 import copy
 import itertools
 
+import deep_gemm
 import torch
 import triton
+from deep_gemm import get_col_major_tma_aligned_tensor
 from sgl_kernel import fp8_blockwise_scaled_mm
 from vllm._custom_ops import cutlass_scaled_mm as vllm_scaled_mm
 
+from sglang.srt.layers.quantization.fp8_kernel import w8a8_block_fp8_matmul
+
 
 def get_weight_shapes(args):
     models_tps = list(itertools.product(args.models, args.tp_sizes))
     # NOTE(HandH1998): The weight shapes only works for DeepSeek-V3. Modify them, if you tune for another different model.
     # cannot TP
     total = [
-        # (512 + 64, 7168), # this weight is not supported by current kernel
+        (512 + 64, 7168),
         ((128 + 64) * 128, 7168),
         (128 * (128 + 128), 512),
         (7168, 16384),
@@ -52,6 +56,23 @@ def cdiv(a: int, b: int) -> int:
     return -(a // -b)
 
 
+def fp8_gemm_deepgemm(
+    x_fp8: torch.Tensor,
+    x_scale: torch.Tensor,
+    y_fp8: torch.Tensor,
+    y_scale: torch.Tensor,
+    m: int,
+    n: int,
+    k: int,
+):
+    """DeepGEMM implementation of FP8 GEMM"""
+    out = torch.empty((m, n), device="cuda", dtype=torch.bfloat16)
+
+    # Run DeepGEMM kernel
+    deep_gemm.gemm_fp8_fp8_bf16_nt((x_fp8, x_scale), (y_fp8, y_scale), out)
+    return out
+
+
 def scale_shape(shape, group_shape):
     assert len(shape) == len(group_shape)
     return tuple(cdiv(shape[i], group_shape[i]) for i in range(len(group_shape)))
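As a quick orientation for reviewers, here is a minimal sketch of how the fp8_gemm_deepgemm helper added above could be exercised on its own. It is illustrative only: it assumes deep_gemm is installed on an FP8-capable GPU, the shapes and random scales are made up, and the scale layout mirrors the per-(1, 128) activation and per-(128, 128) weight grouping used by the benchmark below.

# Illustrative smoke test for the helper above (not part of the benchmark itself).
import torch
from deep_gemm import get_col_major_tma_aligned_tensor

m, n, k = 128, 7168, 16384  # n and k divisible by the 128-wide quant blocks

# Cast random data to FP8; a real caller would quantize and keep the true scales.
x_fp8 = torch.randn(m, k, device="cuda").to(torch.float8_e4m3fn)
y_fp8 = torch.randn(n, k, device="cuda").to(torch.float8_e4m3fn)

# Per-(1, 128) activation scales and per-(128, 128) weight scales.
x_scale = torch.rand(m, k // 128, device="cuda", dtype=torch.float32)
y_scale = torch.rand(n // 128, k // 128, device="cuda", dtype=torch.float32)

# DeepGEMM wants the activation scales in a TMA-aligned, column-major layout,
# which is why the benchmark also calls get_col_major_tma_aligned_tensor.
x_scale = get_col_major_tma_aligned_tensor(x_scale)

out = fp8_gemm_deepgemm(x_fp8, x_scale, y_fp8, y_scale, m, n, k)  # (m, n), bf16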
@@ -60,12 +81,12 @@ def scale_shape(shape, group_shape):
 @triton.testing.perf_report(
     triton.testing.Benchmark(
         x_names=["batch_size"],
-        x_vals=[1, 16, 32, 64, 128, 256, 512, 1024, 2048],
+        x_vals=[1, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096],
         x_log=False,
         line_arg="provider",
-        line_vals=["vllm", "sgl-kernel"],
-        line_names=["vllm fp8 blockwise gemm", "sgl-kernel fp8 blockwise gemm"],
-        styles=[("blue", "-"), ("orange", "-")],
+        line_vals=["vllm", "sgl-kernel", "triton", "deepgemm"],
+        line_names=["vllm", "sgl-kernel", "sglang triton", "deepgemm"],
+        styles=[("blue", "-"), ("orange", "-"), ("red", "-"), ("yellow", "-")],
         ylabel="GB/s",
         plot_name="fp8 blockwise scaled matmul",
         args={},
@@ -80,7 +101,7 @@ def benchmark(batch_size, provider, N, K):
     a_fp8 = a_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn)
 
     b_fp32 = (torch.rand(N, K, dtype=torch.float32, device="cuda") - 0.5) * 2 * fp8_max
-    b_fp8 = b_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn).t()
+    b_fp8 = b_fp32.clamp(min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn)
 
     scale_a_group_shape = (1, 128)
     scale_b_group_shape = (128, 128)
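For reference, these group shapes turn into scale tensor shapes via the scale_shape helper defined earlier. A worked example with purely illustrative sizes:

M, N, K = 16, 7168, 16384  # illustrative sizes only
scale_shape((M, K), (1, 128))    # -> (16, 128): one scale per 1x128 strip of A
scale_shape((N, K), (128, 128))  # -> (56, 128): one scale per 128x128 tile of B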
@@ -89,31 +110,40 @@ def benchmark(batch_size, provider, N, K):
 
     scale_a = torch.randn(scale_a_shape, device="cuda", dtype=torch.float32)
     scale_b = torch.randn(scale_b_shape, device="cuda", dtype=torch.float32)
-    scale_a = scale_a.t().contiguous().t()
-    scale_b = scale_b.t().contiguous().t()
 
     quantiles = [0.5, 0.2, 0.8]
     if provider == "sgl-kernel":
+        scale_a = scale_a.t().contiguous().t()
+        b_fp8, scale_b = b_fp8.t(), scale_b.t()
         ms, min_ms, max_ms = triton.testing.do_bench(
             lambda: fp8_blockwise_scaled_mm(
                 a_fp8, b_fp8, scale_a, scale_b, torch.float16
             ),
             quantiles=quantiles,
         )
     if provider == "vllm":
+        scale_a = scale_a.t().contiguous().t()
+        b_fp8, scale_b = b_fp8.t(), scale_b.t()
         ms, min_ms, max_ms = triton.testing.do_bench(
             lambda: vllm_scaled_mm(a_fp8, b_fp8, scale_a, scale_b, torch.float16),
             quantiles=quantiles,
         )
-    gbps = (
-        lambda ms: (
-            (2 * M * N * K - M * N) * a_fp8.element_size()
-            + (3 * M * N) * scale_a.element_size()
+    if provider == "triton":
+        ms, min_ms, max_ms = triton.testing.do_bench(
+            lambda: w8a8_block_fp8_matmul(
+                a_fp8, b_fp8, scale_a, scale_b, [128, 128], torch.float16
+            ),
+            quantiles=quantiles,
+        )
+    if provider == "deepgemm":
+        scale_a_col_major = get_col_major_tma_aligned_tensor(scale_a.clone())
+        ms, min_ms, max_ms = triton.testing.do_bench(
+            lambda: fp8_gemm_deepgemm(
+                a_fp8, scale_a_col_major, b_fp8, scale_b, M, N, K
+            ),
+            quantiles=quantiles,
         )
-        * 1e-9
-        / (ms * 1e-3)
-    )
-    return gbps(ms), gbps(max_ms), gbps(min_ms)
+    return ms * 1000, max_ms * 1000, min_ms * 1000  # do_bench reports ms; x1000 converts to us
 
 
 if __name__ == "__main__":
@@ -136,6 +166,9 @@ def benchmark(batch_size, provider, N, K):
 
     NK_model_names = get_weight_shapes(args)
     for N, K, model_name in NK_model_names:
+        if N % 128 != 0 or K % 128 != 0:
+            print(f"Skip {N=}, {K=} now")
+            continue
         print(f"{model_name} N={N} K={K}: ")
         benchmark.run(
             print_data=True,