
Commit c998d04

vlm: enable radix cache for qwen-vl models (#5349)
Co-authored-by: Xinyuan Tong <[email protected]>
1 parent 7d0edf3 commit c998d04

26 files changed (+435 / -337 lines)

benchmark/mmmu/eval_utils.py

Lines changed: 45 additions & 13 deletions
@@ -89,7 +89,7 @@ def set_seed(seed_value):


 def prepare_samples(eval_args: EvalArgs):
-    print("preparing samples...")
+    print("Preparing samples...")
     # Build prompts
     set_seed(eval_args.seed)

@@ -105,15 +105,40 @@ def prepare_samples(eval_args: EvalArgs):
             assert len(value) == 1, "key {} has more than one value".format(key)
             eval_args.config[key] = value[0]

-    # run for each subject
+    # run for each subject in parallel
     sub_dataset_list = []
+    subjects = list(CAT_SHORT2LONG.values())  # Get a fixed list of subjects

-    for subject in tqdm(CAT_SHORT2LONG.values()):
-        sub_dataset = load_dataset(
-            eval_args.dataset_path, subject, split=eval_args.split
-        )
-        sub_dataset_list.append(sub_dataset)
-        # break
+    print(f"Loading datasets for {len(subjects)} subjects...")
+    with ThreadPoolExecutor() as executor:
+        # Submit all load_dataset tasks
+        future_to_subject = {
+            executor.submit(
+                load_dataset, eval_args.dataset_path, subject, split=eval_args.split
+            ): subject
+            for subject in subjects
+        }
+
+        # Collect results as they complete
+        results = {}
+        for future in tqdm(
+            as_completed(future_to_subject),
+            total=len(subjects),
+            desc="Loading datasets",
+        ):
+            subject = future_to_subject[future]
+            try:
+                results[subject] = future.result()
+            except Exception as exc:
+                print(f"{subject} generated an exception: {exc}")
+
+    # Ensure datasets are added in the original order for consistency
+    for subject in subjects:
+        if subject in results:
+            sub_dataset_list.append(results[subject])
+        else:
+            # Handle cases where a dataset failed to load (optional, depends on desired behavior)
+            print(f"Warning: Dataset for subject '{subject}' could not be loaded.")

     # merge all dataset
     dataset = concatenate_datasets(sub_dataset_list)
@@ -133,28 +158,35 @@ def process_sample(i, sample):
         width, height = image.size
         if width * height >= eval_args.image_pixels_limit:
            return None, True
-        image_path = f"{images_path}/image_{i}.png"
+        # Use a unique identifier for the image path to avoid potential collisions if indices reset
+        image_path = f"{images_path}/image_{sample['id']}.png"
         if not os.path.exists(image_path):
            image.save(image_path)
         sample["image_path"] = image_path
         return sample, False

+    print("Processing samples...")
     with ThreadPoolExecutor() as executor:
+        # Pass the sample itself to process_sample, index is less reliable now
         futures = [
-            executor.submit(process_sample, i, sample)
+            executor.submit(
+                process_sample, i, sample
+            )  # Keep index i for tqdm maybe? Or remove it. Let's keep it for now.
            for i, sample in enumerate(dataset)
        ]
-        for future in tqdm(as_completed(futures), total=len(futures)):
+        for future in tqdm(
+            as_completed(futures), total=len(dataset), desc="Processing samples"
+        ):
            sample, skipped = future.result()
            if skipped:
                skip_count += 1
            elif sample:
                samples.append(sample)

    print(
-        f"skipping {skip_count} samples with large images, {round((float(skip_count) / len(dataset)) * 100, 2)}% of dataset"
+        f"Skipping {skip_count} samples with large images, {round((float(skip_count) / len(dataset)) * 100, 2)}% of dataset"
    )
-    print("samples have been prepared")
+    print("Samples have been prepared")
     return samples

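For reference, the loading pattern adopted above boils down to submitting one future per subject and re-assembling the results in the original subject order once they complete. A minimal standalone sketch of that pattern, with a hypothetical load_one() standing in for datasets.load_dataset so it runs without the MMMU data:

from concurrent.futures import ThreadPoolExecutor, as_completed

def load_one(subject: str) -> list:
    # Stand-in for an I/O-bound call such as load_dataset(path, subject, split=...).
    return [f"{subject}-sample-{i}" for i in range(3)]

subjects = ["art", "biology", "chemistry"]  # toy subject list
results = {}
with ThreadPoolExecutor() as executor:
    future_to_subject = {executor.submit(load_one, s): s for s in subjects}
    for future in as_completed(future_to_subject):
        subject = future_to_subject[future]
        try:
            results[subject] = future.result()
        except Exception as exc:
            print(f"{subject} generated an exception: {exc}")

# Re-assemble in the original subject order so downstream concatenation is deterministic.
ordered = [results[s] for s in subjects if s in results]

Collecting results via as_completed keeps the progress bar responsive, while the final ordered pass makes the concatenated dataset independent of thread scheduling.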
python/sglang/srt/configs/model_config.py

Lines changed: 6 additions & 7 deletions
@@ -73,15 +73,14 @@ def __init__(
         )

         if enable_multimodal is None:
-            if self.hf_config.architectures[0] == "Llama4ForConditionalGeneration":
+            mm_disabled_models = [
+                "Gemma3ForConditionalGeneration",
+                "Llama4ForConditionalGeneration",
+            ]
+            if self.hf_config.architectures[0] in mm_disabled_models:
                 enable_multimodal = False
                 logger.info(
-                    "Multimodal is disabled for Llama4. To enable it, set --enable-llama4-multimodal."
-                )
-            elif self.hf_config.architectures[0] == "Gemma3ForConditionalGeneration":
-                enable_multimodal = False
-                logger.info(
-                    "Multimodal is disabled for Gemma3. To enable it, set --enable-gemma3-multimodal."
+                    f"Multimodal is disabled for {self.hf_config.model_type}. To enable it, set --enable-multimodal."
                 )
             else:
                 enable_multimodal = True
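The net effect is that the multimodal default is now resolved against a single denylist of architectures instead of per-model branches and per-model flags. A minimal sketch of that resolution logic, using a hypothetical resolve_enable_multimodal helper in place of the inline check in ModelConfig.__init__:

from typing import Optional

MM_DISABLED_MODELS = [
    "Gemma3ForConditionalGeneration",
    "Llama4ForConditionalGeneration",
]

def resolve_enable_multimodal(architecture: str, enable_multimodal: Optional[bool]) -> bool:
    # None means "not set by the user": default to off for the denylisted
    # architectures (re-enable with --enable-multimodal), on for everything else.
    if enable_multimodal is None:
        return architecture not in MM_DISABLED_MODELS
    return enable_multimodal

assert resolve_enable_multimodal("Llama4ForConditionalGeneration", None) is False
assert resolve_enable_multimodal("Qwen2VLForConditionalGeneration", None) is True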

python/sglang/srt/layers/rotary_embedding.py

Lines changed: 150 additions & 114 deletions
@@ -877,127 +877,163 @@ def forward(
         key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)
         return query, key

+    # Copied from https://github.com/huggingface/transformers/blob/c8e0e603de9b3d49161a15fe6e8ea84badfb5d02/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py#L1439
     @staticmethod
-    def get_input_positions(
-        input_tokens: List[int],
-        image_grid_thw: Union[List[List[int]], torch.Tensor],
-        video_grid_thw: Union[List[List[int]], torch.Tensor],
+    def get_rope_index(
+        spatial_merge_size: int,
         image_token_id: int,
         video_token_id: int,
         vision_start_token_id: int,
-        vision_end_token_id: int,
-        spatial_merge_size: int,
-        context_len: int = 0,
-        seq_len: Optional[int] = None,
-        second_per_grid_ts: Optional[torch.Tensor] = None,
+        model_type: str,
         tokens_per_second: Optional[int] = None,
-    ) -> Tuple[List[List[int]], int]:
-        """
-        Get mrope input positions and delta value.
-
-        :arg
-            second_per_grid_ts (`torch.Tensor` of shape `(num_videos)`, *optional*):
-                The time interval (in seconds) for each grid along the temporal dimension in the 3D position IDs.
-
-        """
-
-        if isinstance(image_grid_thw, torch.Tensor):
-            image_grid_thw = image_grid_thw.tolist()
-        if isinstance(video_grid_thw, torch.Tensor):
-            video_grid_thw = video_grid_thw.tolist()
-
-        input_tokens_tensor = torch.tensor(input_tokens)
-        vision_start_indices = torch.argwhere(
-            input_tokens_tensor == vision_start_token_id
-        ).squeeze(1)
-        vision_tokens = input_tokens_tensor[vision_start_indices + 1]
-        image_nums = (vision_tokens == image_token_id).sum()
-        video_nums = (vision_tokens == video_token_id).sum()
-        llm_pos_ids_list: list = []
-
-        st = 0
-        remain_images, remain_videos = image_nums, video_nums
-
-        image_index, video_index = 0, 0
-        for _ in range(image_nums + video_nums):
-            if image_token_id in input_tokens and remain_images > 0:
-                ed_image = input_tokens.index(image_token_id, st)
-            else:
-                ed_image = len(input_tokens) + 1
-            if video_token_id in input_tokens and remain_videos > 0:
-                ed_video = input_tokens.index(video_token_id, st)
-            else:
-                ed_video = len(input_tokens) + 1
-            if ed_image < ed_video:
-                t, h, w = (
-                    image_grid_thw[image_index][0],
-                    image_grid_thw[image_index][1],
-                    image_grid_thw[image_index][2],
-                )
-                image_index += 1
-                remain_images -= 1
-                second_per_grid_t = 0
-                ed = ed_image
-            else:
-                t, h, w = (
-                    video_grid_thw[video_index][0],
-                    video_grid_thw[video_index][1],
-                    video_grid_thw[video_index][2],
-                )
-                if second_per_grid_ts is not None:
-                    second_per_grid_t = second_per_grid_ts[video_index]
-                else:
-                    second_per_grid_t = 1.0
-                video_index += 1
-                remain_videos -= 1
-                ed = ed_video
-            llm_grid_t, llm_grid_h, llm_grid_w = (
-                t,
-                h // spatial_merge_size,
-                w // spatial_merge_size,
-            )
-            text_len = ed - st
-
-            st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
-            llm_pos_ids_list.append(
-                torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx
-            )
-
-            t_index = (
-                torch.arange(llm_grid_t).view(-1, 1).expand(-1, llm_grid_h * llm_grid_w)
-                * second_per_grid_t
-                * tokens_per_second
-            ).flatten()
-
-            h_index = (
-                torch.arange(llm_grid_h)
-                .view(1, -1, 1)
-                .expand(llm_grid_t, -1, llm_grid_w)
-                .flatten()
-            )
-            w_index = (
-                torch.arange(llm_grid_w)
-                .view(1, 1, -1)
-                .expand(llm_grid_t, llm_grid_h, -1)
-                .flatten()
-            )
-            llm_pos_ids_list.append(
-                torch.stack([t_index, h_index, w_index]) + text_len + st_idx
+        input_ids: Optional[torch.LongTensor] = None,
+        image_grid_thw: Optional[torch.LongTensor] = None,
+        video_grid_thw: Optional[torch.LongTensor] = None,
+        second_per_grid_ts: Optional[torch.Tensor] = None,
+        **kwargs,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        mrope_position_deltas = []
+        if input_ids is not None and (
+            image_grid_thw is not None or video_grid_thw is not None
+        ):
+            total_input_ids = input_ids
+            position_ids = torch.ones(
+                3,
+                input_ids.shape[0],
+                input_ids.shape[1],
+                dtype=input_ids.dtype,
+                device=input_ids.device,
             )
-            st = ed + llm_grid_t * llm_grid_h * llm_grid_w
-
-        if st < len(input_tokens):
-            st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
-            text_len = len(input_tokens) - st
-            llm_pos_ids_list.append(
-                torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx
+            image_index, video_index = 0, 0
+            for i, input_ids in enumerate(total_input_ids):
+                image_nums, video_nums = 0, 0
+                vision_start_indices = torch.argwhere(
+                    input_ids == vision_start_token_id
+                ).squeeze(1)
+                vision_tokens = input_ids[vision_start_indices + 1]
+                image_nums = (vision_tokens == image_token_id).sum()
+                video_nums = (vision_tokens == video_token_id).sum()
+                input_tokens = input_ids.tolist()
+                llm_pos_ids_list: list = []
+                st = 0
+                remain_images, remain_videos = image_nums, video_nums
+                for _ in range(image_nums + video_nums):
+                    if image_token_id in input_tokens and remain_images > 0:
+                        ed_image = input_tokens.index(image_token_id, st)
+                    else:
+                        ed_image = len(input_tokens) + 1
+                    if video_token_id in input_tokens and remain_videos > 0:
+                        ed_video = input_tokens.index(video_token_id, st)
+                    else:
+                        ed_video = len(input_tokens) + 1
+                    if ed_image < ed_video:
+                        t, h, w = (
+                            image_grid_thw[image_index][0],
+                            image_grid_thw[image_index][1],
+                            image_grid_thw[image_index][2],
+                        )
+                        second_per_grid_t = 0
+                        image_index += 1
+                        remain_images -= 1
+                        ed = ed_image
+                    else:
+                        t, h, w = (
+                            video_grid_thw[video_index][0],
+                            video_grid_thw[video_index][1],
+                            video_grid_thw[video_index][2],
+                        )
+                        if second_per_grid_ts is not None:
+                            second_per_grid_t = second_per_grid_ts[video_index]
+                        else:
+                            second_per_grid_t = 1.0
+                        video_index += 1
+                        remain_videos -= 1
+                        ed = ed_video
+                    llm_grid_t, llm_grid_h, llm_grid_w = (
+                        t.item(),
+                        h.item() // spatial_merge_size,
+                        w.item() // spatial_merge_size,
+                    )
+                    text_len = ed - st
+
+                    st_idx = (
+                        llm_pos_ids_list[-1].max() + 1
+                        if len(llm_pos_ids_list) > 0
+                        else 0
+                    )
+                    llm_pos_ids_list.append(
+                        torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx
+                    )
+
+                    if model_type == "qwen2_5_vl":
+                        range_tensor = torch.arange(llm_grid_t).view(-1, 1)
+                        expanded_range = range_tensor.expand(
+                            -1, llm_grid_h * llm_grid_w
+                        )
+
+                        time_tensor = (
+                            expanded_range * second_per_grid_t * tokens_per_second
+                        )
+
+                        time_tensor_long = time_tensor.long()
+                        t_index = time_tensor_long.flatten()
+                    elif model_type == "qwen2_vl":
+                        t_index = (
+                            torch.arange(llm_grid_t)
+                            .view(-1, 1)
+                            .expand(-1, llm_grid_h * llm_grid_w)
+                            .flatten()
+                        )
+                    else:
+                        raise RuntimeError("Unimplemented")
+                    h_index = (
+                        torch.arange(llm_grid_h)
+                        .view(1, -1, 1)
+                        .expand(llm_grid_t, -1, llm_grid_w)
+                        .flatten()
+                    )
+                    w_index = (
+                        torch.arange(llm_grid_w)
+                        .view(1, 1, -1)
+                        .expand(llm_grid_t, llm_grid_h, -1)
+                        .flatten()
+                    )
+                    llm_pos_ids_list.append(
+                        torch.stack([t_index, h_index, w_index]) + text_len + st_idx
+                    )
+                    st = ed + llm_grid_t * llm_grid_h * llm_grid_w
+
+                if st < len(input_tokens):
+                    st_idx = (
+                        llm_pos_ids_list[-1].max() + 1
+                        if len(llm_pos_ids_list) > 0
+                        else 0
+                    )
+                    text_len = len(input_tokens) - st
+                    llm_pos_ids_list.append(
+                        torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx
+                    )
+
+                llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
+                position_ids[..., i, :] = llm_positions.to(position_ids.device)
+                mrope_position_deltas.append(
+                    llm_positions.max() + 1 - len(total_input_ids[i])
+                )
+            mrope_position_deltas = torch.tensor(
+                mrope_position_deltas, device=input_ids.device
+            ).unsqueeze(1)
+            return position_ids, mrope_position_deltas
+        else:
+            s = input_ids.shape[1]
+            position_ids = torch.arange(s)
+            position_ids = (
+                position_ids.unsqueeze(0).expand(3, -1, -1).to(input_ids.device)
             )
-
-        llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
-        mrope_position_delta = (llm_positions.max() + 1 - len(input_tokens)).item()
-        llm_positions = llm_positions[:, context_len:seq_len]
-
-        return llm_positions.tolist(), mrope_position_delta
+            max_position_ids = position_ids.max(0, keepdim=False)[0].max(
+                -1, keepdim=True
+            )[0]
+            mrope_position_deltas = max_position_ids + 1 - s
+            return position_ids, mrope_position_deltas

     @staticmethod
     def get_next_input_positions(
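
For context, the new get_rope_index returns batched 3D (temporal/height/width) position ids plus a per-sequence mrope delta, rather than the per-request position list that the old get_input_positions produced. A hedged call-site sketch with toy token ids and a single 1x4x4 image grid, assuming the surrounding class is SGLang's MRotaryEmbedding (the class name is not shown in this hunk):

import torch
from sglang.srt.layers.rotary_embedding import MRotaryEmbedding  # assumed import path

# Toy sequence: 2 text tokens, a vision-start marker, 4 image placeholder tokens, 3 text tokens.
input_ids = torch.tensor([[1, 2, 7, 9, 9, 9, 9, 10, 3, 4]])
position_ids, mrope_deltas = MRotaryEmbedding.get_rope_index(
    spatial_merge_size=2,
    image_token_id=9,        # toy ids, not the real Qwen-VL vocabulary values
    video_token_id=8,
    vision_start_token_id=7,
    model_type="qwen2_vl",   # "qwen2_5_vl" would additionally scale t by tokens_per_second
    input_ids=input_ids,
    image_grid_thw=torch.tensor([[1, 4, 4]]),  # 1x4x4 patches -> 1x2x2 positions after merging
)
print(position_ids.shape)  # torch.Size([3, 1, 10]): (t/h/w, batch, seq_len)
print(mrope_deltas)        # max position + 1 - seq_len for each sequence

Text tokens get identical t/h/w positions, while the four image placeholders share one temporal index and enumerate the 2x2 merged grid along h and w; the delta tells the decoder where to continue counting positions after prefill.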
