
Commit faff51d

zhyncs and Alcanderian authored and committed
chore: upgrade sgl-kernel v0.1.2.post1 (sgl-project#6196)
Co-authored-by: alcanderian <[email protected]>
1 parent 277acb6 commit faff51d

File tree

python/pyproject.toml
python/sglang/srt/entrypoints/engine.py
python/sglang/srt/layers/quantization/deep_gemm.py
scripts/ci_install_dependency.sh
scripts/ci_install_dependency_8_gpu.sh

5 files changed: +61 −71 lines

python/pyproject.toml
Lines changed: 1 addition & 1 deletion

@@ -48,7 +48,7 @@ runtime_common = [
 
 srt = [
     "sglang[runtime_common]",
-    "sgl-kernel==0.1.1",
+    "sgl-kernel==0.1.2.post1",
     "flashinfer_python==0.2.5",
     "torch==2.6.0",
     "torchvision==0.21.0",

python/sglang/srt/entrypoints/engine.py
Lines changed: 1 addition & 1 deletion

@@ -486,7 +486,7 @@ def _set_envs_and_config(server_args: ServerArgs):
     if _is_cuda:
         assert_pkg_version(
             "sgl-kernel",
-            "0.1.1",
+            "0.1.2.post1",
             "Please reinstall the latest version with `pip install sgl-kernel --force-reinstall`",
         )
 
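The bumped pin is enforced at server startup by assert_pkg_version. For reference, a minimal sketch of that kind of check, assuming it enforces a minimum installed version via importlib.metadata and packaging; the helper name check_min_pkg_version below is hypothetical and not sglang's actual implementation:

from importlib.metadata import PackageNotFoundError, version

from packaging.version import Version


def check_min_pkg_version(pkg: str, min_version: str, hint: str) -> None:
    # Raise if the package is missing or older than the required version.
    try:
        installed = version(pkg)
    except PackageNotFoundError:
        raise RuntimeError(f"{pkg} is not installed. {hint}")
    if Version(installed) < Version(min_version):
        raise RuntimeError(f"{pkg}=={installed} is older than {min_version}. {hint}")


# Mirrors the updated pin in this commit.
check_min_pkg_version(
    "sgl-kernel",
    "0.1.2.post1",
    "Please reinstall the latest version with `pip install sgl-kernel --force-reinstall`",
)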

python/sglang/srt/layers/quantization/deep_gemm.py
Lines changed: 57 additions & 67 deletions

@@ -16,11 +16,7 @@
 import deep_gemm
 from deep_gemm import get_num_sms
 from deep_gemm.jit_kernels.gemm import get_best_configs
-from deep_gemm.jit_kernels.gemm import includes as deep_gemm_includes
-from deep_gemm.jit_kernels.gemm import template as deep_gemm_gemm_template
-from deep_gemm.jit_kernels.m_grouped_gemm import (
-    template as deep_gemm_grouped_gemm_template,
-)
+from deep_gemm.jit_kernels.runtime import FP8GemmRuntime, GemmType
 from deep_gemm.jit_kernels.tuner import jit_tuner
 
 sm_version = get_device_sm()

@@ -45,10 +41,15 @@ def get_enable_jit_deepgemm():
 _IN_PRECOMPILE_STAGE = get_bool_env_var("SGL_IN_DEEPGEMM_PRECOMPILE_STAGE", "false")
 
 # Force redirect deep_gemm cache_dir
-os.environ["DG_CACHE_DIR"] = os.getenv(
-    "SGL_DG_CACHE_DIR", os.path.expanduser("~") + "/.cache/deep_gemm"
+os.environ["DG_JIT_CACHE_DIR"] = os.getenv(
+    "SGL_DG_CACHE_DIR", os.path.join(os.path.expanduser("~"), ".cache", "deep_gemm")
 )
 
+# Refer to https://github.com/deepseek-ai/DeepGEMM/commit/d75b218b7b8f4a5dd5406ac87905039ead3ae42f
+# NVRTC may have performance loss in some cases,
+# and NVCC JIT speed is also 9x faster in the referenced commit.
+os.environ["DG_JIT_USE_NVRTC"] = os.getenv("SGL_DG_USE_NVRTC", "0")
+
 
 def update_deep_gemm_config(gpu_id: int, server_args: ServerArgs):
     global _BUILTIN_M_LIST
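Both overrides above are read once at module import time, so they only take effect if the SGL_* variables are set before SGLang imports this module. A minimal usage sketch (the cache path is a purely illustrative placeholder, and whether NVRTC is worth re-enabling depends on the workload):

import os

# Illustrative only: these are the SGL_* variables consumed by the module-level
# os.environ assignments shown in the hunk above.
os.environ["SGL_DG_CACHE_DIR"] = "/tmp/deep_gemm_jit_cache"  # placeholder path
os.environ["SGL_DG_USE_NVRTC"] = "1"  # opt back into NVRTC; the default "0" keeps NVCC JIT

# Importing the module afterwards maps them to DG_JIT_CACHE_DIR / DG_JIT_USE_NVRTC.
import sglang.srt.layers.quantization.deep_gemm  # noqa: E402,F401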
@@ -130,10 +131,18 @@ def _compile_grouped_gemm_nt_f8f8bf16_masked_one(
     num_groups: int,
     config: Tuple[int, int, int, int, Tuple[int, bool], Tuple[int, int, int]],
 ) -> None:
-    # Auto-tuning with compilation
-    global deep_gemm_includes, deep_gemm_grouped_gemm_template
-    _, block_m, block_n, num_stages, tma_multicast_config, smem_config = config
-    _ = jit_tuner.compile_and_tune(
+    num_sms, block_m, block_n, num_stages, tma_multicast_config, smem_config = config
+    block_k = 128
+    num_tma_threads = 128
+    num_math_threads_per_group = 128
+    kwargs = {
+        "NUM_TMA_THREADS": num_tma_threads,
+        "NUM_MATH_THREADS_PER_GROUP": num_math_threads_per_group,
+        "BLOCK_K": block_k,
+        "NUM_SMS": num_sms,
+        "SMEM_SIZE": smem_config[0],
+    }
+    _, _ = jit_tuner.compile_and_tune(
         name="m_grouped_gemm_fp8_fp8_bf16_nt",
         keys={
             "N": n,

@@ -146,24 +155,11 @@ def _compile_grouped_gemm_nt_f8f8bf16_masked_one(
             "NUM_STAGES": num_stages,
             "NUM_TMA_MULTICAST": tma_multicast_config[0],
             "IS_TMA_MULTICAST_ON_A": tma_multicast_config[1],
-            "GEMM_TYPE": "GroupedMasked",
+            "GEMM_TYPE": GemmType.GroupedMasked,
         },
         space=(),
-        includes=deep_gemm_includes,
-        arg_defs=(
-            ("lhs", torch.float8_e4m3fn),
-            ("lhs_scales", torch.float),
-            ("rhs", torch.float8_e4m3fn),
-            ("rhs_scales", torch.float),
-            ("out", torch.bfloat16),
-            ("grouped_layout", torch.int32),
-            ("m", int),
-            ("stream", torch.cuda.Stream),
-            ("num_sms", int),
-            ("smem_size", int),
-        ),
-        template=deep_gemm_grouped_gemm_template,
-        args=[],
+        kwargs=kwargs,
+        runtime_cls=FP8GemmRuntime,
     )
 
 

@@ -173,9 +169,18 @@ def _compile_grouped_gemm_nt_f8f8bf16_contig_one(
     num_groups: int,
     config: Tuple[int, int, int, int, Tuple[int, bool], Tuple[int, int, int]],
 ) -> None:
-    global deep_gemm_includes, deep_gemm_grouped_gemm_template
-    _, block_m, block_n, num_stages, tma_multicast_config, smem_config = config
-    _ = jit_tuner.compile_and_tune(
+    num_sms, block_m, block_n, num_stages, tma_multicast_config, smem_config = config
+    block_k = 128
+    num_tma_threads = 128
+    num_math_threads_per_group = 128
+    kwargs = {
+        "NUM_TMA_THREADS": num_tma_threads,
+        "NUM_MATH_THREADS_PER_GROUP": num_math_threads_per_group,
+        "BLOCK_K": block_k,
+        "NUM_SMS": num_sms,
+        "SMEM_SIZE": smem_config[0],
+    }
+    _, _ = jit_tuner.compile_and_tune(
         name="m_grouped_gemm_fp8_fp8_bf16_nt",
         keys={
             "N": n,

@@ -188,25 +193,11 @@ def _compile_grouped_gemm_nt_f8f8bf16_contig_one(
             "NUM_STAGES": num_stages,
             "NUM_TMA_MULTICAST": tma_multicast_config[0],
             "IS_TMA_MULTICAST_ON_A": tma_multicast_config[1],
-            "GEMM_TYPE": "GroupedContiguous",
+            "GEMM_TYPE": GemmType.GroupedContiguous,
         },
         space=(),
-        includes=deep_gemm_includes,
-        arg_defs=(
-            ("lhs", torch.float8_e4m3fn),
-            ("lhs_scales", torch.float),
-            ("rhs", torch.float8_e4m3fn),
-            ("rhs_scales", torch.float),
-            ("out", torch.bfloat16),
-            ("grouped_layout", torch.int32),
-            ("m", int),
-            ("num_groups", int),
-            ("stream", torch.cuda.Stream),
-            ("num_sms", int),
-            ("smem_size", int),
-        ),
-        template=deep_gemm_grouped_gemm_template,
-        args=[],
+        kwargs=kwargs,
+        runtime_cls=FP8GemmRuntime,
     )
 
 

@@ -216,9 +207,20 @@ def _compile_gemm_nt_f8f8bf16_one(
     _: int,  # _ is a dummy parameter to align with other interfaces
     config: Tuple[int, int, int, int, Tuple[int, bool], Tuple[int, int, int]],
 ) -> None:
-    global deep_gemm_includes, deep_gemm_gemm_template
-    _, block_m, block_n, num_stages, tma_multicast_config, smem_config = config
-    _ = jit_tuner.compile_and_tune(
+    num_sms, block_m, block_n, num_stages, tma_multicast_config, smem_config = config
+    block_k = 128
+    num_tma_threads = 128
+    num_math_threads_per_group = 128
+    kwargs = {
+        "GEMM_TYPE": GemmType.Normal,
+        "NUM_TMA_THREADS": num_tma_threads,
+        "NUM_MATH_THREADS_PER_GROUP": num_math_threads_per_group,
+        "NUM_GROUPS": 1,
+        "BLOCK_K": block_k,
+        "NUM_SMS": num_sms,
+        "SMEM_SIZE": smem_config[0],
+    }
+    _, _ = jit_tuner.compile_and_tune(
         name="gemm_fp8_fp8_bf16_nt",
         keys={
             "N": n,

@@ -232,20 +234,8 @@ def _compile_gemm_nt_f8f8bf16_one(
             "IS_TMA_MULTICAST_ON_A": tma_multicast_config[1],
         },
         space=(),
-        includes=deep_gemm_includes,
-        arg_defs=(
-            ("lhs", torch.float8_e4m3fn),
-            ("lhs_scales", torch.float),
-            ("rhs", torch.float8_e4m3fn),
-            ("rhs_scales", torch.float),
-            ("out", torch.bfloat16),
-            ("m", int),
-            ("stream", torch.cuda.Stream),
-            ("num_sms", int),
-            ("smem_size", int),
-        ),
-        template=deep_gemm_gemm_template,
-        args=[],
+        kwargs=kwargs,
+        runtime_cls=FP8GemmRuntime,
     )
 
 

@@ -373,7 +363,7 @@ def _log_jit_build(M: int, N: int, K: int, kernel_type: DeepGemmKernelType):
 
     from deep_gemm.jit.runtime import RuntimeCache
 
-    origin_func = RuntimeCache.__getitem__
+    origin_func = RuntimeCache.get
 
     def __patched_func(self, *args, **kwargs):
         ret = origin_func(self, *args, **kwargs)

@@ -385,6 +375,6 @@ def __patched_func(self, *args, **kwargs):
             )
         return ret
 
-    RuntimeCache.__getitem__ = __patched_func
+    RuntimeCache.get = __patched_func
     yield
-    RuntimeCache.__getitem__ = origin_func
+    RuntimeCache.get = origin_func
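The last two hunks adapt the JIT-build logging shim to DeepGEMM's newer cache API, which is consulted through RuntimeCache.get rather than RuntimeCache.__getitem__. The shim temporarily swaps in a wrapper around that accessor so a lookup that comes back empty (a kernel that still has to be JIT-built) can be logged, then restores the original. A self-contained sketch of the same patch-and-restore pattern, using a stand-in cache class instead of deep_gemm's RuntimeCache:

import logging
from contextlib import contextmanager

logger = logging.getLogger(__name__)


class DummyCache:
    """Stand-in for deep_gemm's RuntimeCache, used only for this illustration."""

    def __init__(self):
        self._store = {}

    def get(self, key):
        return self._store.get(key)


@contextmanager
def log_cache_misses(cache_cls):
    origin_get = cache_cls.get

    def patched_get(self, *args, **kwargs):
        ret = origin_get(self, *args, **kwargs)
        if ret is None:
            # A miss means the kernel is not cached yet, so a JIT build would follow.
            logger.warning("cache miss for %s, JIT build expected", args)
        return ret

    cache_cls.get = patched_get
    try:
        yield
    finally:
        cache_cls.get = origin_get


# Usage: any cache_cls.get call made inside the block is logged when it misses.
with log_cache_misses(DummyCache):
    DummyCache().get("gemm_fp8_fp8_bf16_nt")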

scripts/ci_install_dependency.sh
Lines changed: 1 addition & 1 deletion

@@ -16,7 +16,7 @@ rm -rf /usr/local/lib/python3.10/dist-packages/sgl_kernel*
 pip install --upgrade pip
 
 # Install sgl-kernel
-pip install sgl-kernel==0.1.1 --no-cache-dir
+pip install sgl-kernel==0.1.2.post1 --no-cache-dir
 
 # Install the main package
 pip install -e "python[all]"

scripts/ci_install_dependency_8_gpu.sh
Lines changed: 1 addition & 1 deletion

@@ -34,7 +34,7 @@ rm -rf /usr/local/include/nvshmem*
 pip install --upgrade pip
 
 # Install sgl-kernel
-pip install sgl-kernel==0.1.1 --no-cache-dir
+pip install sgl-kernel==0.1.2.post1 --no-cache-dir
 
 # Install the main package
 pip install -e "python[all]"
