
Commit 77e9549

yizhang2077 authored and thyecust committed
sgl-kernel use cutlass latest version for fp8 blockwise gemm (sgl-project#5207)
1 parent aae6996 commit 77e9549
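
For context, a minimal sketch of how the updated benchmark drives the sgl-kernel FP8 blockwise GEMM. The shapes, the 1x128 / 128x128 scale groups, and the column-major scale_a / transposed-B layouts are taken from the benchmark diff below; this is illustrative only, not an authoritative description of the kernel's API.

import torch
from sgl_kernel import fp8_blockwise_scaled_mm

# Illustrative shapes (one of the DeepSeek-V3 weight shapes from the benchmark).
M, N, K = 64, 512 + 64, 7168

fp8_info = torch.finfo(torch.float8_e4m3fn)
fp8_max, fp8_min = fp8_info.max, fp8_info.min

a_fp32 = (torch.rand(M, K, dtype=torch.float32, device="cuda") - 0.5) * 2 * fp8_max
a_fp8 = a_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn)
b_fp32 = (torch.rand(N, K, dtype=torch.float32, device="cuda") - 0.5) * 2 * fp8_max
b_fp8 = b_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn)

# Per-token-group (1x128) scales for A, per-block (128x128) scales for B.
scale_a = torch.rand(M, K // 128, device="cuda", dtype=torch.float32)
scale_b = torch.rand(N // 128, K // 128, device="cuda", dtype=torch.float32)

# As in the benchmark: scale_a in column-major layout, B and scale_b transposed.
scale_a = scale_a.t().contiguous().t()
out = fp8_blockwise_scaled_mm(a_fp8, b_fp8.t(), scale_a, scale_b.t(), torch.float16)

Note that after this change the benchmark keeps B in (N, K) row-major form and each provider branch applies its own layout transforms.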

File tree

6 files changed (+86, -923 lines)


sgl-kernel/benchmark/bench_fp8_blockwise_gemm.py

Lines changed: 49 additions & 16 deletions
@@ -2,18 +2,22 @@
 import copy
 import itertools
 
+import deep_gemm
 import torch
 import triton
+from deep_gemm import get_col_major_tma_aligned_tensor
 from sgl_kernel import fp8_blockwise_scaled_mm
 from vllm._custom_ops import cutlass_scaled_mm as vllm_scaled_mm
 
+from sglang.srt.layers.quantization.fp8_kernel import w8a8_block_fp8_matmul
+
 
 def get_weight_shapes(args):
     models_tps = list(itertools.product(args.models, args.tp_sizes))
     # NOTE(HandH1998): The weight shapes only works for DeepSeek-V3. Modify them, if you tune for another different model.
     # cannot TP
     total = [
-        # (512 + 64, 7168), # this weight is not supported by current kernel
+        (512 + 64, 7168),
         ((128 + 64) * 128, 7168),
         (128 * (128 + 128), 512),
         (7168, 16384),
@@ -52,6 +56,23 @@ def cdiv(a: int, b: int) -> int:
     return -(a // -b)
 
 
+def fp8_gemm_deepgemm(
+    x_fp8: torch.Tensor,
+    x_scale: torch.Tensor,
+    y_fp8: torch.Tensor,
+    y_scale: torch.Tensor,
+    m: int,
+    n: int,
+    k: int,
+):
+    """DeepGEMM implementation of FP8 GEMM"""
+    out = torch.empty((m, n), device="cuda", dtype=torch.bfloat16)
+
+    # Run DeepGEMM kernel
+    deep_gemm.gemm_fp8_fp8_bf16_nt((x_fp8, x_scale), (y_fp8, y_scale), out)
+    return out
+
+
 def scale_shape(shape, group_shape):
     assert len(shape) == len(group_shape)
     return tuple(cdiv(shape[i], group_shape[i]) for i in range(len(group_shape)))
@@ -60,12 +81,12 @@ def scale_shape(shape, group_shape):
 @triton.testing.perf_report(
     triton.testing.Benchmark(
         x_names=["batch_size"],
-        x_vals=[1, 16, 32, 64, 128, 256, 512, 1024, 2048],
+        x_vals=[1, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096],
         x_log=False,
         line_arg="provider",
-        line_vals=["vllm", "sgl-kernel"],
-        line_names=["vllm fp8 blockwise gemm", "sgl-kernel fp8 blockwise gemm"],
-        styles=[("blue", "-"), ("orange", "-")],
+        line_vals=["vllm", "sgl-kernel", "triton", "deepgemm"],
+        line_names=["vllm", "sgl-kernel", "sglang triton", "deepgemm"],
+        styles=[("blue", "-"), ("orange", "-"), ("red", "-"), ("yellow", "-")],
         ylabel="GB/s",
         plot_name="fp8 blockwise scaled matmul",
         args={},
@@ -80,7 +101,7 @@ def benchmark(batch_size, provider, N, K):
     a_fp8 = a_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn)
 
     b_fp32 = (torch.rand(N, K, dtype=torch.float32, device="cuda") - 0.5) * 2 * fp8_max
-    b_fp8 = b_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn).t()
+    b_fp8 = b_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn)
 
     scale_a_group_shape = (1, 128)
     scale_b_group_shape = (128, 128)
@@ -89,31 +110,40 @@ def benchmark(batch_size, provider, N, K):
 
     scale_a = torch.randn(scale_a_shape, device="cuda", dtype=torch.float32)
     scale_b = torch.randn(scale_b_shape, device="cuda", dtype=torch.float32)
-    scale_a = scale_a.t().contiguous().t()
-    scale_b = scale_b.t().contiguous().t()
 
     quantiles = [0.5, 0.2, 0.8]
     if provider == "sgl-kernel":
+        scale_a = scale_a.t().contiguous().t()
+        b_fp8, scale_b = b_fp8.t(), scale_b.t()
         ms, min_ms, max_ms = triton.testing.do_bench(
             lambda: fp8_blockwise_scaled_mm(
                 a_fp8, b_fp8, scale_a, scale_b, torch.float16
             ),
             quantiles=quantiles,
         )
     if provider == "vllm":
+        scale_a = scale_a.t().contiguous().t()
+        b_fp8, scale_b = b_fp8.t(), scale_b.t()
         ms, min_ms, max_ms = triton.testing.do_bench(
             lambda: vllm_scaled_mm(a_fp8, b_fp8, scale_a, scale_b, torch.float16),
             quantiles=quantiles,
         )
-    gbps = (
-        lambda ms: (
-            (2 * M * N * K - M * N) * a_fp8.element_size()
-            + (3 * M * N) * scale_a.element_size()
+    if provider == "triton":
+        ms, min_ms, max_ms = triton.testing.do_bench(
+            lambda: w8a8_block_fp8_matmul(
+                a_fp8, b_fp8, scale_a, scale_b, [128, 128], torch.float16
+            ),
+            quantiles=quantiles,
+        )
+    if provider == "deepgemm":
+        scale_a_col_major = get_col_major_tma_aligned_tensor(scale_a.clone())
+        ms, min_ms, max_ms = triton.testing.do_bench(
+            lambda: fp8_gemm_deepgemm(
+                a_fp8, scale_a_col_major, b_fp8, scale_b, M, N, K
+            ),
+            quantiles=quantiles,
         )
-        * 1e-9
-        / (ms * 1e-3)
-    )
-    return gbps(ms), gbps(max_ms), gbps(min_ms)
+    return ms * 1000, max_ms * 1000, min_ms * 1000  # convert to ms
 
 
 if __name__ == "__main__":
@@ -136,6 +166,9 @@ def benchmark(batch_size, provider, N, K):
 
     NK_model_names = get_weight_shapes(args)
     for N, K, model_name in NK_model_names:
+        if N % 128 != 0 or K % 128 != 0:
+            print(f"Skip {N=}, {K=} now")
+            continue
         print(f"{model_name} N={N} K={K}: ")
         benchmark.run(
             print_data=True,
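
The new deepgemm provider is the main addition to the benchmark. Below is a condensed sketch of that path, assuming pre-quantized FP8 inputs laid out as in the benchmark (A as (M, K) and B as (N, K), both row-major, with 1x128 / 128x128 scale groups); it mirrors the fp8_gemm_deepgemm helper added above rather than documenting DeepGEMM's full API, and the wrapper name is only for illustration.

import torch
import deep_gemm
from deep_gemm import get_col_major_tma_aligned_tensor


def deepgemm_fp8_blockwise(a_fp8, scale_a, b_fp8, scale_b, m, n):
    # DeepGEMM takes the per-token activation scales in a TMA-aligned,
    # column-major layout, hence the extra repacking of scale_a.
    scale_a_col_major = get_col_major_tma_aligned_tensor(scale_a.clone())
    out = torch.empty((m, n), device="cuda", dtype=torch.bfloat16)
    # "nt" layout: A is (m, k), B is (n, k); the kernel writes a bf16 result.
    deep_gemm.gemm_fp8_fp8_bf16_nt((a_fp8, scale_a_col_major), (b_fp8, scale_b), out)
    return out

Unlike the sgl-kernel and vLLM branches, this path does not transpose B or scale_b; only scale_a is repacked before the call.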

sgl-kernel/csrc/cutlass_extensions/gemm/collective/collective_builder.hpp

Lines changed: 0 additions & 125 deletions
This file was deleted.

0 commit comments
