diff --git a/python/sglang/srt/layers/quantization/fp8_kernel.py b/python/sglang/srt/layers/quantization/fp8_kernel.py
index 54c07f9094..4ffeab8bab 100644
--- a/python/sglang/srt/layers/quantization/fp8_kernel.py
+++ b/python/sglang/srt/layers/quantization/fp8_kernel.py
@@ -29,10 +29,13 @@
 _is_cuda = torch.cuda.is_available() and torch.version.cuda
 if _is_cuda:
+    import deep_gemm
     from sgl_kernel import sgl_per_token_group_quant_fp8, sgl_per_token_quant_fp8
 
 logger = logging.getLogger(__name__)
 
+_enable_jit_deepgemm = int(os.getenv("SGL_ENABLE_JIT_DEEPGEMM", "0"))
+
 
 @triton.jit
 def _per_token_group_quant_fp8(
@@ -722,34 +725,39 @@ def grid(META):
     num_workgroups = triton.cdiv(M, config["BLOCK_SIZE_M"]) * triton.cdiv(
         N, config["BLOCK_SIZE_N"]
     )
-    kernel = (
-        _w8a8_block_fp8_matmul_unrolledx4
-        if (is_hip_ == True and num_workgroups <= get_device_core_count())
-        else _w8a8_block_fp8_matmul
-    )
-    kernel[grid](
-        A,
-        B,
-        C,
-        As,
-        Bs,
-        M,
-        N,
-        K,
-        block_n,
-        block_k,
-        A.stride(-2),
-        A.stride(-1),
-        B.stride(1),
-        B.stride(0),
-        C.stride(-2),
-        C.stride(-1),
-        As.stride(-2),
-        As.stride(-1),
-        Bs.stride(1),
-        Bs.stride(0),
-        **config,
-    )
+    # deepgemm only support bf16
+    if _is_cuda and C.dtype == torch.bfloat16 and _enable_jit_deepgemm:
+        deep_gemm.gemm_fp8_fp8_bf16_nt((A, As), (B, Bs), C)
+    else:
+        kernel = (
+            _w8a8_block_fp8_matmul_unrolledx4
+            if (is_hip_ == True and num_workgroups <= get_device_core_count())
+            else _w8a8_block_fp8_matmul
+        )
+
+        kernel[grid](
+            A,
+            B,
+            C,
+            As,
+            Bs,
+            M,
+            N,
+            K,
+            block_n,
+            block_k,
+            A.stride(-2),
+            A.stride(-1),
+            B.stride(1),
+            B.stride(0),
+            C.stride(-2),
+            C.stride(-1),
+            As.stride(-2),
+            As.stride(-1),
+            Bs.stride(1),
+            Bs.stride(0),
+            **config,
+        )
 
     return C
diff --git a/python/sglang/test/test_block_fp8.py b/python/sglang/test/test_block_fp8.py
index b3da7690ce..25aaf498a0 100644
--- a/python/sglang/test/test_block_fp8.py
+++ b/python/sglang/test/test_block_fp8.py
@@ -1,4 +1,5 @@
 import itertools
+import os
 import unittest
 
 import torch
@@ -11,6 +12,8 @@
     w8a8_block_fp8_matmul,
 )
 
+_is_cuda = torch.cuda.is_available() and torch.version.cuda
+
 
 # For test
 def native_per_token_group_quant_fp8(
@@ -208,13 +211,35 @@ def native_w8a8_block_fp8_matmul(A, B, As, Bs, block_size, output_dtype=torch.fl
 
 
 class TestW8A8BlockFP8Matmul(unittest.TestCase):
-    OUT_DTYPES = [torch.float32, torch.half, torch.bfloat16]
-    M = [1, 7, 83, 512, 2048]
-    N = [128, 512, 1024, 4096, 7748, 13824]
-    K = [256, 4096, 5120, 3884, 13824]
-    # BLOCK_SIZE = [[64, 64], [64, 128], [128, 64], [128, 128]]
-    BLOCK_SIZE = [[128, 128]]
-    SEEDS = [0]
+
+    if not _is_cuda:
+        OUT_DTYPES = [torch.float32, torch.half, torch.bfloat16]
+        M = [1, 7, 83, 512, 2048]
+        NKs = [
+            (N, K)
+            for N in [128, 512, 1024, 4096, 7748, 13824]
+            for K in [256, 4096, 5120, 3884, 13824]
+        ]
+        # BLOCK_SIZE = [[64, 64], [64, 128], [128, 64], [128, 128]]
+        BLOCK_SIZE = [[128, 128]]
+        SEEDS = [0]
+    else:
+        # use practical shape in DeepSeek V3 for test
+        OUT_DTYPES = [torch.bfloat16]
+        M = [64, 128, 512, 1024, 4096]
+        NKs = [
+            (1536, 7168),
+            (3072, 1536),
+            (24576, 7168),
+            (4096, 512),
+            (7168, 2048),
+            (4608, 7168),
+            (512, 7168),
+            (7168, 2304),
+            (7168, 512),
+        ]
+        BLOCK_SIZE = [[128, 128]]
+        SEEDS = [0]
 
     @classmethod
     def setUpClass(cls):
@@ -222,7 +247,8 @@ def setUpClass(cls):
             raise unittest.SkipTest("CUDA is not available")
         torch.set_default_device("cuda")
 
-    def _w8a8_block_fp8_matmul(self, M, N, K, block_size, out_dtype, seed):
+    def _w8a8_block_fp8_matmul(self, M, NK, block_size, out_dtype, seed):
+        N, K = NK
         torch.manual_seed(seed)
         # NOTE(HandH1998): to avoid overflow when out_dtype = torch.half
         factor_for_scale = 1e-2
@@ -257,19 +283,17 @@ def _w8a8_block_fp8_matmul(self, M, N, K, block_size, out_dtype, seed):
 
     def test_w8a8_block_fp8_matmul(self):
         for params in itertools.product(
             self.M,
-            self.N,
-            self.K,
+            self.NKs,
             self.BLOCK_SIZE,
             self.OUT_DTYPES,
             self.SEEDS,
         ):
             with self.subTest(
                 M=params[0],
-                N=params[1],
-                K=params[2],
-                block_size=params[3],
-                out_dtype=params[4],
-                seed=params[5],
+                NKs=params[1],
+                block_size=params[2],
+                out_dtype=params[3],
+                seed=params[4],
             ):
                 self._w8a8_block_fp8_matmul(*params)
diff --git a/test/srt/test_fp8_kernel.py b/test/srt/test_fp8_kernel.py
index fe92bfd076..dcc5d42748 100644
--- a/test/srt/test_fp8_kernel.py
+++ b/test/srt/test_fp8_kernel.py
@@ -17,7 +17,7 @@ def setUpClass(cls):
         cls.K = 512
         cls.group_size = 128
         cls.quant_type = torch.float8_e4m3fn
-        cls.output_type = torch.float16
+        cls.output_type = torch.bfloat16
 
     @staticmethod
    def _make_A(M, K, group_size, out_dtype):
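Usage sketch (reviewer note, not part of the patch): a minimal way to exercise the new SGL_ENABLE_JIT_DEEPGEMM gate added in fp8_kernel.py, assuming a CUDA device and the deep_gemm package are available. The flag is read at module import time, so it must be set before the import; the tensor setup mirrors the unit test above, and the positional w8a8_block_fp8_matmul call signature (A, B, As, Bs, block_size, output_dtype) is assumed from that test.

    import os

    os.environ["SGL_ENABLE_JIT_DEEPGEMM"] = "1"  # must be set before fp8_kernel is imported

    import torch

    from sglang.srt.layers.quantization.fp8_kernel import w8a8_block_fp8_matmul

    torch.set_default_device("cuda")
    M, (N, K), block_size = 64, (1536, 7168), [128, 128]  # a DeepSeek-V3 shape from the new test list
    block_n, block_k = block_size
    fp8_info = torch.finfo(torch.float8_e4m3fn)

    # Random FP8 operands with small per-block scales, mirroring the unit test setup.
    A = (torch.rand(M, K) - 0.5) * 2 * fp8_info.max
    A = A.clamp(fp8_info.min, fp8_info.max).to(torch.float8_e4m3fn)
    B = (torch.rand(N, K) - 0.5) * 2 * fp8_info.max
    B = B.clamp(fp8_info.min, fp8_info.max).to(torch.float8_e4m3fn)
    k_tiles = (K + block_k - 1) // block_k
    n_tiles = (N + block_n - 1) // block_n
    As = torch.rand(M, k_tiles) * 1e-2
    Bs = torch.rand(n_tiles, k_tiles) * 1e-2

    # A bfloat16 output routes through deep_gemm.gemm_fp8_fp8_bf16_nt; any other output
    # dtype (or SGL_ENABLE_JIT_DEEPGEMM=0) falls back to the Triton kernel.
    out = w8a8_block_fp8_matmul(A, B, As, Bs, block_size, torch.bfloat16)
    print(out.shape)  # torch.Size([64, 1536])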