Skip to content

Commit 5ad2a1a

Browse files
sryapfacebook-github-bot
authored andcommitted
Use bounds_check_indices v2 on ROCm
Differential Revision: D72334377
1 parent def7bbe commit 5ad2a1a

File tree

2 files changed

+24
-6
lines changed

2 files changed

+24
-6
lines changed

fbgemm_gpu/codegen/utils/embedding_bounds_check_host.cpp

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -60,13 +60,21 @@ void bounds_check_indices_cuda(
6060
const int64_t info_B_num_bits,
6161
const int64_t info_B_mask,
6262
const int8_t bounds_check_version) {
63+
#if USE_ROCM
64+
// Force using bounds_check_indices v2 on ROCm because ROCm has a constraint
65+
// that the gridDim * blockDim has to be smaller than 2^32. The v1 kernel can
66+
// be launched with gridDim * blockDim > 2^32 while the v2 kernel limits the
67+
// gridDim size to 64 * # of SMs. Thus, its gridDim * blockDim is guaranteed
68+
// to be smaller than 2^32
69+
const auto bounds_check_indices_fn = _bounds_check_indices_cuda_v2;
70+
#else
6371
TORCH_CHECK(bounds_check_version == 1 || bounds_check_version == 2);
64-
const static bool use_v2 =
65-
fbgemm_gpu::config::is_feature_enabled(
66-
fbgemm_gpu::config::FeatureGateName::BOUNDS_CHECK_INDICES_V2) ||
67-
bounds_check_version == 2;
68-
const auto bounds_check_indices_fn =
69-
use_v2 ? _bounds_check_indices_cuda_v2 : _bounds_check_indices_cuda_v1;
72+
const static bool use_v2 = fbgemm_gpu::config::is_feature_enabled(
73+
fbgemm_gpu::config::FeatureGateName::BOUNDS_CHECK_INDICES_V2);
74+
const auto bounds_check_indices_fn = (use_v2 || bounds_check_version == 2)
75+
? _bounds_check_indices_cuda_v2
76+
: _bounds_check_indices_cuda_v1;
77+
#endif
7078
bounds_check_indices_fn(
7179
rows_per_table,
7280
indices,

fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -716,6 +716,16 @@ def __init__( # noqa C901
716716
# See:
717717
# https://fb.workplace.com/groups/fbgemmusers/permalink/9438488366231860/
718718
cache_precision = SparseType.FP32
719+
self.log("Override cache_precision=SparseType.FP32 on ROCm")
720+
721+
# NOTE: Use bounds_check_indices v2 on ROCm because ROCm has a
722+
# constraint that the gridDim * blockDim has to be smaller than
723+
# 2^32. The v1 kernel can be launched with gridDim * blockDim >
724+
# 2^32 while the v2 kernel limits the gridDim size to 64 * # of
725+
# SMs. Thus, its gridDim * blockDim is guaranteed to be smaller
726+
# than 2^32
727+
self.bounds_check_version = 2
728+
self.log("Override bounds_check_version=2 on ROCm")
719729
else:
720730
# NOTE: The changes from D65865527 are retained here until we can
721731
# test that the the hack also works for non-ROCm environments.

0 commit comments

Comments
 (0)