
Commit c08a717

[Feat] Update sgl-kernel flashinfer to latest main version (#5500)

Authored by yinfan98 and zhyncs
Co-authored-by: zhyncs <[email protected]>

1 parent f13d65a · commit c08a717

8 files changed: +393 −133 lines changed


sgl-kernel/CMakeLists.txt

Lines changed: 2 additions & 2 deletions
@@ -58,8 +58,8 @@ FetchContent_Populate(repo-deepgemm)
 # flashinfer
 FetchContent_Declare(
   repo-flashinfer
-  GIT_REPOSITORY https://github.com/sgl-project/flashinfer
-  GIT_TAG sgl-kernel
+  GIT_REPOSITORY https://github.com/flashinfer-ai/flashinfer.git
+  GIT_TAG 9220fb3443b5a5d274f00ca5552f798e225239b7
   GIT_SHALLOW OFF
 )
 FetchContent_Populate(repo-flashinfer)

sgl-kernel/csrc/common_extension.cc

Lines changed: 12 additions & 17 deletions
@@ -58,16 +58,16 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
   /*
    * From csrc/elementwise
    */
-  m.def("rmsnorm(Tensor! output, Tensor input, Tensor weight, float eps, int cuda_stream) -> ()");
+  m.def("rmsnorm(Tensor! output, Tensor input, Tensor weight, float eps, bool enable_pdl) -> ()");
   m.impl("rmsnorm", torch::kCUDA, &rmsnorm);
 
-  m.def("fused_add_rmsnorm(Tensor! input, Tensor! residual, Tensor weight, float eps) -> ()");
+  m.def("fused_add_rmsnorm(Tensor! input, Tensor! residual, Tensor weight, float eps, bool enable_pdl) -> ()");
   m.impl("fused_add_rmsnorm", torch::kCUDA, &sgl_fused_add_rmsnorm);
 
-  m.def("gemma_rmsnorm(Tensor! output, Tensor input, Tensor weight, float eps, int cuda_stream) -> ()");
+  m.def("gemma_rmsnorm(Tensor! output, Tensor input, Tensor weight, float eps, bool enable_pdl) -> ()");
   m.impl("gemma_rmsnorm", torch::kCUDA, &gemma_rmsnorm);
 
-  m.def("gemma_fused_add_rmsnorm(Tensor! input, Tensor! residual, Tensor weight, float eps, int cuda_stream) -> ()");
+  m.def("gemma_fused_add_rmsnorm(Tensor! input, Tensor! residual, Tensor weight, float eps, bool enable_pdl) -> ()");
   m.impl("gemma_fused_add_rmsnorm", torch::kCUDA, &gemma_fused_add_rmsnorm);
 
   m.def("silu_and_mul(Tensor! out, Tensor input, int cuda_stream) -> ()");

@@ -186,29 +186,24 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
   m.impl("bmm_fp8", torch::kCUDA, &bmm_fp8);
 
   m.def(
-      "min_p_sampling_from_probs(Tensor probs, Tensor uniform_samples, Tensor! samples, Tensor? maybe_min_p_arr, float "
-      "min_p_val, bool deterministic, int cuda_stream) -> ()");
+      "min_p_sampling_from_probs(Tensor probs, Tensor output, Tensor? maybe_indices, Tensor? maybe_min_p_arr, float "
+      "min_p_val, bool deterministic, Generator? gen) -> ()");
   m.impl("min_p_sampling_from_probs", torch::kCUDA, &min_p_sampling_from_probs);
 
-  m.def(
-      "top_k_renorm_probs(Tensor probs, Tensor! renorm_probs, Tensor? maybe_top_k_arr, int top_k_val, int "
-      "cuda_stream) -> ()");
+  m.def("top_k_renorm_probs(Tensor probs, Tensor! renorm_probs, Tensor? maybe_top_k_arr, int top_k_val) -> ()");
   m.impl("top_k_renorm_probs", torch::kCUDA, &top_k_renorm_probs);
 
-  m.def(
-      "top_p_renorm_probs(Tensor probs, Tensor! renorm_probs, Tensor? maybe_top_p_arr, float top_p_val, int "
-      "cuda_stream) -> ()");
+  m.def("top_p_renorm_probs(Tensor probs, Tensor! renorm_probs, Tensor? maybe_top_p_arr, float top_p_val) -> ()");
   m.impl("top_p_renorm_probs", torch::kCUDA, &top_p_renorm_probs);
 
   m.def(
-      "top_k_top_p_sampling_from_probs(Tensor probs, Tensor uniform_samples, Tensor! samples, Tensor! success, Tensor? "
-      "maybe_top_k_arr, float top_k_val, Tensor? maybe_top_p_arr, float top_p_val, bool deterministic, int "
-      "cuda_stream) -> ()");
+      "top_k_top_p_sampling_from_probs(Tensor probs, Tensor output, Tensor? maybe_indices, Tensor? maybe_top_k_arr, "
+      "float top_k_val, Tensor? maybe_top_p_arr, float top_p_val, bool deterministic, Generator? gen) -> ()");
   m.impl("top_k_top_p_sampling_from_probs", torch::kCUDA, &top_k_top_p_sampling_from_probs);
 
   m.def(
-      "top_p_sampling_from_probs(Tensor probs, Tensor uniform_samples, Tensor! samples, Tensor! success, Tensor? "
-      "maybe_top_p_arr, float top_p_val, bool deterministic, int cuda_stream) -> ()");
+      "top_p_sampling_from_probs(Tensor probs, Tensor output, Tensor? maybe_indices, Tensor? "
+      "maybe_top_p_arr, float top_p_val, bool deterministic, Generator? gen) -> ()");
   m.impl("top_p_sampling_from_probs", torch::kCUDA, &top_p_sampling_from_probs);
 
   /*
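The Python wrappers for these sampling ops are not touched in this commit, so the snippet below is only an illustrative sketch of how the updated min_p_sampling_from_probs schema would be driven through torch.ops: randomness now comes from an optional torch.Generator instead of a caller-provided uniform_samples tensor, and the sampled ids are written into output. The shapes, the int32 output dtype, and the None placeholders for maybe_indices and maybe_min_p_arr are assumptions for illustration, not taken from the diff.

import torch

probs = torch.rand(4, 32000, device="cuda")
probs = probs / probs.sum(dim=-1, keepdim=True)             # rows must be probability distributions
output = torch.empty(4, dtype=torch.int32, device="cuda")   # sampled token ids (dtype assumed)
gen = torch.Generator(device="cuda").manual_seed(0)          # optional; pass None to use the default RNG

# Arguments follow the schema above:
# probs, output, maybe_indices, maybe_min_p_arr, min_p_val, deterministic, gen
torch.ops.sgl_kernel.min_p_sampling_from_probs.default(
    probs, output, None, None, 0.05, True, gen
)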

sgl-kernel/csrc/elementwise/fused_add_rms_norm_kernel.cu

Lines changed: 5 additions & 1 deletion
@@ -21,7 +21,8 @@ limitations under the License.
 
 using namespace flashinfer;
 
-void sgl_fused_add_rmsnorm(torch::Tensor input, torch::Tensor residual, torch::Tensor weight, double eps) {
+void sgl_fused_add_rmsnorm(
+    torch::Tensor input, torch::Tensor residual, torch::Tensor weight, double eps, bool enable_pdl) {
   CHECK_INPUT(input);
   CHECK_INPUT(residual);
   CHECK_INPUT(weight);

@@ -46,7 +47,10 @@ void sgl_fused_add_rmsnorm(torch::Tensor input, torch::Tensor residual, torch::T
         static_cast<c_type*>(weight.data_ptr()),
         batch_size,
         hidden_size,
+        input.stride(0),
+        residual.stride(0),
         eps,
+        enable_pdl,
         torch_current_stream);
     TORCH_CHECK(
         status == cudaSuccess, "FusedAddRMSNorm failed with error code " + std::string(cudaGetErrorString(status)));

sgl-kernel/csrc/speculative/speculative_sampling.cuh

Lines changed: 4 additions & 4 deletions
@@ -54,10 +54,10 @@ __global__ void TreeSpeculativeSamplingTargetOnly(
     DType threshold_acc) {
   const uint32_t bx = blockIdx.x, tx = threadIdx.x;
 
-  extern __shared__ __align__(alignof(SamplingTempStorage<DType, BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>))
+  extern __shared__ __align__(alignof(SamplingTempStorage<BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>))
       uint8_t smem_sampling[];
   auto& temp_storage =
-      reinterpret_cast<SamplingTempStorage<DType, BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>&>(smem_sampling);
+      reinterpret_cast<SamplingTempStorage<BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM>&>(smem_sampling);
 
   DType prob_acc = 0.0;
   uint32_t cur_prob_offset = bx * num_draft_tokens * d;

@@ -144,7 +144,7 @@ __global__ void TreeSpeculativeSamplingTargetOnly(
       relu_q_minus_p_vec[j] = max(q_vec[j] - p_vec[j], DType(0));
     }
 
-    DeviceSamplingFromProb<VEC_SIZE, BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM, DETERMINISTIC, DType>(
+    DeviceSamplingFromProb<VEC_SIZE, BLOCK_THREADS, SCAN_ALGORITHM, REDUCE_ALGORITHM, DETERMINISTIC>(
         i, d, [&](DType x) { return x > 0; }, u, relu_q_minus_p_vec, aggregate_relu_q_minus_p, &temp_storage);
     if (aggregate_relu_q_minus_p > u) {
       break;

@@ -179,7 +179,7 @@ cudaError_t TreeSpeculativeSamplingTargetOnly(
   constexpr uint32_t BLOCK_THREADS = 1024;
   const uint32_t vec_size = std::gcd(16 / sizeof(DType), d);
 
-  const uint32_t smem_size = sizeof(SamplingTempStorage<DType, BLOCK_THREADS, SCAN_ALGO, REDUCE_ALGO>);
+  const uint32_t smem_size = sizeof(SamplingTempStorage<BLOCK_THREADS, SCAN_ALGO, REDUCE_ALGO>);
   dim3 nblks(batch_size);
   dim3 nthrs(BLOCK_THREADS);
   float capped_threshold_acc = fmaxf(threshold_acc, 1e-9f);

sgl-kernel/include/sgl_kernel_ops.h

Lines changed: 16 additions & 26 deletions
@@ -102,11 +102,11 @@ int64_t cutlass_mla_get_workspace_size(int64_t max_seq_len, int64_t num_batches,
 /*
  * From csrc/elementwise
  */
-void rmsnorm(at::Tensor& output, at::Tensor& input, at::Tensor& weight, double eps, int64_t cuda_stream);
-void sgl_fused_add_rmsnorm(torch::Tensor input, torch::Tensor residual, torch::Tensor weight, double eps);
-void gemma_rmsnorm(at::Tensor& output, at::Tensor& input, at::Tensor& weight, double eps, int64_t cuda_stream);
-void gemma_fused_add_rmsnorm(
-    at::Tensor& input, at::Tensor& residual, at::Tensor& weight, double eps, int64_t cuda_stream);
+void rmsnorm(at::Tensor& output, at::Tensor& input, at::Tensor& weight, double eps, bool enable_pdl);
+void sgl_fused_add_rmsnorm(
+    torch::Tensor input, torch::Tensor residual, torch::Tensor weight, double eps, bool enable_pdl);
+void gemma_rmsnorm(at::Tensor& output, at::Tensor& input, at::Tensor& weight, double eps, bool enable_pdl);
+void gemma_fused_add_rmsnorm(at::Tensor& input, at::Tensor& residual, at::Tensor& weight, double eps, bool enable_pdl);
 void silu_and_mul(at::Tensor& out, at::Tensor& input, int64_t cuda_stream);
 void gelu_tanh_and_mul(at::Tensor& out, at::Tensor& input, int64_t cuda_stream);
 void gelu_and_mul(at::Tensor& out, at::Tensor& input, int64_t cuda_stream);

@@ -254,48 +254,38 @@ void segment_packbits(
  */
 void min_p_sampling_from_probs(
     at::Tensor probs,
-    at::Tensor uniform_samples,
-    at::Tensor samples,
+    at::Tensor output,
+    std::optional<at::Tensor> maybe_indices,
     std::optional<at::Tensor> maybe_min_p_arr,
     double min_p_val,
     bool deterministic,
-    int64_t cuda_stream);
+    std::optional<at::Generator> gen);
 
 void top_k_renorm_probs(
-    at::Tensor probs,
-    at::Tensor renorm_probs,
-    std::optional<at::Tensor> maybe_top_k_arr,
-    int64_t top_k_val,
-    int64_t cuda_stream);
+    at::Tensor probs, at::Tensor renorm_probs, std::optional<at::Tensor> maybe_top_k_arr, int64_t top_k_val);
 
 void top_p_renorm_probs(
-    at::Tensor probs,
-    at::Tensor renorm_probs,
-    std::optional<at::Tensor> maybe_top_p_arr,
-    double top_p_val,
-    int64_t cuda_stream);
+    at::Tensor probs, at::Tensor renorm_probs, std::optional<at::Tensor> maybe_top_p_arr, double top_p_val);
 
 void top_k_top_p_sampling_from_probs(
     at::Tensor probs,
-    at::Tensor uniform_samples,
-    at::Tensor samples,
-    at::Tensor success,
+    at::Tensor output,
+    std::optional<at::Tensor> maybe_indices,
     std::optional<at::Tensor> maybe_top_k_arr,
     double top_k_val,
     std::optional<at::Tensor> maybe_top_p_arr,
     double top_p_val,
     bool deterministic,
-    int64_t cuda_stream);
+    std::optional<at::Generator> gen);
 
 void top_p_sampling_from_probs(
     at::Tensor probs,
-    at::Tensor uniform_samples,
-    at::Tensor samples,
-    at::Tensor success,
+    at::Tensor output,
+    std::optional<at::Tensor> maybe_indices,
     std::optional<at::Tensor> maybe_top_p_arr,
     double top_p_val,
     bool deterministic,
-    int64_t cuda_stream);
+    std::optional<at::Generator> gen);
 
 namespace flash {
 /*
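The header now mirrors the schema changes above: the renorm helpers drop the trailing cuda_stream argument and the sampling entry points take an optional at::Generator instead of uniform_samples/success tensors. A minimal sketch of the simplest case, top_p_renorm_probs, invoked through the registered op; the tensor shapes and the None placeholder for maybe_top_p_arr are illustrative assumptions, not part of this diff.

import torch

probs = torch.softmax(torch.randn(2, 32000, device="cuda"), dim=-1)
renorm_probs = torch.empty_like(probs)

# No explicit cuda_stream argument anymore; the kernel runs on the current CUDA stream.
torch.ops.sgl_kernel.top_p_renorm_probs.default(probs, renorm_probs, None, 0.9)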

sgl-kernel/python/sgl_kernel/elementwise.py

Lines changed: 108 additions & 8 deletions
@@ -11,38 +11,138 @@ def rmsnorm(
     weight: torch.Tensor,
     eps: float = 1e-6,
     out: Optional[torch.Tensor] = None,
+    enable_pdl: bool = False,
 ) -> torch.Tensor:
+    r"""Root mean square normalization.
+
+    ``out[i] = (input[i] / RMS(input)) * weight[i]``
+
+    Parameters
+    ----------
+    input: torch.Tensor
+        Input tensor, shape (batch_size, hidden_size).
+    weight: torch.Tensor
+        Weight tensor, shape (hidden_size,).
+    eps: float
+        Epsilon for numerical stability.
+    out: Optional[torch.Tensor]
+        The output tensor, if specified, the kernel will update this tensor inplace.
+    enable_pdl: bool
+        Whether to enable `programmatic dependent launch
+        <https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#programmatic-dependent-launch-and-synchronization>`_
+
+    Returns
+    -------
+    output: torch.Tensor
+        Normalized tensor, shape (batch_size, hidden_size).
+    """
     if out is None:
         out = torch.empty_like(input)
-    torch.ops.sgl_kernel.rmsnorm.default(out, input, weight, eps, get_cuda_stream())
+    torch.ops.sgl_kernel.rmsnorm.default(out, input, weight, eps, enable_pdl)
     return out
 
 
 def fused_add_rmsnorm(
-    input: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6
+    input: torch.Tensor,
+    residual: torch.Tensor,
+    weight: torch.Tensor,
+    eps: float = 1e-6,
+    enable_pdl: bool = False,
 ) -> None:
-    torch.ops.sgl_kernel.fused_add_rmsnorm.default(input, residual, weight, eps)
+    r"""Fused add root mean square normalization.
+
+    Step 1:
+    ``residual[i] += input[i]``
+
+    Step 2:
+    ``input[i] = (residual[i] / RMS(residual)) * weight[i]``
+
+    Parameters
+    ----------
+    input: torch.Tensor
+        Input tensor, shape (batch_size, hidden_size).
+    residual: torch.Tensor
+        Residual tensor, shape (batch_size, hidden_size).
+    weight: torch.Tensor
+        Weight tensor, shape (hidden_size,).
+    eps: float
+        Epsilon for numerical stability.
+    enable_pdl: bool
+        Whether to enable `programmatic dependent launch
+        <https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#programmatic-dependent-launch-and-synchronization>`_
+    """
+    torch.ops.sgl_kernel.fused_add_rmsnorm.default(
+        input, residual, weight, eps, enable_pdl
+    )
 
 
 def gemma_rmsnorm(
     input: torch.Tensor,
     weight: torch.Tensor,
     eps: float = 1e-6,
     out: Optional[torch.Tensor] = None,
+    enable_pdl: bool = False,
 ) -> torch.Tensor:
+    r"""Gemma-style root mean square normalization.
+
+    ``out[i] = (input[i] / RMS(input)) * (weight[i] + 1)``
+
+    Parameters
+    ----------
+    input: torch.Tensor
+        Input tensor, shape (batch_size, hidden_size).
+    weight: torch.Tensor
+        Weight tensor, shape (hidden_size,).
+    eps: float
+        Epsilon for numerical stability.
+    out: Optional[torch.Tensor]
+        The output tensor, if specified, the kernel will update this tensor inplace.
+    enable_pdl: bool
+        Whether to enable `programmatic dependent launch
+        <https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#programmatic-dependent-launch-and-synchronization>`_
+
+    Returns
+    -------
+    output: torch.Tensor
+        Gemma Normalized tensor, shape (batch_size, hidden_size).
+    """
     if out is None:
         out = torch.empty_like(input)
-    torch.ops.sgl_kernel.gemma_rmsnorm.default(
-        out, input, weight, eps, get_cuda_stream()
-    )
+    torch.ops.sgl_kernel.gemma_rmsnorm.default(out, input, weight, eps, enable_pdl)
     return out
 
 
 def gemma_fused_add_rmsnorm(
-    input: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6
+    input: torch.Tensor,
+    residual: torch.Tensor,
+    weight: torch.Tensor,
+    eps: float = 1e-6,
+    enable_pdl: bool = False,
 ) -> None:
+    r"""Gemma-style fused add root mean square normalization.
+
+    Step 1:
+    ``residual[i] += input[i]``
+
+    Step 2:
+    ``input[i] = (residual[i] / RMS(residual)) * (weight + 1)``
+
+    Parameters
+    ----------
+    input: torch.Tensor
+        Input tensor, shape (batch_size, hidden_size).
+    residual: torch.Tensor
+        Residual tensor, shape (batch_size, hidden_size).
+    weight: torch.Tensor
+        Weight tensor, shape (hidden_size,).
+    eps: float
+        Epsilon for numerical stability.
+    enable_pdl: bool
+        Whether to enable `programmatic dependent launch
+        <https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#programmatic-dependent-launch-and-synchronization>`_
+    """
     torch.ops.sgl_kernel.gemma_fused_add_rmsnorm.default(
-        input, residual, weight, eps, get_cuda_stream()
+        input, residual, weight, eps, enable_pdl
     )