Skip to content

[Bugfix] Fix Qwen2.5-Omni M-RoPE position ids generation #16878

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Apr 26, 2025

Conversation

imkero
Copy link
Contributor

@imkero imkero commented Apr 19, 2025

What this PR do

This PR fixes the MRotaryEmbedding::_omni_get_input_positions_tensor result when use_audio_in_video=True and there is an image or a video in the input

Root cause

The main branch code incorrectly handles the position shift among vision_bos, audio_bos, audio_eos and vision_eos, causing a 1 position delta between the hf output and the vLLM output

  1. Main branch vLLM incorrectly applies position id shift (-1) for image when use_audio_in_video=True, the position id shift should only apply to video
    • Line 1229: didn't check whether vision_end belongs to an image or a video
      if src_item[idx] not in [
      audio_token_id, video_token_id, image_token_id
      ]:
      if src_item[idx] == vision_end_token_id and use_audio_in_video:
      start_idx -= 1
      new_src_item.append(src_item[idx])
      llm_pos_ids = torch.tensor([start_idx],
      dtype=torch.long).expand(3, -1)
      llm_pos_ids_list.append(llm_pos_ids)
  2. Main branch vLLM incorrectly do audio_bos & audio_eos handling twice when use_audio_in_video=True
    • first time: Line 1229
      if src_item[idx] not in [
      audio_token_id, video_token_id, image_token_id
      ]:
      if src_item[idx] == vision_end_token_id and use_audio_in_video:
      start_idx -= 1
      new_src_item.append(src_item[idx])
      llm_pos_ids = torch.tensor([start_idx],
      dtype=torch.long).expand(3, -1)
      llm_pos_ids_list.append(llm_pos_ids)
    • second time: Line 1288
      else:
      # read audio from video
      assert audio_seqlens is not None
      audio_seqlen = audio_seqlens[audio_idx]
      vision_seqlen = video_grid_thw[video_idx].prod() // (
      spatial_merge_size**2)
      grid_t = video_grid_thw[video_idx][0]
      grid_h = video_grid_thw[video_idx][1]
      grid_w = video_grid_thw[video_idx][2]
      grid_hs = video_grid_thw[:, 1]
      grid_ws = video_grid_thw[:, 2]
      t_ntoken_per_chunk = int(tokens_per_second * seconds_per_chunk)
      t_index = (torch.arange(grid_t) *
      second_per_grid_ts[video_idx] *
      tokens_per_second).long()
      t_index_split_chunk = cls._split_list_into_ranges(
      t_index, t_ntoken_per_chunk)
      new_src_item.extend([audio_start_token_id])
      start_idx -= 1
      llm_pos_ids_list.extend([
      torch.tensor([start_idx], dtype=torch.long).expand(3, -1)
      ] * 1)
      place_num = (((audio_seqlen - 1) // 2 + 1 - 2) // 2 + 1) + 2
      pure_audio_len = place_num - 2
      added_audio_len = 0
      audio_llm_pos_ids_list: List[torch.Tensor] = []

Compare to the groundtruth

Current vLLM impl is producing different result from the huggingface transformers impl, the following snippet do these tests.

https://gist.github.com/imkero/0377d5cc175bf068af237465a861dc3d

Case 1

use_audio_in_video=True and an image

[input_ids]
tensor([
           100,    101,  # text
           151652, # vision_bos
           151655, 151655, 151655, 151655,  # image tokens
           151653, # vision_eos
           102,    103   # text
])

[image_grid_thw]
tensor([[1, 4, 4]])

[hf] mrope_position_ids
tensor([[0, 1, 2, 3, 3, 3, 3, 5, 6, 7],
        [0, 1, 2, 3, 3, 4, 4, 5, 6, 7],
        [0, 1, 2, 3, 4, 3, 4, 5, 6, 7]])

[vllm] mrope_position_ids
tensor([[0, 1, 2, 3, 3, 3, 3, 4, 5, 6],
        [0, 1, 2, 3, 3, 4, 4, 4, 5, 6],
        [0, 1, 2, 3, 4, 3, 4, 4, 5, 6]])

Case 2

use_audio_in_video=True and an video

NOTE: currently the HuggingFace transformers main branch code has another bug in it, and won't produce reliable groundtruth now, see huggingface/transformers#37631

the following result is generated by the transformers code in huggingface/transformers#37631

for detailed explanation about the input and output, please refer to https://github.com/huggingface/transformers/blob/b16b93d6857cc1a3d318d3f9ae9957413a6e976c/tests/models/qwen2_5_omni/test_modeling_qwen2_5_omni.py#L384-L510

t_ntoken_per_chunk 50

video_chunk_indexes [(0, 8), (8, 16), (16, 20)]
audio_chunk_indexes [(0, 50), (50, 100), (100, 125)]

[input_ids]
# for readability, i replace the token_id with the token_name
tensor([
    100,    101,

vision_bos, audio_bos,

video, video, video, video, 
video, video, video, video, 

audio, audio, audio, audio, audio, audio, audio, audio, audio, 
audio, audio, audio, audio, audio, audio, audio, audio, audio, 
audio, audio, audio, audio, audio, audio, audio, audio, audio, 
audio, audio, audio, audio, audio, audio, audio, audio, audio, 
audio, audio, audio, audio, audio, audio, audio, audio, audio, 
audio, audio, audio, audio, audio, 

video, video, video, video, 

audio, audio, audio, audio, audio, audio, audio, audio, audio,
audio, audio, audio, audio, audio, audio, audio, audio, audio,
audio, audio, audio, audio, audio, audio, audio,

audio_eos, vision_eos,

    102,    103
])

[video_grid_thw]
tensor([[5, 4, 4]])

[https://github.com/huggingface/transformers/pull/37631] mrope_position_ids
tensor([[ 0,  1,  2,  2,  3,  3,  3,  3, 28, 28, 28, 28,  3,  4,  5,  6,  7,  8,
          9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26,
         27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44,
         45, 46, 47, 48, 49, 50, 51, 52, 53, 53, 53, 53, 53, 54, 55, 56, 57, 58,
         59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76,
         77, 78, 78, 79, 80],
        [ 0,  1,  2,  2,  3,  3,  4,  4,  3,  3,  4,  4,  3,  4,  5,  6,  7,  8,
          9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26,
         27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44,
         45, 46, 47, 48, 49, 50, 51, 52,  3,  3,  4,  4, 53, 54, 55, 56, 57, 58,
         59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76,
         77, 78, 78, 79, 80],
        [ 0,  1,  2,  2,  3,  4,  3,  4,  3,  4,  3,  4,  3,  4,  5,  6,  7,  8,
          9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26,
         27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44,
         45, 46, 47, 48, 49, 50, 51, 52,  3,  4,  3,  4, 53, 54, 55, 56, 57, 58,
         59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76,
         77, 78, 78, 79, 80]])
[vllm] mrope_position_ids
tensor([[ 0,  1,  2,  3,  3,  4,  4,  4,  4, 29, 29, 29, 29,  4,  5,  6,  7,  8,
          9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26,
         27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44,
         45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 54, 54, 54, 54, 55, 56, 57, 58,
         59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76,
         77, 78, 79, 80, 81],
        [ 0,  1,  2,  3,  3,  4,  4,  5,  5,  4,  4,  5,  5,  4,  5,  6,  7,  8,
          9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26,
         27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44,
         45, 46, 47, 48, 49, 50, 51, 52, 53,  4,  4,  5,  5, 54, 55, 56, 57, 58,
         59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76,
         77, 78, 79, 80, 81],
        [ 0,  1,  2,  3,  3,  4,  5,  4,  5,  4,  5,  4,  5,  4,  5,  6,  7,  8,
          9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26,
         27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44,
         45, 46, 47, 48, 49, 50, 51, 52, 53,  4,  5,  4,  5, 54, 55, 56, 57, 58,
         59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76,
         77, 78, 79, 80, 81]])

cc @ywang96 @fyabc

imkero added 2 commits April 19, 2025 23:48
Signed-off-by: imkero <[email protected]>
@imkero imkero marked this pull request as draft April 19, 2025 16:02
Copy link

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs would not trigger full CI run by default. Instead, it would only run fastcheck CI which starts running only a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build on Buildkite UI (linked in the PR checks section) and unblock them. If you do not have permission to unblock, ping simon-mo or khluu to add you in our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either: Add ready label to the PR or enable auto-merge.

🚀

@imkero imkero marked this pull request as ready for review April 19, 2025 16:30
@fyabc
Copy link
Contributor

fyabc commented Apr 19, 2025

Thank you for your bug fix! I will take a look at it.

@imkero
Copy link
Contributor Author

imkero commented Apr 25, 2025

I will update the test snippet to use the original impl of transformers and vLLM (instead of currently copying the code manually).

The PR huggingface/transformers#37631 has been merged and we can consider it as a groundtruth for video-with-audio now.

Copy link
Collaborator

@WoosukKwon WoosukKwon left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Merging this to unblock followup PRs.

@fyabc Please take a look when you have time and let us know if you find any problem!

@WoosukKwon WoosukKwon merged commit de7eb10 into vllm-project:main Apr 26, 2025
30 of 33 checks passed
@fyabc
Copy link
Contributor

fyabc commented Apr 27, 2025

@WoosukKwon @imkero Sorry for the late reply, this bug fix is ​​correct and consistent with transformers; thank you very much for the fix!

jikunshang pushed a commit to jikunshang/vllm that referenced this pull request Apr 29, 2025
lk-chen pushed a commit to lk-chen/vllm that referenced this pull request Apr 29, 2025
adobrzyn pushed a commit to HabanaAI/vllm-fork that referenced this pull request Apr 30, 2025
RichardoMrMu pushed a commit to RichardoMrMu/vllm that referenced this pull request May 12, 2025
zzzyq pushed a commit to zzzyq/vllm that referenced this pull request May 24, 2025
minpeter pushed a commit to minpeter/vllm that referenced this pull request Jun 24, 2025
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants