Commit 9bf1d21

Merge branch 'main' into zhyncs/upd

2 parents 1613df1 + 8aab7fd
File tree: 15 files changed, +649 −19 lines

.github/workflows/pr-test-sgl-kernel.yml

Lines changed: 1 addition & 6 deletions

@@ -44,12 +44,6 @@ jobs:
         cuda-version: '12.8'
     name: Build Wheel (CUDA ${{ matrix.cuda-version }})
     steps:
-      - name: Skip unnecessary builds on push to main
-        if: github.event_name == 'push' && (matrix.cuda-version == '11.8' || matrix.cuda-version == '12.8')
-        run: |
-          echo "Skipping CUDA ${{ matrix.cuda-version }} build on push to main"
-          exit 0
-
       - name: Cleanup
         run: |
           sudo rm -rf $GITHUB_WORKSPACE/* || true
@@ -64,6 +58,7 @@ jobs:
           python-version: ${{ matrix.python-version }}

       - name: Build wheel for Python ${{ matrix.python-version }} and CUDA ${{ matrix.cuda-version }}
+        if: github.event_name != 'push' || (matrix.cuda-version != '11.8' && matrix.cuda-version != '12.8')
         run: |
           cd sgl-kernel
           chmod +x ./build.sh

docker/Dockerfile.blackwell

Lines changed: 1 addition & 1 deletion

@@ -6,7 +6,7 @@ WORKDIR /sgl-workspace
 
 RUN pip3 install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu128
 
-RUN pip3 install https://github.com/sgl-project/whl/releases/download/v0.0.8.post3/sgl_kernel-0.0.8.post3+cu128-cp39-abi3-manylinux2014_x86_64.whl \
+RUN pip3 install https://github.com/sgl-project/whl/releases/download/v0.0.9/sgl_kernel-0.0.9+cu128-cp39-abi3-manylinux2014_x86_64.whl \
     && pip3 install setuptools==75.0.0 wheel==0.41.0 scikit-build-core
 
 RUN git clone --depth=1 https://github.com/sgl-project/sglang.git \

python/pyproject.toml

Lines changed: 1 addition & 1 deletion

@@ -47,7 +47,7 @@ runtime_common = [
 
 srt = [
     "sglang[runtime_common]",
-    "sgl-kernel==0.0.8.post3",
+    "sgl-kernel==0.0.9",
     "flashinfer_python==0.2.3",
     "torch==2.5.1",
     "torchvision==0.20.1",

scripts/ci_install_dependency.sh

Lines changed: 1 addition & 1 deletion

@@ -20,7 +20,7 @@ pip install --upgrade pip
 
 # Install flashinfer and sgl-kernel
 pip install flashinfer_python==0.2.3 --find-links ${FLASHINFER_REPO} --no-cache-dir
-pip install sgl-kernel==0.0.8.post3 --no-cache-dir
+pip install sgl-kernel==0.0.9 --no-cache-dir
 
 # Install the main package
 pip install -e "python[all]" --find-links ${FLASHINFER_REPO}

sgl-kernel/CMakeLists.txt

Lines changed: 3 additions & 2 deletions

@@ -12,7 +12,7 @@ find_package(Python COMPONENTS Interpreter Development.Module ${SKBUILD_SABI_COM
 set(CMAKE_CXX_STANDARD 17)
 set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -O3")
 
-# Cuda
+# CUDA
 enable_language(CUDA)
 find_package(CUDAToolkit REQUIRED)
 set_property(GLOBAL PROPERTY CUDA_SEPARABLE_COMPILATION ON)
@@ -47,7 +47,7 @@ FetchContent_Populate(repo-cutlass)
 FetchContent_Declare(
   repo-deepgemm
   GIT_REPOSITORY https://github.com/deepseek-ai/DeepGEMM
-  GIT_TAG c187c23ba8dcdbad91720737e8be9c43700cb9e9
+  GIT_TAG 4499c4ccbb5d3958b1a069f29ef666156a121278
   GIT_SHALLOW OFF
 )
 FetchContent_Populate(repo-deepgemm)
@@ -170,6 +170,7 @@ string(REPLACE "-D__CUDA_NO_HALF2_OPERATORS__" "" CMAKE_CUDA_FLAGS "${CMAKE
 set(SOURCES
   "csrc/allreduce/custom_all_reduce.cu"
   "csrc/attention/cascade.cu"
+  "csrc/attention/merge_attn_states.cu"
   "csrc/attention/cutlass_mla_kernel.cu"
   "csrc/attention/lightning_attention_decode_kernel.cu"
   "csrc/elementwise/activation.cu"

sgl-kernel/Makefile

Lines changed: 2 additions & 1 deletion

@@ -46,7 +46,8 @@ format: check-deps ## Format all source files
 
 FILES_TO_UPDATE = python/sgl_kernel/version.py \
                   pyproject.toml \
-                  pyproject_rocm.toml
+                  pyproject_rocm.toml \
+                  ../docker/Dockerfile.blackwell
 
 update: ## Update version numbers across project files. Usage: make update <new_version>
 	@if [ -z "$(filter-out $@,$(MAKECMDGOALS))" ]; then \
sgl-kernel/csrc/attention/merge_attn_states.cu

Lines changed: 201 additions & 0 deletions

@@ -0,0 +1,201 @@
+#include <ATen/cuda/CUDAContext.h>
+#include <c10/cuda/CUDAGuard.h>
+
+#include <algorithm>
+#include <optional>
+
+#include "pytorch_extension_utils.h"
+
+// Helper functions to convert between different data types
+// (float, half, bfloat16) for the merge attention states kernel.
+inline __device__ float to_float(float u) {
+  return u;
+}
+inline __device__ float to_float(half u) {
+  return __half2float(u);
+}
+inline __device__ float to_float(__nv_bfloat16 u) {
+  return __bfloat162float(u);
+}
+inline __device__ void from_float(float& d, float s) {
+  d = s;
+}
+inline __device__ void from_float(half& d, float s) {
+  d = __float2half(s);
+}
+inline __device__ void from_float(__nv_bfloat16& d, float s) {
+  d = __float2bfloat16(s);
+}
+
+// Implements section 2.2 of https://www.arxiv.org/pdf/2501.01005
+template <typename scalar_t, const uint NUM_THREADS>
+__global__ void merge_attn_states_kernel(
+    scalar_t* output,
+    float* output_lse,
+    const scalar_t* prefix_output,
+    const float* prefix_lse,
+    const scalar_t* suffix_output,
+    const float* suffix_lse,
+    const uint num_tokens,
+    const uint num_heads,
+    const uint head_size) {
+  using pack_128b_t = uint4;
+  const uint pack_size = 16 / sizeof(scalar_t);
+  const uint threads_per_head = head_size / pack_size;
+
+  const uint global_idx = blockIdx.x * NUM_THREADS + threadIdx.x;
+  const uint token_head_threads = num_tokens * num_heads * threads_per_head;
+
+  if (global_idx >= token_head_threads) return;
+
+  // global_idx -> token_idx + head_idx + pack_idx
+  const uint token_head_idx = global_idx / threads_per_head;
+  const uint pack_idx = global_idx % threads_per_head;
+
+  const uint token_idx = token_head_idx / num_heads;
+  const uint head_idx = token_head_idx % num_heads;
+
+  const uint pack_offset = pack_idx * pack_size;  // (0~15)*8, etc.
+  const uint head_offset = token_idx * num_heads * head_size + head_idx * head_size;
+  const scalar_t* prefix_head_ptr = prefix_output + head_offset;
+  const scalar_t* suffix_head_ptr = suffix_output + head_offset;
+  scalar_t* output_head_ptr = output + head_offset;
+
+  // float p_lse = prefix_lse[head_idx * num_tokens + token_idx];
+  // float s_lse = suffix_lse[head_idx * num_tokens + token_idx];
+  float p_lse = prefix_lse[token_idx * num_heads + head_idx];
+  float s_lse = suffix_lse[token_idx * num_heads + head_idx];
+  p_lse = std::isinf(p_lse) ? -std::numeric_limits<float>::infinity() : p_lse;
+  s_lse = std::isinf(s_lse) ? -std::numeric_limits<float>::infinity() : s_lse;
+
+  const float max_lse = fmaxf(p_lse, s_lse);
+  p_lse = p_lse - max_lse;
+  s_lse = s_lse - max_lse;
+  const float p_se = expf(p_lse);
+  const float s_se = expf(s_lse);
+  const float out_se = p_se + s_se;
+  const float p_scale = p_se / out_se;
+  const float s_scale = s_se / out_se;
+
+  if (pack_offset < head_size) {
+    // Pack 128b load
+    pack_128b_t p_out_pack = reinterpret_cast<const pack_128b_t*>(prefix_head_ptr)[pack_offset / pack_size];
+    pack_128b_t s_out_pack = reinterpret_cast<const pack_128b_t*>(suffix_head_ptr)[pack_offset / pack_size];
+    pack_128b_t o_out_pack;
+
+#pragma unroll
+    for (uint i = 0; i < pack_size; ++i) {
+      // Always use float for FMA to keep high precision.
+      // half(uint16_t), bfloat16, float -> float.
+      const float p_out_f = to_float(reinterpret_cast<const scalar_t*>(&p_out_pack)[i]);
+      const float s_out_f = to_float(reinterpret_cast<const scalar_t*>(&s_out_pack)[i]);
+      // fma: a * b + c = p_out_f * p_scale + (s_out_f * s_scale)
+      const float o_out_f = p_out_f * p_scale + (s_out_f * s_scale);
+      // float -> half(uint16_t), bfloat16, float.
+      from_float(reinterpret_cast<scalar_t*>(&o_out_pack)[i], o_out_f);
+    }
+
+    // Pack 128b storage
+    reinterpret_cast<pack_128b_t*>(output_head_ptr)[pack_offset / pack_size] = o_out_pack;
+  }
+  // We only need to write to output_lse once per head.
+  if (output_lse != nullptr && pack_idx == 0) {
+    float out_lse = logf(out_se) + max_lse;
+    output_lse[token_idx * num_heads + head_idx] = out_lse;
+  }
+}
+
+// The following macro is used to dispatch the conversion function based on
+// the output data type. The FN is a macro that calls a function with
+// template<typename scalar_t>.
+#define DISPATCH_BY_SCALAR_DTYPE(scalar_dtype, fn)                       \
+  {                                                                      \
+    if (scalar_dtype == at::ScalarType::Float) {                        \
+      fn(float);                                                         \
+    } else if (scalar_dtype == at::ScalarType::Half) {                  \
+      fn(half);                                                          \
+    } else if (scalar_dtype == at::ScalarType::BFloat16) {              \
+      fn(__nv_bfloat16);                                                 \
+    } else {                                                             \
+      TORCH_CHECK(false, "Unsupported data type of O: ", scalar_dtype); \
+    }                                                                    \
+  }
+
+#define LAUNCH_MERGE_ATTN_STATES(scalar_t, NUM_THREADS)               \
+  {                                                                   \
+    merge_attn_states_kernel<scalar_t, NUM_THREADS><<<grid, block>>>( \
+        reinterpret_cast<scalar_t*>(output.data_ptr()),               \
+        reinterpret_cast<float*>(output_lse.data_ptr()),              \
+        reinterpret_cast<scalar_t*>(prefix_output.data_ptr()),        \
+        reinterpret_cast<float*>(prefix_lse.data_ptr()),              \
+        reinterpret_cast<scalar_t*>(suffix_output.data_ptr()),        \
+        reinterpret_cast<float*>(suffix_lse.data_ptr()),              \
+        num_tokens,                                                   \
+        num_heads,                                                    \
+        head_size);                                                   \
+  }
+
+/*@brief Merges the attention states from prefix and suffix
+ * into the output tensor. NUM_TOKENS: n, NUM_HEADS: h, HEAD_SIZE: d
+ *
+ * @param output [n,h,d] The output tensor to store the merged attention states.
+ * @param output_lse [h,d] Optional tensor to store the log-sum-exp values.
+ * @param prefix_output [n,h,d] The prefix attention states.
+ * @param prefix_lse [n,h] The log-sum-exp values for the prefix attention
+ * states.
+ * @param suffix_output [n,h,d] The suffix attention states.
+ * @param suffix_lse [n,h] The log-sum-exp values for the suffix attention
+ * states.
+ */
+template <typename scalar_t>
+void merge_attn_states_launcher(
+    const at::Tensor& prefix_output,  // [NUM_TOKENS, NUM_HEADS, HEAD_SIZE]
+    const at::Tensor& prefix_lse,     // [NUM_TOKENS, NUM_HEADS]
+    const at::Tensor& suffix_output,  // [NUM_TOKENS, NUM_HEADS, HEAD_SIZE]
+    const at::Tensor& suffix_lse,     // [NUM_TOKENS, NUM_HEADS]
+    at::Tensor& output,               // [NUM_TOKENS, NUM_HEADS, HEAD_SIZE]
+    at::Tensor& output_lse            // [NUM_TOKENS, NUM_HEADS]
+) {
+  constexpr uint NUM_THREADS = 128;
+  const uint num_tokens = output.size(0);
+  const uint num_heads = output.size(1);
+  const uint head_size = output.size(2);
+  const uint pack_size = 16 / sizeof(scalar_t);
+  TORCH_CHECK(head_size % pack_size == 0, "headsize must be multiple of pack_size:", pack_size);
+  // Process one pack elements per thread. for float, the
+  // pack_size is 4 for half/bf16, the pack_size is 8.
+  const uint threads_per_head = head_size / pack_size;
+  const uint total_threads = num_tokens * num_heads * threads_per_head;
+
+  dim3 block(NUM_THREADS);
+  dim3 grid((total_threads + NUM_THREADS - 1) / NUM_THREADS);
+
+  LAUNCH_MERGE_ATTN_STATES(scalar_t, NUM_THREADS);
+}
+
+#define CALL_MERGE_ATTN_STATES_LAUNCHER(scalar_t) \
+  { merge_attn_states_launcher<scalar_t>(v_a, s_a, v_b, s_b, v_merged, s_merged); }
+
+void merge_state_v2(
+    at::Tensor v_a, at::Tensor s_a, at::Tensor v_b, at::Tensor s_b, at::Tensor v_merged, at::Tensor s_merged) {
+  // Input tensors must be contiguous
+  CHECK_INPUT(v_a);  // v_a prefix_output (seq_len, num_heads, head_dim)
+  CHECK_INPUT(s_a);  // s_a prefix_lse (seq_len, num_heads)
+  CHECK_INPUT(v_b);  // v_b suffix_output (seq_len, num_heads, head_dim)
+  CHECK_INPUT(s_b);  // s_b suffix_lse (seq_len, num_heads)
+  // v_merged output (seq_len, num_heads, head_dim)
+  // s_merged output_lse (seq_len, num_heads)
+  auto device = v_a.device();
+  CHECK_EQ(s_a.device(), device);
+  CHECK_EQ(v_b.device(), device);
+  CHECK_EQ(s_b.device(), device);
+  CHECK_DIM(3, v_a);
+  CHECK_DIM(2, s_a);
+  CHECK_DIM(3, v_b);
+  CHECK_DIM(2, s_b);
+  CHECK_SHAPE(v_a, v_b);
+  CHECK_SHAPE(s_a, s_b);
+  CHECK_EQ(v_a.size(0), s_a.size(0));
+  CHECK_EQ(v_a.size(1), s_b.size(1));
+  DISPATCH_BY_SCALAR_DTYPE(v_merged.dtype(), CALL_MERGE_ATTN_STATES_LAUNCHER);
+}
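The math in the new kernel is compact: for each (token, head) pair it rescales the prefix and suffix partial outputs by softmax weights derived from their log-sum-exp (LSE) values, then writes back the merged LSE (section 2.2 of the linked paper). The PyTorch sketch below is illustrative only, not part of this commit; it assumes the same [num_tokens, num_heads, head_size] and [num_tokens, num_heads] layouts documented in the launcher and computes in float32 for clarity.

import torch

def merge_attn_states_reference(prefix_output, prefix_lse, suffix_output, suffix_lse):
    # prefix_output / suffix_output: [num_tokens, num_heads, head_size]
    # prefix_lse / suffix_lse:       [num_tokens, num_heads], float32
    # Guard inf LSE entries, mirroring the kernel's isinf check.
    p_lse = torch.where(torch.isinf(prefix_lse), torch.full_like(prefix_lse, float("-inf")), prefix_lse)
    s_lse = torch.where(torch.isinf(suffix_lse), torch.full_like(suffix_lse, float("-inf")), suffix_lse)
    # Subtract the running max before exponentiating for numerical stability.
    max_lse = torch.maximum(p_lse, s_lse)
    p_se = torch.exp(p_lse - max_lse)
    s_se = torch.exp(s_lse - max_lse)
    out_se = p_se + s_se
    # Each output element is a convex combination of the two partial outputs.
    p_scale = (p_se / out_se).unsqueeze(-1)  # broadcast over head_size
    s_scale = (s_se / out_se).unsqueeze(-1)
    output = prefix_output.float() * p_scale + suffix_output.float() * s_scale
    # Merged log-sum-exp per (token, head).
    output_lse = torch.log(out_se) + max_lse
    return output.to(prefix_output.dtype), output_lse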

sgl-kernel/csrc/common_extension.cc

Lines changed: 2 additions & 0 deletions

@@ -47,6 +47,8 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
   m.impl("lightning_attention_decode", torch::kCUDA, &lightning_attention_decode);
   m.def("merge_state(Tensor v_a, Tensor s_a, Tensor v_b, Tensor s_b, Tensor! v_merged, Tensor! s_merged) -> ()");
   m.impl("merge_state", torch::kCUDA, &merge_state);
+  m.def("merge_state_v2(Tensor v_a, Tensor s_a, Tensor v_b, Tensor s_b, Tensor! v_merged, Tensor! s_merged) -> ()");
+  m.impl("merge_state_v2", torch::kCUDA, &merge_state_v2);
   m.def(
       "cutlass_mla_decode(Tensor! out, Tensor q_nope_and_q_pe, Tensor kv_c_and_k_pe_cache, Tensor seq_lens, Tensor "
       "page_table, Tensor workspace) -> ()");

sgl-kernel/include/sgl_kernel_ops.h

Lines changed: 2 additions & 0 deletions

@@ -89,6 +89,8 @@ void lightning_attention_decode(
     torch::Tensor new_kv);
 void merge_state(
     at::Tensor v_a, at::Tensor s_a, at::Tensor v_b, at::Tensor s_b, at::Tensor v_merged, at::Tensor s_merged);
+void merge_state_v2(
+    at::Tensor v_a, at::Tensor s_a, at::Tensor v_b, at::Tensor s_b, at::Tensor v_merged, at::Tensor s_merged);
 void cutlass_mla_decode(
     torch::Tensor const& out,
     torch::Tensor const& q_nope_and_q_pe,

sgl-kernel/pyproject.toml

Lines changed: 1 addition & 1 deletion

@@ -8,7 +8,7 @@ build-backend = "scikit_build_core.build"
 
 [project]
 name = "sgl-kernel"
-version = "0.0.8.post3"
+version = "0.0.9"
 description = "Kernel Library for SGLang"
 readme = "README.md"
 requires-python = ">=3.9"

sgl-kernel/pyproject_rocm.toml

Lines changed: 1 addition & 1 deletion

@@ -9,7 +9,7 @@ build-backend = "setuptools.build_meta"
 
 [project]
 name = "sgl-kernel"
-version = "0.0.8.post3"
+version = "0.0.9"
 description = "Kernel Library for SGLang"
 readme = "README.md"
 requires-python = ">=3.9"

sgl-kernel/python/sgl_kernel/__init__.py

Lines changed: 1 addition & 0 deletions

@@ -16,6 +16,7 @@
     cutlass_mla_get_workspace_size,
     lightning_attention_decode,
     merge_state,
+    merge_state_v2,
 )
 from sgl_kernel.elementwise import (
     apply_rope_with_cos_sin_cache_inplace,

sgl-kernel/python/sgl_kernel/attention.py

Lines changed: 35 additions & 4 deletions

@@ -1,4 +1,4 @@
-from typing import Tuple
+from typing import Optional, Tuple
 
 import torch
 
@@ -10,16 +10,47 @@ def lightning_attention_decode(q, k, v, past_kv, slope, output, new_kv):
 
 
 def merge_state(
-    v_a: torch.Tensor, s_a: torch.Tensor, v_b: torch.Tensor, s_b: torch.Tensor
+    v_a: torch.Tensor,
+    s_a: torch.Tensor,
+    v_b: torch.Tensor,
+    s_b: torch.Tensor,
+    v_merged: Optional[torch.Tensor] = None,
+    s_merged: Optional[torch.Tensor] = None,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
     s_a = s_a.to(torch.float32)
     s_b = s_b.to(torch.float32)
-    v_merged = torch.empty_like(v_a)
-    s_merged = torch.empty_like(s_a)
+    # Avoid creating new tensors if they are already provided
+    if v_merged is None:
+        v_merged = torch.empty_like(v_a)
+    if s_merged is None:
+        s_merged = torch.empty_like(s_a)
     torch.ops.sgl_kernel.merge_state.default(v_a, s_a, v_b, s_b, v_merged, s_merged)
     return v_merged, s_merged
 
 
+def merge_state_v2(
+    v_a: torch.Tensor,
+    s_a: torch.Tensor,
+    v_b: torch.Tensor,
+    s_b: torch.Tensor,
+    v_merged: Optional[torch.Tensor] = None,
+    s_merged: Optional[torch.Tensor] = None,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    s_a = s_a.to(torch.float32)
+    s_b = s_b.to(torch.float32)
+    # TODO(DefTruth): Currently, the custom merge_attn_states kernel
+    # does not support the FP8 data type and non-CUDA devices.
+    # It may be necessary to fall back to using the Triton kernel.
+
+    # Avoid creating new tensors if they are already provided
+    if v_merged is None:
+        v_merged = torch.empty_like(v_a)
+    if s_merged is None:
+        s_merged = torch.empty_like(s_a)
+    torch.ops.sgl_kernel.merge_state_v2.default(v_a, s_a, v_b, s_b, v_merged, s_merged)
+    return v_merged, s_merged
+
+
 def cutlass_mla_decode(
     q_nope_and_q_pe: torch.Tensor,
     kv_c_and_k_pe_cache: torch.Tensor,
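A minimal usage sketch for the new wrapper (illustrative, not part of this diff); it assumes sgl-kernel 0.0.9 with this kernel built and a CUDA device available, and uses shapes that satisfy the checks in the C++ entry point.

import torch
from sgl_kernel import merge_state_v2

num_tokens, num_heads, head_size = 4, 8, 128

# Partial attention outputs and their log-sum-exp values, e.g. from a
# prefix (shared prompt) pass and a suffix (per-request) pass.
v_a = torch.randn(num_tokens, num_heads, head_size, dtype=torch.float16, device="cuda")  # prefix output
s_a = torch.randn(num_tokens, num_heads, dtype=torch.float32, device="cuda")             # prefix LSE
v_b = torch.randn(num_tokens, num_heads, head_size, dtype=torch.float16, device="cuda")  # suffix output
s_b = torch.randn(num_tokens, num_heads, dtype=torch.float32, device="cuda")             # suffix LSE

# Preallocated outputs may be passed as v_merged/s_merged to skip the
# internal torch.empty_like calls; omitted here for brevity.
v_merged, s_merged = merge_state_v2(v_a, s_a, v_b, s_b)
print(v_merged.shape, s_merged.shape)  # torch.Size([4, 8, 128]) torch.Size([4, 8])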
sgl-kernel/python/sgl_kernel/version.py

Lines changed: 1 addition & 1 deletion

@@ -1 +1 @@
-__version__ = "0.0.8.post3"
+__version__ = "0.0.9"
