Skip to content

Commit 0375709

Browse files
zhyncsjimoosciuc
authored andcommitted
fix bmm fp8 (sgl-project#4926)
1 parent e5e3782 commit 0375709

File tree

1 file changed

+4
-1
lines changed

1 file changed

+4
-1
lines changed

sgl-kernel/csrc/torch_extension.cc

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,10 @@ TORCH_LIBRARY_EXPAND(sgl_kernel, m) {
8282
/*
8383
* From FlashInfer
8484
*/
85-
m.def("bmm_fp8", bmm_fp8);
85+
m.def(
86+
"bmm_fp8(Tensor A, Tensor B, Tensor! D, Tensor A_scale, Tensor B_scale, Tensor workspace_buffer, int "
87+
"cublas_handle, int cuda_stream) -> ()");
88+
m.impl("bmm_fp8", torch::kCUDA, &bmm_fp8);
8689
m.def("min_p_sampling_from_probs", min_p_sampling_from_probs);
8790
m.def("top_k_renorm_probs", top_k_renorm_probs);
8891
m.def("top_p_renorm_probs", top_p_renorm_probs);

0 commit comments

Comments
 (0)