
Commit ce3ef6b

DefTruth authored and jimoosciuc committed
kernel: support slightly faster merge_state_v2 cuda kernel (sgl-project#5381)
1 parent 9f2c2cc commit ce3ef6b

7 files changed: +638 -4 lines changed

sgl-kernel/CMakeLists.txt

Lines changed: 1 addition & 0 deletions

@@ -170,6 +170,7 @@ string(REPLACE "-D__CUDA_NO_HALF2_OPERATORS__" "" CMAKE_CUDA_FLAGS "${CMAKE
 set(SOURCES
   "csrc/allreduce/custom_all_reduce.cu"
   "csrc/attention/cascade.cu"
+  "csrc/attention/merge_attn_states.cu"
   "csrc/attention/cutlass_mla_kernel.cu"
   "csrc/attention/lightning_attention_decode_kernel.cu"
   "csrc/elementwise/activation.cu"
sgl-kernel/csrc/attention/merge_attn_states.cu

Lines changed: 201 additions & 0 deletions

@@ -0,0 +1,201 @@
+#include <ATen/cuda/CUDAContext.h>
+#include <c10/cuda/CUDAGuard.h>
+
+#include <algorithm>
+#include <optional>
+
+#include "pytorch_extension_utils.h"
+
+// Helper functions to convert between different data types
+// (float, half, bfloat16) for the merge attention states kernel.
+inline __device__ float to_float(float u) {
+  return u;
+}
+inline __device__ float to_float(half u) {
+  return __half2float(u);
+}
+inline __device__ float to_float(__nv_bfloat16 u) {
+  return __bfloat162float(u);
+}
+inline __device__ void from_float(float& d, float s) {
+  d = s;
+}
+inline __device__ void from_float(half& d, float s) {
+  d = __float2half(s);
+}
+inline __device__ void from_float(__nv_bfloat16& d, float s) {
+  d = __float2bfloat16(s);
+}
+
+// Implements section 2.2 of https://www.arxiv.org/pdf/2501.01005
+template <typename scalar_t, const uint NUM_THREADS>
+__global__ void merge_attn_states_kernel(
+    scalar_t* output,
+    float* output_lse,
+    const scalar_t* prefix_output,
+    const float* prefix_lse,
+    const scalar_t* suffix_output,
+    const float* suffix_lse,
+    const uint num_tokens,
+    const uint num_heads,
+    const uint head_size) {
+  using pack_128b_t = uint4;
+  const uint pack_size = 16 / sizeof(scalar_t);
+  const uint threads_per_head = head_size / pack_size;
+
+  const uint global_idx = blockIdx.x * NUM_THREADS + threadIdx.x;
+  const uint token_head_threads = num_tokens * num_heads * threads_per_head;
+
+  if (global_idx >= token_head_threads) return;
+
+  // global_idx -> token_idx + head_idx + pack_idx
+  const uint token_head_idx = global_idx / threads_per_head;
+  const uint pack_idx = global_idx % threads_per_head;
+
+  const uint token_idx = token_head_idx / num_heads;
+  const uint head_idx = token_head_idx % num_heads;
+
+  const uint pack_offset = pack_idx * pack_size;  // (0~15)*8, etc.
+  const uint head_offset = token_idx * num_heads * head_size + head_idx * head_size;
+  const scalar_t* prefix_head_ptr = prefix_output + head_offset;
+  const scalar_t* suffix_head_ptr = suffix_output + head_offset;
+  scalar_t* output_head_ptr = output + head_offset;
+
+  // float p_lse = prefix_lse[head_idx * num_tokens + token_idx];
+  // float s_lse = suffix_lse[head_idx * num_tokens + token_idx];
+  float p_lse = prefix_lse[token_idx * num_heads + head_idx];
+  float s_lse = suffix_lse[token_idx * num_heads + head_idx];
+  p_lse = std::isinf(p_lse) ? -std::numeric_limits<float>::infinity() : p_lse;
+  s_lse = std::isinf(s_lse) ? -std::numeric_limits<float>::infinity() : s_lse;
+
+  const float max_lse = fmaxf(p_lse, s_lse);
+  p_lse = p_lse - max_lse;
+  s_lse = s_lse - max_lse;
+  const float p_se = expf(p_lse);
+  const float s_se = expf(s_lse);
+  const float out_se = p_se + s_se;
+  const float p_scale = p_se / out_se;
+  const float s_scale = s_se / out_se;
+
+  if (pack_offset < head_size) {
+    // Pack 128b load
+    pack_128b_t p_out_pack = reinterpret_cast<const pack_128b_t*>(prefix_head_ptr)[pack_offset / pack_size];
+    pack_128b_t s_out_pack = reinterpret_cast<const pack_128b_t*>(suffix_head_ptr)[pack_offset / pack_size];
+    pack_128b_t o_out_pack;
+
+#pragma unroll
+    for (uint i = 0; i < pack_size; ++i) {
+      // Always use float for FMA to keep high precision.
+      // half(uint16_t), bfloat16, float -> float.
+      const float p_out_f = to_float(reinterpret_cast<const scalar_t*>(&p_out_pack)[i]);
+      const float s_out_f = to_float(reinterpret_cast<const scalar_t*>(&s_out_pack)[i]);
+      // fma: a * b + c = p_out_f * p_scale + (s_out_f * s_scale)
+      const float o_out_f = p_out_f * p_scale + (s_out_f * s_scale);
+      // float -> half(uint16_t), bfloat16, float.
+      from_float(reinterpret_cast<scalar_t*>(&o_out_pack)[i], o_out_f);
+    }
+
+    // Pack 128b store
+    reinterpret_cast<pack_128b_t*>(output_head_ptr)[pack_offset / pack_size] = o_out_pack;
+  }
+  // We only need to write to output_lse once per head.
+  if (output_lse != nullptr && pack_idx == 0) {
+    float out_lse = logf(out_se) + max_lse;
+    output_lse[token_idx * num_heads + head_idx] = out_lse;
+  }
+}
+
+// The following macro is used to dispatch the conversion function based on
+// the output data type. The FN is a macro that calls a function with
+// template<typename scalar_t>.
+#define DISPATCH_BY_SCALAR_DTYPE(scalar_dtype, fn)                      \
+  {                                                                     \
+    if (scalar_dtype == at::ScalarType::Float) {                        \
+      fn(float);                                                        \
+    } else if (scalar_dtype == at::ScalarType::Half) {                  \
+      fn(half);                                                         \
+    } else if (scalar_dtype == at::ScalarType::BFloat16) {              \
+      fn(__nv_bfloat16);                                                \
+    } else {                                                            \
+      TORCH_CHECK(false, "Unsupported data type of O: ", scalar_dtype); \
+    }                                                                   \
+  }
+
+#define LAUNCH_MERGE_ATTN_STATES(scalar_t, NUM_THREADS)               \
+  {                                                                   \
+    merge_attn_states_kernel<scalar_t, NUM_THREADS><<<grid, block>>>( \
+        reinterpret_cast<scalar_t*>(output.data_ptr()),               \
+        reinterpret_cast<float*>(output_lse.data_ptr()),              \
+        reinterpret_cast<scalar_t*>(prefix_output.data_ptr()),        \
+        reinterpret_cast<float*>(prefix_lse.data_ptr()),              \
+        reinterpret_cast<scalar_t*>(suffix_output.data_ptr()),        \
+        reinterpret_cast<float*>(suffix_lse.data_ptr()),               \
+        num_tokens,                                                   \
+        num_heads,                                                    \
+        head_size);                                                   \
+  }
+
+/* @brief Merges the attention states from prefix and suffix
+ * into the output tensor. NUM_TOKENS: n, NUM_HEADS: h, HEAD_SIZE: d
+ *
+ * @param output [n,h,d] The output tensor to store the merged attention states.
+ * @param output_lse [n,h] Optional tensor to store the log-sum-exp values.
+ * @param prefix_output [n,h,d] The prefix attention states.
+ * @param prefix_lse [n,h] The log-sum-exp values for the prefix attention
+ * states.
+ * @param suffix_output [n,h,d] The suffix attention states.
+ * @param suffix_lse [n,h] The log-sum-exp values for the suffix attention
+ * states.
+ */
+template <typename scalar_t>
+void merge_attn_states_launcher(
+    const at::Tensor& prefix_output,  // [NUM_TOKENS, NUM_HEADS, HEAD_SIZE]
+    const at::Tensor& prefix_lse,     // [NUM_TOKENS, NUM_HEADS]
+    const at::Tensor& suffix_output,  // [NUM_TOKENS, NUM_HEADS, HEAD_SIZE]
+    const at::Tensor& suffix_lse,     // [NUM_TOKENS, NUM_HEADS]
+    at::Tensor& output,               // [NUM_TOKENS, NUM_HEADS, HEAD_SIZE]
+    at::Tensor& output_lse            // [NUM_TOKENS, NUM_HEADS]
+) {
+  constexpr uint NUM_THREADS = 128;
+  const uint num_tokens = output.size(0);
+  const uint num_heads = output.size(1);
+  const uint head_size = output.size(2);
+  const uint pack_size = 16 / sizeof(scalar_t);
+  TORCH_CHECK(head_size % pack_size == 0, "head_size must be a multiple of pack_size:", pack_size);
+  // Process one pack of elements per thread. For float the
+  // pack_size is 4; for half/bf16 the pack_size is 8.
+  const uint threads_per_head = head_size / pack_size;
+  const uint total_threads = num_tokens * num_heads * threads_per_head;
+
+  dim3 block(NUM_THREADS);
+  dim3 grid((total_threads + NUM_THREADS - 1) / NUM_THREADS);
+
+  LAUNCH_MERGE_ATTN_STATES(scalar_t, NUM_THREADS);
+}
+
+#define CALL_MERGE_ATTN_STATES_LAUNCHER(scalar_t) \
+  { merge_attn_states_launcher<scalar_t>(v_a, s_a, v_b, s_b, v_merged, s_merged); }
+
+void merge_state_v2(
+    at::Tensor v_a, at::Tensor s_a, at::Tensor v_b, at::Tensor s_b, at::Tensor v_merged, at::Tensor s_merged) {
+  // Input tensors must be contiguous
+  CHECK_INPUT(v_a);  // v_a prefix_output (seq_len, num_heads, head_dim)
+  CHECK_INPUT(s_a);  // s_a prefix_lse (seq_len, num_heads)
+  CHECK_INPUT(v_b);  // v_b suffix_output (seq_len, num_heads, head_dim)
+  CHECK_INPUT(s_b);  // s_b suffix_lse (seq_len, num_heads)
+  // v_merged output (seq_len, num_heads, head_dim)
+  // s_merged output_lse (seq_len, num_heads)
+  auto device = v_a.device();
+  CHECK_EQ(s_a.device(), device);
+  CHECK_EQ(v_b.device(), device);
+  CHECK_EQ(s_b.device(), device);
+  CHECK_DIM(3, v_a);
+  CHECK_DIM(2, s_a);
+  CHECK_DIM(3, v_b);
+  CHECK_DIM(2, s_b);
+  CHECK_SHAPE(v_a, v_b);
+  CHECK_SHAPE(s_a, s_b);
+  CHECK_EQ(v_a.size(0), s_a.size(0));
+  CHECK_EQ(v_a.size(1), s_b.size(1));
+  DISPATCH_BY_SCALAR_DTYPE(v_merged.dtype(), CALL_MERGE_ATTN_STATES_LAUNCHER);
+}
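
For reference, the merge in merge_attn_states_kernel is the standard log-sum-exp combination of two partial softmax-attention results: with m = max(lse_a, lse_b), w_a = exp(lse_a - m) and w_b = exp(lse_b - m), the merged output is (w_a * v_a + w_b * v_b) / (w_a + w_b) and the merged LSE is log(w_a + w_b) + m. Below is a minimal pure-PyTorch sketch of that math, not part of this commit; the helper name merge_attn_states_ref is ours for illustration, and the tensor layouts follow the [num_tokens, num_heads, head_size] / [num_tokens, num_heads] convention used above.

# Hypothetical pure-PyTorch reference for the merge performed by
# merge_attn_states_kernel (illustrative only; not part of this commit).
# v_a, v_b: [num_tokens, num_heads, head_size]; s_a, s_b: [num_tokens, num_heads].
import torch


def merge_attn_states_ref(v_a, s_a, v_b, s_b):
    # Map +/-inf LSE values to -inf, mirroring the kernel's isinf handling,
    # so a source with an infinite LSE contributes nothing to the merge.
    s_a = s_a.float().masked_fill(torch.isinf(s_a), float("-inf"))
    s_b = s_b.float().masked_fill(torch.isinf(s_b), float("-inf"))
    max_lse = torch.maximum(s_a, s_b)   # m = max(lse_a, lse_b)
    w_a = torch.exp(s_a - max_lse)      # un-normalized prefix weight
    w_b = torch.exp(s_b - max_lse)      # un-normalized suffix weight
    denom = w_a + w_b
    # Weighted sum of the two partial outputs, accumulated in fp32.
    v_merged = (v_a.float() * (w_a / denom).unsqueeze(-1)
                + v_b.float() * (w_b / denom).unsqueeze(-1)).to(v_a.dtype)
    s_merged = torch.log(denom) + max_lse  # merged log-sum-exp
    return v_merged, s_merged

The CUDA kernel computes exactly this, but one 128-bit pack of output elements per thread and with the per-(token, head) scales computed once in fp32.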

sgl-kernel/csrc/common_extension.cc

Lines changed: 2 additions & 0 deletions

@@ -47,6 +47,8 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
   m.impl("lightning_attention_decode", torch::kCUDA, &lightning_attention_decode);
   m.def("merge_state(Tensor v_a, Tensor s_a, Tensor v_b, Tensor s_b, Tensor! v_merged, Tensor! s_merged) -> ()");
   m.impl("merge_state", torch::kCUDA, &merge_state);
+  m.def("merge_state_v2(Tensor v_a, Tensor s_a, Tensor v_b, Tensor s_b, Tensor! v_merged, Tensor! s_merged) -> ()");
+  m.impl("merge_state_v2", torch::kCUDA, &merge_state_v2);
   m.def(
       "cutlass_mla_decode(Tensor! out, Tensor q_nope_and_q_pe, Tensor kv_c_and_k_pe_cache, Tensor seq_lens, Tensor "
       "page_table, Tensor workspace) -> ()");

sgl-kernel/include/sgl_kernel_ops.h

Lines changed: 2 additions & 0 deletions

@@ -89,6 +89,8 @@ void lightning_attention_decode(
     torch::Tensor new_kv);
 void merge_state(
     at::Tensor v_a, at::Tensor s_a, at::Tensor v_b, at::Tensor s_b, at::Tensor v_merged, at::Tensor s_merged);
+void merge_state_v2(
+    at::Tensor v_a, at::Tensor s_a, at::Tensor v_b, at::Tensor s_b, at::Tensor v_merged, at::Tensor s_merged);
 void cutlass_mla_decode(
     torch::Tensor const& out,
     torch::Tensor const& q_nope_and_q_pe,

sgl-kernel/python/sgl_kernel/__init__.py

Lines changed: 1 addition & 0 deletions

@@ -16,6 +16,7 @@
     cutlass_mla_get_workspace_size,
     lightning_attention_decode,
     merge_state,
+    merge_state_v2,
 )
 from sgl_kernel.elementwise import (
     apply_rope_with_cos_sin_cache_inplace,

sgl-kernel/python/sgl_kernel/attention.py

Lines changed: 35 additions & 4 deletions

@@ -1,4 +1,4 @@
-from typing import Tuple
+from typing import Optional, Tuple
 
 import torch
 
@@ -10,16 +10,47 @@ def lightning_attention_decode(q, k, v, past_kv, slope, output, new_kv):
 
 
 def merge_state(
-    v_a: torch.Tensor, s_a: torch.Tensor, v_b: torch.Tensor, s_b: torch.Tensor
+    v_a: torch.Tensor,
+    s_a: torch.Tensor,
+    v_b: torch.Tensor,
+    s_b: torch.Tensor,
+    v_merged: Optional[torch.Tensor] = None,
+    s_merged: Optional[torch.Tensor] = None,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
     s_a = s_a.to(torch.float32)
     s_b = s_b.to(torch.float32)
-    v_merged = torch.empty_like(v_a)
-    s_merged = torch.empty_like(s_a)
+    # Avoid creating new tensors if they are already provided
+    if v_merged is None:
+        v_merged = torch.empty_like(v_a)
+    if s_merged is None:
+        s_merged = torch.empty_like(s_a)
     torch.ops.sgl_kernel.merge_state.default(v_a, s_a, v_b, s_b, v_merged, s_merged)
     return v_merged, s_merged
 
 
+def merge_state_v2(
+    v_a: torch.Tensor,
+    s_a: torch.Tensor,
+    v_b: torch.Tensor,
+    s_b: torch.Tensor,
+    v_merged: Optional[torch.Tensor] = None,
+    s_merged: Optional[torch.Tensor] = None,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    s_a = s_a.to(torch.float32)
+    s_b = s_b.to(torch.float32)
+    # TODO(DefTruth): Currently, the custom merge_attn_states kernel
+    # does not support the FP8 data type and non-CUDA devices.
+    # It may be necessary to fall back to using the Triton kernel.
+
+    # Avoid creating new tensors if they are already provided
+    if v_merged is None:
+        v_merged = torch.empty_like(v_a)
+    if s_merged is None:
+        s_merged = torch.empty_like(s_a)
+    torch.ops.sgl_kernel.merge_state_v2.default(v_a, s_a, v_b, s_b, v_merged, s_merged)
+    return v_merged, s_merged
+
+
 def cutlass_mla_decode(
     q_nope_and_q_pe: torch.Tensor,
    kv_c_and_k_pe_cache: torch.Tensor,
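
As a quick usage sketch (not part of the diff): the new wrapper is a drop-in companion to merge_state. The shapes and tolerances below are arbitrary assumptions, and on a CUDA device the two entry points should agree up to floating-point tolerance.

# Hypothetical smoke test comparing the v1 and v2 merge kernels
# (illustrative only; shapes chosen arbitrarily).
import torch
from sgl_kernel import merge_state, merge_state_v2

num_tokens, num_heads, head_size = 128, 32, 128
v_a = torch.randn(num_tokens, num_heads, head_size, dtype=torch.float16, device="cuda")
v_b = torch.randn_like(v_a)
s_a = torch.randn(num_tokens, num_heads, dtype=torch.float32, device="cuda")
s_b = torch.randn_like(s_a)

v_ref, s_ref = merge_state(v_a, s_a, v_b, s_b)      # existing cascade kernel
v_out, s_out = merge_state_v2(v_a, s_a, v_b, s_b)   # new merge_attn_states kernel
torch.testing.assert_close(v_out, v_ref, rtol=1e-3, atol=1e-3)
torch.testing.assert_close(s_out, s_ref, rtol=1e-3, atol=1e-3)

Preallocated v_merged / s_merged tensors can also be passed to either wrapper to avoid the torch.empty_like allocations on the hot path.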
