
linear support deepgemm #4199


Merged
merged 13 commits on Mar 11, 2025
64 changes: 36 additions & 28 deletions python/sglang/srt/layers/quantization/fp8_kernel.py
@@ -29,10 +29,13 @@

_is_cuda = torch.cuda.is_available() and torch.version.cuda
if _is_cuda:
import deep_gemm
from sgl_kernel import sgl_per_token_group_quant_fp8, sgl_per_token_quant_fp8

logger = logging.getLogger(__name__)

_enable_jit_deepgemm = int(os.getenv("SGL_ENABLE_JIT_DEEPGEMM", "0"))


@triton.jit
def _per_token_group_quant_fp8(
@@ -722,34 +725,39 @@ def grid(META):
num_workgroups = triton.cdiv(M, config["BLOCK_SIZE_M"]) * triton.cdiv(
N, config["BLOCK_SIZE_N"]
)
kernel = (
_w8a8_block_fp8_matmul_unrolledx4
if (is_hip_ == True and num_workgroups <= get_device_core_count())
else _w8a8_block_fp8_matmul
)

kernel[grid](
A,
B,
C,
As,
Bs,
M,
N,
K,
block_n,
block_k,
A.stride(-2),
A.stride(-1),
B.stride(1),
B.stride(0),
C.stride(-2),
C.stride(-1),
As.stride(-2),
As.stride(-1),
Bs.stride(1),
Bs.stride(0),
**config,
)
# deepgemm only support bf16
if _is_cuda and C.dtype == torch.bfloat16 and _enable_jit_deepgemm:
deep_gemm.gemm_fp8_fp8_bf16_nt((A, As), (B, Bs), C)
else:
kernel = (
Review comment (Member): This is a duplicate. Please delete line 726 of the code.

_w8a8_block_fp8_matmul_unrolledx4
if (is_hip_ == True and num_workgroups <= get_device_core_count())
else _w8a8_block_fp8_matmul
)

kernel[grid](
A,
B,
C,
As,
Bs,
M,
N,
K,
block_n,
block_k,
A.stride(-2),
A.stride(-1),
B.stride(1),
B.stride(0),
C.stride(-2),
C.stride(-1),
As.stride(-2),
As.stride(-1),
Bs.stride(1),
Bs.stride(0),
**config,
)

return C
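
For context, this hunk routes the block-wise FP8 matmul to DeepGEMM's JIT kernel whenever CUDA is available, the output buffer is bf16, and the SGL_ENABLE_JIT_DEEPGEMM environment variable is set to 1; every other case keeps the existing Triton kernel launch. A minimal sketch of that dispatch, where dispatch_w8a8_block_fp8_matmul and run_triton_kernel are hypothetical stand-ins rather than names from the PR:

import os

import torch

_ENABLE_JIT_DEEPGEMM = int(os.getenv("SGL_ENABLE_JIT_DEEPGEMM", "0"))
_IS_CUDA = torch.cuda.is_available() and torch.version.cuda


def dispatch_w8a8_block_fp8_matmul(A, B, C, As, Bs, run_triton_kernel):
    # DeepGEMM only supports bf16 output, so the dtype check gates the fast path.
    if _IS_CUDA and _ENABLE_JIT_DEEPGEMM and C.dtype == torch.bfloat16:
        import deep_gemm

        # DeepGEMM takes (tensor, scale) pairs and writes the result into C in place.
        deep_gemm.gemm_fp8_fp8_bf16_nt((A, As), (B, Bs), C)
    else:
        # Fall back to the Triton _w8a8_block_fp8_matmul launch shown above.
        run_triton_kernel(A, B, C, As, Bs)
    return C
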
54 changes: 39 additions & 15 deletions python/sglang/test/test_block_fp8.py
@@ -1,4 +1,5 @@
import itertools
import os
import unittest

import torch
@@ -11,6 +12,8 @@
w8a8_block_fp8_matmul,
)

_is_cuda = torch.cuda.is_available() and torch.version.cuda


# For test
def native_per_token_group_quant_fp8(
@@ -208,21 +211,44 @@ def native_w8a8_block_fp8_matmul(A, B, As, Bs, block_size, output_dtype=torch.fl


class TestW8A8BlockFP8Matmul(unittest.TestCase):
OUT_DTYPES = [torch.float32, torch.half, torch.bfloat16]
M = [1, 7, 83, 512, 2048]
N = [128, 512, 1024, 4096, 7748, 13824]
K = [256, 4096, 5120, 3884, 13824]
# BLOCK_SIZE = [[64, 64], [64, 128], [128, 64], [128, 128]]
BLOCK_SIZE = [[128, 128]]
SEEDS = [0]

if not _is_cuda:
OUT_DTYPES = [torch.float32, torch.half, torch.bfloat16]
M = [1, 7, 83, 512, 2048]
NKs = [
(N, K)
for N in [128, 512, 1024, 4096, 7748, 13824]
for K in [256, 4096, 5120, 3884, 13824]
]
# BLOCK_SIZE = [[64, 64], [64, 128], [128, 64], [128, 128]]
BLOCK_SIZE = [[128, 128]]
SEEDS = [0]
else:
# use practical shape in DeepSeek V3 for test
OUT_DTYPES = [torch.bfloat16]
M = [64, 128, 512, 1024, 4096]
NKs = [
(1536, 7168),
(3072, 1536),
(24576, 7168),
(4096, 512),
(7168, 2048),
(4608, 7168),
(512, 7168),
(7168, 2304),
(7168, 512),
]
BLOCK_SIZE = [[128, 128]]
SEEDS = [0]

@classmethod
def setUpClass(cls):
if not torch.cuda.is_available():
raise unittest.SkipTest("CUDA is not available")
torch.set_default_device("cuda")

def _w8a8_block_fp8_matmul(self, M, N, K, block_size, out_dtype, seed):
def _w8a8_block_fp8_matmul(self, M, NK, block_size, out_dtype, seed):
N, K = NK
torch.manual_seed(seed)
# NOTE(HandH1998): to avoid overflow when out_dtype = torch.half
factor_for_scale = 1e-2
@@ -257,19 +283,17 @@ def _w8a8_block_fp8_matmul(self, M, N, K, block_size, out_dtype, seed):
def test_w8a8_block_fp8_matmul(self):
for params in itertools.product(
self.M,
self.N,
self.K,
self.NKs,
self.BLOCK_SIZE,
self.OUT_DTYPES,
self.SEEDS,
):
with self.subTest(
M=params[0],
N=params[1],
K=params[2],
block_size=params[3],
out_dtype=params[4],
seed=params[5],
NKs=params[1],
block_size=params[2],
out_dtype=params[3],
seed=params[4],
):
self._w8a8_block_fp8_matmul(*params)

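The test change above replaces the independent N and K lists with explicit (N, K) pairs, so itertools.product no longer sweeps the full N x K cross product and each CUDA case runs on a realistic DeepSeek V3 weight shape. A small illustration of the resulting iteration (shapes abbreviated; not code from the PR):

import itertools

M = [64, 128]
NKs = [(1536, 7168), (3072, 1536)]  # (N, K) stays paired instead of being crossed
BLOCK_SIZE = [[128, 128]]

for m, (n, k), block_size in itertools.product(M, NKs, BLOCK_SIZE):
    # Mirrors how the test unpacks NK inside _w8a8_block_fp8_matmul via `N, K = NK`.
    print(m, n, k, block_size)
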
2 changes: 1 addition & 1 deletion test/srt/test_fp8_kernel.py
@@ -17,7 +17,7 @@ def setUpClass(cls):
cls.K = 512
cls.group_size = 128
cls.quant_type = torch.float8_e4m3fn
cls.output_type = torch.float16
cls.output_type = torch.bfloat16

@staticmethod
def _make_A(M, K, group_size, out_dtype):
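The one-line change in test_fp8_kernel.py switches the expected output dtype from float16 to bfloat16, which lines up with the DeepGEMM path only producing bf16 results (see the "deepgemm only support bf16" comment in fp8_kernel.py above). A tiny sketch of that eligibility rule, stated here as an assumption distilled from the PR rather than a documented API guarantee:

import torch

def uses_deepgemm_path(out_dtype, jit_deepgemm_enabled=True, is_cuda=True):
    # Assumption based on this PR: only bf16 outputs can take the DeepGEMM path;
    # float16 or float32 outputs keep using the Triton kernel.
    return bool(is_cuda and jit_deepgemm_enabled and out_dtype == torch.bfloat16)

assert uses_deepgemm_path(torch.bfloat16)
assert not uses_deepgemm_path(torch.float16)  # the old test dtype would skip DeepGEMM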