
Commit 625f2e9

q10 and liligwu authored and committed
Add feature gate for HIP-based backward kernel (pytorch#3835)
Summary:
X-link: https://github.com/facebookresearch/FBGEMM/pull/921
Pull Request resolved: pytorch#3835

- Add feature gate for HIP-based backward kernel. This is a followup to D66986498 (pytorch#3488)

Reviewed By: sryap

Differential Revision: D71329616

fbshipit-source-id: b2b7beb034d5a43edeebb83dfe6597b133f984fe
1 parent a4274de commit 625f2e9
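For orientation, here is a minimal, self-contained sketch of how the new gate is meant to be consulted, mirroring the check this commit adds in embedding_backward_split_template.cu (see the first diff below). The backward_op wrapper is hypothetical, and the environment-variable toggle in the trailing comment assumes FBGEMM's usual OSS convention of FBGEMM_<FLAG_NAME>; neither is spelled out in this commit.

#include "fbgemm_gpu/config/feature_gates.h"

// Hypothetical wrapper around the gated dispatch added by this commit.
void backward_op() {
  // Evaluated once per process (function-local static), then cached.
  const static auto use_hip_kernel = fbgemm_gpu::config::is_feature_enabled(
      fbgemm_gpu::config::FeatureGateName::TBE_ROCM_HIP_BACKWARD_KERNEL);

  if (use_hip_kernel) {
    // Take the HIP-optimized rowwise_adagrad backward path.
  } else {
    // Fall back to the existing generic backward template kernel.
  }
}

// Assumed OSS toggle (FBGEMM convention, not confirmed by this commit):
//   export FBGEMM_TBE_ROCM_HIP_BACKWARD_KERNEL=1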

3 files changed: +14 −4 lines


fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu

Lines changed: 9 additions & 3 deletions
@@ -22,10 +22,12 @@
 
 #include "fbgemm_gpu/embedding_backward_template_helpers.cuh"
 #include "fbgemm_gpu/sparse_ops.h"
+#include "fbgemm_gpu/config/feature_gates.h"
 #include "fbgemm_gpu/split_embeddings_utils.cuh"
 #include "fbgemm_gpu/utils/barrier_isolation.cuh"
 #include "fbgemm_gpu/utils/ops_utils.h"
 #include "fbgemm_gpu/utils/tensor_accessor.h"
+
 {%- if is_rocm %}
 #include "fbgemm_gpu/rocm/cdna_guard.h"
 {%- endif %}
@@ -1218,14 +1220,18 @@ Tensor {{ embedding_cuda_op }}(
 #ifdef USE_ROCM
   {%- if is_rocm and not is_index_select and optimizer == "rowwise_adagrad" and
       not dense and not is_gwd_kernel and not vbe and not ssd and not nobag %}
-  const bool isSupportedWeightsType = dev_weights.scalar_type() == at::ScalarType::Half
+
+  const static auto use_hip_kernel = fbgemm_gpu::config::is_feature_enabled(fbgemm_gpu::config::FeatureGateName::TBE_ROCM_HIP_BACKWARD_KERNEL);
+
+  const auto supported_weights_type = dev_weights.scalar_type() == at::ScalarType::Half
       || dev_weights.scalar_type() == at::ScalarType::Float;
-  if(isSupportedWeightsType && !mixed_D && rocm::is_supported_cdna())
+
+  if (use_hip_kernel && supported_weights_type && !mixed_D && rocm::is_supported_cdna())
   {
     constexpr int segments_per_workgroup = 4;
     {%- for kDimSize in [64, 128, 160, 192, 256] %}
     {%- for kWeightDecayMode in [0, 1, 2] %}
-    if(max_D == {{ kDimSize }} && weight_decay_mode == {{ kWeightDecayMode }})
+    if (max_D == {{ kDimSize }} && weight_decay_mode == {{ kWeightDecayMode }})
     {
       warp_per_row_grid_size = div_round_up(sorted_linear_indices_num_runs[0].item<int32_t>(), segments_per_workgroup);
       blockSize = dim3(256);
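One detail worth noting in the hunk above: use_hip_kernel is a function-local const static, so the feature-gate lookup runs once per process and is cached on every later call. A minimal standalone C++ sketch of that caching behavior (illustrative only, not FBGEMM code):

#include <iostream>

// Stand-in for fbgemm_gpu::config::is_feature_enabled(...); hypothetical.
bool query_feature_gate() {
  std::cout << "gate queried\n";  // side effect shows when the lookup runs
  return true;
}

void backward_op() {
  // Initialized exactly once, on the first call; reused afterwards.
  const static bool use_hip_kernel = query_feature_gate();
  std::cout << "use_hip_kernel = " << use_hip_kernel << "\n";
}

int main() {
  backward_op();  // prints "gate queried" then "use_hip_kernel = 1"
  backward_op();  // prints only "use_hip_kernel = 1"
}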

fbgemm_gpu/fbgemm_gpu/config/feature_list.py

Lines changed: 4 additions & 1 deletion
@@ -51,9 +51,12 @@ def foo():
     # Enable Ensemble Rowwise Adagrad (D60189486 stack)
     TBE_ENSEMBLE_ROWWISE_ADAGRAD = auto()
 
-    # Enable ROCm packed bags optimization in inference
+    # Enable ROCm packed bags optimization in TBE inference
     TBE_ROCM_INFERENCE_PACKED_BAGS = auto()
 
+    # Enable HIP-based backward kernels in TBE training
+    TBE_ROCM_HIP_BACKWARD_KERNEL = auto()
+
     # Enable bounds_check_indices_v2
     BOUNDS_CHECK_INDICES_V2 = auto()
 

fbgemm_gpu/include/fbgemm_gpu/config/feature_gates.h

Lines changed: 1 addition & 0 deletions
@@ -60,6 +60,7 @@ namespace fbgemm_gpu::config {
   X(TBE_ENSEMBLE_ROWWISE_ADAGRAD) \
   X(TBE_ANNOTATE_KINETO_TRACE) \
   X(TBE_ROCM_INFERENCE_PACKED_BAGS) \
+  X(TBE_ROCM_HIP_BACKWARD_KERNEL) \
   X(BOUNDS_CHECK_INDICES_V2)
 // X(EXAMPLE_FEATURE_FLAG)
 
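The header registers flags through an X-macro list, which is why the one-line addition above is all the C++ plumbing the new flag needs. As a minimal, self-contained sketch of how such a list typically expands (a reduced hypothetical, not the actual feature_gates.h, whose surrounding macros this commit does not show):

#include <iostream>

// Hypothetical reduced X-macro list in the style of feature_gates.h;
// the real file carries more flags and additional generated code.
#define ENUMERATE_ALL_FEATURE_FLAGS \
  X(TBE_ROCM_INFERENCE_PACKED_BAGS) \
  X(TBE_ROCM_HIP_BACKWARD_KERNEL) \
  X(BOUNDS_CHECK_INDICES_V2)

// Expanding the list with X defined as an enumerator yields the enum.
enum class FeatureGateName {
#define X(name) name,
  ENUMERATE_ALL_FEATURE_FLAGS
#undef X
};

// Expanding it again with a different X yields matching name strings.
const char* to_string(FeatureGateName flag) {
  switch (flag) {
#define X(name) \
  case FeatureGateName::name: \
    return #name;
    ENUMERATE_ALL_FEATURE_FLAGS
#undef X
  }
  return "UNKNOWN";
}

int main() {
  // Prints "TBE_ROCM_HIP_BACKWARD_KERNEL".
  std::cout << to_string(FeatureGateName::TBE_ROCM_HIP_BACKWARD_KERNEL) << "\n";
}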
