 import multiprocessing as mp
 import os
 from abc import ABC, abstractmethod
-from typing import Optional
+from typing import List, Optional

 import numpy as np
 import PIL
 from decord import VideoReader, cpu
 from PIL import Image
+from transformers import BaseImageProcessorFast

-from sglang.srt.utils import encode_video, load_audio, load_image, logger
+from sglang.srt.managers.schedule_batch import Modality
+from sglang.srt.utils import encode_video, load_audio, load_image


 @dataclasses.dataclass
@@ -78,6 +80,10 @@ def process_mm_data(
             kwargs["audios"] = audios

         processor = self._processor
+        if hasattr(processor, "image_processor") and isinstance(
+            processor.image_processor, BaseImageProcessorFast
+        ):
+            kwargs["device"] = "cuda"
         result = processor.__call__(
             text=[input_text],
             padding=True,
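The check added above routes preprocessing onto the GPU whenever the loaded Hugging Face processor wraps a fast (torch-backed) image processor. A standalone sketch of the same idea, outside sglang — the checkpoint name is only an illustrative placeholder, and it assumes a transformers version that exports `BaseImageProcessorFast`:

```python
# Standalone sketch of the fast-processor check (illustrative checkpoint name).
from transformers import AutoImageProcessor, BaseImageProcessorFast

image_processor = AutoImageProcessor.from_pretrained(
    "llava-hf/llava-1.5-7b-hf", use_fast=True
)

kwargs = {}
if isinstance(image_processor, BaseImageProcessorFast):
    # Fast image processors are torch-backed and accept a `device` argument,
    # so resize/rescale/normalize can run on the GPU instead of the CPU.
    kwargs["device"] = "cuda"
print(kwargs)
```

In the PR itself the check goes through `processor.image_processor`, since the top-level `AutoProcessor` wraps the image processor together with the tokenizer.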
@@ -111,6 +117,84 @@ def get_estimated_frames_list(self, image_data):

         return estimated_frames_list

+    @staticmethod
+    def _load_single_item(
+        data, is_video, is_audio, frame_count_limit=None, discard_alpha_channel=True
+    ):
+        """Static method that can be pickled for multiprocessing"""
+        try:
+            if is_audio:
+                return load_audio(data)
+            elif is_video:
+                path = data[len("video:") :]
+                return encode_video(path, frame_count_limit)
+            else:
+                img, _ = load_image(data)
+                return img.convert("RGB") if discard_alpha_channel else img
+        except Exception as e:
+            raise RuntimeError(f"Error while loading data {data}: {e}")
+
+    def submit_data_loading_tasks(
+        self,
+        text_parts: List[str],
+        multimodal_tokens: MultimodalSpecialTokens,
+        image_data: Optional[list] = None,
+        audio_data: Optional[list] = None,
+        discard_alpha_channel: bool = True,
+    ):
+        """
+        Load multimodal data in parallel.
+        """
+
+        # TODO(mick): load from server_args, env, or sampling_params
+        MAX_NUM_FRAMES = 30
+        estimated_frames_list = self.get_estimated_frames_list(image_data=image_data)
+        total_frame_count = sum(estimated_frames_list)
+        # A heuristic value, suggesting the maximum fraction of frames to embed
+        # from all visual inputs. E.g., 0.1 suggests that 1 frame out of every
+        # 10 input frames should be used.
+        scaling_factor = min(1.0, MAX_NUM_FRAMES / max(1, total_frame_count))
+
+        assert len(image_data) == len(estimated_frames_list)
+        # Submit all tasks
+        futures = []
+        task_info = []
+        image_index, audio_index = 0, 0
+
+        for text_part in text_parts:
+            if text_part == multimodal_tokens.image_token:
+                data = image_data[image_index]
+                is_video = isinstance(data, str) and data.startswith("video:")
+                estimated_frames = estimated_frames_list[image_index]
+                frame_count_limit = max(1, int(estimated_frames * scaling_factor))
+                futures.append(
+                    self.io_executor.submit(
+                        BaseMultimodalProcessor._load_single_item,
+                        data,
+                        is_video,
+                        False,
+                        frame_count_limit,
+                        discard_alpha_channel,
+                    )
+                )
+                task_info.append((Modality.IMAGE, data, frame_count_limit))
+                image_index += 1
+            elif text_part == multimodal_tokens.audio_token:
+                data = audio_data[audio_index]
+                futures.append(
+                    self.io_executor.submit(
+                        BaseMultimodalProcessor._load_single_item,
+                        data,
+                        False,
+                        True,
+                        None,
+                        discard_alpha_channel,
+                    )
+                )
+                task_info.append((Modality.AUDIO, data, None))
+                audio_index += 1
+
+        return futures, task_info
+
     def load_mm_data(
         self,
         prompt: str,
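For intuition on the frame-budget heuristic used in `submit_data_loading_tasks` above, here is a tiny worked example; the frame counts are made up for illustration, only `MAX_NUM_FRAMES = 30` and the formulas come from the PR:

```python
# Worked example of the frame-budget scaling heuristic (illustrative numbers).
MAX_NUM_FRAMES = 30
estimated_frames_list = [40, 10, 50]  # e.g. two long videos and a short clip

total_frame_count = sum(estimated_frames_list)                          # 100
scaling_factor = min(1.0, MAX_NUM_FRAMES / max(1, total_frame_count))   # 0.3

# Each input keeps at least one frame; the rest are scaled down proportionally.
frame_count_limits = [max(1, int(n * scaling_factor)) for n in estimated_frames_list]
print(frame_count_limits)  # [12, 3, 15] -> roughly 30 frames across all inputs
```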
@@ -155,84 +239,37 @@ def load_mm_data(
         # split text into list of normal text and special tokens
         text_parts = re.split(pattern, prompt)

-        # TODO(mick): load from server_args, env, or sampling_params
-        MAX_NUM_FRAMES = 30
-        estimated_frames_list = self.get_estimated_frames_list(image_data=image_data)
-        total_frame_count = sum(estimated_frames_list)
-        # a heuristic value, suggesting the maximum fraction of frames to embed from all visual inputs.
-        # e.g., 0.1 suggests that 1 frame out of 10 input frames should be used
-        scaling_factor = min(1.0, MAX_NUM_FRAMES / max(1, total_frame_count))
-
-        assert len(image_data) == len(estimated_frames_list)
-
-        image_index, audio_index = 0, 0
-        hashes, image_sizes, images, audios = [], [], [], []
+        futures, task_info = self.submit_data_loading_tasks(
+            text_parts=text_parts,
+            multimodal_tokens=multimodal_tokens,
+            image_data=image_data,
+            audio_data=audio_data,
+            discard_alpha_channel=discard_alpha_channel,
+        )
+        # Process results
+        image_sizes, images, audios = [], [], []
         new_text = ""
-        for index, text_part in enumerate(text_parts):
-            try:
-                if text_part == multimodal_tokens.image_token:
-                    # load as image
-                    if len(images) >= MAX_NUM_FRAMES:
-                        frames_to_process = 0
-                    else:
-                        estimated_frames = estimated_frames_list[image_index]
-                        frames_to_process = max(
-                            1, int(estimated_frames * scaling_factor)
-                        )
-
-                    if frames_to_process == 0:
-                        frames = []
-                    else:
-                        image_file = image_data[image_index]
-                        if isinstance(image_file, str) and image_file.startswith(
-                            "video:"
-                        ):
-                            # video
-                            path = image_file[len("video:") :]
-                            frames = encode_video(
-                                path, frame_count_limit=frames_to_process
-                            )
-                        else:
-                            # image
-                            raw_image, _size = load_image(image_file)
-                            if discard_alpha_channel:
-                                raw_image = raw_image.convert("RGB")
-                            frames = [raw_image]
-                        if len(frames) == 0:
-                            continue
-
-                    image_sizes += frames[0].size * len(frames)
-
-                    # Generate a hashable value for the image file
-                    if isinstance(image_file, Image.Image):
-                        # For PIL.Image objects, use the ID as a hashable value
-                        hash_value = hash(id(image_file))
-                    else:
-                        # For other types (strings, etc.), use the regular hash
-                        hash_value = hash(image_file)
-
-                    hashes += [hash_value] * len(frames)
-                    images += frames
-                    image_index += 1
-                    if frames_to_process != 0:
+        task_ptr = 0
+
+        for text_part in text_parts:
+            if text_part in multimodal_tokens.collect():
+                task_type, data, frame_limit = task_info[task_ptr]
+                result = futures[task_ptr].result()
+                task_ptr += 1
+
+                if task_type == Modality.IMAGE:
+                    frames = [result] if not isinstance(result, list) else result
+                    if frames:
+                        image_sizes += frames[0].size * len(frames)
+                        images += frames
                         new_text += multimodal_tokens.image_token * len(frames)
-                    assert frames_to_process == len(frames)
-                elif text_part == multimodal_tokens.audio_token:
-                    # load as audio
-                    audio_file = audio_data[audio_index]
-                    audio = load_audio(audio_file)
-                    hashes += [hash(audio_file)]
-                    audios += [audio]
-                    audio_index += 1
+                elif task_type == Modality.AUDIO:
+                    # audio
+                    audios.append(result)
                     new_text += multimodal_tokens.audio_token
-                else:
-                    # TODO(mick): handle video
-                    # normal text
-                    new_text += text_part
-
-            except Exception as e:
-                logger.error(f"An exception occurred while loading images: {e}")
-                raise RuntimeError(f"An exception occurred while loading images: {e}")
+            # TODO: handle video
+            else:
+                new_text += text_part

         out = BaseMultiModalProcessorOutput(
             images=images,
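The refactor splits loading into a submit phase (`submit_data_loading_tasks`) and a collect phase in `load_mm_data`, keeping `futures` and `task_info` index-aligned so results can be consumed in prompt order. A self-contained sketch of that pattern with `concurrent.futures` — the helper and data below are stand-ins, not sglang code, and `io_executor` is assumed to be a standard executor:

```python
# Minimal stand-in for the submit/collect split (not sglang code).
from concurrent.futures import ThreadPoolExecutor


def load_item(data, is_video):
    # Stand-in for BaseMultimodalProcessor._load_single_item.
    return [f"frame:{data}"] if is_video else f"image:{data}"


items = ["a.png", "video:b.mp4", "c.png"]

with ThreadPoolExecutor(max_workers=4) as io_executor:
    futures, task_info = [], []
    for item in items:  # submit phase: queue all I/O-bound loads at once
        is_video = item.startswith("video:")
        futures.append(io_executor.submit(load_item, item, is_video))
        task_info.append(("video" if is_video else "image", item))

    # Collect phase: iterate in submission order so results stay aligned
    # with task_info; .result() blocks until that item has finished loading.
    for (kind, item), fut in zip(task_info, futures):
        print(kind, item, fut.result())
```

Because both lists are appended to in the same loop, indexing them with a single pointer (as `task_ptr` does in `load_mm_data`) is safe without any extra bookkeeping.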