
Commit 5f79128

Fix Qwen2.5-Omni get_chunked_index chunking functionality (#37631)
* fix: qwen2.5 omni modular get_rope_index
* test: add test for qwen2.5 omni rope index (video with audio input)
* style
* expected_position_ids readability
* fix: use spatial_merge_size = 1 in unit test
1 parent fee1190 commit 5f79128
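
As context for the fix, here is a minimal standalone sketch of the chunking behaviour the corrected helper implements (illustrative only: the real `get_chunked_index` is a method on the Qwen2.5-Omni model classes and now expects a 1D `torch.Tensor` rather than indexing a 2D one):

import torch

def chunked_index(token_indices: torch.Tensor, tokens_per_chunk: int, remove_index: int = 0):
    # Split a 1D tensor of monotonically increasing token indices into (start, end)
    # ranges so that chunk k holds values below (k + 1) * tokens_per_chunk
    # (after subtracting remove_index), mirroring the fixed _iter() loop below.
    def _iter():
        i, start_idx = 0, 0
        current_chunk = 1
        while i < len(token_indices):
            # the fix: index the 1D tensor directly instead of token_indices[0][i]
            if token_indices[i] - remove_index >= current_chunk * tokens_per_chunk:
                yield (start_idx, i)
                start_idx = i
                current_chunk += 1
            i += 1
        yield (start_idx, len(token_indices))

    return list(_iter())

# With tokens_per_chunk=1000, values < 1000 form the first chunk,
# values in [1000, 2000) the second, and so on (as the docstring in the diff describes):
print(chunked_index(torch.tensor([5, 300, 999, 1001, 1500, 2100]), 1000))
# [(0, 3), (3, 5), (5, 6)]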

4 files changed: +150 additions, -11 deletions


src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py

Lines changed: 6 additions & 5 deletions
@@ -244,7 +244,8 @@ def get_chunked_index(
         - the second chunk contains values >= 1000 and < 2000, and so on.
 
         Parameters:
-            token_indices (`List[int]`): A monotonically increasing list of token index values.
+            token_indices (`torch.Tensor` of shape `(seq_len, )`): A monotonically increasing list of
+                token index values.
             t_ntoken_per_chunk (`int`): Number of tokens per chunk (used as the chunk size threshold).
             remove_index (`int`) An index id to subtract from `token_indices` before chunking
 
@@ -257,12 +258,12 @@ def _iter():
             i, start_idx = 0, 0  # skip bos token
             current_chunk = 1
             while i < len(token_indices):  # skip eos token
-                if token_indices[0][i] - remove_index >= current_chunk * tokens_per_chunk:
+                if token_indices[i] - remove_index >= current_chunk * tokens_per_chunk:
                     yield (start_idx, i)
                     start_idx = i
                     current_chunk += 1
                 i += 1
-            yield (start_idx, token_indices.shape[1])
+            yield (start_idx, len(token_indices))
 
         return list(_iter())
 
@@ -499,8 +500,8 @@ def get_rope_index(
                     )
 
                     t_ntoken_per_chunk = int(position_id_per_seconds * seconds_per_chunk)
-                    video_chunk_indexes = self.get_chunked_index(video_llm_pos_ids, t_ntoken_per_chunk, st_idx)
-                    audio_chunk_indexes = self.get_chunked_index(audio_llm_pos_ids, t_ntoken_per_chunk, st_idx)
+                    video_chunk_indexes = self.get_chunked_index(video_llm_pos_ids[0], t_ntoken_per_chunk, st_idx)
+                    audio_chunk_indexes = self.get_chunked_index(audio_llm_pos_ids[0], t_ntoken_per_chunk, st_idx)
                     sub_len = 0
                     for j in range(max(len(video_chunk_indexes), len(audio_chunk_indexes))):
                         video_chunk_index = video_chunk_indexes[j] if j < len(video_chunk_indexes) else None
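
For context on the caller-side change: inside `get_rope_index`, the per-modality position ids appear to be stacked along the three M-RoPE axes (temporal, height, width), as the expected output in the new test suggests, so `video_llm_pos_ids` has shape `(3, n_video_tokens)` and `video_llm_pos_ids[0]` is the 1D temporal row the fixed helper expects. An illustrative stand-in built from the values in the new unit test (3 x 2 x 2 video grid, spatial_merge_size = 1):

import torch

video_llm_pos_ids = torch.tensor([
    [3, 3, 3, 3, 28, 28, 28, 28, 53, 53, 53, 53],  # temporal axis
    [3, 3, 4, 4, 3, 3, 4, 4, 3, 3, 4, 4],          # height axis
    [3, 4, 3, 4, 3, 4, 3, 4, 3, 4, 3, 4],          # width axis
])

temporal_row = video_llm_pos_ids[0]  # shape (12,), monotonically non-decreasing
# Only this 1D temporal row is compared against t_ntoken_per_chunk when the
# video and audio chunks are interleaved.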

src/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py

Lines changed: 6 additions & 5 deletions
@@ -1145,7 +1145,8 @@ def get_chunked_index(
         - the second chunk contains values >= 1000 and < 2000, and so on.
 
         Parameters:
-            token_indices (`List[int]`): A monotonically increasing list of token index values.
+            token_indices (`torch.Tensor` of shape `(seq_len, )`): A monotonically increasing list of
+                token index values.
             t_ntoken_per_chunk (`int`): Number of tokens per chunk (used as the chunk size threshold).
             remove_index (`int`) An index id to subtract from `token_indices` before chunking
 
@@ -1158,12 +1159,12 @@ def _iter():
             i, start_idx = 0, 0  # skip bos token
             current_chunk = 1
             while i < len(token_indices):  # skip eos token
-                if token_indices[0][i] - remove_index >= current_chunk * tokens_per_chunk:
+                if token_indices[i] - remove_index >= current_chunk * tokens_per_chunk:
                     yield (start_idx, i)
                     start_idx = i
                     current_chunk += 1
                 i += 1
-            yield (start_idx, token_indices.shape[1])
+            yield (start_idx, len(token_indices))
 
         return list(_iter())
 
@@ -1400,8 +1401,8 @@ def get_rope_index(
                     )
 
                     t_ntoken_per_chunk = int(position_id_per_seconds * seconds_per_chunk)
-                    video_chunk_indexes = self.get_chunked_index(video_llm_pos_ids, t_ntoken_per_chunk, st_idx)
-                    audio_chunk_indexes = self.get_chunked_index(audio_llm_pos_ids, t_ntoken_per_chunk, st_idx)
+                    video_chunk_indexes = self.get_chunked_index(video_llm_pos_ids[0], t_ntoken_per_chunk, st_idx)
+                    audio_chunk_indexes = self.get_chunked_index(audio_llm_pos_ids[0], t_ntoken_per_chunk, st_idx)
                     sub_len = 0
                     for j in range(max(len(video_chunk_indexes), len(audio_chunk_indexes))):
                         video_chunk_index = video_chunk_indexes[j] if j < len(video_chunk_indexes) else None

src/transformers/models/qwen2_5_omni/processing_qwen2_5_omni.py

Lines changed: 1 addition & 1 deletion
@@ -289,7 +289,7 @@ def get_chunked_index(self, token_indices: np.ndarray, tokens_per_chunk: int) ->
         - the second chunk contains values >= 1000 and < 2000, and so on.
 
         Parameters:
-            token_indices (`List[int]`): A monotonically increasing list of token index values.
+            token_indices (`np.ndarray`): A monotonically increasing list of token index values.
             t_ntoken_per_chunk (`int`): Number of tokens per chunk (used as the chunk size threshold).
 
         Returns:

tests/models/qwen2_5_omni/test_modeling_qwen2_5_omni.py

Lines changed: 137 additions & 0 deletions
@@ -381,6 +381,143 @@ def test_generate_with_static_cache(self):
     def test_custom_4d_attention_mask(self):
         pass
 
+    def test_get_rope_index_video_with_audio(self):
+        image_grid_thw = torch.empty((0, 3), dtype=torch.long)
+
+        # 3 * 2 * 2 = 12 video tokens
+        video_grid_thw = torch.tensor([[3, 2, 2]], dtype=torch.long)
+
+        # num_audio_tokens = ((audio_seqlen - 1) // 2 + 1 - 2) // 2 + 1
+        # i.e.: 300 audio_seqlen -> 75 audio tokens
+        audio_seqlens = torch.tensor([300], dtype=torch.long)
+
+        second_per_grids = torch.tensor([1.0], dtype=torch.float)
+
+        use_audio_in_video = True
+
+        # fmt: off
+        expected_position_ids = torch.tensor([
+            [[
+                0, 1,  # text
+                2, 2,  # vision_bos + audio_bos
+
+                # video chunk
+                3, 3, 3, 3,
+                28, 28, 28, 28,
+
+                # audio chunk
+                3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16,
+                17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30,
+                31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44,
+                45, 46, 47, 48, 49, 50, 51, 52,
+
+                # video chunk
+                53, 53, 53, 53,
+
+                # audio chunk
+                53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66,
+                67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77,
+
+                78, 78,  # audio_eos + vision_eos
+                79, 80,  # text
+            ]],
+            [[
+                0, 1,  # text
+                2, 2,  # vision_bos + audio_bos
+
+                # video chunk
+                3, 3, 4, 4,
+                3, 3, 4, 4,
+
+                # audio chunk
+                3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16,
+                17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30,
+                31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44,
+                45, 46, 47, 48, 49, 50, 51, 52,
+
+                # video chunk
+                3, 3, 4, 4,
+
+                # audio chunk
+                53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66,
+                67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77,
+
+                78, 78,  # audio_eos + vision_eos
+                79, 80,  # text
+            ]],
+            [[
+                0, 1,  # text
+                2, 2,  # vision_bos + audio_bos
+
+                # video chunk
+                3, 4, 3, 4,
+                3, 4, 3, 4,
+
+                # audio chunk
+                3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16,
+                17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30,
+                31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44,
+                45, 46, 47, 48, 49, 50, 51, 52,
+
+                # video chunk
+                3, 4, 3, 4,
+
+                # audio chunk
+                53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66,
+                67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77,
+
+                78, 78,  # audio_eos + vision_eos
+                79, 80,  # text
+            ]],
+        ], dtype=torch.long)
+        # fmt: on
+
+        for model_class in self.all_model_classes:
+            config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+
+            input_ids = torch.tensor(
+                [
+                    [
+                        100,
+                        101,
+                    ]
+                    + [
+                        config.vision_start_token_id,
+                        config.audio_start_token_id,
+                    ]
+                    # 1st chunk: 8 video tokens, 50 audio tokens
+                    + [config.video_token_id] * 2 * 2 * 2
+                    + [config.audio_token_id] * 50
+                    +
+                    # 2nd chunk: 4 video tokens, 25 audio tokens
+                    [config.video_token_id] * 1 * 2 * 2
+                    + [config.audio_token_id] * 25
+                    + [
+                        config.audio_end_token_id,
+                        config.vision_end_token_id,
+                    ]
+                    + [
+                        102,
+                        103,
+                    ]
+                ],
+                dtype=torch.long,
+            )
+
+            model = model_class(config)
+
+            position_ids, mrope_position_deltas = model.get_rope_index(
+                input_ids=input_ids,
+                image_grid_thw=image_grid_thw,
+                video_grid_thw=video_grid_thw,
+                attention_mask=None,
+                use_audio_in_video=use_audio_in_video,
+                audio_seqlens=audio_seqlens,
+                second_per_grids=second_per_grids,
+            )
+
+            self.assertTrue(torch.equal(position_ids, expected_position_ids))
+
 
 @require_torch
 class Qwen2_5OmniModelIntegrationTest(unittest.TestCase):
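
As a quick standalone sanity check of the audio-token count the test relies on (same arithmetic as the comment in the test body):

audio_seqlen = 300
num_audio_tokens = ((audio_seqlen - 1) // 2 + 1 - 2) // 2 + 1
assert num_audio_tokens == 75  # interleaved in the test as 50 tokens in the first chunk + 25 in the second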
