
Commit a7d9212

zhyncs and jimoosciuc authored and committed
use default for torch.ops (sgl-project#4835)
1 parent 1f7286e commit a7d9212

File tree (7 files changed: +51 −47 lines)

sgl-kernel/python/sgl_kernel/allreduce.py
sgl-kernel/python/sgl_kernel/attention.py
sgl-kernel/python/sgl_kernel/elementwise.py
sgl-kernel/python/sgl_kernel/gemm.py
sgl-kernel/python/sgl_kernel/moe.py
sgl-kernel/python/sgl_kernel/sampling.py
sgl-kernel/python/sgl_kernel/speculative.py
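
Every hunk below makes the same mechanical change: calls that went through the overload packet, e.g. torch.ops.sgl_kernel.rmsnorm(...), now call the resolved overload torch.ops.sgl_kernel.rmsnorm.default(...). The usual motivation for this pattern is that .default selects the registered OpOverload directly, skipping packet-level overload resolution on each call and matching the form that torch.compile / FX tracing records. A minimal sketch of the two call forms, using a hypothetical demo::my_add op registered with torch.library in place of the compiled sgl_kernel extension:

import torch

# Hypothetical stand-in for a compiled extension such as sgl_kernel: register a
# trivial custom op under a made-up "demo" namespace via torch.library.
lib = torch.library.Library("demo", "DEF")
lib.define("my_add(Tensor a, Tensor b) -> Tensor")
lib.impl("my_add", lambda a, b: a + b, "CompositeExplicitAutograd")

x = torch.ones(4)
y = torch.full((4,), 2.0)

packet = torch.ops.demo.my_add            # OpOverloadPacket: resolves an overload on each call
overload = torch.ops.demo.my_add.default  # OpOverload: the already-resolved default overload

assert torch.equal(packet(x, y), overload(x, y))  # same result either way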

sgl-kernel/python/sgl_kernel/allreduce.py

Lines changed: 15 additions & 15 deletions
@@ -12,49 +12,49 @@ def init_custom_ar(
        rank: int,
        full_nvlink: bool,
    ) -> int:
-        return torch.ops.sgl_kernel.init_custom_ar(
+        return torch.ops.sgl_kernel.init_custom_ar.default(
            meta, rank_data, handles, offsets, rank, full_nvlink
        )

    def all_reduce_reg(fa: int, inp: torch.Tensor, out: torch.Tensor) -> None:
-        torch.ops.sgl_kernel.all_reduce_reg(fa, inp, out)
+        torch.ops.sgl_kernel.all_reduce_reg.default(fa, inp, out)

    def all_reduce_unreg(
        fa: int, inp: torch.Tensor, reg_buffer: torch.Tensor, out: torch.Tensor
    ) -> None:
-        torch.ops.sgl_kernel.all_reduce_unreg(fa, inp, reg_buffer, out)
+        torch.ops.sgl_kernel.all_reduce_unreg.default(fa, inp, reg_buffer, out)

    def dispose(fa: int) -> None:
-        torch.ops.sgl_kernel.dispose(fa)
+        torch.ops.sgl_kernel.dispose.default(fa)

    def meta_size() -> int:
-        return torch.ops.sgl_kernel.meta_size()
+        return torch.ops.sgl_kernel.meta_size.default()

    def register_buffer(
        fa: int, t: torch.Tensor, handles: List[str], offsets: List[int]
    ) -> None:
-        return torch.ops.sgl_kernel.register_buffer(fa, t, handles, offsets)
+        return torch.ops.sgl_kernel.register_buffer.default(fa, t, handles, offsets)

    def get_graph_buffer_ipc_meta(fa: int) -> Tuple[torch.Tensor, List[int]]:
-        return torch.ops.sgl_kernel.get_graph_buffer_ipc_meta(fa)
+        return torch.ops.sgl_kernel.get_graph_buffer_ipc_meta.default(fa)

    def register_graph_buffers(
        fa: int, handles: List[str], offsets: List[List[int]]
    ) -> None:
-        torch.ops.sgl_kernel.register_graph_buffers(fa, handles, offsets)
+        torch.ops.sgl_kernel.register_graph_buffers.default(fa, handles, offsets)

    def allocate_meta_buffer(size: int) -> torch.Tensor:
-        return torch.ops.sgl_kernel.allocate_meta_buffer(size)
+        return torch.ops.sgl_kernel.allocate_meta_buffer.default(size)

    def get_meta_buffer_ipc_handle(inp: torch.Tensor) -> torch.Tensor:
-        return torch.ops.sgl_kernel.get_meta_buffer_ipc_handle(inp)
+        return torch.ops.sgl_kernel.get_meta_buffer_ipc_handle.default(inp)

else:
    # TRTLLM custom allreduce
    def init_custom_reduce(
        rank_id, num_devices, rank_data, buffers, tmp_buffers, barrier_in, barrier_out
    ):
-        return torch.ops.sgl_kernel.init_custom_ar(
+        return torch.ops.sgl_kernel.init_custom_ar.default(
            rank_id,
            num_devices,
            rank_data,
@@ -65,13 +65,13 @@ def init_custom_reduce(
        )

    def custom_dispose(fa):
-        torch.ops.sgl_kernel.dispose(fa)
+        torch.ops.sgl_kernel.dispose.default(fa)

    def custom_reduce(fa, inp, out):
-        torch.ops.sgl_kernel.all_reduce(fa, inp, out)
+        torch.ops.sgl_kernel.all_reduce.default(fa, inp, out)

    def get_graph_buffer_ipc_meta(fa):
-        return torch.ops.sgl_kernel.get_graph_buffer_ipc_meta(fa)
+        return torch.ops.sgl_kernel.get_graph_buffer_ipc_meta.default(fa)

    def register_graph_buffers(fa, handles, offsets):
-        torch.ops.sgl_kernel.register_graph_buffers(fa, handles, offsets)
+        torch.ops.sgl_kernel.register_graph_buffers.default(fa, handles, offsets)

sgl-kernel/python/sgl_kernel/attention.py

Lines changed: 1 addition & 1 deletion
@@ -2,6 +2,6 @@


def lightning_attention_decode(q, k, v, past_kv, slope, output, new_kv):
-    torch.ops.sgl_kernel.lightning_attention_decode(
+    torch.ops.sgl_kernel.lightning_attention_decode.default(
        q, k, v, past_kv, slope, output, new_kv
    )

sgl-kernel/python/sgl_kernel/elementwise.py

Lines changed: 10 additions & 8 deletions
@@ -14,14 +14,14 @@ def rmsnorm(
) -> torch.Tensor:
    if out is None:
        out = torch.empty_like(input)
-    torch.ops.sgl_kernel.rmsnorm(out, input, weight, eps, get_cuda_stream())
+    torch.ops.sgl_kernel.rmsnorm.default(out, input, weight, eps, get_cuda_stream())
    return out


def fused_add_rmsnorm(
    input: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6
) -> None:
-    torch.ops.sgl_kernel.fused_add_rmsnorm(input, residual, weight, eps)
+    torch.ops.sgl_kernel.fused_add_rmsnorm.default(input, residual, weight, eps)


def gemma_rmsnorm(
@@ -32,14 +32,16 @@ def gemma_rmsnorm(
) -> torch.Tensor:
    if out is None:
        out = torch.empty_like(input)
-    torch.ops.sgl_kernel.gemma_rmsnorm(out, input, weight, eps, get_cuda_stream())
+    torch.ops.sgl_kernel.gemma_rmsnorm.default(
+        out, input, weight, eps, get_cuda_stream()
+    )
    return out


def gemma_fused_add_rmsnorm(
    input: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6
) -> None:
-    torch.ops.sgl_kernel.gemma_fused_add_rmsnorm(
+    torch.ops.sgl_kernel.gemma_fused_add_rmsnorm.default(
        input, residual, weight, eps, get_cuda_stream()
    )

@@ -65,7 +67,7 @@ def silu_and_mul(input: torch.Tensor, out: torch.Tensor = None) -> torch.Tensor:
            device=input.device,
            dtype=input.dtype,
        )
-    torch.ops.sgl_kernel.silu_and_mul(out, input, get_cuda_stream())
+    torch.ops.sgl_kernel.silu_and_mul.default(out, input, get_cuda_stream())
    return out


@@ -80,7 +82,7 @@ def gelu_tanh_and_mul(input: torch.Tensor, out: torch.Tensor = None) -> torch.Tensor:
            device=input.device,
            dtype=input.dtype,
        )
-    torch.ops.sgl_kernel.gelu_tanh_and_mul(out, input, get_cuda_stream())
+    torch.ops.sgl_kernel.gelu_tanh_and_mul.default(out, input, get_cuda_stream())
    return out


@@ -95,7 +97,7 @@ def gelu_and_mul(input: torch.Tensor, out: torch.Tensor = None) -> torch.Tensor:
            device=input.device,
            dtype=input.dtype,
        )
-    torch.ops.sgl_kernel.gelu_and_mul(out, input, get_cuda_stream())
+    torch.ops.sgl_kernel.gelu_and_mul.default(out, input, get_cuda_stream())
    return out


@@ -139,7 +141,7 @@ def apply_rope_with_cos_sin_cache_inplace(
    if cos_sin_cache.dtype != torch.float32:
        raise ValueError("cos_sin_cache should be float32")

-    torch.ops.sgl_kernel.apply_rope_pos_ids_cos_sin_cache(
+    torch.ops.sgl_kernel.apply_rope_pos_ids_cos_sin_cache.default(
        q=query.view(query.shape[0], -1, head_size),
        k=key.view(key.shape[0], -1, head_size),
        q_rope=query.view(query.shape[0], -1, head_size),

sgl-kernel/python/sgl_kernel/gemm.py

Lines changed: 14 additions & 12 deletions
@@ -7,11 +7,11 @@
def awq_dequantize(
    qweight: torch.Tensor, scales: torch.Tensor, qzeros: torch.Tensor
) -> torch.ByteTensor:
-    return torch.ops.sgl_kernel.awq_dequantize(qweight, scales, qzeros)
+    return torch.ops.sgl_kernel.awq_dequantize.default(qweight, scales, qzeros)


def int8_scaled_mm(mat_a, mat_b, scales_a, scales_b, out_dtype, bias=None):
-    return torch.ops.sgl_kernel.int8_scaled_mm(
+    return torch.ops.sgl_kernel.int8_scaled_mm.default(
        mat_a,
        mat_b,
        scales_a,
@@ -22,7 +22,7 @@ def int8_scaled_mm(mat_a, mat_b, scales_a, scales_b, out_dtype, bias=None):


def fp8_blockwise_scaled_mm(mat_a, mat_b, scales_a, scales_b, out_dtype):
-    return torch.ops.sgl_kernel.fp8_blockwise_scaled_mm(
+    return torch.ops.sgl_kernel.fp8_blockwise_scaled_mm.default(
        mat_a,
        mat_b,
        scales_a,
@@ -32,7 +32,7 @@ def fp8_blockwise_scaled_mm(mat_a, mat_b, scales_a, scales_b, out_dtype):


def fp8_scaled_mm(mat_a, mat_b, scales_a, scales_b, out_dtype, bias=None):
-    return torch.ops.sgl_kernel.fp8_scaled_mm(
+    return torch.ops.sgl_kernel.fp8_scaled_mm.default(
        mat_a,
        mat_b,
        scales_a,
@@ -51,7 +51,7 @@ def _bmm_fp8_internal(
    B_scale: torch.Tensor,
) -> None:
    cublas_handle = torch.cuda.current_blas_handle()
-    torch.ops.sgl_kernel.bmm_fp8(
+    torch.ops.sgl_kernel.bmm_fp8.default(
        A,
        B,
        D,
@@ -91,7 +91,7 @@ def sgl_per_token_group_quant_fp8(
    fp8_min: float,
    fp8_max: float,
) -> None:
-    torch.ops.sgl_kernel.sgl_per_token_group_quant_fp8(
+    torch.ops.sgl_kernel.sgl_per_token_group_quant_fp8.default(
        input, output_q, output_s, group_size, eps, fp8_min, fp8_max
    )

@@ -105,7 +105,7 @@ def sgl_per_token_group_quant_int8(
    int8_min: float,
    int8_max: float,
) -> None:
-    torch.ops.sgl_kernel.sgl_per_token_group_quant_int8(
+    torch.ops.sgl_kernel.sgl_per_token_group_quant_int8.default(
        input, output_q, output_s, group_size, eps, int8_min, int8_max
    )

@@ -116,7 +116,9 @@ def sgl_per_tensor_quant_fp8(
    output_s: torch.Tensor,
    is_static: bool,
) -> None:
-    torch.ops.sgl_kernel.sgl_per_tensor_quant_fp8(input, output_q, output_s, is_static)
+    torch.ops.sgl_kernel.sgl_per_tensor_quant_fp8.default(
+        input, output_q, output_s, is_static
+    )


def cublas_grouped_gemm(
@@ -129,7 +131,7 @@ def cublas_grouped_gemm(
        len(inputs) > 0 and len(weights) > 0 and len(outputs) > 0
    ), "Inputs/weights/outputs should not be empty!"
    cublas_handle = torch.cuda.current_blas_handle()
-    torch.ops.sgl_kernel.cublas_grouped_gemm(
+    torch.ops.sgl_kernel.cublas_grouped_gemm.default(
        inputs,
        weights,
        outputs,
@@ -144,7 +146,7 @@ def sgl_per_token_quant_fp8(
    output_q: torch.Tensor,
    output_s: torch.Tensor,
) -> None:
-    torch.ops.sgl_kernel.sgl_per_token_quant_fp8(input, output_q, output_s)
+    torch.ops.sgl_kernel.sgl_per_token_quant_fp8.default(input, output_q, output_s)


def cutlass_scaled_fp4_mm(
@@ -158,7 +160,7 @@ def cutlass_scaled_fp4_mm(
    assert a.ndim == 2 and b.ndim == 2
    m, n = a.shape[0], b.shape[0]
    out = torch.empty((m, n), dtype=out_dtype, device=a.device)
-    torch.ops.sgl_kernels.cutlass_scaled_fp4_mm(
+    torch.ops.sgl_kernel.cutlass_scaled_fp4_mm.default(
        out, a, b, block_scale_a, block_scale_b, alpha
    )
    return out
@@ -210,7 +212,7 @@ def scaled_fp4_quant(
        (rounded_m, rounded_n // 4), device=device, dtype=torch.int32
    )

-    torch.ops.sgl_kernels.scaled_fp4_quant(
+    torch.ops.sgl_kernel.scaled_fp4_quant.default(
        output, input, output_scale, input_global_scale
    )
    output_scale = output_scale.view(torch.float8_e4m3fn)
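
Besides adding .default, the last two gemm.py hunks also correct the namespace spelling from torch.ops.sgl_kernels to torch.ops.sgl_kernel; nothing is registered under the plural name, so the old spelling would fail at call time. A small sketch of that failure mode, using a deliberately misspelled, hypothetical namespace:

import torch

try:
    # "demo_typo" is a made-up namespace with no registered ops.
    torch.ops.demo_typo.my_add(torch.ones(1), torch.ones(1))
except (AttributeError, RuntimeError) as err:
    print(f"operator lookup failed: {err}")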

sgl-kernel/python/sgl_kernel/moe.py

Lines changed: 2 additions & 2 deletions
@@ -11,7 +11,7 @@ def moe_align_block_size(
    token_cnts_buffer,
    cumsum_buffer,
):
-    torch.ops.sgl_kernel.moe_align_block_size(
+    torch.ops.sgl_kernel.moe_align_block_size.default(
        topk_ids,
        num_experts,
        block_size,
@@ -29,6 +29,6 @@ def topk_softmax(
    token_expert_indices: torch.Tensor,
    gating_output: float,
) -> None:
-    torch.ops.sgl_kernel.topk_softmax(
+    torch.ops.sgl_kernel.topk_softmax.default(
        topk_weights, topk_ids, token_expert_indices, gating_output
    )

sgl-kernel/python/sgl_kernel/sampling.py

Lines changed: 5 additions & 5 deletions
@@ -12,7 +12,7 @@ def _top_k_renorm_probs_internal(
    probs = probs.float()
    maybe_top_k_arr = maybe_top_k_arr.int() if maybe_top_k_arr is not None else None
    renorm_probs = torch.empty_like(probs)
-    torch.ops.sgl_kernel.top_k_renorm_probs(
+    torch.ops.sgl_kernel.top_k_renorm_probs.default(
        probs,
        renorm_probs,
        maybe_top_k_arr,
@@ -40,7 +40,7 @@ def _top_p_renorm_probs_internal(
    probs = probs.float()
    maybe_top_p_arr = maybe_top_p_arr.float() if maybe_top_p_arr is not None else None
    renorm_probs = torch.empty_like(probs)
-    torch.ops.sgl_kernel.top_p_renorm_probs(
+    torch.ops.sgl_kernel.top_p_renorm_probs.default(
        probs,
        renorm_probs,
        maybe_top_p_arr,
@@ -75,7 +75,7 @@ def _top_p_sampling_from_probs_internal(
    )
    samples = torch.empty(probs.size(0), dtype=torch.int32, device=device)
    success = torch.empty(probs.size(0), dtype=torch.bool, device=device)
-    torch.ops.sgl_kernel.top_p_sampling_from_probs(
+    torch.ops.sgl_kernel.top_p_sampling_from_probs.default(
        probs,
        uniform_samples,
        samples,
@@ -121,7 +121,7 @@ def _top_k_top_p_sampling_from_probs_internal(
    )
    samples = torch.empty(probs.size(0), dtype=torch.int32, device=device)
    success = torch.empty(probs.size(0), dtype=torch.bool, device=device)
-    torch.ops.sgl_kernel.top_k_top_p_sampling_from_probs(
+    torch.ops.sgl_kernel.top_k_top_p_sampling_from_probs.default(
        probs,
        uniform_samples,
        samples,
@@ -179,7 +179,7 @@ def _min_p_sampling_from_probs_internal(
        maybe_min_p_arr.float() if maybe_min_p_arr is not None else None
    )
    samples = torch.empty(probs.size(0), dtype=torch.int32, device=device)
-    torch.ops.sgl_kernel.min_p_sampling_from_probs(
+    torch.ops.sgl_kernel.min_p_sampling_from_probs.default(
        probs,
        uniform_samples,
        samples,

sgl-kernel/python/sgl_kernel/speculative.py

Lines changed: 4 additions & 4 deletions
@@ -17,7 +17,7 @@ def tree_speculative_sampling_target_only(
    threshold_acc: float = 1.0,
    deterministic: bool = True,
) -> None:
-    torch.ops.sgl_kernel.tree_speculative_sampling_target_only(
+    torch.ops.sgl_kernel.tree_speculative_sampling_target_only.default(
        predicts,
        accept_index,
        accept_token_num,
@@ -45,7 +45,7 @@ def verify_tree_greedy(
    retrive_next_sibling: torch.Tensor,
    target_predict: torch.Tensor,
) -> None:
-    torch.ops.sgl_kernel.verify_tree_greedy(
+    torch.ops.sgl_kernel.verify_tree_greedy.default(
        predicts,
        accept_index,
        accept_token_num,
@@ -71,7 +71,7 @@ def build_tree_kernel_efficient(
    depth: int,
    draft_token_num: int,
) -> None:
-    torch.ops.sgl_kernel.build_tree_kernel_efficient(
+    torch.ops.sgl_kernel.build_tree_kernel_efficient.default(
        parent_list,
        selected_index,
        verified_seq_len,
@@ -92,7 +92,7 @@ def segment_packbits(
    output_indptr: torch.Tensor,
    y: torch.Tensor,
) -> None:
-    torch.ops.sgl_kernel.segment_packbits(
+    torch.ops.sgl_kernel.segment_packbits.default(
        x,
        input_indptr,
        output_indptr,

0 commit comments
