@@ -53,7 +53,10 @@
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
from sglang.srt.layers.moe.topk import select_experts
from sglang.srt.layers.quantization.base_config import QuantizationConfig
-from sglang.srt.layers.quantization.fp8_kernel import per_tensor_quant_mla_fp8
+from sglang.srt.layers.quantization.fp8_kernel import (
+    per_tensor_quant_mla_deep_gemm_masked_fp8,
+    per_tensor_quant_mla_fp8,
+)
from sglang.srt.layers.quantization.fp8_utils import (
    block_quant_to_tensor_quant,
    channel_quant_to_tensor_quant,
@@ -78,6 +81,7 @@
_is_cuda = is_cuda()

if _is_cuda:
+    from deep_gemm import m_grouped_gemm_fp8_fp8_bf16_nt_masked
    from sgl_kernel import awq_dequantize, bmm_fp8

    from sglang.srt.layers.moe.ep_moe.token_dispatcher import DeepEPDispatcher
@@ -691,6 +695,10 @@ def __init__(
        self.w_vc = None
        self.w_scale = None

+        self.w_scale_k = None
+        self.w_scale_v = None
+        self.use_deep_gemm_bmm = False
+
        self.flashinfer_mla_disable_ragged = global_server_args_dict[
            "flashinfer_mla_disable_ragged"
        ]
@@ -809,7 +817,24 @@ def forward_absorb(
            )
        q_nope, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)

-        if self.w_kc.dtype == torch.float8_e4m3fnuz:
+        if self.use_deep_gemm_bmm:
+            q_nope_val, q_nope_scale, masked_m, expected_m, aligned_m = (
+                per_tensor_quant_mla_deep_gemm_masked_fp8(
+                    q_nope.transpose(0, 1), dtype=torch.float8_e4m3fn
+                )
+            )
+            q_nope_out = q_nope.new_empty(
+                (self.num_local_heads, aligned_m, self.kv_lora_rank)
+            )
+            m_grouped_gemm_fp8_fp8_bf16_nt_masked(
+                (q_nope_val, q_nope_scale),
+                (self.w_kc, self.w_scale_k),
+                q_nope_out,
+                masked_m,
+                expected_m,
+            )
+            q_nope_out = q_nope_out[:, :expected_m, :]
+        elif self.w_kc.dtype == torch.float8_e4m3fnuz:
            # TODO(kernel): add bmm_fp8 for torch.float8_e4m3fnuz
            q_nope_out = torch.bmm(
                q_nope.to(torch.bfloat16).transpose(0, 1),
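
Reviewer note (not part of the diff): the new branch quantizes `q_nope` head-major into fp8, pads the token dimension up to `aligned_m`, runs DeepGEMM's masked grouped GEMM against the block-scaled `w_kc`, and trims the result back to `expected_m` rows. The sketch below emulates only that shape bookkeeping with a plain float32 `torch.bmm`; the tensor sizes and the alignment value are illustrative assumptions, the fp8 quantization and scale handling are omitted, and the weight layout follows the existing bmm fallback rather than the kernel's "nt" layout.

```python
# Shape-level emulation of the q_nope path above (assumed DeepSeek-V3-like sizes).
# torch.bmm in float32 stands in for m_grouped_gemm_fp8_fp8_bf16_nt_masked so the
# sketch runs without deep_gemm or an fp8-capable GPU.
import torch

num_local_heads, qk_nope_head_dim, kv_lora_rank = 16, 128, 512
expected_m = 7      # real rows (tokens) per head in this batch
aligned_m = 64      # hypothetical padded row count produced by the masked quantizer

q_nope = torch.randn(expected_m, num_local_heads, qk_nope_head_dim)  # [tokens, heads, dim]
w_kc = torch.randn(num_local_heads, qk_nope_head_dim, kv_lora_rank)  # per-head weight

# Head-major, padded along the token dimension; the masked kernel only computes
# the first masked_m rows of each head group, so the padding rows are never read.
q_padded = q_nope.new_zeros(num_local_heads, aligned_m, qk_nope_head_dim)
q_padded[:, :expected_m] = q_nope.transpose(0, 1)

q_nope_out = torch.bmm(q_padded, w_kc)        # [heads, aligned_m, kv_lora_rank]
q_nope_out = q_nope_out[:, :expected_m, :]    # trim back to the real token count
print(q_nope_out.shape)                       # torch.Size([16, 7, 512])
```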
@@ -840,7 +865,24 @@ def forward_absorb(
        attn_output = self.attn_mqa(q_input, k_input, v_input, forward_batch)
        attn_output = attn_output.view(-1, self.num_local_heads, self.kv_lora_rank)

-        if self.w_vc.dtype == torch.float8_e4m3fnuz:
+        if self.use_deep_gemm_bmm:
+            attn_output_val, attn_output_scale, masked_m, expected_m, aligned_m = (
+                per_tensor_quant_mla_deep_gemm_masked_fp8(
+                    attn_output.transpose(0, 1), dtype=torch.float8_e4m3fn
+                )
+            )
+            attn_bmm_output = attn_output.new_empty(
+                (self.num_local_heads, aligned_m, self.v_head_dim)
+            )
+            m_grouped_gemm_fp8_fp8_bf16_nt_masked(
+                (attn_output_val, attn_output_scale),
+                (self.w_vc, self.w_scale_v),
+                attn_bmm_output,
+                masked_m,
+                expected_m,
+            )
+            attn_bmm_output = attn_bmm_output[:, :expected_m, :]
+        elif self.w_vc.dtype == torch.float8_e4m3fnuz:
            # TODO(kernel): add bmm_fp8 for torch.float8_e4m3fnuz
            attn_bmm_output = torch.bmm(
                attn_output.to(torch.bfloat16).transpose(0, 1),
@@ -1412,6 +1454,10 @@ def post_load_weights(self):
                w = self_attn.kv_b_proj.weight
            # NOTE(HandH1998): Since `bmm_fp8` only supports per-tensor scale, we have to requantize `self_attn.kv_b_proj`.
            # This may affect the accuracy of fp8 model.
+            # Fix deepseek v3 blockwise bmm by using deep_gemm
+            use_deep_gemm_bmm = False
+            model_dtype = torch.get_default_dtype()
+
            if w.dtype in (
                torch.float8_e4m3fn,
                torch.float8_e4m3fnuz,
@@ -1430,10 +1476,19 @@ def post_load_weights(self):
                            weight = w
                            weight_scale = self_attn.kv_b_proj.weight_scale_inv

-                        w, scale = block_quant_to_tensor_quant(
-                            weight, weight_scale, weight_block_size
-                        )
-                        self_attn.w_scale = scale
+                        if (
+                            _is_cuda
+                            and weight_block_size[0] == 128
+                            and weight_block_size[1] == 128
+                            and model_dtype == torch.bfloat16
+                        ):
+                            block_scale = weight_scale
+                            use_deep_gemm_bmm = True
+                        else:
+                            w, scale = block_quant_to_tensor_quant(
+                                weight, weight_scale, weight_block_size
+                            )
+                            self_attn.w_scale = scale
                else:
                    weight = w
                    weight_scale = self_attn.kv_b_proj.weight_scale
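
Reviewer note (not part of the diff): the DeepGEMM path is only taken when all of its preconditions hold; otherwise the existing per-tensor requantization, which the NOTE above flags as a potential accuracy cost, remains the fallback. The helper below just restates that gate as a standalone predicate; the function name and argument list are hypothetical, not code from the repository.

```python
# Hypothetical restatement of the gate introduced above, for readability only.
import torch

def should_use_deep_gemm_bmm(is_cuda: bool, weight_block_size, model_dtype) -> bool:
    # CUDA + 128x128 blockwise fp8 scales + bf16 default dtype -> keep the block
    # scales and use DeepGEMM; anything else -> requantize to a per-tensor scale.
    return (
        is_cuda
        and weight_block_size is not None
        and weight_block_size[0] == 128
        and weight_block_size[1] == 128
        and model_dtype == torch.bfloat16
    )

print(should_use_deep_gemm_bmm(True, [128, 128], torch.bfloat16))  # True
print(should_use_deep_gemm_bmm(True, [128, 128], torch.float16))   # False
```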
@@ -1459,15 +1514,28 @@ def post_load_weights(self):
            w_kc, w_vc = w.unflatten(
                0, (-1, self_attn.qk_nope_head_dim + self_attn.v_head_dim)
            ).split([self_attn.qk_nope_head_dim, self_attn.v_head_dim], dim=1)
-            self_attn.w_kc = w_kc.transpose(1, 2).contiguous().transpose(1, 2)
-            self_attn.w_vc = w_vc.contiguous().transpose(1, 2)
-            if (
-                hasattr(self_attn.kv_b_proj, "weight_scale")
-                and self_attn.w_scale is None
-            ):
-                self_attn.w_scale = self_attn.kv_b_proj.weight_scale
-                if _is_hip:
-                    self_attn.w_scale *= 2.0
+
+            if not use_deep_gemm_bmm:
+                self_attn.w_kc = w_kc.transpose(1, 2).contiguous().transpose(1, 2)
+                self_attn.w_vc = w_vc.contiguous().transpose(1, 2)
+                if (
+                    hasattr(self_attn.kv_b_proj, "weight_scale")
+                    and self_attn.w_scale is None
+                ):
+                    self_attn.w_scale = self_attn.kv_b_proj.weight_scale
+                    if _is_hip:
+                        self_attn.w_scale *= 2.0
+            else:
+                num_tile_k = self_attn.qk_nope_head_dim // weight_block_size[1]
+                num_tile_n = self_attn.v_head_dim // weight_block_size[0]
+                ws_kc, ws_vc = block_scale.unflatten(
+                    0, (-1, (num_tile_k + num_tile_n))
+                ).split([num_tile_k, num_tile_n], dim=1)
+                self_attn.w_scale_k = ws_kc.transpose(1, 2).contiguous()
+                self_attn.w_scale_v = ws_vc.contiguous()
+                self_attn.w_kc = w_kc.transpose(1, 2).contiguous()
+                self_attn.w_vc = w_vc.contiguous()
+                self_attn.use_deep_gemm_bmm = True

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        stacked_params_mapping = [
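
Reviewer note (not part of the diff): in the new `else` branch the blockwise scales of `kv_b_proj` are split per head into a K part (`w_scale_k`, paired with `w_kc`) and a V part (`w_scale_v`, paired with `w_vc`), mirroring the `unflatten`/`split` already applied to the weight itself. The snippet below replays that arithmetic with assumed DeepSeek-V3-like sizes (128x128 blocks, `qk_nope_head_dim = v_head_dim = 128`, `kv_lora_rank = 512`); the concrete numbers are illustrative, not read from any config.

```python
# Replaying the per-head scale split from the else branch with assumed sizes.
import torch

num_heads = 16
kv_lora_rank = 512
qk_nope_head_dim = v_head_dim = 128
block_n, block_k = 128, 128                  # weight_block_size == [128, 128]

# kv_b_proj weight is [num_heads * (qk_nope + v), kv_lora_rank]; blockwise fp8
# quantization stores one scale per 128x128 tile of that matrix.
out_dim = num_heads * (qk_nope_head_dim + v_head_dim)
block_scale = torch.rand(out_dim // block_n, kv_lora_rank // block_k)  # [32, 4]

num_tile_k = qk_nope_head_dim // block_k     # scale tiles covering the w_kc rows
num_tile_n = v_head_dim // block_n           # scale tiles covering the w_vc rows
ws_kc, ws_vc = block_scale.unflatten(
    0, (-1, num_tile_k + num_tile_n)
).split([num_tile_k, num_tile_n], dim=1)

w_scale_k = ws_kc.transpose(1, 2).contiguous()   # [heads, kv_lora_rank/128, 1]
w_scale_v = ws_vc.contiguous()                   # [heads, 1, kv_lora_rank/128]
print(w_scale_k.shape, w_scale_v.shape)          # torch.Size([16, 4, 1]) torch.Size([16, 1, 4])
```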