Commit fbebcb7

model: support mllama4 (#5144)
Parent: 87edded

7 files changed, +145 -65 lines

python/sglang/srt/configs/model_config.py

Lines changed: 1 addition & 1 deletion
@@ -486,8 +486,8 @@ def is_generation_model(model_architectures: List[str], is_embedding: bool = False):
         "Gemma3ForConditionalGeneration",
         "Grok1VForCausalLM",
         "Grok1AForCausalLM",
-        # TODO: add multimodal support for "Llama4ForConditionalGeneration",
         "LlavaLlamaForCausalLM",
+        "Llama4ForConditionalGeneration",
         "LlavaMistralForCausalLM",
         "LlavaQwenForCausalLM",
         "LlavaVidForCausalLM",

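For context, this allowlist is what decides whether a checkpoint is treated as a generation model: the `architectures` field of the checkpoint's Hugging Face `config.json` is matched against it. A minimal sketch of that lookup (the list and helper below are illustrative, not the actual SGLang code):

```python
# Illustrative only: matching a checkpoint's architecture string against a
# generation-model allowlist. The helper and set below are hypothetical.
from typing import List

GENERATION_ARCHITECTURES = {
    "Gemma3ForConditionalGeneration",
    "Llama4ForConditionalGeneration",
    "LlavaLlamaForCausalLM",
    "LlavaMistralForCausalLM",
}


def is_generation_model(model_architectures: List[str]) -> bool:
    # A checkpoint's config.json lists its model classes under "architectures".
    return any(arch in GENERATION_ARCHITECTURES for arch in model_architectures)


print(is_generation_model(["Llama4ForConditionalGeneration"]))  # True
```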
python/sglang/srt/managers/mm_utils.py

Lines changed: 3 additions & 2 deletions
@@ -148,7 +148,8 @@ def get_embedding_and_mask(
         placeholder_tensor,
     ).unsqueeze(-1)
 
-    num_mm_tokens_in_input_ids = special_multimodal_mask.sum()
+    num_mm_tokens_in_input_ids = special_multimodal_mask.sum().item()
+
     if num_mm_tokens_in_input_ids != num_mm_tokens_in_embedding:
         logger.warning(
             f"Number of tokens in multimodal embedding does not match those in the input text."
@@ -172,7 +173,7 @@ def get_embedding_and_mask(
             embedding = embedding[-num_multimodal:, :]
         else:
             raise RuntimeError(
-                "Insufficient multimodal embedding length. This is an internal error"
+                f"Insufficient multimodal embedding length: {num_mm_tokens_in_input_ids=} vs {num_mm_tokens_in_embedding=}. This is an internal error"
             )
 
     return embedding, special_multimodal_mask

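Two small fixes here: `.sum().item()` turns the mask count from a 0-d tensor into a plain Python int, and the error message now reports both counts. An illustrative, standalone sketch of the difference the `.item()` call makes in the formatted message:

```python
# Illustrative only: why .item() matters when the count appears in a message.
import torch

special_multimodal_mask = torch.tensor([False, True, True, False])

as_tensor = special_multimodal_mask.sum()       # tensor(2), a 0-d tensor
as_int = special_multimodal_mask.sum().item()   # 2, a plain Python int

num_mm_tokens_in_embedding = 3
print(f"{as_tensor=} vs {num_mm_tokens_in_embedding=}")  # as_tensor=tensor(2) vs ...
print(f"{as_int=} vs {num_mm_tokens_in_embedding=}")     # as_int=2 vs ...
```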
python/sglang/srt/managers/multimodal_processors/mllama4.py

Lines changed: 21 additions & 36 deletions
@@ -1,10 +1,8 @@
-from typing import List, Mapping, Optional, Tuple, Union
+from typing import List, Union
 
 import torch
-from PIL import Image
-from transformers import Llama4Processor
 from transformers.image_utils import SizeDict
-from transformers.models.llama4.image_processing_llama4 import (
+from transformers.models.llama4.image_processing_llama4_fast import (
     find_supported_resolutions,
     get_best_fit,
 )
@@ -15,7 +13,6 @@
 )
 from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem
 from sglang.srt.models.mllama4 import Llama4ForConditionalGeneration
-from sglang.srt.utils import load_image
 
 
 class Mllama4ImageProcessor(BaseMultimodalProcessor):
@@ -25,6 +22,9 @@ def __init__(self, hf_config, server_args, _processor):
         super().__init__(hf_config, server_args, _processor)
         self.vision_config = hf_config.vision_config
         self.text_config = hf_config.text_config
+        self.boi_token_index = hf_config.boi_token_index
+        self.eoi_token_index = hf_config.eoi_token_index
+        self.image_token_index = hf_config.image_token_index
         self.multimodal_tokens = MultimodalSpecialTokens(
             image_token=_processor.image_token
         )
@@ -54,19 +54,16 @@ async def process_mm_data_async(
         )
 
         # Process the images using the processor
-        processor = Llama4Processor.from_pretrained(
-            self.server_args.model_path, **kwargs
-        )
+        processor = self._processor
 
         # Process the prompt and images
-        image_inputs = processor(
-            text=processed_data.input_text,
+        processor_output = self.process_mm_data(
+            input_text=processed_data.input_text,
             images=processed_data.images,
-            return_tensors="pt",
         )
 
         # Handle image resolutions and aspect ratios
-        if "pixel_values" in image_inputs:
+        if "pixel_values" in processor_output:
             image_processor = processor.image_processor
             tokenizer = self._processor.tokenizer
 
@@ -100,16 +97,16 @@ async def process_mm_data_async(
             ]
 
             # Add to image_inputs
-            image_inputs["aspect_ratios"] = aspect_ratios
-            image_inputs["patches_per_image"] = torch.tensor(patches_per_image)
+            processor_output["aspect_ratios"] = aspect_ratios
+            processor_output["patches_per_image"] = torch.tensor(patches_per_image)
 
             # Process embed_is_patch
             vocab = tokenizer.get_vocab()
             patch_id = vocab.get(processor.img_patch_token, -1)
            image_end_id = vocab.get(processor.end_of_img_token, -1)
 
            if patch_id != -1 and image_end_id != -1:
-                input_ids = image_inputs["input_ids"].view(-1)
+                input_ids = processor_output["input_ids"].view(-1)
 
                 # Remove BOS token if present
                 if input_ids.size(0) > 0 and input_ids[0] == tokenizer.bos_token_id:
@@ -129,33 +126,21 @@ async def process_mm_data_async(
                 for per_image_input_ids in split_input_ids:
                     embed_is_patch.append(per_image_input_ids == patch_id)
 
-                image_inputs["embed_is_patch"] = embed_is_patch
+                processor_output["embed_is_patch"] = embed_is_patch
 
         # Convert to the format expected by SGLang
-        image_inputs["input_ids"] = image_inputs["input_ids"].tolist()[0]
+        processor_output["input_ids"] = processor_output["input_ids"].tolist()[0]
+
+        processor_output["im_start_id"] = self.boi_token_index
+        processor_output["im_end_id"] = self.eoi_token_index
+        processor_output["im_token_id"] = self.image_token_index
 
         # Add metadata for image processing
-        image_inputs["mm_items"] = [
+        processor_output["mm_items"] = [
             MultimodalDataItem(
-                pixel_values=image_inputs["pixel_values"],
+                pixel_values=processor_output["pixel_values"],
                 modality=Modality.IMAGE,
-                # Add additional metadata needed for Llama4 vision processing
-                embed_is_patch=image_inputs.get("embed_is_patch", None),
-                aspect_ratios=image_inputs.get("aspect_ratios", None),
-                patches_per_image=image_inputs.get("patches_per_image", None),
             )
         ]
 
-        return image_inputs
-
-    def get_patch_per_chunk(self):
-        """Calculate patches per chunk based on vision config"""
-        image_size = self.vision_config.image_size
-        patch_size = self.vision_config.patch_size
-
-        assert (
-            image_size % patch_size == 0
-        ), f"chunk size {image_size} should be multiple of patch_size {patch_size}"
-
-        ds_ratio = int(round(1.0 / (self.vision_config.pixel_shuffle_ratio**2)))
-        return (image_size // patch_size) ** 2 // ds_ratio
+        return processor_output

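The processor now reuses the already-loaded `self._processor` instead of re-instantiating `Llama4Processor` per request, records the begin/end/patch image-token IDs from the HF config, and removes the `get_patch_per_chunk` helper. The `embed_is_patch` computation splits the prompt's token IDs at each end-of-image token and marks the patch placeholders in every chunk; a rough standalone sketch of that idea with made-up token IDs (not the exact SGLang logic):

```python
# Illustrative sketch of the embed_is_patch idea with hypothetical token IDs:
# split input_ids at each end-of-image token, then mark which positions in
# every per-image chunk are image-patch placeholders.
import torch

patch_id, image_end_id = 200, 201          # made-up vocab entries
input_ids = torch.tensor([1, 5, 200, 200, 201, 7, 200, 200, 200, 201, 9])

# Indices of end-of-image tokens delimit one chunk per image.
end_positions = (input_ids == image_end_id).nonzero(as_tuple=True)[0]
split_input_ids = torch.tensor_split(input_ids, (end_positions + 1).tolist())

embed_is_patch = [chunk == patch_id for chunk in split_input_ids]
print([mask.tolist() for mask in embed_is_patch])
# [[False, False, True, True, False], [False, True, True, True, False], [False]]
```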
python/sglang/srt/managers/schedule_batch.py

Lines changed: 42 additions & 14 deletions
@@ -1,5 +1,6 @@
 from __future__ import annotations
 
+import hashlib
 from enum import Enum, auto
 
 # Copyright 2023-2024 SGLang Team
@@ -157,7 +158,7 @@ class Modality(Enum):
 @dataclasses.dataclass
 class MultimodalDataItem:
     """
-    A single multimodal data, from a single image/video/audio or other
+    A single multimodal data, from a single image/video/audio or others
     """
 
     modality: Modality
@@ -195,25 +196,54 @@ def is_empty_list(l):
 
     def set_pad_value(self):
         """
-        Set the pad value after first hashign the data
+        Set the pad value after first hashing the data
        """
 
-        def tensor_hash(f):
-            f_list = flatten_nested_list(f)
-            f_list = [x.flatten() if isinstance(x, torch.Tensor) else x for x in f_list]
-            f_cat = torch.concat(f_list).contiguous().numpy().tobytes()
-            return hash(f_cat)
+        def data_hash(data) -> int:
+            hash_bytes = hashlib.sha256(data).digest()[:8]
+            return int.from_bytes(hash_bytes, byteorder="big", signed=False)
+
+        def tensor_hash(tensor_list) -> int:
+            """
+            hash a tensor or a tensor list
+            """
+            tensor = tensor_list
+            if isinstance(tensor_list, list):
+                tensor_list = flatten_nested_list(tensor_list)
+                tensor_list = [
+                    x.flatten() if isinstance(x, torch.Tensor) else x
+                    for x in tensor_list
+                ]
+                tensor = torch.concat(tensor_list)
+
+            tensor = tensor.detach().contiguous()
+
+            if tensor.dtype == torch.bfloat16:
+                # memoryview() doesn't support PyTorch's BFloat16 dtype
+                tensor = tensor.float()
+
+            if tensor.is_cuda:
+                tensor_cpu = torch.frombuffer(
+                    tensor.storage().untyped(), dtype=tensor.dtype, count=tensor.numel()
+                ).clone()
+            else:
+                tensor_cpu = tensor
+
+            mv = memoryview(tensor_cpu.numpy())
+            return data_hash(mv.tobytes())
 
         def hash_feature(f):
             if isinstance(f, list):
                 if isinstance(f[0], torch.Tensor):
                     return tensor_hash(f)
-                return hash(tuple(flatten_nested_list(f)))
+                return data_hash(tuple(flatten_nested_list(f)))
             elif isinstance(f, np.ndarray):
                 arr = np.ascontiguousarray(f)
                 arr_bytes = arr.tobytes()
-                return hash(arr_bytes)
-            return hash(f)
+                return data_hash(arr_bytes)
+            elif isinstance(f, torch.Tensor):
+                return tensor_hash([f])
+            return data_hash(f)
 
         if self.is_audio():
             self.hash = hash_feature(self.audio_features)
@@ -256,7 +286,7 @@ class MultimodalInputs:
     mrope_position_delta: Optional[torch.Tensor] = None
 
     # image
-    im_token_id: Optional[torch.Tensor] = None
+    im_token_id: Optional[int] = None
     im_start_id: Optional[int] = None
     im_end_id: Optional[int] = None
     slice_start_id: Optional[int] = None
@@ -330,10 +360,8 @@ def merge(self, other: MultimodalInputs):
 
         # args needed to be merged
        optional_args = [
-            "items",
-            "image_offsets",
+            "mm_items",
             "image_pad_len",
-            # "modalities", # modalities should be ["multi-images"] (one entry) even for multiple images
         ]
         for arg in optional_args:
             self_arg = getattr(self, arg, None)

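The hashing rework replaces Python's built-in `hash()`, which is salted per process for str/bytes and therefore not stable across workers or restarts, with a SHA-256 digest over the raw feature bytes, and adds handling for bare tensors, bfloat16, and CUDA tensors. A standalone sketch of the core idea (simplified, not the exact code above):

```python
# Standalone sketch of the deterministic hashing idea: hash the raw bytes of a
# tensor with SHA-256 and keep the first 8 bytes as an integer. Unlike the
# built-in hash() on bytes, this value is stable across processes.
import hashlib

import torch


def data_hash(data: bytes) -> int:
    return int.from_bytes(hashlib.sha256(data).digest()[:8], byteorder="big")


def tensor_hash(tensor: torch.Tensor) -> int:
    t = tensor.detach().contiguous()
    if t.dtype == torch.bfloat16:  # numpy has no bfloat16 dtype
        t = t.float()
    return data_hash(t.cpu().numpy().tobytes())


pixel_values = torch.arange(12, dtype=torch.float32).reshape(3, 4)
print(tensor_hash(pixel_values))  # same value every run, on every worker
```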
python/sglang/srt/models/llama4.py

Lines changed: 3 additions & 0 deletions
@@ -466,6 +466,9 @@ def __init__(
     ):
         super().__init__(config, quant_config, prefix)
 
+    def get_input_embeddings(self):
+        return self.model.embed_tokens
+
     def _init_model(
         self,
         config: Llama4TextConfig,

python/sglang/srt/models/mllama4.py

Lines changed: 50 additions & 11 deletions
@@ -1,14 +1,19 @@
-# TODO: add Aapted from vllm/mllama4.py
 from collections.abc import Iterable
-from typing import Optional, Set, Tuple
+from typing import List, Optional, Set, Tuple
 
 import torch
 from torch import nn
-from transformers import Llama4Config
+from transformers import Llama4Config, Llama4VisionModel
+from transformers.models.llama4.modeling_llama4 import Llama4MultiModalProjector
 
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
 from sglang.srt.layers.quantization import QuantizationConfig
+from sglang.srt.managers.mm_utils import (
+    MultiModalityDataPaddingPatternImageTokens,
+    general_mm_embed_routine,
+)
+from sglang.srt.managers.schedule_batch import MultimodalDataItem, MultimodalInputs
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.utils import add_prefix
@@ -30,6 +35,9 @@ def __init__(
         self.config = config
         self.quant_config = quant_config
 
+        self.vision_model = Llama4VisionModel(config.vision_config)
+        self.multi_modal_projector = Llama4MultiModalProjector(config)
+
         # Initialize the language model
         from sglang.srt.models.llama4 import Llama4ForCausalLM
 
@@ -41,6 +49,29 @@ def __init__(
 
         self.logits_processor = LogitsProcessor(config.text_config)
 
+    def pad_input_ids(self, input_ids: List[int], mm_inputs: MultimodalInputs):
+        # Get all special token IDs
+        im_token_id: int = mm_inputs.im_token_id
+
+        pattern = MultiModalityDataPaddingPatternImageTokens(torch.tensor(im_token_id))
+        return pattern.pad_input_tokens(input_ids, mm_inputs)
+
+    def get_image_feature(
+        self,
+        items: List[MultimodalDataItem],
+    ) -> torch.Tensor:
+        pixel_values = (
+            torch.concat([item.pixel_values for item in items])
+            .to(next(self.vision_model.parameters()).device)
+            .type(next(self.vision_model.parameters()).dtype)
+        )
+
+        image_outputs = self.vision_model(pixel_values, output_hidden_states=False)
+        image_features = image_outputs.last_hidden_state
+        vision_flat = image_features.view(-1, image_features.size(-1))
+        projected_vision_flat = self.multi_modal_projector(vision_flat)
+        return projected_vision_flat
+
     def forward(
         self,
         input_ids: torch.Tensor,
@@ -49,7 +80,15 @@ def forward(
         **kwargs: object,
     ) -> torch.Tensor:
 
-        return self.language_model(input_ids, positions, forward_batch)
+        hs = general_mm_embed_routine(
+            input_ids=input_ids,
+            forward_batch=forward_batch,
+            language_model=self.language_model,
+            image_data_embedding_func=self.get_image_feature,
+            positions=positions,
+        )
+
+        return hs
 
     def permute_qk_weight_for_rotary(
         self,
@@ -108,17 +147,17 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]:
         )
 
         for name, loaded_weight in weights:
-
-            if name.startswith("vision_model") or name.startswith(
-                "multi_modal_projector"
-            ):
-                continue
-
-            name, loaded_weight = self.permute_qk_weight_for_rotary(name, loaded_weight)
+            if not "vision" in name:
+                name, loaded_weight = self.permute_qk_weight_for_rotary(
+                    name, loaded_weight
+                )
 
             for param_name, weight_name, shard_id in stacked_params_mapping:
                 if weight_name not in name:
                     continue
+
+                if "vision" in name:
+                    continue
                 name = name.replace(weight_name, param_name)
                 param = params_dict[name]
                 weight_loader = param.weight_loader

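With these changes the forward pass routes through `general_mm_embed_routine`, which uses the model's `get_image_feature` to encode and project pixel values and then injects the result at the image-token positions of the text embedding (the new `pad_input_ids` hook prepares those placeholder positions). A generic sketch of that merge step, with made-up shapes and token IDs rather than SGLang's actual implementation:

```python
# Generic sketch of the embedding-merge step performed by a routine like
# general_mm_embed_routine: text tokens are embedded as usual, and positions
# holding the image placeholder token are overwritten with the projected
# vision features. vocab_size, hidden_size, and im_token_id are made up.
import torch

vocab_size, hidden_size, im_token_id = 1000, 16, 999
embed_tokens = torch.nn.Embedding(vocab_size, hidden_size)

# A prompt with three image-placeholder positions.
input_ids = torch.tensor([5, 999, 999, 999, 42])
projected_vision_flat = torch.randn(3, hidden_size)  # projector output

inputs_embeds = embed_tokens(input_ids).detach()  # inference-only sketch
special_multimodal_mask = input_ids == im_token_id
inputs_embeds[special_multimodal_mask] = projected_vision_flat

print(inputs_embeds.shape)  # torch.Size([5, 16]); rows 1-3 are image features
```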