 from sglang.srt.layers.linear import (
     ColumnParallelLinear,
     QKVParallelLinear,
+    ReplicatedLinear,
     RowParallelLinear,
 )
 from sglang.srt.layers.logits_processor import LogitsProcessor
@@ -184,6 +185,7 @@ class MllamaVisionEncoderLayer(nn.Module):
     def __init__(
         self,
         config: config_mllama.MllamaVisionConfig,
+        quant_config: Optional[QuantizationConfig] = None,
         is_gated: bool = False,
         prefix: str = "",
     ):
@@ -199,14 +201,16 @@ def __init__(
             self.num_attention_heads,
             self.hidden_size,
             use_qkv_parallel=True,
-            quant_config=None,
+            quant_config=quant_config,
             dropout=0.0,
             use_context_forward=False,
             softmax_in_single_precision=False,
             flatten_batch=False,
             prefix=add_prefix("self_attn", prefix),
         )
-        self.mlp = MllamaVisionMLP(config, prefix=add_prefix("mlp", prefix))
+        self.mlp = MllamaVisionMLP(
+            config, quant_config, prefix=add_prefix("mlp", prefix)
+        )

         self.input_layernorm = nn.LayerNorm(self.hidden_size, eps=config.norm_eps)
         self.post_attention_layernorm = nn.LayerNorm(
@@ -244,6 +248,7 @@ class MllamaVisionEncoder(nn.Module):
     def __init__(
         self,
         config: config_mllama.MllamaVisionConfig,
+        quant_config: Optional[QuantizationConfig] = None,
         num_layers=32,
         is_gated=False,
         output_hidden_states=None,
@@ -254,7 +259,10 @@ def __init__(
         self.layers = nn.ModuleList(
             [
                 MllamaVisionEncoderLayer(
-                    config, is_gated, prefix=add_prefix(f"layers.{i}", prefix)
+                    config,
+                    quant_config,
+                    is_gated,
+                    prefix=add_prefix(f"layers.{i}", prefix),
                 )
                 for i in range(num_layers)
             ]
@@ -283,7 +291,12 @@ def forward(
 class MllamaVisionModel(nn.Module):
-    def __init__(self, config: config_mllama.MllamaVisionConfig, prefix: str = ""):
+    def __init__(
+        self,
+        config: config_mllama.MllamaVisionConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ):
         super().__init__()
         self.image_size = config.image_size
         self.patch_size = config.patch_size
@@ -320,13 +333,15 @@ def __init__(self, config: config_mllama.MllamaVisionConfig, prefix: str = ""):
         # encoders
         self.transformer = MllamaVisionEncoder(
             config,
+            quant_config,
             config.num_hidden_layers,
             is_gated=False,
             output_hidden_states=config.intermediate_layers_indices,
             prefix=add_prefix("transformer", prefix),
         )
         self.global_transformer = MllamaVisionEncoder(
             config,
+            quant_config,
             config.num_global_layers,
             is_gated=True,
             prefix=add_prefix("global_transformer", prefix),
@@ -765,13 +780,35 @@ def forward(
 class MllamaForConditionalGeneration(nn.Module):
+    # BitandBytes specific attributes
+    default_bitsandbytes_target_modules = [
+        ".gate_proj.",
+        ".down_proj.",
+        ".up_proj.",
+        ".q_proj.",
+        ".k_proj.",
+        ".v_proj.",
+        ".o_proj.",
+    ]
+    # in TP, these weights are partitioned along the column dimension (dim=-1)
+    column_parallel_weights_modules = [".down_proj.", ".o_proj."]
+    bitsandbytes_stacked_params_mapping = {
+        # shard_name, weight_name, index
+        "q_proj": ("qkv_proj", 0),
+        "k_proj": ("qkv_proj", 1),
+        "v_proj": ("qkv_proj", 2),
+        "gate_proj": ("gate_up_proj", 0),
+        "up_proj": ("gate_up_proj", 1),
+    }
+
     def __init__(
         self,
         config: config_mllama.MllamaConfig,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
     ):
         super().__init__()
+        self.quant_config = quant_config
         self.vocab_size = config.text_config.vocab_size
         self.hidden_size = config.text_config.hidden_size
         self.max_num_tiles = config.vision_config.max_num_tiles
@@ -782,17 +819,21 @@ def __init__(
         self.image_size = config.vision_config.image_size

         self.vision_model = MllamaVisionModel(
-            config.vision_config, prefix=add_prefix("vision_model", prefix)
+            config.vision_config,
+            quant_config=quant_config,
+            prefix=add_prefix("vision_model", prefix),
         )
         self.language_model = MllamaForCausalLM(
             config.text_config,
             quant_config=quant_config,
             prefix=add_prefix("language_model", prefix),
         )
-        self.multi_modal_projector = nn.Linear(
+        self.multi_modal_projector = ReplicatedLinear(
             config.vision_config.vision_output_dim,
             config.text_config.hidden_size,
             bias=True,
+            quant_config=quant_config,
+            prefix="multi_modal_projector",
         )
         self.logits_processor = LogitsProcessor(config.text_config)
         self.capture_mode = False
@@ -959,7 +1000,9 @@ def forward(
             cross_attention_states = self.vision_model(
                 batched_images, batched_ar_ids, batched_ar_mask
             )
-            cross_attention_states = self.multi_modal_projector(cross_attention_states)
+            cross_attention_states, _ = self.multi_modal_projector(
+                cross_attention_states
+            )

             bs, _, _, _, image_token_dim = cross_attention_states.shape
             cross_attention_states = cross_attention_states.view(
@@ -1013,7 +1056,6 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
             if "vision_model" in name:
                 # adapt to VisionAttention
                 name = name.replace("self_attn.o_proj", "self_attn.proj")
-
             param = params_dict.pop(name)
             weight_loader = getattr(param, "weight_loader", default_weight_loader)
             weight_loader(param, loaded_weight)
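
Note on the projector call convention: unlike `nn.Linear`, sglang's parallel linear layers (ReplicatedLinear included) return an `(output, output_bias)` tuple from `forward`, which is why the `multi_modal_projector` call site above now unpacks two values. The sketch below illustrates that convention in isolation; the `input_size`/`output_size` keyword names and the placeholder dimensions are assumptions for illustration and are not taken from this diff.

```python
# Standalone sketch (assumptions noted above): shows the (output, output_bias)
# return convention that motivates `cross_attention_states, _ = ...` in the diff.
import torch

from sglang.srt.layers.linear import ReplicatedLinear

proj = ReplicatedLinear(
    input_size=1280,    # placeholder for config.vision_config.vision_output_dim
    output_size=2048,   # placeholder for config.text_config.hidden_size
    bias=True,
    quant_config=None,  # no quantization in this sketch
)

x = torch.randn(4, 1280)
out, out_bias = proj(x)  # forward returns a tuple, not a bare tensor
print(out.shape)         # torch.Size([4, 2048])
```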