Commit 34ef6c8

[VLM] Adopt fast image processor by default (#5065)
1 parent 6117209 commit 34ef6c8

File tree

12 files changed: +165 −100 lines
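This commit switches SGLang's multimodal path to Hugging Face's fast (torchvision-backed) image processors by default. As a hedged illustration of the underlying transformers API this builds on (not part of the diff; the model name is an arbitrary example, and whether a given model ships a fast processor depends on the transformers version):

```python
# use_fast=True selects a torchvision-backed *ImageProcessorFast subclass
# when the model provides one; use_fast=False keeps the base (PIL/numpy) one.
from transformers import AutoImageProcessor, BaseImageProcessorFast

fast = AutoImageProcessor.from_pretrained("Qwen/Qwen2-VL-7B-Instruct", use_fast=True)
slow = AutoImageProcessor.from_pretrained("Qwen/Qwen2-VL-7B-Instruct", use_fast=False)

print(isinstance(fast, BaseImageProcessorFast))  # True
print(isinstance(slow, BaseImageProcessorFast))  # False
```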

benchmark/mmmu/bench_sglang.py

Lines changed: 0 additions & 1 deletion
```diff
@@ -89,5 +89,4 @@ def eval_mmmu(args):
     EvalArgs.add_cli_args(parser)
     args = add_common_sglang_args_and_parse(parser)
     args = parser.parse_args()
-
     eval_mmmu(args)
```

benchmark/mmmu/eval_utils.py

Lines changed: 20 additions & 10 deletions
```diff
@@ -7,6 +7,7 @@
 import pprint
 import random
 import re
+from concurrent.futures import ThreadPoolExecutor, as_completed
 from typing import Dict, Optional
 
 import numpy as np
@@ -117,29 +118,38 @@ def prepare_samples(eval_args: EvalArgs):
     # merge all dataset
     dataset = concatenate_datasets(sub_dataset_list)
 
-    ## prepare images
-    samples = []
-    skip_count = 0
-
-    # use image file as input to ensure the consistency between sglang and hf
+    # Prepare images in parallel
     images_path = os.path.expanduser("~/.cache/mmmu/images")
     os.makedirs(images_path, exist_ok=True)
     print(f"Saving images to: {images_path}")
 
-    for i, sample in enumerate(tqdm(dataset)):
+    samples = []
+    skip_count = 0
+
+    def process_sample(i, sample):
         sample = process_single_sample(sample)
         sample = construct_prompt(sample, eval_args.config)
         image = sample["image"]
-
         width, height = image.size
         if width * height >= eval_args.image_pixels_limit:
-            skip_count += 1
-            continue
+            return None, True
         image_path = f"{images_path}/image_{i}.png"
         if not os.path.exists(image_path):
             image.save(image_path)
         sample["image_path"] = image_path
-        samples.append(sample)
+        return sample, False
+
+    with ThreadPoolExecutor() as executor:
+        futures = [
+            executor.submit(process_sample, i, sample)
+            for i, sample in enumerate(dataset)
+        ]
+        for future in tqdm(as_completed(futures), total=len(futures)):
+            sample, skipped = future.result()
+            if skipped:
+                skip_count += 1
+            elif sample:
+                samples.append(sample)
 
     print(
         f"skipping {skip_count} samples with large images, {round((float(skip_count) / len(dataset)) * 100, 2)}% of dataset"
```

docs/backend/server_arguments.md

Lines changed: 2 additions & 1 deletion
```diff
@@ -45,7 +45,7 @@ python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3-8B-Instruct
 Please consult the documentation below to learn more about the parameters you may provide when launching a server.
 
 
-## Model and tokenizer
+## Model, processor and tokenizer
 
 * `model_path`: Path to the model that will be served.
 * `tokenizer_path`: Defaults to the `model_path`.
@@ -62,6 +62,7 @@ Please consult the documentation below to learn more about the parameters you may provide when launching a server.
 * `skip_tokenizer_init`: Set to true to provide the tokens to the engine and get the output tokens directly, typically used in RLHF. Please see this [example for reference](https://github.com/sgl-project/sglang/blob/main/examples/runtime/token_in_token_out/).
 * `json_model_override_args`: Override model config with the provided JSON.
 * `delete_ckpt_after_loading`: Delete the model checkpoint after loading the model.
+* `disable_fast_image_processor`: Use the base image processor instead of the fast image processor, which is the default. For more details, see: https://huggingface.co/docs/transformers/main/en/main_classes/image_processor#image-processor
 
 
 ## Serving: HTTP & API
```
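At launch time the option would be passed as `--disable-fast-image-processor`, following SGLang's usual underscore-to-dash CLI mapping. A hedged sketch of where the flag lands (the field name comes from this diff; constructing `ServerArgs` directly like this is only for illustration):

```python
# The flag is stored on ServerArgs and forwarded to the processor builder
# as use_fast=not disable_fast_image_processor (see the scheduler and
# tokenizer_manager changes below).
from sglang.srt.server_args import ServerArgs

args = ServerArgs(
    model_path="meta-llama/Meta-Llama-3-8B-Instruct",  # example from the docs above
    disable_fast_image_processor=True,  # opt back into the base image processor
)
print(not args.disable_fast_image_processor)  # the use_fast value passed along
```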

python/sglang/srt/hf_transformers_utils.py

Lines changed: 4 additions & 0 deletions
```diff
@@ -215,6 +215,7 @@ def get_processor(
     tokenizer_mode: str = "auto",
     trust_remote_code: bool = False,
     tokenizer_revision: Optional[str] = None,
+    use_fast: Optional[bool] = True,
     **kwargs,
 ):
     # pop 'revision' from kwargs if present.
@@ -232,6 +233,9 @@ def get_processor(
     if "size" not in kwargs:
         kwargs["size"] = {"shortest_edge": 3136, "longest_edge": 1003520}
 
+    if config.model_type not in {"llava", "clip"}:
+        kwargs["use_fast"] = use_fast
+
     processor = AutoProcessor.from_pretrained(
         tokenizer_name,
         *args,
```
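A hedged usage sketch of the updated helper (the model path is an arbitrary example; note that `llava` and `clip` model types are deliberately excluded from the `use_fast` forwarding above):

```python
from sglang.srt.hf_transformers_utils import get_processor

# use_fast defaults to True, so the fast image processor is used
# unless explicitly disabled.
processor = get_processor(
    "Qwen/Qwen2-VL-7B-Instruct",
    tokenizer_mode="auto",
    trust_remote_code=False,
    use_fast=True,
)
```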

python/sglang/srt/managers/multimodal_processors/base_processor.py

Lines changed: 114 additions & 77 deletions
```diff
@@ -4,14 +4,16 @@
 import multiprocessing as mp
 import os
 from abc import ABC, abstractmethod
-from typing import Optional
+from typing import List, Optional
 
 import numpy as np
 import PIL
 from decord import VideoReader, cpu
 from PIL import Image
+from transformers import BaseImageProcessorFast
 
-from sglang.srt.utils import encode_video, load_audio, load_image, logger
+from sglang.srt.managers.schedule_batch import Modality
+from sglang.srt.utils import encode_video, load_audio, load_image
 
 
 @dataclasses.dataclass
@@ -78,6 +80,10 @@ def process_mm_data(
             kwargs["audios"] = audios
 
         processor = self._processor
+        if hasattr(processor, "image_processor") and isinstance(
+            processor.image_processor, BaseImageProcessorFast
+        ):
+            kwargs["device"] = "cuda"
         result = processor.__call__(
             text=[input_text],
             padding=True,
```
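The branch added above routes preprocessing to the GPU whenever a fast processor is in use. A minimal sketch of the same check outside the class (assuming a CUDA device and a model that ships a fast processor; the model name is an arbitrary example):

```python
from transformers import AutoProcessor, BaseImageProcessorFast

processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-7B-Instruct", use_fast=True)

kwargs = {}
if hasattr(processor, "image_processor") and isinstance(
    processor.image_processor, BaseImageProcessorFast
):
    # Fast processors accept a `device` kwarg and can preprocess on GPU.
    kwargs["device"] = "cuda"
# result = processor(text=[prompt], images=images, padding=True, **kwargs)
```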
```diff
@@ -111,6 +117,84 @@ def get_estimated_frames_list(self, image_data):
 
         return estimated_frames_list
 
+    @staticmethod
+    def _load_single_item(
+        data, is_video, is_audio, frame_count_limit=None, discard_alpha_channel=True
+    ):
+        """Static method that can be pickled for multiprocessing"""
+        try:
+            if is_audio:
+                return load_audio(data)
+            elif is_video:
+                path = data[len("video:") :]
+                return encode_video(path, frame_count_limit)
+            else:
+                img, _ = load_image(data)
+                return img.convert("RGB") if discard_alpha_channel else img
+        except Exception as e:
+            raise RuntimeError(f"Error while loading data {data}: {e}")
+
+    def submit_data_loading_tasks(
+        self,
+        text_parts: List[str],
+        multimodal_tokens: MultimodalSpecialTokens,
+        image_data: Optional[list] = None,
+        audio_data: Optional[list] = None,
+        discard_alpha_channel: bool = True,
+    ):
+        """
+        load multimodal data parallelly
+        """
+
+        # TODO(mick): load from server_args, env, or sampling_params
+        MAX_NUM_FRAMES = 30
+        estimated_frames_list = self.get_estimated_frames_list(image_data=image_data)
+        total_frame_count = sum(estimated_frames_list)
+        # a heuristic value, suggesting the maximum fraction of frames to embed from all visual inputs.
+        # e.g., 0.1 suggests that 1 frame out of 10 input frames should be used
+        scaling_factor = min(1.0, MAX_NUM_FRAMES / max(1, total_frame_count))
+
+        assert len(image_data) == len(estimated_frames_list)
+        # Submit all tasks
+        futures = []
+        task_info = []
+        image_index, audio_index = 0, 0
+
+        for text_part in text_parts:
+            if text_part == multimodal_tokens.image_token:
+                data = image_data[image_index]
+                is_video = isinstance(data, str) and data.startswith("video:")
+                estimated_frames = estimated_frames_list[image_index]
+                frame_count_limit = max(1, int(estimated_frames * scaling_factor))
+                futures.append(
+                    self.io_executor.submit(
+                        BaseMultimodalProcessor._load_single_item,
+                        data,
+                        is_video,
+                        False,
+                        frame_count_limit,
+                        discard_alpha_channel,
+                    )
+                )
+                task_info.append((Modality.IMAGE, data, frame_count_limit))
+                image_index += 1
+            elif text_part == multimodal_tokens.audio_token:
+                data = audio_data[audio_index]
+                futures.append(
+                    self.io_executor.submit(
+                        BaseMultimodalProcessor._load_single_item,
+                        data,
+                        False,
+                        True,
+                        None,
+                        discard_alpha_channel,
+                    )
+                )
+                task_info.append((Modality.AUDIO, data, None))
+                audio_index += 1
+
+        return futures, task_info
+
     def load_mm_data(
         self,
         prompt: str,
@@ -155,84 +239,37 @@ def load_mm_data(
         # split text into list of normal text and special tokens
         text_parts = re.split(pattern, prompt)
 
-        # TODO(mick): load from server_args, env, or sampling_params
-        MAX_NUM_FRAMES = 30
-        estimated_frames_list = self.get_estimated_frames_list(image_data=image_data)
-        total_frame_count = sum(estimated_frames_list)
-        # a heuristic value, suggesting the maximum fraction of frames to embed from all visual inputs.
-        # e.g., 0.1 suggests that 1 frame out of 10 input frames should be used
-        scaling_factor = min(1.0, MAX_NUM_FRAMES / max(1, total_frame_count))
-
-        assert len(image_data) == len(estimated_frames_list)
-
-        image_index, audio_index = 0, 0
-        hashes, image_sizes, images, audios = [], [], [], []
+        futures, task_info = self.submit_data_loading_tasks(
+            text_parts=text_parts,
+            multimodal_tokens=multimodal_tokens,
+            image_data=image_data,
+            audio_data=audio_data,
+            discard_alpha_channel=discard_alpha_channel,
+        )
+        # Process results
+        image_sizes, images, audios = [], [], []
         new_text = ""
-        for index, text_part in enumerate(text_parts):
-            try:
-                if text_part == multimodal_tokens.image_token:
-                    # load as image
-                    if len(images) >= MAX_NUM_FRAMES:
-                        frames_to_process = 0
-                    else:
-                        estimated_frames = estimated_frames_list[image_index]
-                        frames_to_process = max(
-                            1, int(estimated_frames * scaling_factor)
-                        )
-
-                    if frames_to_process == 0:
-                        frames = []
-                    else:
-                        image_file = image_data[image_index]
-                        if isinstance(image_file, str) and image_file.startswith(
-                            "video:"
-                        ):
-                            # video
-                            path = image_file[len("video:") :]
-                            frames = encode_video(
-                                path, frame_count_limit=frames_to_process
-                            )
-                        else:
-                            # image
-                            raw_image, _size = load_image(image_file)
-                            if discard_alpha_channel:
-                                raw_image = raw_image.convert("RGB")
-                            frames = [raw_image]
-                        if len(frames) == 0:
-                            continue
-
-                    image_sizes += frames[0].size * len(frames)
-
-                    # Generate a hashable value for the image file
-                    if isinstance(image_file, Image.Image):
-                        # For PIL.Image objects, use the ID as a hashable value
-                        hash_value = hash(id(image_file))
-                    else:
-                        # For other types (strings, etc.), use the regular hash
-                        hash_value = hash(image_file)
-
-                    hashes += [hash_value] * len(frames)
-                    images += frames
-                    image_index += 1
-                    if frames_to_process != 0:
+        task_ptr = 0
+
+        for text_part in text_parts:
+            if text_part in multimodal_tokens.collect():
+                task_type, data, frame_limit = task_info[task_ptr]
+                result = futures[task_ptr].result()
+                task_ptr += 1
+
+                if task_type == Modality.IMAGE:
+                    frames = [result] if not isinstance(result, list) else result
+                    if frames:
+                        image_sizes += frames[0].size * len(frames)
+                        images += frames
                         new_text += multimodal_tokens.image_token * len(frames)
-                    assert frames_to_process == len(frames)
-                elif text_part == multimodal_tokens.audio_token:
-                    # load as audio
-                    audio_file = audio_data[audio_index]
-                    audio = load_audio(audio_file)
-                    hashes += [hash(audio_file)]
-                    audios += [audio]
-                    audio_index += 1
+                elif task_type == Modality.AUDIO:
+                    # audio
+                    audios.append(result)
                     new_text += multimodal_tokens.audio_token
-                else:
-                    # TODO(mick): handle video
-                    # normal text
-                    new_text += text_part
-
-            except Exception as e:
-                logger.error(f"An exception occurred while loading images: {e}")
-                raise RuntimeError(f"An exception occurred while loading images: {e}")
+                # TODO: handle video
+            else:
+                new_text += text_part
 
         out = BaseMultiModalProcessorOutput(
             images=images,
```
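The restructuring above splits loading into two phases: first submit every image/audio load to `self.io_executor`, recording per-task metadata, then walk the prompt again and consume futures in submission order so results stay aligned with their placeholder tokens. A self-contained sketch of that pattern (toy stand-ins, not the real loaders):

```python
from concurrent.futures import ThreadPoolExecutor

def load(payload):
    return payload.upper()  # stand-in for load_image / load_audio / encode_video

text_parts = ["<image>", "some text", "<audio>"]
payloads = {"<image>": "img0", "<audio>": "aud0"}

with ThreadPoolExecutor(max_workers=4) as io_executor:
    # Phase 1: submit all loads, remembering which prompt slot each one fills.
    futures, task_info = [], []
    for part in text_parts:
        if part in payloads:
            futures.append(io_executor.submit(load, payloads[part]))
            task_info.append(part)

    # Phase 2: consume in submission order so outputs align with the prompt.
    task_ptr = 0
    for part in text_parts:
        if part in payloads:
            result = futures[task_ptr].result()  # blocks until this item is ready
            task_ptr += 1
            print(part, "->", result)
```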

python/sglang/srt/managers/multimodal_processors/janus_pro.py

Lines changed: 3 additions & 1 deletion
```diff
@@ -33,7 +33,9 @@ async def process_mm_data_async(
         base_out = self.load_mm_data(
             prompt=input_ids,
             image_data=image_data,
-            multimodal_tokens=MultimodalSpecialTokens(image_token=processor.image_tag),
+            multimodal_tokens=MultimodalSpecialTokens(
+                image_token=processor.image_token
+            ),
             max_req_input_len=max_req_input_len,
         )
```
python/sglang/srt/managers/schedule_batch.py

Lines changed: 3 additions & 4 deletions
```diff
@@ -222,10 +222,10 @@ def tensor_hash(tensor_list) -> int:
         # memoryview() doesn't support PyTorch's BFloat16 dtype
         tensor = tensor.float()
 
+    assert isinstance(tensor, torch.Tensor)
     if tensor.is_cuda:
-        tensor_cpu = torch.frombuffer(
-            tensor.storage().untyped(), dtype=tensor.dtype, count=tensor.numel()
-        ).clone()
+        # TODO: improve this
+        tensor_cpu = tensor.cpu()
     else:
         tensor_cpu = tensor
 
@@ -321,7 +321,6 @@ def from_dict(obj: dict):
         item.set_pad_value()
 
         optional_args = [
-            "modalities",
             "im_token_id",
             "im_start_id",
             "im_end_id",
```

python/sglang/srt/managers/scheduler.py

Lines changed: 1 addition & 0 deletions
```diff
@@ -452,6 +452,7 @@ def init_tokenizer(self):
                 tokenizer_mode=server_args.tokenizer_mode,
                 trust_remote_code=server_args.trust_remote_code,
                 revision=server_args.revision,
+                use_fast=not server_args.disable_fast_image_processor,
             )
             self.tokenizer = self.processor.tokenizer
         else:
```

python/sglang/srt/managers/tokenizer_manager.py

Lines changed: 1 addition & 0 deletions
```diff
@@ -180,6 +180,7 @@ def __init__(
                 tokenizer_mode=server_args.tokenizer_mode,
                 trust_remote_code=server_args.trust_remote_code,
                 revision=server_args.revision,
+                use_fast=not server_args.disable_fast_image_processor,
             )
 
             # We want to parallelize the image pre-processing so we create an executor for it
```
