
Commit 7d840c5

yinfan98 authored and jimoosciuc committed
[Misc] Clean m.def and add Development Tips (sgl-project#4890)
1 parent 8b51ab0 · commit 7d840c5

File tree: 3 files changed, +86 -159 lines

sgl-kernel/README.md

Lines changed: 41 additions & 0 deletions
@@ -51,6 +51,47 @@ Steps to add a new kernel:
4. Update [CMakeLists.txt](https://github.com/sgl-project/sglang/blob/main/sgl-kernel/CMakeLists.txt) to include new CUDA source
5. Expose Python interface in [python](https://github.com/sgl-project/sglang/blob/main/sgl-kernel/python/sgl_kernel)

### Development Tips

1. When implementing kernels in [csrc](https://github.com/sgl-project/sglang/tree/main/sgl-kernel/csrc), only define pure CUDA files and C++ interfaces. If you need `torch::Tensor`, include `<torch/all.h>` instead of `<torch/extension.h>`; using `<torch/extension.h>` causes compilation errors when building with SABI. A minimal sketch follows.
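
A minimal sketch of such an interface file, assuming a hypothetical kernel `my_kernel` whose CUDA launcher lives in a separate `.cu` file (all names are illustrative, not from this repo):

```cpp
// my_kernel.cc -- illustrative sketch only.
// <torch/all.h> keeps this file compatible with SABI builds;
// <torch/extension.h> pulls in Python headers and would break them.
#include <torch/all.h>

// Implemented in a separate pure CUDA file (e.g. my_kernel.cu).
void my_kernel_launcher(torch::Tensor input, torch::Tensor output);

// The C++ interface that gets registered with the torch extension.
void my_kernel(torch::Tensor input, torch::Tensor output) {
  TORCH_CHECK(input.is_cuda(), "input must be a CUDA tensor");
  my_kernel_launcher(input, output);
}
```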

2. When creating torch extensions, simply add the function definition with `m.def`:

```cpp
m.def("register_graph_buffers", register_graph_buffers);
```
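
For context, a full registration unit in this style looks roughly like the sketch below, assembled from pieces that appear in this commit's `torch_extension.cc` (`TORCH_LIBRARY_EXPAND` and `REGISTER_EXTENSION` are this repo's macros; the two ops come from the diff further down):

```cpp
// Registration sketch: m.def(name, fn) lets torch infer the operator
// schema from the C++ function signature, so no schema string is needed.
TORCH_LIBRARY_EXPAND(sgl_kernel, m) {
  m.def("rmsnorm", rmsnorm);
  m.def("silu_and_mul", silu_and_mul);
}

REGISTER_EXTENSION(common_ops)
```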

3. When exposing Python interfaces, avoid using kwargs when calling C++ interface kernels.

**Avoid this:**

```python
torch.ops.sgl_kernel.apply_rope_pos_ids_cos_sin_cache.default(
    q=query.view(query.shape[0], -1, head_size),
    k=key.view(key.shape[0], -1, head_size),
    q_rope=query.view(query.shape[0], -1, head_size),
    k_rope=key.view(key.shape[0], -1, head_size),
    cos_sin_cache=cos_sin_cache,
    pos_ids=positions.long(),
    interleave=(not is_neox),
    cuda_stream=get_cuda_stream(),
)
```

**Use this instead:**

```python
torch.ops.sgl_kernel.apply_rope_pos_ids_cos_sin_cache.default(
    query.view(query.shape[0], -1, head_size),
    key.view(key.shape[0], -1, head_size),
    query.view(query.shape[0], -1, head_size),
    key.view(key.shape[0], -1, head_size),
    cos_sin_cache,
    positions.long(),
    (not is_neox),
    get_cuda_stream(),
)
```
### Build & Install

Development build:

sgl-kernel/csrc/torch_extension.cc

Lines changed: 37 additions & 151 deletions
```diff
@@ -22,121 +22,49 @@ TORCH_LIBRARY_EXPAND(sgl_kernel, m) {
   /*
    * From csrc/allreduce
    */
-  m.def(
-      "init_custom_ar(int rank_id, int world_size, Tensor rank_data, int[] buffers, int[] tmp_result_buffers, int[] "
-      "barrier_in, int[] barrier_out) -> int");
-  m.impl("init_custom_ar", torch::kCUDA, &init_custom_ar);
-
-  m.def("dispose", &dispose);
-
-  m.def("all_reduce(int fa, Tensor inp, Tensor! out) -> ()");
-  m.impl("all_reduce", torch::kCUDA, &all_reduce);
-
-  m.def("get_graph_buffer_ipc_meta", &get_graph_buffer_ipc_meta);
-  m.def("register_graph_buffers", &register_graph_buffers);
+  m.def("init_custom_ar", init_custom_ar);
+  m.def("dispose", dispose);
+  m.def("all_reduce", all_reduce);
+  m.def("get_graph_buffer_ipc_meta", get_graph_buffer_ipc_meta);
+  m.def("register_graph_buffers", register_graph_buffers);
 
   /*
    * From csrc/attention
    */
-  m.def(
-      "lightning_attention_decode(Tensor q, Tensor k, Tensor v, Tensor past_kv, Tensor slope, Tensor! output, Tensor! "
-      "new_kv) -> ()");
-  m.impl("lightning_attention_decode", torch::kCUDA, &lightning_attention_decode);
+  m.def("lightning_attention_decode", lightning_attention_decode);
 
   /*
    * From csrc/elementwise
    */
-  m.def("rmsnorm(Tensor! output, Tensor input, Tensor weight, float eps, int cuda_stream) -> ()");
-  m.impl("rmsnorm", torch::kCUDA, &rmsnorm);
-
-  m.def("fused_add_rmsnorm(Tensor! input, Tensor! residual, Tensor weight, float eps) -> ()");
-  m.impl("fused_add_rmsnorm", torch::kCUDA, &sgl_fused_add_rmsnorm);
-
-  m.def("gemma_rmsnorm(Tensor! output, Tensor input, Tensor weight, float eps, int cuda_stream) -> ()");
-  m.impl("gemma_rmsnorm", torch::kCUDA, &gemma_rmsnorm);
-
-  m.def("gemma_fused_add_rmsnorm(Tensor! input, Tensor! residual, Tensor weight, float eps, int cuda_stream) -> ()");
-  m.impl("gemma_fused_add_rmsnorm", torch::kCUDA, &gemma_fused_add_rmsnorm);
-
-  m.def("silu_and_mul(Tensor! out, Tensor input, int cuda_stream) -> ()");
-  m.impl("silu_and_mul", torch::kCUDA, &silu_and_mul);
-
-  m.def("gelu_tanh_and_mul(Tensor! out, Tensor input, int cuda_stream) -> ()");
-  m.impl("gelu_tanh_and_mul", torch::kCUDA, &gelu_tanh_and_mul);
-
-  m.def("gelu_and_mul(Tensor! out, Tensor input, int cuda_stream) -> ()");
-  m.impl("gelu_and_mul", torch::kCUDA, &gelu_and_mul);
-
-  m.def(
-      "apply_rope_pos_ids_cos_sin_cache(Tensor q, Tensor k, Tensor! q_rope, Tensor! k_rope, Tensor cos_sin_cache, "
-      "Tensor pos_ids, bool interleave, int cuda_stream) -> ()");
-  m.impl("apply_rope_pos_ids_cos_sin_cache", torch::kCUDA, &apply_rope_pos_ids_cos_sin_cache);
+  m.def("rmsnorm", rmsnorm);
+  m.def("fused_add_rmsnorm", sgl_fused_add_rmsnorm);
+  m.def("gemma_rmsnorm", gemma_rmsnorm);
+  m.def("gemma_fused_add_rmsnorm", gemma_fused_add_rmsnorm);
+  m.def("silu_and_mul", silu_and_mul);
+  m.def("gelu_tanh_and_mul", gelu_tanh_and_mul);
+  m.def("gelu_and_mul", gelu_and_mul);
+  m.def("apply_rope_pos_ids_cos_sin_cache", apply_rope_pos_ids_cos_sin_cache);
 
   /*
    * From csrc/gemm
    */
-  m.def("awq_dequantize(Tensor qweight, Tensor scales, Tensor qzeros) -> Tensor");
-  m.impl("awq_dequantize", torch::kCUDA, &awq_dequantize);
-
-  m.def(
-      "int8_scaled_mm(Tensor mat_a, Tensor mat_b, Tensor scales_a, Tensor scales_b, ScalarType out_dtype, Tensor? "
-      "bias) -> Tensor");
-  m.impl("int8_scaled_mm", torch::kCUDA, &int8_scaled_mm);
-
-  m.def(
-      "fp8_scaled_mm(Tensor mat_a, Tensor mat_b, Tensor scales_a, Tensor scales_b, ScalarType out_dtype, Tensor? "
-      "bias) -> Tensor");
-  m.impl("fp8_scaled_mm", torch::kCUDA, &fp8_scaled_mm);
-
-  m.def(
-      "fp8_blockwise_scaled_mm(Tensor mat_a, Tensor mat_b, Tensor scales_a, Tensor scales_b, ScalarType out_dtype) -> "
-      "Tensor");
-  m.impl("fp8_blockwise_scaled_mm", torch::kCUDA, &fp8_blockwise_scaled_mm);
-
-  m.def(
-      "sgl_per_token_group_quant_fp8(Tensor input, Tensor output_q, Tensor output_s, int group_size,"
-      " float eps, float fp8_min, float fp8_max) -> ()");
-  m.impl("sgl_per_token_group_quant_fp8", torch::kCUDA, &sgl_per_token_group_quant_fp8);
-
-  m.def(
-      "sgl_per_token_group_quant_int8(Tensor input, Tensor output_q, Tensor output_s, int group_size,"
-      " float eps, float int8_min, float int8_max) -> ()");
-  m.impl("sgl_per_token_group_quant_int8", torch::kCUDA, &sgl_per_token_group_quant_int8);
-
-  m.def("sgl_per_tensor_quant_fp8(Tensor input, Tensor output_q, Tensor output_s, bool is_static) -> ()");
-  m.impl("sgl_per_tensor_quant_fp8", torch::kCUDA, &sgl_per_tensor_quant_fp8);
-
-  m.def("sgl_per_token_quant_fp8(Tensor input, Tensor output_q, Tensor output_s) -> ()");
-  m.impl("sgl_per_token_quant_fp8", torch::kCUDA, &sgl_per_token_quant_fp8);
-
-  m.def(
-      "cublas_grouped_gemm(Tensor[] inputs, Tensor[] weights, Tensor[] outputs,"
-      " ScalarType out_dtype, int cublas_handle, int cuda_stream) -> ()");
-  m.impl("cublas_grouped_gemm", torch::kCUDA, &cublas_grouped_gemm);
-
-  m.def(
-      "cutlass_scaled_fp4_mm(Tensor! out, Tensor a, Tensor b,"
-      " Tensor block_scale_a, Tensor block_scale_b,"
-      " Tensor alpha) -> ()");
-  m.impl("cutlass_scaled_fp4_mm", torch::kCUDA, &cutlass_scaled_fp4_mm);
-
-  m.def(
-      "scaled_fp4_quant(Tensor! output, Tensor! input,"
-      " Tensor! output_scale, Tensor! input_scale) -> ()");
-  m.impl("scaled_fp4_quant", torch::kCUDA, &scaled_fp4_quant);
+  m.def("awq_dequantize", awq_dequantize);
+  m.def("int8_scaled_mm", int8_scaled_mm);
+  m.def("fp8_scaled_mm", fp8_scaled_mm);
+  m.def("fp8_blockwise_scaled_mm", fp8_blockwise_scaled_mm);
+  m.def("sgl_per_token_group_quant_fp8", sgl_per_token_group_quant_fp8);
+  m.def("sgl_per_token_group_quant_int8", sgl_per_token_group_quant_int8);
+  m.def("sgl_per_tensor_quant_fp8", sgl_per_tensor_quant_fp8);
+  m.def("sgl_per_token_quant_fp8", sgl_per_token_quant_fp8);
+  m.def("cublas_grouped_gemm", cublas_grouped_gemm);
+  m.def("cutlass_scaled_fp4_mm", cutlass_scaled_fp4_mm);
+  m.def("scaled_fp4_quant", scaled_fp4_quant);
 
   /*
    * From csrc/moe
    */
-  m.def(
-      "moe_align_block_size(Tensor topk_ids, int num_experts, int block_size, Tensor! sorted_token_ids, Tensor! "
-      "experts_ids, Tensor! num_tokens_post_pad, Tensor! token_cnts_buffer, Tensor! cumsum_buffer) -> ()");
-  m.impl("moe_align_block_size", torch::kCUDA, &moe_align_block_size);
-
-  m.def(
-      "topk_softmax(Tensor! topk_weights, Tensor! topk_indices, Tensor! "
-      "token_expert_indices, Tensor gating_output) -> ()");
-  m.impl("topk_softmax", torch::kCUDA, &topk_softmax);
+  m.def("moe_align_block_size", moe_align_block_size);
+  m.def("topk_softmax", topk_softmax);
 
   m.def(
       "moe_fused_gate(Tensor input, Tensor bias, int num_expert_group, int topk_group, int topk) -> "
@@ -146,62 +74,20 @@ TORCH_LIBRARY_EXPAND(sgl_kernel, m) {
   /*
    * From csrc/speculative
    */
-  m.def(
-      "tree_speculative_sampling_target_only(Tensor! predicts, Tensor! accept_index, Tensor! accept_token_num, "
-      "Tensor candidates, Tensor retrive_index, Tensor retrive_next_token, Tensor retrive_next_sibling, "
-      "Tensor uniform_samples, Tensor target_probs, Tensor draft_probs, "
-      "float threshold_single, float threshold_acc, "
-      "bool deterministic, int cuda_stream) -> ()");
-  m.impl("tree_speculative_sampling_target_only", torch::kCUDA, &tree_speculative_sampling_target_only);
-
-  m.def(
-      "verify_tree_greedy(Tensor! predicts, Tensor! accept_index, Tensor! accept_token_num, "
-      "Tensor candidates, Tensor retrive_index, Tensor retrive_next_token, Tensor retrive_next_sibling, "
-      "Tensor target_predict, int cuda_stream) -> ()");
-  m.impl("verify_tree_greedy", torch::kCUDA, &verify_tree_greedy);
-
-  m.def(
-      "build_tree_kernel_efficient(Tensor parent_list, Tensor selected_index, Tensor verified_seq_len, "
-      "Tensor! tree_mask, Tensor! positions, Tensor! retrive_index, Tensor! retrive_next_token, "
-      "Tensor! retrive_next_sibling, int topk, int depth, int draft_token_num) -> ()");
-  m.impl("build_tree_kernel_efficient", torch::kCUDA, &build_tree_kernel_efficient);
-
-  m.def("segment_packbits(Tensor x, Tensor input_indptr, Tensor output_indptr, Tensor! y, int cuda_stream) -> ()");
-  m.impl("segment_packbits", torch::kCUDA, &segment_packbits);
+  m.def("tree_speculative_sampling_target_only", tree_speculative_sampling_target_only);
+  m.def("verify_tree_greedy", verify_tree_greedy);
+  m.def("build_tree_kernel_efficient", build_tree_kernel_efficient);
+  m.def("segment_packbits", segment_packbits);
 
   /*
    * From FlashInfer
    */
-  m.def(
-      "bmm_fp8(Tensor A, Tensor B, Tensor! D, Tensor A_scale, Tensor B_scale, Tensor workspace_buffer, int "
-      "cublas_handle, int cuda_stream) -> ()");
-  m.impl("bmm_fp8", torch::kCUDA, &bmm_fp8);
-
-  m.def(
-      "min_p_sampling_from_probs(Tensor probs, Tensor uniform_samples, Tensor! samples, Tensor? maybe_min_p_arr, float "
-      "min_p_val, bool deterministic, int cuda_stream) -> ()");
-  m.impl("min_p_sampling_from_probs", torch::kCUDA, &min_p_sampling_from_probs);
-
-  m.def(
-      "top_k_renorm_probs(Tensor probs, Tensor! renorm_probs, Tensor? maybe_top_k_arr, int top_k_val, int "
-      "cuda_stream) -> ()");
-  m.impl("top_k_renorm_probs", torch::kCUDA, &top_k_renorm_probs);
-
-  m.def(
-      "top_p_renorm_probs(Tensor probs, Tensor! renorm_probs, Tensor? maybe_top_p_arr, float top_p_val, int "
-      "cuda_stream) -> ()");
-  m.impl("top_p_renorm_probs", torch::kCUDA, &top_p_renorm_probs);
-
-  m.def(
-      "top_k_top_p_sampling_from_probs(Tensor probs, Tensor uniform_samples, Tensor! samples, Tensor! success, Tensor? "
-      "maybe_top_k_arr, float top_k_val, Tensor? maybe_top_p_arr, float top_p_val, bool deterministic, int "
-      "cuda_stream) -> ()");
-  m.impl("top_k_top_p_sampling_from_probs", torch::kCUDA, &top_k_top_p_sampling_from_probs);
-
-  m.def(
-      "top_p_sampling_from_probs(Tensor probs, Tensor uniform_samples, Tensor! samples, Tensor! success, Tensor? "
-      "maybe_top_p_arr, float top_p_val, bool deterministic, int cuda_stream) -> ()");
-  m.impl("top_p_sampling_from_probs", torch::kCUDA, &top_p_sampling_from_probs);
+  m.def("bmm_fp8", bmm_fp8);
+  m.def("min_p_sampling_from_probs", min_p_sampling_from_probs);
+  m.def("top_k_renorm_probs", top_k_renorm_probs);
+  m.def("top_p_renorm_probs", top_p_renorm_probs);
+  m.def("top_k_top_p_sampling_from_probs", top_k_top_p_sampling_from_probs);
+  m.def("top_p_sampling_from_probs", top_p_sampling_from_probs);
 }
 
 REGISTER_EXTENSION(common_ops)
```
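
Distilled, the cleanup replaces the schema-string-plus-`m.impl` pattern with direct function registration; both forms below are taken verbatim from the diff above:

```cpp
// Before: explicit schema string plus a CUDA-specific impl registration.
m.def("rmsnorm(Tensor! output, Tensor input, Tensor weight, float eps, int cuda_stream) -> ()");
m.impl("rmsnorm", torch::kCUDA, &rmsnorm);

// After: register the function directly; the schema is inferred
// from the C++ signature.
m.def("rmsnorm", rmsnorm);
```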

sgl-kernel/python/sgl_kernel/elementwise.py

Lines changed: 8 additions & 8 deletions
```diff
@@ -142,12 +142,12 @@ def apply_rope_with_cos_sin_cache_inplace(
         raise ValueError("cos_sin_cache should be float32")
 
     torch.ops.sgl_kernel.apply_rope_pos_ids_cos_sin_cache.default(
-        q=query.view(query.shape[0], -1, head_size),
-        k=key.view(key.shape[0], -1, head_size),
-        q_rope=query.view(query.shape[0], -1, head_size),
-        k_rope=key.view(key.shape[0], -1, head_size),
-        cos_sin_cache=cos_sin_cache,
-        pos_ids=positions.long(),
-        interleave=(not is_neox),
-        cuda_stream=get_cuda_stream(),
+        query.view(query.shape[0], -1, head_size),
+        key.view(key.shape[0], -1, head_size),
+        query.view(query.shape[0], -1, head_size),
+        key.view(key.shape[0], -1, head_size),
+        cos_sin_cache,
+        positions.long(),
+        (not is_neox),
+        get_cuda_stream(),
     )
```
