@@ -16,11 +16,7 @@
 import deep_gemm
 from deep_gemm import get_num_sms
 from deep_gemm.jit_kernels.gemm import get_best_configs
-from deep_gemm.jit_kernels.gemm import includes as deep_gemm_includes
-from deep_gemm.jit_kernels.gemm import template as deep_gemm_gemm_template
-from deep_gemm.jit_kernels.m_grouped_gemm import (
-    template as deep_gemm_grouped_gemm_template,
-)
+from deep_gemm.jit_kernels.runtime import FP8GemmRuntime, GemmType
 from deep_gemm.jit_kernels.tuner import jit_tuner
 
 sm_version = get_device_sm()
@@ -45,10 +41,15 @@ def get_enable_jit_deepgemm():
 _IN_PRECOMPILE_STAGE = get_bool_env_var("SGL_IN_DEEPGEMM_PRECOMPILE_STAGE", "false")
 
 # Force redirect deep_gemm cache_dir
-os.environ["DG_CACHE_DIR"] = os.getenv(
-    "SGL_DG_CACHE_DIR", os.path.expanduser("~") + "/.cache/deep_gemm"
+os.environ["DG_JIT_CACHE_DIR"] = os.getenv(
+    "SGL_DG_CACHE_DIR", os.path.join(os.path.expanduser("~"), ".cache", "deep_gemm")
 )
 
+# Refer to https://github.com/deepseek-ai/DeepGEMM/commit/d75b218b7b8f4a5dd5406ac87905039ead3ae42f
+# NVRTC may have performance loss in some cases,
+# and NVCC JIT compilation is also ~9x faster in the referenced commit.
+os.environ["DG_JIT_USE_NVRTC"] = os.getenv("SGL_DG_USE_NVRTC", "0")
+
 
 def update_deep_gemm_config(gpu_id: int, server_args: ServerArgs):
     global _BUILTIN_M_LIST
@@ -130,10 +131,18 @@ def _compile_grouped_gemm_nt_f8f8bf16_masked_one(
     num_groups: int,
     config: Tuple[int, int, int, int, Tuple[int, bool], Tuple[int, int, int]],
 ) -> None:
-    # Auto-tuning with compilation
-    global deep_gemm_includes, deep_gemm_grouped_gemm_template
-    _, block_m, block_n, num_stages, tma_multicast_config, smem_config = config
-    _ = jit_tuner.compile_and_tune(
+    num_sms, block_m, block_n, num_stages, tma_multicast_config, smem_config = config
+    block_k = 128
+    num_tma_threads = 128
+    num_math_threads_per_group = 128
+    kwargs = {
+        "NUM_TMA_THREADS": num_tma_threads,
+        "NUM_MATH_THREADS_PER_GROUP": num_math_threads_per_group,
+        "BLOCK_K": block_k,
+        "NUM_SMS": num_sms,
+        "SMEM_SIZE": smem_config[0],
+    }
+    _, _ = jit_tuner.compile_and_tune(
         name="m_grouped_gemm_fp8_fp8_bf16_nt",
         keys={
             "N": n,
@@ -146,24 +155,11 @@ def _compile_grouped_gemm_nt_f8f8bf16_masked_one(
             "NUM_STAGES": num_stages,
             "NUM_TMA_MULTICAST": tma_multicast_config[0],
             "IS_TMA_MULTICAST_ON_A": tma_multicast_config[1],
-            "GEMM_TYPE": "GroupedMasked",
+            "GEMM_TYPE": GemmType.GroupedMasked,
         },
         space=(),
-        includes=deep_gemm_includes,
-        arg_defs=(
-            ("lhs", torch.float8_e4m3fn),
-            ("lhs_scales", torch.float),
-            ("rhs", torch.float8_e4m3fn),
-            ("rhs_scales", torch.float),
-            ("out", torch.bfloat16),
-            ("grouped_layout", torch.int32),
-            ("m", int),
-            ("stream", torch.cuda.Stream),
-            ("num_sms", int),
-            ("smem_size", int),
-        ),
-        template=deep_gemm_grouped_gemm_template,
-        args=[],
+        kwargs=kwargs,
+        runtime_cls=FP8GemmRuntime,
     )
 
 
@@ -173,9 +169,18 @@ def _compile_grouped_gemm_nt_f8f8bf16_contig_one(
     num_groups: int,
     config: Tuple[int, int, int, int, Tuple[int, bool], Tuple[int, int, int]],
 ) -> None:
-    global deep_gemm_includes, deep_gemm_grouped_gemm_template
-    _, block_m, block_n, num_stages, tma_multicast_config, smem_config = config
-    _ = jit_tuner.compile_and_tune(
+    num_sms, block_m, block_n, num_stages, tma_multicast_config, smem_config = config
+    block_k = 128
+    num_tma_threads = 128
+    num_math_threads_per_group = 128
+    kwargs = {
+        "NUM_TMA_THREADS": num_tma_threads,
+        "NUM_MATH_THREADS_PER_GROUP": num_math_threads_per_group,
+        "BLOCK_K": block_k,
+        "NUM_SMS": num_sms,
+        "SMEM_SIZE": smem_config[0],
+    }
+    _, _ = jit_tuner.compile_and_tune(
         name="m_grouped_gemm_fp8_fp8_bf16_nt",
         keys={
             "N": n,
@@ -188,25 +193,11 @@ def _compile_grouped_gemm_nt_f8f8bf16_contig_one(
             "NUM_STAGES": num_stages,
             "NUM_TMA_MULTICAST": tma_multicast_config[0],
             "IS_TMA_MULTICAST_ON_A": tma_multicast_config[1],
-            "GEMM_TYPE": "GroupedContiguous",
+            "GEMM_TYPE": GemmType.GroupedContiguous,
         },
         space=(),
-        includes=deep_gemm_includes,
-        arg_defs=(
-            ("lhs", torch.float8_e4m3fn),
-            ("lhs_scales", torch.float),
-            ("rhs", torch.float8_e4m3fn),
-            ("rhs_scales", torch.float),
-            ("out", torch.bfloat16),
-            ("grouped_layout", torch.int32),
-            ("m", int),
-            ("num_groups", int),
-            ("stream", torch.cuda.Stream),
-            ("num_sms", int),
-            ("smem_size", int),
-        ),
-        template=deep_gemm_grouped_gemm_template,
-        args=[],
+        kwargs=kwargs,
+        runtime_cls=FP8GemmRuntime,
     )
 
 
@@ -216,9 +207,20 @@ def _compile_gemm_nt_f8f8bf16_one(
     _: int,  # _ is a dummy parameter to align with other interfaces
     config: Tuple[int, int, int, int, Tuple[int, bool], Tuple[int, int, int]],
 ) -> None:
-    global deep_gemm_includes, deep_gemm_gemm_template
-    _, block_m, block_n, num_stages, tma_multicast_config, smem_config = config
-    _ = jit_tuner.compile_and_tune(
+    num_sms, block_m, block_n, num_stages, tma_multicast_config, smem_config = config
+    block_k = 128
+    num_tma_threads = 128
+    num_math_threads_per_group = 128
+    kwargs = {
+        "GEMM_TYPE": GemmType.Normal,
+        "NUM_TMA_THREADS": num_tma_threads,
+        "NUM_MATH_THREADS_PER_GROUP": num_math_threads_per_group,
+        "NUM_GROUPS": 1,
+        "BLOCK_K": block_k,
+        "NUM_SMS": num_sms,
+        "SMEM_SIZE": smem_config[0],
+    }
+    _, _ = jit_tuner.compile_and_tune(
         name="gemm_fp8_fp8_bf16_nt",
         keys={
             "N": n,
@@ -232,20 +234,8 @@ def _compile_gemm_nt_f8f8bf16_one(
             "IS_TMA_MULTICAST_ON_A": tma_multicast_config[1],
         },
         space=(),
-        includes=deep_gemm_includes,
-        arg_defs=(
-            ("lhs", torch.float8_e4m3fn),
-            ("lhs_scales", torch.float),
-            ("rhs", torch.float8_e4m3fn),
-            ("rhs_scales", torch.float),
-            ("out", torch.bfloat16),
-            ("m", int),
-            ("stream", torch.cuda.Stream),
-            ("num_sms", int),
-            ("smem_size", int),
-        ),
-        template=deep_gemm_gemm_template,
-        args=[],
+        kwargs=kwargs,
+        runtime_cls=FP8GemmRuntime,
     )
 
 
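Note: the three `_compile_*_one` helpers above now build the same block of compile-time constants (`BLOCK_K = 128`, 128 TMA threads, 128 math threads per group, `NUM_SMS`, `SMEM_SIZE`) and hand them to `jit_tuner.compile_and_tune` via `kwargs`/`runtime_cls` instead of a raw template. A hedged sketch of how that shared part could be factored out; the helper name `_build_fp8_gemm_kwargs` is hypothetical and not part of this commit:

```python
from typing import Any, Dict, Optional

from deep_gemm.jit_kernels.runtime import GemmType


def _build_fp8_gemm_kwargs(
    num_sms: int,
    smem_size: int,
    gemm_type: Optional[GemmType] = None,
    num_groups: Optional[int] = None,
) -> Dict[str, Any]:
    # Collects the compile-time constants that the FP8GemmRuntime path expects
    # in `kwargs`, mirroring the values hard-coded in the hunks above.
    kwargs: Dict[str, Any] = {
        "NUM_TMA_THREADS": 128,
        "NUM_MATH_THREADS_PER_GROUP": 128,
        "BLOCK_K": 128,
        "NUM_SMS": num_sms,
        "SMEM_SIZE": smem_size,
    }
    # In the diff, only the plain (non-grouped) GEMM passes GEMM_TYPE and
    # NUM_GROUPS through kwargs; the grouped variants put GEMM_TYPE in `keys`.
    if gemm_type is not None:
        kwargs["GEMM_TYPE"] = gemm_type
    if num_groups is not None:
        kwargs["NUM_GROUPS"] = num_groups
    return kwargs
```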
@@ -373,7 +363,7 @@ def _log_jit_build(M: int, N: int, K: int, kernel_type: DeepGemmKernelType):
 
     from deep_gemm.jit.runtime import RuntimeCache
 
-    origin_func = RuntimeCache.__getitem__
+    origin_func = RuntimeCache.get
 
     def __patched_func(self, *args, **kwargs):
         ret = origin_func(self, *args, **kwargs)
@@ -385,6 +375,6 @@ def __patched_func(self, *args, **kwargs):
             )
         return ret
 
-    RuntimeCache.__getitem__ = __patched_func
+    RuntimeCache.get = __patched_func
     yield
-    RuntimeCache.__getitem__ = origin_func
+    RuntimeCache.get = origin_func
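The `_log_jit_build` hunks only touch the patched attribute: newer DeepGEMM looks kernels up through `RuntimeCache.get` rather than `RuntimeCache.__getitem__`. A minimal sketch of the overall monkey-patch pattern the function relies on, assuming (not taken from this diff) that `.get` returns `None` on a cache miss and that the log message looks roughly like this:

```python
import logging
from contextlib import contextmanager

logger = logging.getLogger(__name__)


@contextmanager
def _log_jit_build_sketch(M: int, N: int, K: int, kernel_type):
    # Sketch only: intercept DeepGEMM's runtime-cache lookup so that a miss
    # (which triggers an actual JIT build) is logged with the GEMM shape.
    from deep_gemm.jit.runtime import RuntimeCache

    origin_func = RuntimeCache.get  # newer DeepGEMM exposes .get, not __getitem__

    def __patched_func(self, *args, **kwargs):
        ret = origin_func(self, *args, **kwargs)
        if ret is None:  # assumption: .get returns None when no cached runtime exists
            logger.warning(
                f"DeepGEMM JIT build ({kernel_type}): M={M}, N={N}, K={K}"
            )
        return ret

    RuntimeCache.get = __patched_func
    try:
        yield
    finally:
        # Always restore the original lookup, even if the wrapped block raises.
        RuntimeCache.get = origin_func
```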