Skip to content

Commit 154b862

Browse files
zhyncs and grimoire
authored and committed
feat: add DeepGEMM build warning (sgl-project#5176)
Co-authored-by: grimoire <[email protected]>
1 parent b38ba6e commit 154b862

File tree

1 file changed

+26
-2
lines changed

1 file changed

+26
-2
lines changed

python/sglang/srt/layers/quantization/fp8_kernel.py

Lines changed: 26 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
import json
1717
import logging
1818
import os
19+
from contextlib import contextmanager
1920
from typing import Any, Dict, List, Optional, Tuple
2021

2122
import torch
@@ -59,7 +60,10 @@ def deep_gemm_fp8_fp8_bf16_nt(
5960
Bs: torch.Tensor,
6061
C: torch.Tensor,
6162
) -> None:
62-
deep_gemm.gemm_fp8_fp8_bf16_nt((A, As), (B, Bs), C)
63+
M, K = A.shape
64+
N, _ = B.shape
65+
with _log_jit_build(M, N, K):
66+
deep_gemm.gemm_fp8_fp8_bf16_nt((A, As), (B, Bs), C)
6367

6468
def deep_gemm_fp8_fp8_bf16_nt_fake(
6569
A: torch.Tensor,
@@ -708,6 +712,25 @@ def get_w8a8_block_fp8_configs(
708712
return None
709713

710714

715+
@contextmanager
def _log_jit_build(M: int, N: int, K: int):
    """Log a warning when a DeepGEMM runtime-cache lookup misses.

    While the ``with`` body runs, ``RuntimeCache.__getitem__`` is temporarily
    replaced by a wrapper that calls the original method and emits a warning
    whenever that call returns ``None``.
    NOTE(review): a ``None`` result is taken to mean the kernel is not cached
    and will be JIT-built -- confirm against the installed deep_gemm version.

    The original ``__getitem__`` is restored when the context exits, including
    when the body raises, so a failing GEMM call cannot leave the class
    permanently patched.

    Args:
        M: Value reported as ``M`` in the warning message.
        N: Value reported as ``N`` in the warning message.
        K: Value reported as ``K`` in the warning message.
    """
    # Imported at call time rather than at module import time.
    from deep_gemm.jit.runtime import RuntimeCache

    origin_func = RuntimeCache.__getitem__

    def _patched_getitem(self, *args, **kwargs):
        ret = origin_func(self, *args, **kwargs)
        if ret is None:
            logger.warning(
                f"DeepGEMM JIT code generation <gemm_fp8_fp8_bf16_nt>: M={M}, N={N}, K={K}. Please wait."
            )
        return ret

    RuntimeCache.__getitem__ = _patched_getitem
    try:
        yield
    finally:
        # Always undo the monkeypatch, even if the wrapped call raised.
        RuntimeCache.__getitem__ = origin_func
732+
733+
711734
def w8a8_block_fp8_matmul(
712735
A: torch.Tensor,
713736
B: torch.Tensor,
@@ -782,7 +805,8 @@ def grid(META):
782805
if supports_custom_op():
783806
torch.ops.sglang.deep_gemm_fp8_fp8_bf16_nt(A, As, B, Bs, C)
784807
else:
785-
deep_gemm.gemm_fp8_fp8_bf16_nt((A, As), (B, Bs), C)
808+
with _log_jit_build(M, N, K):
809+
deep_gemm.gemm_fp8_fp8_bf16_nt((A, As), (B, Bs), C)
786810
else:
787811
kernel = (
788812
_w8a8_block_fp8_matmul_unrolledx4

0 commit comments

Comments
 (0)