@@ -22,121 +22,49 @@ TORCH_LIBRARY_EXPAND(sgl_kernel, m) {
  /*
   * From csrc/allreduce
   */
-  m.def(
-      "init_custom_ar(int rank_id, int world_size, Tensor rank_data, int[] buffers, int[] tmp_result_buffers, int[] "
-      "barrier_in, int[] barrier_out) -> int");
-  m.impl("init_custom_ar", torch::kCUDA, &init_custom_ar);
-
-  m.def("dispose", &dispose);
-
-  m.def("all_reduce(int fa, Tensor inp, Tensor! out) -> ()");
-  m.impl("all_reduce", torch::kCUDA, &all_reduce);
-
-  m.def("get_graph_buffer_ipc_meta", &get_graph_buffer_ipc_meta);
-  m.def("register_graph_buffers", &register_graph_buffers);
+  m.def("init_custom_ar", init_custom_ar);
+  m.def("dispose", dispose);
+  m.def("all_reduce", all_reduce);
+  m.def("get_graph_buffer_ipc_meta", get_graph_buffer_ipc_meta);
+  m.def("register_graph_buffers", register_graph_buffers);

  /*
   * From csrc/attention
   */
-  m.def(
-      "lightning_attention_decode(Tensor q, Tensor k, Tensor v, Tensor past_kv, Tensor slope, Tensor! output, Tensor! "
-      "new_kv) -> ()");
-  m.impl("lightning_attention_decode", torch::kCUDA, &lightning_attention_decode);
+  m.def("lightning_attention_decode", lightning_attention_decode);

  /*
   * From csrc/elementwise
   */
-  m.def("rmsnorm(Tensor! output, Tensor input, Tensor weight, float eps, int cuda_stream) -> ()");
-  m.impl("rmsnorm", torch::kCUDA, &rmsnorm);
-
-  m.def("fused_add_rmsnorm(Tensor! input, Tensor! residual, Tensor weight, float eps) -> ()");
-  m.impl("fused_add_rmsnorm", torch::kCUDA, &sgl_fused_add_rmsnorm);
-
-  m.def("gemma_rmsnorm(Tensor! output, Tensor input, Tensor weight, float eps, int cuda_stream) -> ()");
-  m.impl("gemma_rmsnorm", torch::kCUDA, &gemma_rmsnorm);
-
-  m.def("gemma_fused_add_rmsnorm(Tensor! input, Tensor! residual, Tensor weight, float eps, int cuda_stream) -> ()");
-  m.impl("gemma_fused_add_rmsnorm", torch::kCUDA, &gemma_fused_add_rmsnorm);
-
-  m.def("silu_and_mul(Tensor! out, Tensor input, int cuda_stream) -> ()");
-  m.impl("silu_and_mul", torch::kCUDA, &silu_and_mul);
-
-  m.def("gelu_tanh_and_mul(Tensor! out, Tensor input, int cuda_stream) -> ()");
-  m.impl("gelu_tanh_and_mul", torch::kCUDA, &gelu_tanh_and_mul);
-
-  m.def("gelu_and_mul(Tensor! out, Tensor input, int cuda_stream) -> ()");
-  m.impl("gelu_and_mul", torch::kCUDA, &gelu_and_mul);
-
-  m.def(
-      "apply_rope_pos_ids_cos_sin_cache(Tensor q, Tensor k, Tensor! q_rope, Tensor! k_rope, Tensor cos_sin_cache, "
-      "Tensor pos_ids, bool interleave, int cuda_stream) -> ()");
-  m.impl("apply_rope_pos_ids_cos_sin_cache", torch::kCUDA, &apply_rope_pos_ids_cos_sin_cache);
+  m.def("rmsnorm", rmsnorm);
+  m.def("fused_add_rmsnorm", sgl_fused_add_rmsnorm);
+  m.def("gemma_rmsnorm", gemma_rmsnorm);
+  m.def("gemma_fused_add_rmsnorm", gemma_fused_add_rmsnorm);
+  m.def("silu_and_mul", silu_and_mul);
+  m.def("gelu_tanh_and_mul", gelu_tanh_and_mul);
+  m.def("gelu_and_mul", gelu_and_mul);
+  m.def("apply_rope_pos_ids_cos_sin_cache", apply_rope_pos_ids_cos_sin_cache);

  /*
   * From csrc/gemm
   */
-  m.def("awq_dequantize(Tensor qweight, Tensor scales, Tensor qzeros) -> Tensor");
-  m.impl("awq_dequantize", torch::kCUDA, &awq_dequantize);
-
-  m.def(
-      "int8_scaled_mm(Tensor mat_a, Tensor mat_b, Tensor scales_a, Tensor scales_b, ScalarType out_dtype, Tensor? "
-      "bias) -> Tensor");
-  m.impl("int8_scaled_mm", torch::kCUDA, &int8_scaled_mm);
-
-  m.def(
-      "fp8_scaled_mm(Tensor mat_a, Tensor mat_b, Tensor scales_a, Tensor scales_b, ScalarType out_dtype, Tensor? "
-      "bias) -> Tensor");
-  m.impl("fp8_scaled_mm", torch::kCUDA, &fp8_scaled_mm);
-
-  m.def(
-      "fp8_blockwise_scaled_mm(Tensor mat_a, Tensor mat_b, Tensor scales_a, Tensor scales_b, ScalarType out_dtype) -> "
-      "Tensor");
-  m.impl("fp8_blockwise_scaled_mm", torch::kCUDA, &fp8_blockwise_scaled_mm);
-
-  m.def(
-      "sgl_per_token_group_quant_fp8(Tensor input, Tensor output_q, Tensor output_s, int group_size,"
-      " float eps, float fp8_min, float fp8_max) -> ()");
-  m.impl("sgl_per_token_group_quant_fp8", torch::kCUDA, &sgl_per_token_group_quant_fp8);
-
-  m.def(
-      "sgl_per_token_group_quant_int8(Tensor input, Tensor output_q, Tensor output_s, int group_size,"
-      " float eps, float int8_min, float int8_max) -> ()");
-  m.impl("sgl_per_token_group_quant_int8", torch::kCUDA, &sgl_per_token_group_quant_int8);
-
-  m.def("sgl_per_tensor_quant_fp8(Tensor input, Tensor output_q, Tensor output_s, bool is_static) -> ()");
-  m.impl("sgl_per_tensor_quant_fp8", torch::kCUDA, &sgl_per_tensor_quant_fp8);
-
-  m.def("sgl_per_token_quant_fp8(Tensor input, Tensor output_q, Tensor output_s) -> ()");
-  m.impl("sgl_per_token_quant_fp8", torch::kCUDA, &sgl_per_token_quant_fp8);
-
-  m.def(
-      "cublas_grouped_gemm(Tensor[] inputs, Tensor[] weights, Tensor[] outputs,"
-      " ScalarType out_dtype, int cublas_handle, int cuda_stream) -> ()");
-  m.impl("cublas_grouped_gemm", torch::kCUDA, &cublas_grouped_gemm);
-
-  m.def(
-      "cutlass_scaled_fp4_mm(Tensor! out, Tensor a, Tensor b,"
-      " Tensor block_scale_a, Tensor block_scale_b,"
-      " Tensor alpha) -> ()");
-  m.impl("cutlass_scaled_fp4_mm", torch::kCUDA, &cutlass_scaled_fp4_mm);
-
-  m.def(
-      "scaled_fp4_quant(Tensor! output, Tensor! input,"
-      " Tensor! output_scale, Tensor! input_scale) -> ()");
-  m.impl("scaled_fp4_quant", torch::kCUDA, &scaled_fp4_quant);
+  m.def("awq_dequantize", awq_dequantize);
+  m.def("int8_scaled_mm", int8_scaled_mm);
+  m.def("fp8_scaled_mm", fp8_scaled_mm);
+  m.def("fp8_blockwise_scaled_mm", fp8_blockwise_scaled_mm);
+  m.def("sgl_per_token_group_quant_fp8", sgl_per_token_group_quant_fp8);
+  m.def("sgl_per_token_group_quant_int8", sgl_per_token_group_quant_int8);
+  m.def("sgl_per_tensor_quant_fp8", sgl_per_tensor_quant_fp8);
+  m.def("sgl_per_token_quant_fp8", sgl_per_token_quant_fp8);
+  m.def("cublas_grouped_gemm", cublas_grouped_gemm);
+  m.def("cutlass_scaled_fp4_mm", cutlass_scaled_fp4_mm);
+  m.def("scaled_fp4_quant", scaled_fp4_quant);

  /*
   * From csrc/moe
   */
-  m.def(
-      "moe_align_block_size(Tensor topk_ids, int num_experts, int block_size, Tensor! sorted_token_ids, Tensor! "
-      "experts_ids, Tensor! num_tokens_post_pad, Tensor! token_cnts_buffer, Tensor! cumsum_buffer) -> ()");
-  m.impl("moe_align_block_size", torch::kCUDA, &moe_align_block_size);
-
-  m.def(
-      "topk_softmax(Tensor! topk_weights, Tensor! topk_indices, Tensor! "
-      "token_expert_indices, Tensor gating_output) -> ()");
-  m.impl("topk_softmax", torch::kCUDA, &topk_softmax);
+  m.def("moe_align_block_size", moe_align_block_size);
+  m.def("topk_softmax", topk_softmax);

  m.def(
      "moe_fused_gate(Tensor input, Tensor bias, int num_expert_group, int topk_group, int topk) -> "
@@ -146,62 +74,20 @@ TORCH_LIBRARY_EXPAND(sgl_kernel, m) {
  /*
   * From csrc/speculative
   */
-  m.def(
-      "tree_speculative_sampling_target_only(Tensor! predicts, Tensor! accept_index, Tensor! accept_token_num, "
-      "Tensor candidates, Tensor retrive_index, Tensor retrive_next_token, Tensor retrive_next_sibling, "
-      "Tensor uniform_samples, Tensor target_probs, Tensor draft_probs, "
-      "float threshold_single, float threshold_acc, "
-      "bool deterministic, int cuda_stream) -> ()");
-  m.impl("tree_speculative_sampling_target_only", torch::kCUDA, &tree_speculative_sampling_target_only);
-
-  m.def(
-      "verify_tree_greedy(Tensor! predicts, Tensor! accept_index, Tensor! accept_token_num, "
-      "Tensor candidates, Tensor retrive_index, Tensor retrive_next_token, Tensor retrive_next_sibling, "
-      "Tensor target_predict, int cuda_stream) -> ()");
-  m.impl("verify_tree_greedy", torch::kCUDA, &verify_tree_greedy);
-
-  m.def(
-      "build_tree_kernel_efficient(Tensor parent_list, Tensor selected_index, Tensor verified_seq_len, "
-      "Tensor! tree_mask, Tensor! positions, Tensor! retrive_index, Tensor! retrive_next_token, "
-      "Tensor! retrive_next_sibling, int topk, int depth, int draft_token_num) -> ()");
-  m.impl("build_tree_kernel_efficient", torch::kCUDA, &build_tree_kernel_efficient);
-
-  m.def("segment_packbits(Tensor x, Tensor input_indptr, Tensor output_indptr, Tensor! y, int cuda_stream) -> ()");
-  m.impl("segment_packbits", torch::kCUDA, &segment_packbits);
+  m.def("tree_speculative_sampling_target_only", tree_speculative_sampling_target_only);
+  m.def("verify_tree_greedy", verify_tree_greedy);
+  m.def("build_tree_kernel_efficient", build_tree_kernel_efficient);
+  m.def("segment_packbits", segment_packbits);

  /*
   * From FlashInfer
   */
-  m.def(
-      "bmm_fp8(Tensor A, Tensor B, Tensor! D, Tensor A_scale, Tensor B_scale, Tensor workspace_buffer, int "
-      "cublas_handle, int cuda_stream) -> ()");
-  m.impl("bmm_fp8", torch::kCUDA, &bmm_fp8);
-
-  m.def(
-      "min_p_sampling_from_probs(Tensor probs, Tensor uniform_samples, Tensor! samples, Tensor? maybe_min_p_arr, float "
-      "min_p_val, bool deterministic, int cuda_stream) -> ()");
-  m.impl("min_p_sampling_from_probs", torch::kCUDA, &min_p_sampling_from_probs);
-
-  m.def(
-      "top_k_renorm_probs(Tensor probs, Tensor! renorm_probs, Tensor? maybe_top_k_arr, int top_k_val, int "
-      "cuda_stream) -> ()");
-  m.impl("top_k_renorm_probs", torch::kCUDA, &top_k_renorm_probs);
-
-  m.def(
-      "top_p_renorm_probs(Tensor probs, Tensor! renorm_probs, Tensor? maybe_top_p_arr, float top_p_val, int "
-      "cuda_stream) -> ()");
-  m.impl("top_p_renorm_probs", torch::kCUDA, &top_p_renorm_probs);
-
-  m.def(
-      "top_k_top_p_sampling_from_probs(Tensor probs, Tensor uniform_samples, Tensor! samples, Tensor! success, Tensor? "
-      "maybe_top_k_arr, float top_k_val, Tensor? maybe_top_p_arr, float top_p_val, bool deterministic, int "
-      "cuda_stream) -> ()");
-  m.impl("top_k_top_p_sampling_from_probs", torch::kCUDA, &top_k_top_p_sampling_from_probs);
-
-  m.def(
-      "top_p_sampling_from_probs(Tensor probs, Tensor uniform_samples, Tensor! samples, Tensor! success, Tensor? "
-      "maybe_top_p_arr, float top_p_val, bool deterministic, int cuda_stream) -> ()");
-  m.impl("top_p_sampling_from_probs", torch::kCUDA, &top_p_sampling_from_probs);
+  m.def("bmm_fp8", bmm_fp8);
+  m.def("min_p_sampling_from_probs", min_p_sampling_from_probs);
+  m.def("top_k_renorm_probs", top_k_renorm_probs);
+  m.def("top_p_renorm_probs", top_p_renorm_probs);
+  m.def("top_k_top_p_sampling_from_probs", top_k_top_p_sampling_from_probs);
+  m.def("top_p_sampling_from_probs", top_p_sampling_from_probs);

}

REGISTER_EXTENSION(common_ops)
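Note on the pattern this diff adopts: when a C++ function is passed directly to m.def, TORCH_LIBRARY infers the operator schema from the function's signature and registers the function itself for all backends, so the explicit schema string and the per-dispatch-key m.impl(..., torch::kCUDA, ...) calls removed above are no longer needed. A minimal sketch of the two styles, using a hypothetical my_scale op (not part of sgl-kernel) and assuming only libtorch:

#include <ATen/ATen.h>
#include <torch/library.h>

// Hypothetical op used only for illustration.
at::Tensor my_scale(const at::Tensor& input, double factor) {
  return input * factor;
}

TORCH_LIBRARY(example_ops, m) {
  // Old style (what this diff removes): explicit schema string plus a
  // per-dispatch-key implementation.
  //   m.def("my_scale(Tensor input, float factor) -> Tensor");
  //   m.impl("my_scale", torch::kCUDA, &my_scale);

  // New style (what this diff keeps): the schema is inferred from the C++
  // signature and my_scale is registered as the implementation for all backends.
  m.def("my_scale", my_scale);
}

With either style the operator is exposed to Python as torch.ops.example_ops.my_scale.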