diff --git a/src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py b/src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py
index 8accfc638e3e..bda93a796e90 100644
--- a/src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py
+++ b/src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py
@@ -244,7 +244,8 @@ def get_chunked_index(
         - the second chunk contains values >= 1000 and < 2000, and so on.
 
         Parameters:
-            token_indices (`List[int]`): A monotonically increasing list of token index values.
+            token_indices (`torch.Tensor` of shape `(seq_len, )`): A monotonically increasing list of
+                token index values.
             t_ntoken_per_chunk (`int`): Number of tokens per chunk (used as the chunk size threshold).
             remove_index (`int`) An index id to subtract from `token_indices` before chunking
 
@@ -257,12 +258,12 @@ def _iter():
             i, start_idx = 0, 0  # skip bos token
             current_chunk = 1
             while i < len(token_indices):  # skip eos token
-                if token_indices[0][i] - remove_index >= current_chunk * tokens_per_chunk:
+                if token_indices[i] - remove_index >= current_chunk * tokens_per_chunk:
                     yield (start_idx, i)
                     start_idx = i
                     current_chunk += 1
                 i += 1
-            yield (start_idx, token_indices.shape[1])
+            yield (start_idx, len(token_indices))
 
         return list(_iter())
 
@@ -499,8 +500,8 @@ def get_rope_index(
                         )
 
                         t_ntoken_per_chunk = int(position_id_per_seconds * seconds_per_chunk)
-                        video_chunk_indexes = self.get_chunked_index(video_llm_pos_ids, t_ntoken_per_chunk, st_idx)
-                        audio_chunk_indexes = self.get_chunked_index(audio_llm_pos_ids, t_ntoken_per_chunk, st_idx)
+                        video_chunk_indexes = self.get_chunked_index(video_llm_pos_ids[0], t_ntoken_per_chunk, st_idx)
+                        audio_chunk_indexes = self.get_chunked_index(audio_llm_pos_ids[0], t_ntoken_per_chunk, st_idx)
                         sub_len = 0
                         for j in range(max(len(video_chunk_indexes), len(audio_chunk_indexes))):
                             video_chunk_index = video_chunk_indexes[j] if j < len(video_chunk_indexes) else None
diff --git a/src/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py b/src/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py
index 2524fd9186a6..61ef59d9a169 100644
--- a/src/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py
+++ b/src/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py
@@ -1145,7 +1145,8 @@ def get_chunked_index(
         - the second chunk contains values >= 1000 and < 2000, and so on.
 
         Parameters:
-            token_indices (`List[int]`): A monotonically increasing list of token index values.
+            token_indices (`torch.Tensor` of shape `(seq_len, )`): A monotonically increasing list of
+                token index values.
             t_ntoken_per_chunk (`int`): Number of tokens per chunk (used as the chunk size threshold).
             remove_index (`int`) An index id to subtract from `token_indices` before chunking
 
@@ -1158,12 +1159,12 @@ def _iter():
             i, start_idx = 0, 0  # skip bos token
             current_chunk = 1
             while i < len(token_indices):  # skip eos token
-                if token_indices[0][i] - remove_index >= current_chunk * tokens_per_chunk:
+                if token_indices[i] - remove_index >= current_chunk * tokens_per_chunk:
                     yield (start_idx, i)
                     start_idx = i
                     current_chunk += 1
                 i += 1
-            yield (start_idx, token_indices.shape[1])
+            yield (start_idx, len(token_indices))
 
         return list(_iter())
 
@@ -1400,8 +1401,8 @@ def get_rope_index(
                         )
 
                         t_ntoken_per_chunk = int(position_id_per_seconds * seconds_per_chunk)
-                        video_chunk_indexes = self.get_chunked_index(video_llm_pos_ids, t_ntoken_per_chunk, st_idx)
-                        audio_chunk_indexes = self.get_chunked_index(audio_llm_pos_ids, t_ntoken_per_chunk, st_idx)
+                        video_chunk_indexes = self.get_chunked_index(video_llm_pos_ids[0], t_ntoken_per_chunk, st_idx)
+                        audio_chunk_indexes = self.get_chunked_index(audio_llm_pos_ids[0], t_ntoken_per_chunk, st_idx)
                         sub_len = 0
                         for j in range(max(len(video_chunk_indexes), len(audio_chunk_indexes))):
                             video_chunk_index = video_chunk_indexes[j] if j < len(video_chunk_indexes) else None
diff --git a/src/transformers/models/qwen2_5_omni/processing_qwen2_5_omni.py b/src/transformers/models/qwen2_5_omni/processing_qwen2_5_omni.py
index d607b8b95e80..64b444b716f8 100644
--- a/src/transformers/models/qwen2_5_omni/processing_qwen2_5_omni.py
+++ b/src/transformers/models/qwen2_5_omni/processing_qwen2_5_omni.py
@@ -289,7 +289,7 @@ def get_chunked_index(self, token_indices: np.ndarray, tokens_per_chunk: int) ->
         - the second chunk contains values >= 1000 and < 2000, and so on.
 
         Parameters:
-            token_indices (`List[int]`): A monotonically increasing list of token index values.
+            token_indices (`np.ndarray`): A monotonically increasing list of token index values.
             t_ntoken_per_chunk (`int`): Number of tokens per chunk (used as the chunk size threshold).
 
         Returns:
diff --git a/tests/models/qwen2_5_omni/test_modeling_qwen2_5_omni.py b/tests/models/qwen2_5_omni/test_modeling_qwen2_5_omni.py
index f43f7bc6b179..116425f349c0 100644
--- a/tests/models/qwen2_5_omni/test_modeling_qwen2_5_omni.py
+++ b/tests/models/qwen2_5_omni/test_modeling_qwen2_5_omni.py
@@ -381,6 +381,143 @@ def test_generate_with_static_cache(self):
     def test_custom_4d_attention_mask(self):
         pass
 
+    def test_get_rope_index_video_with_audio(self):
+        image_grid_thw = torch.empty((0, 3), dtype=torch.long)
+
+        # 3 * 2 * 2 = 12 video tokens
+        video_grid_thw = torch.tensor([[3, 2, 2]], dtype=torch.long)
+
+        # num_audio_tokens = ((audio_seqlen - 1) // 2 + 1 - 2) // 2 + 1
+        # i.e.: 300 audio_seqlen -> 75 audio tokens
+        audio_seqlens = torch.tensor([300], dtype=torch.long)
+
+        second_per_grids = torch.tensor([1.0], dtype=torch.float)
+
+        use_audio_in_video = True
+
+        # fmt: off
+        expected_position_ids = torch.tensor([
+            [[
+                0, 1,  # text
+                2, 2,  # vision_bos + audio_bos
+
+                # video chunk
+                3, 3, 3, 3,
+                28, 28, 28, 28,
+
+                # audio chunk
+                3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16,
+                17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30,
+                31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44,
+                45, 46, 47, 48, 49, 50, 51, 52,
+
+                # video chunk
+                53, 53, 53, 53,
+
+                # audio chunk
+                53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66,
+                67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77,
+
+                78, 78,  # audio_eos + vision_eos
+                79, 80,  # text
+            ]],
+            [[
+                0, 1,  # text
+                2, 2,  # vision_bos + audio_bos
+
+                # video chunk
+                3, 3, 4, 4,
+                3, 3, 4, 4,
+
+                # audio chunk
+                3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16,
+                17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30,
+                31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44,
+                45, 46, 47, 48, 49, 50, 51, 52,
+
+                # video chunk
+                3, 3, 4, 4,
+
+                # audio chunk
+                53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66,
+                67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77,
+
+                78, 78,  # audio_eos + vision_eos
+                79, 80,  # text
+            ]],
+            [[
+                0, 1,  # text
+                2, 2,  # vision_bos + audio_bos
+
+                # video chunk
+                3, 4, 3, 4,
+                3, 4, 3, 4,
+
+                # audio chunk
+                3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16,
+                17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30,
+                31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44,
+                45, 46, 47, 48, 49, 50, 51, 52,
+
+                # video chunk
+                3, 4, 3, 4,
+
+                # audio chunk
+                53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66,
+                67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77,
+
+                78, 78,  # audio_eos + vision_eos
+                79, 80,  # text
+            ]],
+        ], dtype=torch.long)
+        # fmt: on
+
+        for model_class in self.all_model_classes:
+            config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+
+            input_ids = torch.tensor(
+                [
+                    [
+                        100,
+                        101,
+                    ]
+                    + [
+                        config.vision_start_token_id,
+                        config.audio_start_token_id,
+                    ]
+                    # 1st chunk: 8 video tokens, 50 audio tokens
+                    + [config.video_token_id] * 2 * 2 * 2
+                    + [config.audio_token_id] * 50
+                    +
+                    # 2nd chunk: 4 video tokens, 25 audio tokens
+                    [config.video_token_id] * 1 * 2 * 2
+                    + [config.audio_token_id] * 25
+                    + [
+                        config.audio_end_token_id,
+                        config.vision_end_token_id,
+                    ]
+                    + [
+                        102,
+                        103,
+                    ]
+                ],
+                dtype=torch.long,
+            )
+
+            model = model_class(config)
+
+            position_ids, mrope_position_deltas = model.get_rope_index(
+                input_ids=input_ids,
+                image_grid_thw=image_grid_thw,
+                video_grid_thw=video_grid_thw,
+                attention_mask=None,
+                use_audio_in_video=use_audio_in_video,
+                audio_seqlens=audio_seqlens,
+                second_per_grids=second_per_grids,
+            )
+
+            self.assertTrue(torch.equal(position_ids, expected_position_ids))
+
 
 @require_torch
 class Qwen2_5OmniModelIntegrationTest(unittest.TestCase):
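Note (editorial, not part of the patch): the behavioral change above is that `get_chunked_index` now receives a 1-D tensor of position ids (callers pass `video_llm_pos_ids[0]` / `audio_llm_pos_ids[0]`), so element access becomes `token_indices[i]` and the final chunk boundary `len(token_indices)`. Below is a minimal standalone sketch of that chunking logic, with made-up example values (25 positions per chunk, bos offset 2) chosen only for illustration:

    import torch

    def get_chunked_index(token_indices: torch.Tensor, tokens_per_chunk: int, remove_index: int):
        # Mirrors the patched `_iter` above: `token_indices` is a 1-D tensor of
        # monotonically increasing position ids; a chunk closes each time the
        # shifted position id reaches the next multiple of `tokens_per_chunk`.
        def _iter():
            i, start_idx = 0, 0  # skip bos token
            current_chunk = 1
            while i < len(token_indices):  # skip eos token
                if token_indices[i] - remove_index >= current_chunk * tokens_per_chunk:
                    yield (start_idx, i)
                    start_idx = i
                    current_chunk += 1
                i += 1
            yield (start_idx, len(token_indices))

        return list(_iter())

    # Position ids 3..77 with 25 positions per chunk and a bos offset of 2:
    print(get_chunked_index(torch.arange(3, 78), tokens_per_chunk=25, remove_index=2))
    # [(0, 24), (24, 49), (49, 74), (74, 75)]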