Skip to content

Commit adf0bdc

Browse files
yinfan98zhyncs
authored andcommitted
[sgl-kernel] fix: fix cu118 compile error (sgl-project#6123)
Co-authored-by: zhyncs <[email protected]>
1 parent 3e145c9 commit adf0bdc

File tree

2 files changed

+22
-3
lines changed

2 files changed

+22
-3
lines changed

sgl-kernel/csrc/attention/cutlass_mla_kernel.cu

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,21 @@ limitations under the License.
2525
#include <device/sm100_mla.hpp>
2626
#include <kernel/sm100_mla_tile_scheduler.hpp>
2727

28-
#if defined CUDA_VERSION && CUDA_VERSION >= 12040
28+
// clang-format off
29+
#if !defined(CUDA_VERSION) || CUDA_VERSION < 12040
30+
void cutlass_mla_decode(
31+
torch::Tensor const& out,
32+
torch::Tensor const& q_nope_and_q_pe,
33+
torch::Tensor const& kv_c_and_k_pe_cache,
34+
torch::Tensor const& seq_lens,
35+
torch::Tensor const& page_table,
36+
torch::Tensor const& workspace) {
37+
TORCH_CHECK(false, "CUDA version must be >= 12.4 for cutlass_mla_decode");
38+
}
39+
int64_t cutlass_mla_get_workspace_size(int64_t max_seq_len, int64_t num_batches, int64_t sm_count) {
40+
TORCH_CHECK(false, "CUDA version must be >= 12.4 for cutlass_mla_get_workspace_size");
41+
}
42+
#else
2943

3044
#define CUTLASS_CHECK(status) \
3145
{ \
@@ -209,3 +223,4 @@ int64_t cutlass_mla_get_workspace_size(int64_t max_seq_len, int64_t num_batches,
209223
}
210224

211225
#endif
226+
// clang-format on

sgl-kernel/csrc/grammar/apply_token_bitmask_inplace_cuda.cu

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,9 +24,12 @@
2424
#include <cuda_runtime.h>
2525
#include <torch/all.h>
2626
#include <ATen/cuda/CUDAContext.h>
27-
// clang-format on
2827

29-
#if defined CUDA_VERSION && CUDA_VERSION >= 12040
28+
#if !defined(CUDA_VERSION) || CUDA_VERSION < 12040
29+
void ApplyTokenBitmaskInplace(at::Tensor logits, at::Tensor bitmask, at::optional<at::Tensor> indices = at::nullopt) {
30+
TORCH_CHECK(false, "CUDA version must be >= 12.4 for ApplyTokenBitmaskInplace");
31+
}
32+
#else
3033

3134
#ifndef CUDART_INF_FP16
3235
#define CUDART_INF_FP16 __ushort_as_half((unsigned short)0x7C00U)
@@ -252,3 +255,4 @@ void ApplyTokenBitmaskInplace(at::Tensor logits, at::Tensor bitmask, at::optiona
252255
}
253256
}
254257
#endif
258+
// clang-format on

0 commit comments

Comments
 (0)