
Commit 69583c9

laixinn, sleepcoo, HandH1998, shuaills, and yinfan98 authored and committed
DeepGemm integrate to sgl-kernel (sgl-project#4165)
Co-authored-by: sleepcoo <[email protected]>
Co-authored-by: HandH1998 <[email protected]>
Co-authored-by: shuaills <[email protected]>
Co-authored-by: yinfan98 <[email protected]>
Co-authored-by: Yineng Zhang <[email protected]>
1 parent 110b0ad commit 69583c9

6 files changed: 324 additions & 5 deletions


.gitmodules

Lines changed: 3 additions & 0 deletions
@@ -7,3 +7,6 @@
 [submodule "sgl-kernel/3rdparty/flashinfer"]
 	path = sgl-kernel/3rdparty/flashinfer
 	url = https://github.com/flashinfer-ai/flashinfer.git
+[submodule "sgl-kernel/3rdparty/deepgemm"]
+	path = sgl-kernel/3rdparty/deepgemm
+	url = https://github.com/deepseek-ai/DeepGEMM

sgl-kernel/3rdparty/deepgemm

Submodule deepgemm added at 5e4badc
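
DeepGEMM is vendored as a git submodule and copied into the wheel at build time (see the setup.py change below), so the submodule must be checked out before building. A minimal pre-build check, as a sketch (the path mirrors the .gitmodules entry above; this check is not part of the commit):

# Sketch: verify the DeepGEMM submodule is populated before building sgl-kernel.
# If it is missing, fetch it first, e.g. `git submodule update --init --recursive`.
from pathlib import Path

deepgemm_pkg = Path("sgl-kernel/3rdparty/deepgemm/deep_gemm")
if not deepgemm_pkg.is_dir():
    raise SystemExit("DeepGEMM submodule not checked out; run: git submodule update --init --recursive")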

sgl-kernel/build.sh

Lines changed: 4 additions & 3 deletions
@@ -11,11 +11,11 @@ else
 fi
 
 docker run --rm \
-    -v "$(pwd)":/sgl-kernel \
+    -v $(pwd):/sgl-kernel \
     pytorch/manylinux-builder:cuda${CUDA_VERSION} \
     bash -c "
     ${PYTHON_ROOT_PATH}/bin/pip install --no-cache-dir torch==2.5.1 --index-url https://download.pytorch.org/whl/cu${CUDA_VERSION//.} && \
-    ${PYTHON_ROOT_PATH}/bin/pip install --no-cache-dir ninja && \
+    ${PYTHON_ROOT_PATH}/bin/pip install --no-cache-dir ninja setuptools==75.0.0 wheel==0.41.0 numpy && \
     export TORCH_CUDA_ARCH_LIST='7.5 8.0 8.9 9.0+PTX' && \
     export CUDA_VERSION=${CUDA_VERSION} && \
     export SGL_KERNEL_ENABLE_BF16=1 && \
@@ -24,5 +24,6 @@ docker run --rm \
     mkdir -p /usr/lib/x86_64-linux-gnu/ && \
     ln -s /usr/local/cuda-${CUDA_VERSION}/targets/x86_64-linux/lib/stubs/libcuda.so /usr/lib/x86_64-linux-gnu/libcuda.so && \
     cd /sgl-kernel && \
-    ${PYTHON_ROOT_PATH}/bin/python setup.py bdist_wheel
+    ls -la ${PYTHON_ROOT_PATH}/lib/python${PYTHON_VERSION}/site-packages/wheel/ && \
+    PYTHONPATH=${PYTHON_ROOT_PATH}/lib/python${PYTHON_VERSION}/site-packages ${PYTHON_ROOT_PATH}/bin/python setup.py bdist_wheel
     "

sgl-kernel/pyproject.toml

Lines changed: 1 addition & 1 deletion
@@ -1,6 +1,6 @@
 [build-system]
 requires = [
-    "setuptools>=61.0",
+    "setuptools>=75.0",
     "scikit-build-core>=0.10",
     "torch==2.5.1",
     "wheel",

sgl-kernel/setup.py

Lines changed: 52 additions & 1 deletion
@@ -14,11 +14,13 @@
 # ==============================================================================
 
 import os
+import shutil
 import sys
 from pathlib import Path
 
 import torch
 from setuptools import find_packages, setup
+from setuptools.command.build_py import build_py
 from torch.utils.cpp_extension import BuildExtension, CUDAExtension
 
 root = Path(__file__).parent.resolve()
@@ -52,6 +54,7 @@ def _get_version():
 cutlass_default = root / "3rdparty" / "cutlass"
 cutlass = Path(os.environ.get("CUSTOM_CUTLASS_SRC_DIR", default=cutlass_default))
 flashinfer = root / "3rdparty" / "flashinfer"
+deepgemm = root / "3rdparty" / "deepgemm"
 include_dirs = [
     root / "include",
     root / "csrc",
@@ -63,6 +66,51 @@ def _get_version():
     "cublas",
 ]
 
+
+class CustomBuildPy(build_py):
+    def run(self):
+        self.copy_deepgemm_to_build_lib()
+        self.make_jit_include_symlinks()
+        build_py.run(self)
+
+    def make_jit_include_symlinks(self):
+        # Make symbolic links of third-party include directories
+        build_include_dir = os.path.join(self.build_lib, "deep_gemm/include")
+        os.makedirs(build_include_dir, exist_ok=True)
+
+        third_party_include_dirs = [
+            cutlass.resolve() / "include" / "cute",
+            cutlass.resolve() / "include" / "cutlass",
+        ]
+
+        for d in third_party_include_dirs:
+            dirname = str(d).split("/")[-1]
+            src_dir = d
+            dst_dir = f"{build_include_dir}/{dirname}"
+            assert os.path.exists(src_dir)
+            if os.path.exists(dst_dir):
+                assert os.path.islink(dst_dir)
+                os.unlink(dst_dir)
+            os.symlink(src_dir, dst_dir, target_is_directory=True)
+
+    def copy_deepgemm_to_build_lib(self):
+        """
+        This function copies DeepGemm to python's site-packages
+        """
+        dst_dir = os.path.join(self.build_lib, "deep_gemm")
+        os.makedirs(dst_dir, exist_ok=True)
+
+        # Copy deepgemm/deep_gemm to the build directory
+        src_dir = os.path.join(str(deepgemm.resolve()), "deep_gemm")
+
+        # Remove existing directory if it exists
+        if os.path.exists(dst_dir):
+            shutil.rmtree(dst_dir)
+
+        # Copy the directory
+        shutil.copytree(src_dir, dst_dir)
+
+
 nvcc_flags = [
     "-DNDEBUG",
     f"-DOPERATOR_NAMESPACE={operator_namespace}",
@@ -175,6 +223,9 @@ def _get_version():
     packages=find_packages(where="python"),
     package_dir={"": "python"},
     ext_modules=ext_modules,
-    cmdclass={"build_ext": BuildExtension.with_options(use_ninja=True)},
+    cmdclass={
+        "build_ext": BuildExtension.with_options(use_ninja=True),
+        "build_py": CustomBuildPy,
+    },
     options={"bdist_wheel": {"py_limited_api": "cp39"}},
 )
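
The CustomBuildPy hook above copies the DeepGEMM Python package into build_lib and symlinks the CUTLASS cute/cutlass headers under deep_gemm/include, so the built wheel ships a self-contained deep_gemm that can JIT-compile its kernels. A small post-install sanity check, as a sketch (not part of the commit; assumes the wheel produced by this setup.py is installed):

# Sketch: confirm deep_gemm was packaged by the CustomBuildPy hook.
import os
import deep_gemm

pkg_dir = deep_gemm.__path__[0]          # .../site-packages/deep_gemm
include_dir = os.path.join(pkg_dir, "include")
# Expect 'cute' and 'cutlass' entries here (symlinks in the build tree; plain
# directories once materialized into the wheel).
print(pkg_dir)
print(sorted(os.listdir(include_dir)))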

sgl-kernel/tests/test_deep_gemm.py

Lines changed: 263 additions & 0 deletions
@@ -0,0 +1,263 @@
+import os
+import random
+import unittest
+from typing import Any, Tuple
+
+import deep_gemm
+import torch
+from deep_gemm import calc_diff, ceil_div, get_col_major_tma_aligned_tensor, jit
+
+"""
+fork deepgemm/tests/test_core.py
+"""
+
+
+def per_token_cast_to_fp8(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+    assert x.dim() == 2 and x.size(1) % 128 == 0
+    m, n = x.shape
+    x_view = x.view(m, -1, 128)
+    x_amax = x_view.abs().float().amax(dim=2).view(m, -1).clamp(1e-4)
+    return (x_view * (448.0 / x_amax.unsqueeze(2))).to(torch.float8_e4m3fn).view(
+        m, n
+    ), (x_amax / 448.0).view(m, -1)
+
+
+def per_block_cast_to_fp8(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+    assert x.dim() == 2
+    m, n = x.shape
+    x_padded = torch.zeros(
+        (ceil_div(m, 128) * 128, ceil_div(n, 128) * 128), dtype=x.dtype, device=x.device
+    )
+    x_padded[:m, :n] = x
+    x_view = x_padded.view(-1, 128, x_padded.size(1) // 128, 128)
+    x_amax = x_view.abs().float().amax(dim=(1, 3), keepdim=True).clamp(1e-4)
+    x_scaled = (x_view * (448.0 / x_amax)).to(torch.float8_e4m3fn)
+    return x_scaled.view_as(x_padded)[:m, :n].contiguous(), (x_amax / 448.0).view(
+        x_view.size(0), x_view.size(2)
+    )
+
+
+def construct(m: int, k: int, n: int) -> Tuple[
+    Tuple[torch.Tensor, torch.Tensor],
+    Tuple[torch.Tensor, torch.Tensor],
+    torch.Tensor,
+    torch.Tensor,
+]:
+    x = torch.randn((m, k), device="cuda", dtype=torch.bfloat16)
+    y = torch.randn((n, k), device="cuda", dtype=torch.bfloat16)
+    out = torch.empty((m, n), device="cuda", dtype=torch.bfloat16)
+    ref_out = x @ y.t()
+
+    x_fp8, y_fp8 = per_token_cast_to_fp8(x), per_block_cast_to_fp8(y)
+    # Transpose earlier so that the testing will not trigger transposing kernels
+    x_fp8 = (x_fp8[0], get_col_major_tma_aligned_tensor(x_fp8[1]))
+    return x_fp8, y_fp8, out, ref_out
+
+
+def construct_grouped(
+    num_groups: int, m: int, k: int, n: int, is_masked: bool
+) -> Tuple[
+    Tuple[torch.Tensor, torch.Tensor],
+    Tuple[torch.Tensor, torch.Tensor],
+    torch.Tensor,
+    torch.Tensor,
+]:
+    x = torch.randn((num_groups, m, k), device="cuda", dtype=torch.bfloat16)
+    y = torch.randn((num_groups, n, k), device="cuda", dtype=torch.bfloat16)
+    out = torch.empty((num_groups, m, n), device="cuda", dtype=torch.bfloat16)
+    ref_out = torch.einsum("gmk,gnk->gmn", x, y)
+
+    assert m % 4 == 0, f"TMA alignment error: {m}"
+    x_fp8 = (
+        torch.empty_like(x, dtype=torch.float8_e4m3fn),
+        torch.empty((num_groups, m, k // 128), device="cuda", dtype=torch.float),
+    )
+    y_fp8 = (
+        torch.empty_like(y, dtype=torch.float8_e4m3fn),
+        torch.empty(
+            (num_groups, (n + 127) // 128, k // 128), device="cuda", dtype=torch.float
+        ),
+    )
+    for i in range(num_groups):
+        x_fp8[0][i], x_fp8[1][i] = per_token_cast_to_fp8(x[i])
+        y_fp8[0][i], y_fp8[1][i] = per_block_cast_to_fp8(y[i])
+
+    # For non-masked input, we must merge the group and M dims
+    if not is_masked:
+        x_fp8 = (x_fp8[0].view(-1, k), per_token_cast_to_fp8(x.view(-1, k))[1])
+        out, ref_out = out.view(-1, n), ref_out.view(-1, n)
+
+    # Transpose earlier so that the testing will not trigger transposing kernels
+    x_fp8 = (x_fp8[0], get_col_major_tma_aligned_tensor(x_fp8[1]))
+    return x_fp8, y_fp8, out, ref_out
+
+
+class TestDeepGemmCore(unittest.TestCase):
+    @classmethod
+    def setUpClass(cls):
+        torch.backends.cuda.matmul.allow_tf32 = True
+        torch.backends.cudnn.allow_tf32 = True
+        torch.manual_seed(0)
+        random.seed(0)
+
+        print("Library path:")
+        print(f" > {deep_gemm.__path__}\n")
+
+    def test_gemm(self):
+        print("Testing GEMM:")
+        for m in (64, 128, 4096):
+            for k, n in [
+                (7168, 2112),
+                (1536, 24576),
+                (512, 32768),
+                (16384, 7168),
+                (7168, 4096),
+                (2048, 7168),
+            ]:
+                x_fp8, y_fp8, out, ref_out = construct(m, k, n)
+                deep_gemm.gemm_fp8_fp8_bf16_nt(x_fp8, y_fp8, out)
+                diff = calc_diff(out, ref_out)
+                self.assertTrue(diff < 0.001, f"{m=}, {k=}, {n=}, {diff:.5f}")
+
+    def test_m_grouped_gemm_contiguous(self):
+        print("Testing grouped contiguous GEMM:")
+
+        for num_groups, m, k, n in (
+            (4, 8192, 7168, 4096),
+            (4, 8192, 2048, 7168),
+            (8, 4096, 7168, 4096),
+            (8, 4096, 2048, 7168),
+        ):
+            # TODO: make a stronger test
+            x_fp8, y_fp8, out, ref_out = construct_grouped(
+                num_groups, m, k, n, is_masked=False
+            )
+            m_indices = torch.arange(0, num_groups, device="cuda", dtype=torch.int)
+            m_indices = (
+                m_indices.unsqueeze(-1).expand(num_groups, m).contiguous().view(-1)
+            )
+            deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(
+                x_fp8, y_fp8, out, m_indices
+            )
+            diff = calc_diff(out, ref_out)
+            self.assertTrue(diff < 0.001, f"m={m * num_groups}, {k=}, {n=}, {diff:.5f}")
+
+    def test_m_grouped_gemm_masked(self):
+        print("Testing grouped masked GEMM:")
+
+        for num_groups, m in ((1, 1024), (2, 512), (4, 256)):
+            for k, n in (
+                (7168, 4096),
+                (2048, 7168),
+            ):
+                # Test correctness
+                masked_m_candidates = list(
+                    filter(
+                        lambda candidate: candidate <= m, (64, 128, 192, 256, 320, 384)
+                    )
+                )
+                for i in range(10):
+                    x_fp8, y_fp8, out, ref_out = construct_grouped(
+                        num_groups, m, k, n, is_masked=True
+                    )
+                    masked_m = torch.empty(
+                        (num_groups,), device="cuda", dtype=torch.int
+                    )
+                    for j in range(num_groups):
+                        masked_m[j] = random.choice(masked_m_candidates)
+                    expected_m = min(int(masked_m.float().mean()) + 1, m)
+                    deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_masked(
+                        x_fp8, y_fp8, out, masked_m, expected_m
+                    )
+                    for j in range(num_groups):
+                        diff = calc_diff(
+                            out[j, : masked_m[j].item()],
+                            ref_out[j, : masked_m[j].item()],
+                        )
+                        self.assertTrue(
+                            diff < 0.001,
+                            f"{m=}, {k=}, {n=}, {j=}, masked_m={masked_m[j]}, {num_groups=}, {diff:.5f}",
+                        )
+
+
+"""
+fork deepgemm/tests/test_jit.py
+"""
+
+
+class Capture:
+    def __init__(self) -> None:
+        self.read_fd = None
+        self.write_fd = None
+        self.saved_stdout = None
+        self.captured = None
+
+    def __enter__(self) -> Any:
+        self.read_fd, self.write_fd = os.pipe()
+        self.saved_stdout = os.dup(1)
+        os.dup2(self.write_fd, 1)
+        return self
+
+    def __exit__(self, exc_type, exc_val, exc_tb) -> None:
+        os.dup2(self.saved_stdout, 1)
+        os.close(self.write_fd)
+        with os.fdopen(self.read_fd, "r") as f:
+            self.captured = f.read()
+
+    def capture(self) -> str:
+        return self.captured
+
+
+class TestDeepGemmJIT(unittest.TestCase):
+    def test_jit(self):
+        # Runtime
+        print(f"NVCC compiler: {jit.get_nvcc_compiler()}\n")
+
+        # Templates
+        print("Generated code:")
+        args = (
+            ("lhs", torch.float8_e4m3fn),
+            ("rhs", torch.float8_e4m3fn),
+            ("scale", torch.float),
+            ("out", torch.bfloat16),
+            ("enable_double_streams", bool),
+            ("stream", torch.cuda.Stream),
+        )
+        body = "\n"
+        body += "std::cout << reinterpret_cast<uint64_t>(lhs) << std::endl;\n"
+        body += "std::cout << reinterpret_cast<uint64_t>(rhs) << std::endl;\n"
+        body += "std::cout << reinterpret_cast<uint64_t>(scale) << std::endl;\n"
+        body += "std::cout << reinterpret_cast<uint64_t>(out) << std::endl;\n"
+        body += "std::cout << enable_double_streams << std::endl;\n"
+        body += "std::cout << reinterpret_cast<uint64_t>(stream) << std::endl;\n"
+        code = jit.generate((), args, body)
+        print(code)
+
+        # Build
+        print("Building ...")
+        func = jit.build("test_func", args, code)
+
+        # Test correctness
+        print("Running ...")
+        fp8_tensor = torch.empty((1,), dtype=torch.float8_e4m3fn, device="cuda")
+        fp32_tensor = torch.empty((1,), dtype=torch.float, device="cuda")
+        bf16_tensor = torch.empty((1,), dtype=torch.bfloat16, device="cuda")
+        with Capture() as capture:
+            self.assertTrue(
+                func(
+                    fp8_tensor,
+                    fp8_tensor,
+                    fp32_tensor,
+                    bf16_tensor,
+                    True,
+                    torch.cuda.current_stream(),
+                )
+                == 0
+            )
+        output = capture.capture()
+        ref_output = f"{fp8_tensor.data_ptr()}\n{fp8_tensor.data_ptr()}\n{fp32_tensor.data_ptr()}\n{bf16_tensor.data_ptr()}\n1\n{torch.cuda.current_stream().cuda_stream}\n"
+        self.assertTrue(output == ref_output, f"{output=}, {ref_output=}")
+
+
+if __name__ == "__main__":
+    unittest.main()
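
per_token_cast_to_fp8 above scales each 128-element block of a row by 448 / amax (448 is the largest finite float8_e4m3fn value) before casting, and returns the inverse scales alongside the FP8 data; multiplying back recovers the input up to FP8 rounding. A small round-trip sketch using that helper (assumes a CUDA device with FP8 support; not part of the test file):

import torch

# Round-trip the per-token (1 x 128 block) FP8 scaling used by the tests above.
x = torch.randn(4, 256, device="cuda", dtype=torch.bfloat16)
q, s = per_token_cast_to_fp8(x)   # q: (4, 256) float8_e4m3fn, s: (4, 2) per-block scales
x_rec = (q.view(4, -1, 128).to(torch.float32) * s.unsqueeze(2)).view(4, 256)
print((x_rec - x.float()).abs().max())  # small, bounded by e4m3 rounding error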
