@@ -60,13 +60,21 @@ void bounds_check_indices_cuda(
60
60
const int64_t info_B_num_bits,
61
61
const int64_t info_B_mask,
62
62
const int8_t bounds_check_version) {
63
+ #if USE_ROCM
64
+ // Force using bounds_check_indices v2 on ROCm because ROCm has a constraint
65
+ // that the gridDim * blockDim has to be smaller than 2^32. The v1 kernel can
66
+ // be launched with gridDim * blockDim > 2^32 while the v2 kernel limits the
67
+ // gridDim size to 64 * # of SMs. Thus, its gridDim * blockDim is guaranteed
68
+ // to be smaller than 2^32
69
+ const auto bounds_check_indices_fn = _bounds_check_indices_cuda_v2;
70
+ #else
63
71
TORCH_CHECK (bounds_check_version == 1 || bounds_check_version == 2 );
64
- const static bool use_v2 =
65
- fbgemm_gpu::config::is_feature_enabled (
66
- fbgemm_gpu::config::FeatureGateName::BOUNDS_CHECK_INDICES_V2) ||
67
- bounds_check_version == 2 ;
68
- const auto bounds_check_indices_fn =
69
- use_v2 ? _bounds_check_indices_cuda_v2 : _bounds_check_indices_cuda_v1;
72
+ const static bool use_v2 = fbgemm_gpu::config::is_feature_enabled (
73
+ fbgemm_gpu::config::FeatureGateName::BOUNDS_CHECK_INDICES_V2);
74
+ const auto bounds_check_indices_fn = (use_v2 || bounds_check_version == 2 )
75
+ ? _bounds_check_indices_cuda_v2
76
+ : _bounds_check_indices_cuda_v1;
77
+ # endif
70
78
bounds_check_indices_fn (
71
79
rows_per_table,
72
80
indices,
0 commit comments