From c6a84a93178af9b00b652bfc7ba0963fcf5ad107 Mon Sep 17 00:00:00 2001
From: imkero
Date: Sat, 19 Apr 2025 23:48:12 +0800
Subject: [PATCH 1/2] fix: omni get input positions

Signed-off-by: imkero
---
 .../model_executor/layers/rotary_embedding.py | 28 +++++++++++++---------------
 1 file changed, 13 insertions(+), 15 deletions(-)

diff --git a/vllm/model_executor/layers/rotary_embedding.py b/vllm/model_executor/layers/rotary_embedding.py
index e6f2461eb67..e61a96f2695 100644
--- a/vllm/model_executor/layers/rotary_embedding.py
+++ b/vllm/model_executor/layers/rotary_embedding.py
@@ -1197,6 +1197,7 @@ def _omni_get_input_positions_tensor(
         video_token_id = thinker_config.video_token_index
         audio_start_token_id = thinker_config.audio_start_token_id
         audio_end_token_id = thinker_config.audio_end_token_id
+        vision_start_token_id = thinker_config.vision_start_token_id
         vision_end_token_id = thinker_config.vision_end_token_id
         seconds_per_chunk = thinker_config.seconds_per_chunk
         spatial_merge_size = thinker_config.vision_config.spatial_merge_size
@@ -1226,8 +1227,15 @@ def _omni_get_input_positions_tensor(
             if src_item[idx] not in [
                     audio_token_id, video_token_id, image_token_id
             ]:
-                if src_item[idx] == vision_end_token_id and use_audio_in_video:
-                    start_idx -= 1
+                if use_audio_in_video and idx > 0:
+                    if src_item[idx] == vision_end_token_id and \
+                        # processing the <|audio_eos|> before <|vision_eos|>
+                        src_item[idx - 1] == audio_end_token_id:
+                        start_idx -= 1
+                    elif src_item[idx] == audio_start_token_id and \
+                        # processing the <|audio_bos|> after <|vision_bos|>
+                        src_item[idx - 1] == vision_start_token_id:
+                        start_idx -= 1
                 new_src_item.append(src_item[idx])
                 llm_pos_ids = torch.tensor([start_idx],
                                            dtype=torch.long).expand(3, -1)
@@ -1285,11 +1293,6 @@ def _omni_get_input_positions_tensor(
                            tokens_per_second).long()
                 t_index_split_chunk = cls._split_list_into_ranges(
                     t_index, t_ntoken_per_chunk)
-                new_src_item.extend([audio_start_token_id])
-                start_idx -= 1
-                llm_pos_ids_list.extend([
-                    torch.tensor([start_idx], dtype=torch.long).expand(3, -1)
-                ] * 1)
                 place_num = (((audio_seqlen - 1) // 2 + 1 - 2) // 2 + 1) + 2
                 pure_audio_len = place_num - 2
                 added_audio_len = 0
@@ -1300,7 +1303,7 @@ def _omni_get_input_positions_tensor(
                     new_src_item.extend([video_token_id] *
                                         vision_ntoken_per_chunk)
                     vision_llm_pos_ids_list = cls._get_llm_pos_ids_for_vision(
-                        start_idx + 1, video_idx, spatial_merge_size, t_chunk,
+                        start_idx, video_idx, spatial_merge_size, t_chunk,
                         grid_hs, grid_ws).split(1, dim=1)
                     llm_pos_ids_list.extend(vision_llm_pos_ids_list)
                     new_src_item.extend(
@@ -1308,13 +1311,13 @@ def _omni_get_input_positions_tensor(
                         min(t_ntoken_per_chunk, pure_audio_len -
                             added_audio_len) * [audio_token_id])
                     audio_start_idx = start_idx if len(
                         audio_llm_pos_ids_list
-                    ) == 0 else audio_llm_pos_ids_list[-1][0].item()
+                    ) == 0 else audio_llm_pos_ids_list[-1][0].item() + 1
                     if min(t_ntoken_per_chunk, pure_audio_len - added_audio_len) > 0:
                         audio_llm_pos_ids_list = (torch.arange(
                             min(t_ntoken_per_chunk, pure_audio_len -
                                 added_audio_len)).expand(3, -1) +
-                                                  audio_start_idx + 1).split(
+                                                  audio_start_idx).split(
                                                       1, dim=1)
                     else:
                         audio_llm_pos_ids_list = []
@@ -1329,11 +1332,6 @@ def _omni_get_input_positions_tensor(
                             3, -1) + llm_pos_ids_list[-1].max() + 1).split(
                                 1, dim=1)
                     llm_pos_ids_list.extend(audio_llm_pos_ids_list)
-                llm_pos_ids_list.extend([
-                    torch.tensor(
-                        [llm_pos_ids_list[-1].max() + 1] * 3).unsqueeze(1)
-                ] * 1)
-                new_src_item.extend([audio_end_token_id])
                 audio_idx += 1
                 video_idx += 1
             # move to the next token
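Note on [PATCH 1/2]: Python's explicit line joining with a backslash cannot be interrupted by a comment. In the hunk at line 1227 as written, each "# processing ..." comment line terminates the logical line and leaves a dangling "and", which raises SyntaxError as soon as the module is imported. [PATCH 2/2] below moves each comment into its branch body. A minimal standalone reproduction of the rule (illustrative only, not part of the series):

    # A comment line between two halves of a backslash-continued
    # expression ends the logical line early; the dangling "and" is
    # then a SyntaxError at compile time.
    snippet = (
        "x = 1 and \\\n"
        "    # comment interrupts the continuation\n"
        "    2\n"
    )
    try:
        compile(snippet, "<demo>", "exec")
        print("compiled OK")
    except SyntaxError as exc:
        print("SyntaxError:", exc.msg)  # "invalid syntax" on CPython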
From ff5be1231cdd0a83d53efd02fecb1c4413655ec8 Mon Sep 17 00:00:00 2001
From: imkero
Date: Sat, 19 Apr 2025 23:50:19 +0800
Subject: [PATCH 2/2] fix

Signed-off-by: imkero
---
 vllm/model_executor/layers/rotary_embedding.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/vllm/model_executor/layers/rotary_embedding.py b/vllm/model_executor/layers/rotary_embedding.py
index e61a96f2695..09f94ac4d50 100644
--- a/vllm/model_executor/layers/rotary_embedding.py
+++ b/vllm/model_executor/layers/rotary_embedding.py
@@ -1229,12 +1229,12 @@ def _omni_get_input_positions_tensor(
             ]:
                 if use_audio_in_video and idx > 0:
                     if src_item[idx] == vision_end_token_id and \
-                        # processing the <|audio_eos|> before <|vision_eos|>
                         src_item[idx - 1] == audio_end_token_id:
+                        # processing the <|audio_eos|> before <|vision_eos|>
                         start_idx -= 1
                     elif src_item[idx] == audio_start_token_id and \
-                        # processing the <|audio_bos|> after <|vision_bos|>
                         src_item[idx - 1] == vision_start_token_id:
+                        # processing the <|audio_bos|> after <|vision_bos|>
                         start_idx -= 1
                 new_src_item.append(src_item[idx])
                 llm_pos_ids = torch.tensor([start_idx],
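Note on the series as a whole: with use_audio_in_video enabled, the interleaved audio/video content is wrapped as <|vision_bos|><|audio_bos|> ... <|audio_eos|><|vision_eos|>. [PATCH 1/2] stops emitting a separate audio_start/audio_end token with its own position step and instead makes each adjacent marker pair share one position id (the start_idx -= 1 branches), while the ".item() + 1" change makes audio_start_idx point one past the last assigned audio position rather than compensating with a "+ 1" at the use site. A toy one-dimensional sketch of the shared-position rule (the token ids and the assign_positions helper are invented for illustration; the real code assigns 3xN temporal/height/width M-RoPE tensors from thinker_config token ids):

    # Made-up ids standing in for the thinker_config special tokens.
    AUDIO_BOS, AUDIO_EOS = 100, 101
    VISION_BOS, VISION_EOS = 200, 201

    def assign_positions(tokens, use_audio_in_video=True):
        positions, pos = [], 0
        for idx, tok in enumerate(tokens):
            if use_audio_in_video and idx > 0:
                prev = tokens[idx - 1]
                if tok == VISION_EOS and prev == AUDIO_EOS:
                    pos -= 1  # <|audio_eos|><|vision_eos|> share a position
                elif tok == AUDIO_BOS and prev == VISION_BOS:
                    pos -= 1  # <|vision_bos|><|audio_bos|> share a position
            positions.append(pos)
            pos += 1
        return positions

    # <|vision_bos|><|audio_bos|> content <|audio_eos|><|vision_eos|>
    print(assign_positions([VISION_BOS, AUDIO_BOS, 7, AUDIO_EOS, VISION_EOS]))
    # -> [0, 0, 1, 2, 2]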