
Commit a086a11

Use sgl-kernel sgl_per_token_group_quant_int8 (#4971)

1 parent bdbe5f8

2 files changed: +39 −2 lines

python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py

Lines changed: 7 additions & 1 deletion
@@ -755,6 +755,9 @@ def invoke_fused_moe_kernel(
         from sglang.srt.layers.quantization.fp8_kernel import (
             sglang_per_token_group_quant_fp8,
         )
+        from sglang.srt.layers.quantization.int8_kernel import (
+            sglang_per_token_group_quant_int8,
+        )
     else:
         from sglang.srt.layers.quantization.fp8_kernel import per_token_group_quant_fp8

@@ -794,7 +797,10 @@ def invoke_fused_moe_kernel(
             # activation block-wise int8 quantization
             assert len(block_shape) == 2
             block_n, block_k = block_shape[0], block_shape[1]
-            A, A_scale = per_token_group_quant_int8(A, block_k)
+            if _is_cuda:
+                A, A_scale = sglang_per_token_group_quant_int8(A, block_k)
+            else:
+                A, A_scale = per_token_group_quant_int8(A, block_k)
             assert triton.cdiv(A.shape[-1], block_k) == A_scale.shape[-1]
             assert triton.cdiv(B.shape[-2], block_n) == B_scale.shape[-2]
             assert triton.cdiv(B.shape[-1], block_k) == B_scale.shape[-1]
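For context, a minimal pure-PyTorch sketch of the per-token-group int8 quantization that both branches above compute: each contiguous group of group_size elements along the last dimension gets its own float32 scale derived from the group's absolute maximum. The helper name per_token_group_quant_int8_ref and the absmax/127 scaling rule are illustrative assumptions, not part of this commit.

import torch

def per_token_group_quant_int8_ref(x: torch.Tensor, group_size: int, eps: float = 1e-10):
    # Hypothetical reference; split the last dimension into groups of `group_size`.
    g = x.reshape(*x.shape[:-1], x.shape[-1] // group_size, group_size).to(torch.float32)
    # One float32 scale per group: the group's absolute maximum mapped onto int8's max (127).
    scale = g.abs().amax(dim=-1, keepdim=True).clamp(min=eps) / 127.0
    # Quantize: divide by the scale, round, and clamp into the int8 range [-128, 127].
    q = (g / scale).round().clamp(-128, 127).to(torch.int8)
    return q.reshape(x.shape), scale.squeeze(-1)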

python/sglang/srt/layers/quantization/int8_kernel.py

Lines changed: 32 additions & 1 deletion
@@ -8,7 +8,11 @@
 import triton
 import triton.language as tl

-from sglang.srt.utils import get_device_name
+from sglang.srt.utils import get_device_name, is_cuda
+
+_is_cuda = is_cuda()
+if _is_cuda:
+    from sgl_kernel import sgl_per_token_group_quant_int8

 logger = logging.getLogger(__name__)

@@ -165,6 +169,33 @@ def per_token_group_quant_int8(
     return x_q, x_s


+def sglang_per_token_group_quant_int8(
+    x: torch.Tensor,
+    group_size: int,
+    eps: float = 1e-10,
+    dtype: torch.dtype = torch.int8,
+):
+    assert (
+        x.shape[-1] % group_size == 0
+    ), "the last dimension of `x` must be divisible by `group_size`"
+    assert x.is_contiguous(), "`x` is not contiguous"
+
+    iinfo = torch.iinfo(dtype)
+    int8_max = iinfo.max
+    int8_min = iinfo.min
+
+    x_q = torch.empty_like(x, device=x.device, dtype=dtype)
+    x_s = torch.empty(
+        x.shape[:-1] + (x.shape[-1] // group_size,),
+        device=x.device,
+        dtype=torch.float32,
+    )
+
+    sgl_per_token_group_quant_int8(x, x_q, x_s, group_size, eps, int8_min, int8_max)
+
+    return x_q, x_s
+
+
 @triton.jit
 def _w8a8_block_int8_matmul(
     # Pointers to inputs and output
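A minimal usage sketch of the new wrapper, assuming a CUDA build with the sgl_kernel extension installed; the shapes and dtype below are illustrative, not taken from the commit.

import torch
from sglang.srt.layers.quantization.int8_kernel import sglang_per_token_group_quant_int8

# Illustrative input: 4 tokens with hidden size 256, quantized in groups of 128.
x = torch.randn(4, 256, device="cuda", dtype=torch.float16)
x_q, x_s = sglang_per_token_group_quant_int8(x, group_size=128)

assert x_q.shape == (4, 256) and x_q.dtype == torch.int8
assert x_s.shape == (4, 2) and x_s.dtype == torch.float32  # one scale per group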
