Commit c98e0b4
[perf] use deepgemm for block quant bmm
1 parent: d06a83f

File tree: 3 files changed (+353, -16 lines)

python/sglang/srt/layers/quantization/fp8_kernel.py

Lines changed: 102 additions & 0 deletions
@@ -941,3 +941,105 @@ def per_tensor_quant_mla_fp8(
     )
 
     return x_q, x_s
+
+
+@triton.jit
+def _per_token_group_quant_mla_deep_gemm_masked_fp8(
+    y_ptr,
+    y_q_ptr,
+    y_s_ptr,
+    masked_m_ptr,
+    group_size,
+    y_stride_b,
+    y_stride_t,
+    y_q_stride_b,
+    y_q_stride_t,
+    y_s_stride_b,
+    y_s_stride_g,
+    eps,
+    fp8_min,
+    fp8_max,
+    NUM_GROUP: tl.constexpr,
+    BLOCK: tl.constexpr,
+):
+    """A Triton-accelerated function to perform per-token-group
+    quantization on a tensor for deep_gemm grouped_gemm_masked.
+    This function converts the tensor values into float8 values.
+    y and y_q: (b, t, k)
+    y_s: (b, k//group_size, t)
+    """
+    t_id = tl.program_id(0)
+    b_id = tl.program_id(1)
+
+    y_ptr += b_id * y_stride_b + t_id * y_stride_t
+    y_q_ptr += b_id * y_q_stride_b + t_id * y_q_stride_t
+    y_s_ptr += b_id * y_s_stride_b + t_id
+
+    if t_id == 0:
+        tl.store(masked_m_ptr + b_id, tl.num_programs(0))
+
+    cols = tl.arange(0, BLOCK)  # group_size <= BLOCK
+    mask = cols < group_size
+
+    for gid in range(NUM_GROUP):
+        y = tl.load(y_ptr + gid * group_size + cols, mask=mask, other=0.0).to(
+            tl.float32
+        )
+        _absmax = tl.maximum(tl.max(tl.abs(y)), eps)
+        y_s = _absmax / fp8_max
+        y_q = tl.clamp(y / y_s, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty)
+
+        tl.store(y_q_ptr + gid * group_size + cols, y_q, mask=mask)
+        tl.store(y_s_ptr + gid * y_s_stride_g, y_s)
+
+
+def per_tensor_quant_mla_deep_gemm_masked_fp8(
+    x: torch.Tensor,
+    group_size: int = 128,
+    eps: float = 1e-12,
+    dtype: torch.dtype = torch.float8_e4m3fn,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    """
+    This function quantizes input values to float8 values with per-token-group-quantization
+    for deep_gemm grouped_gemm_masked and specialized for mla absorbed case.
+    """
+    assert x.dim() == 3, "`x` is not a 3d-tensor"
+
+    finfo = torch.finfo(dtype)
+    fp8_max = finfo.max
+    if _is_hip:
+        dtype = torch.float8_e4m3fnuz
+        fp8_max = 224.0
+
+    b, m, k = x.shape
+    aligned_m = (m + 255) // 256 * 256  # 256 is the max block_m of the gemm kernel
+    num_tiles_k = k // group_size
+    assert num_tiles_k * group_size == k, f"k % {group_size} must be zero"
+
+    x_q = x.new_empty((b, aligned_m, k), dtype=dtype)
+    x_s = x.new_empty((b, num_tiles_k, aligned_m), dtype=torch.float32)
+    masked_m = x.new_empty((b,), dtype=torch.int32)
+
+    BLOCK_SIZE = triton.next_power_of_2(group_size)
+    grid = (m, b)
+
+    _per_token_group_quant_mla_deep_gemm_masked_fp8[grid](
+        x,
+        x_q,
+        x_s,
+        masked_m,
+        group_size,
+        x.stride(0),
+        x.stride(1),
+        x_q.stride(0),
+        x_q.stride(1),
+        x_s.stride(0),
+        x_s.stride(1),
+        eps,
+        -fp8_max,
+        fp8_max,
+        num_tiles_k,
+        BLOCK_SIZE,
+    )
+
+    return x_q, x_s.transpose(1, 2), masked_m, m, aligned_m
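
Not part of the commit: the layout this new helper produces can be described by a short, unoptimized PyTorch reference. The sketch below assumes CUDA with torch.float8_e4m3fn and zero-fills the padded rows purely for readability (the Triton path leaves them uninitialized, since masked_m tells deep_gemm how many rows of each batch are real); the returned scale tensor has the same (b, aligned_m, k // group_size) shape as the x_s.transpose(1, 2) view returned above, though not the same underlying storage order.

import torch


def per_token_group_quant_masked_reference(
    x: torch.Tensor, group_size: int = 128, eps: float = 1e-12
):
    """Unoptimized reference for the masked per-token-group fp8 quantization.

    x: (b, m, k) activations with k % group_size == 0.
    Returns (x_q, x_s, masked_m, m, aligned_m), mirroring the Triton helper:
    x_q is (b, aligned_m, k) fp8, x_s is (b, aligned_m, k // group_size) fp32.
    """
    b, m, k = x.shape
    aligned_m = (m + 255) // 256 * 256  # pad rows to the GEMM block size
    num_groups = k // group_size
    fp8_max = torch.finfo(torch.float8_e4m3fn).max

    # One scale per (token, group): absmax / fp8_max, kept away from zero by eps.
    groups = x.float().reshape(b, m, num_groups, group_size)
    scales = groups.abs().amax(dim=-1).clamp_min(eps) / fp8_max  # (b, m, num_groups)
    q = (groups / scales.unsqueeze(-1)).clamp(-fp8_max, fp8_max).reshape(b, m, k)

    q_full = x.new_zeros((b, aligned_m, k), dtype=torch.float32)
    q_full[:, :m] = q
    x_q = q_full.to(torch.float8_e4m3fn)

    x_s = x.new_zeros((b, aligned_m, num_groups), dtype=torch.float32)
    x_s[:, :m] = scales
    masked_m = torch.full((b,), m, dtype=torch.int32, device=x.device)
    return x_q, x_s, masked_m, m, aligned_m
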

python/sglang/srt/models/deepseek_v2.py

Lines changed: 84 additions & 16 deletions
@@ -53,7 +53,10 @@
 from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
 from sglang.srt.layers.moe.topk import select_experts
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
-from sglang.srt.layers.quantization.fp8_kernel import per_tensor_quant_mla_fp8
+from sglang.srt.layers.quantization.fp8_kernel import (
+    per_tensor_quant_mla_deep_gemm_masked_fp8,
+    per_tensor_quant_mla_fp8,
+)
 from sglang.srt.layers.quantization.fp8_utils import (
     block_quant_to_tensor_quant,
     channel_quant_to_tensor_quant,
@@ -78,6 +81,7 @@
 _is_cuda = is_cuda()
 
 if _is_cuda:
+    from deep_gemm import m_grouped_gemm_fp8_fp8_bf16_nt_masked
     from sgl_kernel import awq_dequantize, bmm_fp8
 
     from sglang.srt.layers.moe.ep_moe.token_dispatcher import DeepEPDispatcher
@@ -691,6 +695,10 @@ def __init__(
         self.w_vc = None
         self.w_scale = None
 
+        self.w_scale_k = None
+        self.w_scale_v = None
+        self.use_deep_gemm_bmm = False
+
         self.flashinfer_mla_disable_ragged = global_server_args_dict[
             "flashinfer_mla_disable_ragged"
         ]
@@ -809,7 +817,24 @@ def forward_absorb(
         )
         q_nope, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
 
-        if self.w_kc.dtype == torch.float8_e4m3fnuz:
+        if self.use_deep_gemm_bmm:
+            q_nope_val, q_nope_scale, masked_m, expected_m, aligned_m = (
+                per_tensor_quant_mla_deep_gemm_masked_fp8(
+                    q_nope.transpose(0, 1), dtype=torch.float8_e4m3fn
+                )
+            )
+            q_nope_out = q_nope.new_empty(
+                (self.num_local_heads, aligned_m, self.kv_lora_rank)
+            )
+            m_grouped_gemm_fp8_fp8_bf16_nt_masked(
+                (q_nope_val, q_nope_scale),
+                (self.w_kc, self.w_scale_k),
+                q_nope_out,
+                masked_m,
+                expected_m,
+            )
+            q_nope_out = q_nope_out[:, :expected_m, :]
+        elif self.w_kc.dtype == torch.float8_e4m3fnuz:
             # TODO(kernel): add bmm_fp8 for torch.float8_e4m3fnuz
             q_nope_out = torch.bmm(
                 q_nope.to(torch.bfloat16).transpose(0, 1),
@@ -840,7 +865,24 @@ def forward_absorb(
         attn_output = self.attn_mqa(q_input, k_input, v_input, forward_batch)
         attn_output = attn_output.view(-1, self.num_local_heads, self.kv_lora_rank)
 
-        if self.w_vc.dtype == torch.float8_e4m3fnuz:
+        if self.use_deep_gemm_bmm:
+            attn_output_val, attn_output_scale, masked_m, expected_m, aligned_m = (
+                per_tensor_quant_mla_deep_gemm_masked_fp8(
+                    attn_output.transpose(0, 1), dtype=torch.float8_e4m3fn
+                )
+            )
+            attn_bmm_output = attn_output.new_empty(
+                (self.num_local_heads, aligned_m, self.v_head_dim)
+            )
+            m_grouped_gemm_fp8_fp8_bf16_nt_masked(
+                (attn_output_val, attn_output_scale),
+                (self.w_vc, self.w_scale_v),
+                attn_bmm_output,
+                masked_m,
+                expected_m,
+            )
+            attn_bmm_output = attn_bmm_output[:, :expected_m, :]
+        elif self.w_vc.dtype == torch.float8_e4m3fnuz:
             # TODO(kernel): add bmm_fp8 for torch.float8_e4m3fnuz
             attn_bmm_output = torch.bmm(
                 attn_output.to(torch.bfloat16).transpose(0, 1),
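
For readers unfamiliar with deep_gemm's masked grouped GEMM, here is a hedged, pure-PyTorch sketch of the arithmetic the two m_grouped_gemm_fp8_fp8_bf16_nt_masked calls above stand in for (up to fp8 accumulation details); it is not from the commit, and the 128-wide group and block sizes are an assumption carried over from the quantization helper. Each attention head forms one group: the first masked_m[g] rows of that head's fp8 activations are multiplied against the head's fp8 weight slice, stored in (n, k) "NT" layout, after both sides are rescaled by their per-token-group and per-block scales.

import torch


def masked_grouped_gemm_reference(
    lhs_q: torch.Tensor,     # (g, aligned_m, k) fp8 activations
    lhs_s: torch.Tensor,     # (g, aligned_m, k // 128) per-token-group scales
    rhs_q: torch.Tensor,     # (g, n, k) fp8 weights in NT layout
    rhs_s: torch.Tensor,     # (g, n // 128, k // 128) per-block weight scales
    masked_m: torch.Tensor,  # (g,) number of valid rows per group
    expected_m: int,
) -> torch.Tensor:
    """Dequantize-then-matmul reference for the masked fp8 grouped GEMM."""
    g, aligned_m, k = lhs_q.shape
    n = rhs_q.shape[1]

    # Broadcast scales back over their 128-wide groups / 128x128 blocks.
    lhs = lhs_q.float() * lhs_s.repeat_interleave(128, dim=2)
    rhs = rhs_q.float() * rhs_s.repeat_interleave(128, dim=1).repeat_interleave(
        128, dim=2
    )

    out = torch.einsum("gmk,gnk->gmn", lhs, rhs).to(torch.bfloat16)
    # Only the first masked_m[i] (<= expected_m) rows of each group are defined;
    # zero the rest so the padding is visibly inert.
    for i in range(g):
        out[i, masked_m[i]:] = 0
    return out  # callers keep out[:, :expected_m, :]
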
@@ -1412,6 +1454,10 @@ def post_load_weights(self):
             w = self_attn.kv_b_proj.weight
             # NOTE(HandH1998): Since `bmm_fp8` only supports per-tensor scale, we have to requantize `self_attn.kv_b_proj`.
             # This may affect the accuracy of fp8 model.
+            # Fix deepseek v3 blockwise bmm by using deep_gemm
+            use_deep_gemm_bmm = False
+            model_dtype = torch.get_default_dtype()
+
             if w.dtype in (
                 torch.float8_e4m3fn,
                 torch.float8_e4m3fnuz,
@@ -1430,10 +1476,19 @@ def post_load_weights(self):
                         weight = w
                         weight_scale = self_attn.kv_b_proj.weight_scale_inv
 
-                    w, scale = block_quant_to_tensor_quant(
-                        weight, weight_scale, weight_block_size
-                    )
-                    self_attn.w_scale = scale
+                    if (
+                        _is_cuda
+                        and weight_block_size[0] == 128
+                        and weight_block_size[1] == 128
+                        and model_dtype == torch.bfloat16
+                    ):
+                        block_scale = weight_scale
+                        use_deep_gemm_bmm = True
+                    else:
+                        w, scale = block_quant_to_tensor_quant(
+                            weight, weight_scale, weight_block_size
+                        )
+                        self_attn.w_scale = scale
                 else:
                     weight = w
                     weight_scale = self_attn.kv_b_proj.weight_scale
@@ -1459,15 +1514,28 @@ def post_load_weights(self):
             w_kc, w_vc = w.unflatten(
                 0, (-1, self_attn.qk_nope_head_dim + self_attn.v_head_dim)
             ).split([self_attn.qk_nope_head_dim, self_attn.v_head_dim], dim=1)
-            self_attn.w_kc = w_kc.transpose(1, 2).contiguous().transpose(1, 2)
-            self_attn.w_vc = w_vc.contiguous().transpose(1, 2)
-            if (
-                hasattr(self_attn.kv_b_proj, "weight_scale")
-                and self_attn.w_scale is None
-            ):
-                self_attn.w_scale = self_attn.kv_b_proj.weight_scale
-                if _is_hip:
-                    self_attn.w_scale *= 2.0
+
+            if not use_deep_gemm_bmm:
+                self_attn.w_kc = w_kc.transpose(1, 2).contiguous().transpose(1, 2)
+                self_attn.w_vc = w_vc.contiguous().transpose(1, 2)
+                if (
+                    hasattr(self_attn.kv_b_proj, "weight_scale")
+                    and self_attn.w_scale is None
+                ):
+                    self_attn.w_scale = self_attn.kv_b_proj.weight_scale
+                    if _is_hip:
+                        self_attn.w_scale *= 2.0
+            else:
+                num_tile_k = self_attn.qk_nope_head_dim // weight_block_size[1]
+                num_tile_n = self_attn.v_head_dim // weight_block_size[0]
+                ws_kc, ws_vc = block_scale.unflatten(
+                    0, (-1, (num_tile_k + num_tile_n))
+                ).split([num_tile_k, num_tile_n], dim=1)
+                self_attn.w_scale_k = ws_kc.transpose(1, 2).contiguous()
+                self_attn.w_scale_v = ws_vc.contiguous()
+                self_attn.w_kc = w_kc.transpose(1, 2).contiguous()
+                self_attn.w_vc = w_vc.contiguous()
+                self_attn.use_deep_gemm_bmm = True
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
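
The scale handling in the new deep_gemm branch mirrors the weight split above. The following self-contained snippet (hypothetical MLA sizes chosen for illustration, not taken from the commit) shows how the block scales of kv_b_proj are unflattened per head and split into w_scale_k and w_scale_v, which appears to line up with the per-head (n // 128, k // 128) block-scale layout consumed alongside w_kc and w_vc in forward_absorb.

import torch

# Hypothetical DeepSeek-V3-like MLA sizes, for illustration only.
num_heads, qk_nope_head_dim, v_head_dim, kv_lora_rank = 16, 128, 128, 512
block = 128  # weight_block_size == [128, 128]

# kv_b_proj.weight_scale_inv: one fp32 scale per 128x128 block of the fused
# (num_heads * (qk_nope_head_dim + v_head_dim), kv_lora_rank) weight.
block_scale = torch.rand(
    num_heads * (qk_nope_head_dim + v_head_dim) // block, kv_lora_rank // block
)

num_tile_k = qk_nope_head_dim // block  # 1
num_tile_n = v_head_dim // block        # 1
ws_kc, ws_vc = block_scale.unflatten(0, (-1, num_tile_k + num_tile_n)).split(
    [num_tile_k, num_tile_n], dim=1
)
w_scale_k = ws_kc.transpose(1, 2).contiguous()  # scales for w_kc, per head
w_scale_v = ws_vc.contiguous()                  # scales for w_vc, per head

print(w_scale_k.shape, w_scale_v.shape)
# torch.Size([16, 4, 1]) torch.Size([16, 1, 4])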
