@@ -877,127 +877,163 @@ def forward(
        key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)
        return query, key

+    # Copied from https://github.com/huggingface/transformers/blob/c8e0e603de9b3d49161a15fe6e8ea84badfb5d02/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py#L1439
    @staticmethod
-    def get_input_positions(
-        input_tokens: List[int],
-        image_grid_thw: Union[List[List[int]], torch.Tensor],
-        video_grid_thw: Union[List[List[int]], torch.Tensor],
+    def get_rope_index(
+        spatial_merge_size: int,
        image_token_id: int,
        video_token_id: int,
        vision_start_token_id: int,
-        vision_end_token_id: int,
-        spatial_merge_size: int,
-        context_len: int = 0,
-        seq_len: Optional[int] = None,
-        second_per_grid_ts: Optional[torch.Tensor] = None,
+        model_type: str,
        tokens_per_second: Optional[int] = None,
-    ) -> Tuple[List[List[int]], int]:
-        """
-        Get mrope input positions and delta value.
-
-        :arg
-            second_per_grid_ts (`torch.Tensor` of shape `(num_videos)`, *optional*):
-                The time interval (in seconds) for each grid along the temporal dimension in the 3D position IDs.
-
-        """
-
-        if isinstance(image_grid_thw, torch.Tensor):
-            image_grid_thw = image_grid_thw.tolist()
-        if isinstance(video_grid_thw, torch.Tensor):
-            video_grid_thw = video_grid_thw.tolist()
-
-        input_tokens_tensor = torch.tensor(input_tokens)
-        vision_start_indices = torch.argwhere(
-            input_tokens_tensor == vision_start_token_id
-        ).squeeze(1)
-        vision_tokens = input_tokens_tensor[vision_start_indices + 1]
-        image_nums = (vision_tokens == image_token_id).sum()
-        video_nums = (vision_tokens == video_token_id).sum()
-        llm_pos_ids_list: list = []
-
-        st = 0
-        remain_images, remain_videos = image_nums, video_nums
-
-        image_index, video_index = 0, 0
-        for _ in range(image_nums + video_nums):
-            if image_token_id in input_tokens and remain_images > 0:
-                ed_image = input_tokens.index(image_token_id, st)
-            else:
-                ed_image = len(input_tokens) + 1
-            if video_token_id in input_tokens and remain_videos > 0:
-                ed_video = input_tokens.index(video_token_id, st)
-            else:
-                ed_video = len(input_tokens) + 1
-            if ed_image < ed_video:
-                t, h, w = (
-                    image_grid_thw[image_index][0],
-                    image_grid_thw[image_index][1],
-                    image_grid_thw[image_index][2],
-                )
-                image_index += 1
-                remain_images -= 1
-                second_per_grid_t = 0
-                ed = ed_image
-            else:
-                t, h, w = (
-                    video_grid_thw[video_index][0],
-                    video_grid_thw[video_index][1],
-                    video_grid_thw[video_index][2],
-                )
-                if second_per_grid_ts is not None:
-                    second_per_grid_t = second_per_grid_ts[video_index]
-                else:
-                    second_per_grid_t = 1.0
-                video_index += 1
-                remain_videos -= 1
-                ed = ed_video
-            llm_grid_t, llm_grid_h, llm_grid_w = (
-                t,
-                h // spatial_merge_size,
-                w // spatial_merge_size,
-            )
-            text_len = ed - st
-
-            st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
-            llm_pos_ids_list.append(
-                torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx
-            )
-
-            t_index = (
-                torch.arange(llm_grid_t).view(-1, 1).expand(-1, llm_grid_h * llm_grid_w)
-                * second_per_grid_t
-                * tokens_per_second
-            ).flatten()
-
-            h_index = (
-                torch.arange(llm_grid_h)
-                .view(1, -1, 1)
-                .expand(llm_grid_t, -1, llm_grid_w)
-                .flatten()
-            )
-            w_index = (
-                torch.arange(llm_grid_w)
-                .view(1, 1, -1)
-                .expand(llm_grid_t, llm_grid_h, -1)
-                .flatten()
-            )
-            llm_pos_ids_list.append(
-                torch.stack([t_index, h_index, w_index]) + text_len + st_idx
+        input_ids: Optional[torch.LongTensor] = None,
+        image_grid_thw: Optional[torch.LongTensor] = None,
+        video_grid_thw: Optional[torch.LongTensor] = None,
+        second_per_grid_ts: Optional[torch.Tensor] = None,
+        **kwargs,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
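+        # position_ids carries the temporal, height and width mRoPE components
+        # along its first dimension; mrope_position_deltas records, per sequence,
+        # (max position id + 1) minus the sequence length.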
+        mrope_position_deltas = []
+        if input_ids is not None and (
+            image_grid_thw is not None or video_grid_thw is not None
+        ):
+            total_input_ids = input_ids
+            position_ids = torch.ones(
+                3,
+                input_ids.shape[0],
+                input_ids.shape[1],
+                dtype=input_ids.dtype,
+                device=input_ids.device,
            )
-            st = ed + llm_grid_t * llm_grid_h * llm_grid_w
-
-        if st < len(input_tokens):
-            st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
-            text_len = len(input_tokens) - st
-            llm_pos_ids_list.append(
-                torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx
+            image_index, video_index = 0, 0
+            for i, input_ids in enumerate(total_input_ids):
+                image_nums, video_nums = 0, 0
+                vision_start_indices = torch.argwhere(
+                    input_ids == vision_start_token_id
+                ).squeeze(1)
+                vision_tokens = input_ids[vision_start_indices + 1]
+                image_nums = (vision_tokens == image_token_id).sum()
+                video_nums = (vision_tokens == video_token_id).sum()
+                input_tokens = input_ids.tolist()
+                llm_pos_ids_list: list = []
+                st = 0
+                remain_images, remain_videos = image_nums, video_nums
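+                # Each iteration handles one vision segment: text tokens before it
+                # get ordinary sequential positions, then the image/video patch grid
+                # gets 3D (t, h, w) positions after spatial merging.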
+                for _ in range(image_nums + video_nums):
+                    if image_token_id in input_tokens and remain_images > 0:
+                        ed_image = input_tokens.index(image_token_id, st)
+                    else:
+                        ed_image = len(input_tokens) + 1
+                    if video_token_id in input_tokens and remain_videos > 0:
+                        ed_video = input_tokens.index(video_token_id, st)
+                    else:
+                        ed_video = len(input_tokens) + 1
+                    if ed_image < ed_video:
+                        t, h, w = (
+                            image_grid_thw[image_index][0],
+                            image_grid_thw[image_index][1],
+                            image_grid_thw[image_index][2],
+                        )
+                        second_per_grid_t = 0
+                        image_index += 1
+                        remain_images -= 1
+                        ed = ed_image
+                    else:
+                        t, h, w = (
+                            video_grid_thw[video_index][0],
+                            video_grid_thw[video_index][1],
+                            video_grid_thw[video_index][2],
+                        )
+                        if second_per_grid_ts is not None:
+                            second_per_grid_t = second_per_grid_ts[video_index]
+                        else:
+                            second_per_grid_t = 1.0
+                        video_index += 1
+                        remain_videos -= 1
+                        ed = ed_video
+                    llm_grid_t, llm_grid_h, llm_grid_w = (
+                        t.item(),
+                        h.item() // spatial_merge_size,
+                        w.item() // spatial_merge_size,
+                    )
+                    text_len = ed - st
+
+                    st_idx = (
+                        llm_pos_ids_list[-1].max() + 1
+                        if len(llm_pos_ids_list) > 0
+                        else 0
+                    )
+                    llm_pos_ids_list.append(
+                        torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx
+                    )
+
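+                    # qwen2_5_vl scales the temporal index by the real time span of
+                    # each frame (second_per_grid_t * tokens_per_second); qwen2_vl
+                    # uses the raw frame index.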
+                    if model_type == "qwen2_5_vl":
+                        range_tensor = torch.arange(llm_grid_t).view(-1, 1)
+                        expanded_range = range_tensor.expand(
+                            -1, llm_grid_h * llm_grid_w
+                        )
+
+                        time_tensor = (
+                            expanded_range * second_per_grid_t * tokens_per_second
+                        )
+
+                        time_tensor_long = time_tensor.long()
+                        t_index = time_tensor_long.flatten()
+                    elif model_type == "qwen2_vl":
+                        t_index = (
+                            torch.arange(llm_grid_t)
+                            .view(-1, 1)
+                            .expand(-1, llm_grid_h * llm_grid_w)
+                            .flatten()
+                        )
+                    else:
+                        raise RuntimeError("Unimplemented")
+                    h_index = (
+                        torch.arange(llm_grid_h)
+                        .view(1, -1, 1)
+                        .expand(llm_grid_t, -1, llm_grid_w)
+                        .flatten()
+                    )
+                    w_index = (
+                        torch.arange(llm_grid_w)
+                        .view(1, 1, -1)
+                        .expand(llm_grid_t, llm_grid_h, -1)
+                        .flatten()
+                    )
+                    llm_pos_ids_list.append(
+                        torch.stack([t_index, h_index, w_index]) + text_len + st_idx
+                    )
+                    st = ed + llm_grid_t * llm_grid_h * llm_grid_w
+
+                if st < len(input_tokens):
+                    st_idx = (
+                        llm_pos_ids_list[-1].max() + 1
+                        if len(llm_pos_ids_list) > 0
+                        else 0
+                    )
+                    text_len = len(input_tokens) - st
+                    llm_pos_ids_list.append(
+                        torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx
+                    )
+
+                llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
+                position_ids[..., i, :] = llm_positions.to(position_ids.device)
+                mrope_position_deltas.append(
+                    llm_positions.max() + 1 - len(total_input_ids[i])
+                )
+            mrope_position_deltas = torch.tensor(
+                mrope_position_deltas, device=input_ids.device
+            ).unsqueeze(1)
+            return position_ids, mrope_position_deltas
+        else:
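+            # Text-only input: plain sequential positions broadcast to the three
+            # mRoPE channels.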
+            s = input_ids.shape[1]
+            position_ids = torch.arange(s)
+            position_ids = (
+                position_ids.unsqueeze(0).expand(3, -1, -1).to(input_ids.device)
            )
-
-        llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
-        mrope_position_delta = (llm_positions.max() + 1 - len(input_tokens)).item()
-        llm_positions = llm_positions[:, context_len:seq_len]
-
-        return llm_positions.tolist(), mrope_position_delta
+            max_position_ids = position_ids.max(0, keepdim=False)[0].max(
+                -1, keepdim=True
+            )[0]
+            mrope_position_deltas = max_position_ids + 1 - s
+            return position_ids, mrope_position_deltas


    @staticmethod
    def get_next_input_positions(