
Commit 4b7e3cc

ryang-max and yhyang201 authored and committed
Support BNB quantization for llama/mllama (sgl-project#5038)
Co-authored-by: Yuhao Yang <[email protected]>
1 parent 4195ebe commit 4b7e3cc

File tree: 3 files changed (+60 lines, −11 lines)

python/sglang/srt/model_loader/loader.py (5 additions, 1 deletion)

@@ -1074,7 +1074,11 @@ def _load_weights(self, model_config: ModelConfig, model: nn.Module) -> None:
         model_type = model_config.hf_config.model_type
         for quant_param_name in quant_state_dict:
             non_stacked_param_name = quant_param_name
-
+            if model_type == "mllama" and "vision_model" in quant_param_name:
+                # adapt to VisionAttention
+                quant_param_name = quant_param_name.replace(
+                    "self_attn.o_proj", "self_attn.proj"
+                )
             shard_index = 0
             for shard_name, (
                 weight_name,
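
The rename above bridges a naming difference: HF Mllama checkpoints call the vision attention output projection `self_attn.o_proj`, while sglang's shared VisionAttention module registers it as `self_attn.proj`. A minimal, self-contained sketch of the mapping (hypothetical helper name, not the loader's actual code):

```python
# Hypothetical helper illustrating the rename applied by the BNB loader above.
def remap_vision_param(model_type: str, quant_param_name: str) -> str:
    if model_type == "mllama" and "vision_model" in quant_param_name:
        # adapt HF's o_proj naming to sglang's VisionAttention.proj
        return quant_param_name.replace("self_attn.o_proj", "self_attn.proj")
    return quant_param_name


print(remap_vision_param(
    "mllama", "vision_model.transformer.layers.0.self_attn.o_proj.weight"
))
# vision_model.transformer.layers.0.self_attn.proj.weight
```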

python/sglang/srt/models/mllama.py (50 additions, 8 deletions)

@@ -22,6 +22,7 @@
 from sglang.srt.layers.linear import (
     ColumnParallelLinear,
     QKVParallelLinear,
+    ReplicatedLinear,
     RowParallelLinear,
 )
 from sglang.srt.layers.logits_processor import LogitsProcessor
@@ -184,6 +185,7 @@ class MllamaVisionEncoderLayer(nn.Module):
     def __init__(
         self,
         config: config_mllama.MllamaVisionConfig,
+        quant_config: Optional[QuantizationConfig] = None,
         is_gated: bool = False,
         prefix: str = "",
     ):
@@ -199,14 +201,16 @@ def __init__(
             self.num_attention_heads,
             self.hidden_size,
             use_qkv_parallel=True,
-            quant_config=None,
+            quant_config=quant_config,
             dropout=0.0,
             use_context_forward=False,
             softmax_in_single_precision=False,
             flatten_batch=False,
             prefix=add_prefix("self_attn", prefix),
         )
-        self.mlp = MllamaVisionMLP(config, prefix=add_prefix("mlp", prefix))
+        self.mlp = MllamaVisionMLP(
+            config, quant_config, prefix=add_prefix("mlp", prefix)
+        )

         self.input_layernorm = nn.LayerNorm(self.hidden_size, eps=config.norm_eps)
         self.post_attention_layernorm = nn.LayerNorm(
@@ -244,6 +248,7 @@ class MllamaVisionEncoder(nn.Module):
     def __init__(
         self,
         config: config_mllama.MllamaVisionConfig,
+        quant_config: Optional[QuantizationConfig] = None,
         num_layers=32,
         is_gated=False,
         output_hidden_states=None,
@@ -254,7 +259,10 @@ def __init__(
         self.layers = nn.ModuleList(
             [
                 MllamaVisionEncoderLayer(
-                    config, is_gated, prefix=add_prefix(f"layers.{i}", prefix)
+                    config,
+                    quant_config,
+                    is_gated,
+                    prefix=add_prefix(f"layers.{i}", prefix),
                 )
                 for i in range(num_layers)
             ]
@@ -283,7 +291,12 @@ def forward(


 class MllamaVisionModel(nn.Module):
-    def __init__(self, config: config_mllama.MllamaVisionConfig, prefix: str = ""):
+    def __init__(
+        self,
+        config: config_mllama.MllamaVisionConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ):
         super().__init__()
         self.image_size = config.image_size
         self.patch_size = config.patch_size
@@ -320,13 +333,15 @@ def __init__(self, config: config_mllama.MllamaVisionConfig, prefix: str = ""):
         # encoders
         self.transformer = MllamaVisionEncoder(
             config,
+            quant_config,
             config.num_hidden_layers,
             is_gated=False,
             output_hidden_states=config.intermediate_layers_indices,
             prefix=add_prefix("transformer", prefix),
         )
         self.global_transformer = MllamaVisionEncoder(
             config,
+            quant_config,
             config.num_global_layers,
             is_gated=True,
             prefix=add_prefix("global_transformer", prefix),
@@ -765,13 +780,35 @@ def forward(


 class MllamaForConditionalGeneration(nn.Module):
+    # BitandBytes specific attributes
+    default_bitsandbytes_target_modules = [
+        ".gate_proj.",
+        ".down_proj.",
+        ".up_proj.",
+        ".q_proj.",
+        ".k_proj.",
+        ".v_proj.",
+        ".o_proj.",
+    ]
+    # in TP, these weights are partitioned along the column dimension (dim=-1)
+    column_parallel_weights_modules = [".down_proj.", ".o_proj."]
+    bitsandbytes_stacked_params_mapping = {
+        # shard_name, weight_name, index
+        "q_proj": ("qkv_proj", 0),
+        "k_proj": ("qkv_proj", 1),
+        "v_proj": ("qkv_proj", 2),
+        "gate_proj": ("gate_up_proj", 0),
+        "up_proj": ("gate_up_proj", 1),
+    }
+
     def __init__(
         self,
         config: config_mllama.MllamaConfig,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
     ):
         super().__init__()
+        self.quant_config = quant_config
         self.vocab_size = config.text_config.vocab_size
         self.hidden_size = config.text_config.hidden_size
         self.max_num_tiles = config.vision_config.max_num_tiles
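
The `bitsandbytes_stacked_params_mapping` above tells the BNB loader how per-shard checkpoint names (`q_proj`, `k_proj`, `v_proj`, `gate_proj`, `up_proj`) map onto sglang's fused parameters (`qkv_proj`, `gate_up_proj`) and which shard index each occupies. A minimal sketch of how such a mapping is typically consumed (assumed helper, not the loader's exact code):

```python
# Assumed illustration: rewrite a per-shard name to its fused parameter name.
BNB_STACKED = {
    # shard_name: (fused_weight_name, shard_index)
    "q_proj": ("qkv_proj", 0),
    "k_proj": ("qkv_proj", 1),
    "v_proj": ("qkv_proj", 2),
    "gate_proj": ("gate_up_proj", 0),
    "up_proj": ("gate_up_proj", 1),
}


def to_stacked_name(param_name: str):
    for shard_name, (fused_name, shard_index) in BNB_STACKED.items():
        if shard_name in param_name:
            return param_name.replace(shard_name, fused_name), shard_index
    return param_name, 0  # non-stacked weights keep their original name


print(to_stacked_name("language_model.model.layers.0.self_attn.k_proj.weight"))
# ('language_model.model.layers.0.self_attn.qkv_proj.weight', 1)
```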
@@ -782,17 +819,21 @@ def __init__(
         self.image_size = config.vision_config.image_size

         self.vision_model = MllamaVisionModel(
-            config.vision_config, prefix=add_prefix("vision_model", prefix)
+            config.vision_config,
+            quant_config=quant_config,
+            prefix=add_prefix("vision_model", prefix),
         )
         self.language_model = MllamaForCausalLM(
             config.text_config,
             quant_config=quant_config,
             prefix=add_prefix("language_model", prefix),
         )
-        self.multi_modal_projector = nn.Linear(
+        self.multi_modal_projector = ReplicatedLinear(
             config.vision_config.vision_output_dim,
             config.text_config.hidden_size,
             bias=True,
+            quant_config=quant_config,
+            prefix="multi_modal_projector",
         )
         self.logits_processor = LogitsProcessor(config.text_config)
         self.capture_mode = False
@@ -959,7 +1000,9 @@ def forward(
             cross_attention_states = self.vision_model(
                 batched_images, batched_ar_ids, batched_ar_mask
             )
-            cross_attention_states = self.multi_modal_projector(cross_attention_states)
+            cross_attention_states, _ = self.multi_modal_projector(
+                cross_attention_states
+            )

             bs, _, _, _, image_token_dim = cross_attention_states.shape
             cross_attention_states = cross_attention_states.view(
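
Swapping `nn.Linear` for `ReplicatedLinear` is what lets the multi-modal projector participate in BNB quantization, but it also changes the calling convention: sglang's parallel linear layers return an `(output, bias)` tuple rather than a bare tensor, hence the `cross_attention_states, _ = ...` unpack above. A toy stand-in showing the convention (hypothetical class, assumes torch is installed):

```python
import torch
import torch.nn as nn


class TupleReturningLinear(nn.Module):
    """Toy stand-in for ReplicatedLinear's (output, bias) return convention."""

    def __init__(self, in_features: int, out_features: int, bias: bool = True):
        super().__init__()
        self.inner = nn.Linear(in_features, out_features, bias=bias)

    def forward(self, x: torch.Tensor):
        return self.inner(x), None  # second element mirrors the optional bias


projector = TupleReturningLinear(8, 4)
out, _ = projector(torch.randn(2, 8))  # same unpack as the call site above
print(out.shape)  # torch.Size([2, 4])
```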
@@ -1013,7 +1056,6 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
             if "vision_model" in name:
                 # adapt to VisionAttention
                 name = name.replace("self_attn.o_proj", "self_attn.proj")
-
             param = params_dict.pop(name)
             weight_loader = getattr(param, "weight_loader", default_weight_loader)
             weight_loader(param, loaded_weight)

test/srt/test_bnb.py (5 additions, 2 deletions)

@@ -1,7 +1,7 @@
 """
 Usage:
-python3 -m unittest test_vision_openai_server.TestOpenAIVisionServer.test_mixed_batch
-python3 -m unittest test_vision_openai_server.TestOpenAIVisionServer.test_multi_images_chat_completion
+python3 -m unittest test_bnb.TestVisionModel.test_vlm
+python3 -m unittest test_bnb.TestLanguageModel.test_mmlu
 """

 import base64
@@ -31,10 +31,13 @@
 VISION_MODELS = [
     ("unsloth/Qwen2.5-VL-7B-Instruct-bnb-4bit", "qwen2-vl"),
     ("unsloth/Qwen2-VL-7B-Instruct-bnb-4bit", "qwen2-vl"),
+    ("unsloth/Llama-3.2-11B-Vision-Instruct-bnb-4bit", "llama_3_vision"),
+    ("unsloth/Llama-3.2-11B-Vision-bnb-4bit", "llama_3_vision"),
 ]
 LANGUAGE_MODELS = [
     "unsloth/Qwen2.5-7B-Instruct-bnb-4bit",
     "unsloth/Qwen2-7B-Instruct-bnb-4bit",
+    "unsloth/Llama-3.2-3B-Instruct-bnb-4bit",
 ]

 # image
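
With these entries, the test list now covers BNB-quantized Llama 3.2 checkpoints alongside Qwen. As a hedged usage sketch (not part of the commit), one way to load such a checkpoint with sglang's offline engine, assuming the `Engine` constructor accepts `model_path` and `load_format` as keyword arguments:

```python
# Hedged sketch: loading a pre-quantized BNB 4-bit Llama checkpoint offline.
import sglang as sgl

if __name__ == "__main__":
    llm = sgl.Engine(
        model_path="unsloth/Llama-3.2-3B-Instruct-bnb-4bit",
        load_format="bitsandbytes",  # select the BitsAndBytes model loader
    )
    out = llm.generate("The capital of France is", {"max_new_tokens": 8})
    print(out)
```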
