|
22 | 22 |
|
23 | 23 | #include "fbgemm_gpu/embedding_backward_template_helpers.cuh"
|
24 | 24 | #include "fbgemm_gpu/sparse_ops.h"
|
| 25 | +#include "fbgemm_gpu/config/feature_gates.h" |
25 | 26 | #include "fbgemm_gpu/split_embeddings_utils.cuh"
|
26 | 27 | #include "fbgemm_gpu/utils/barrier_isolation.cuh"
|
27 | 28 | #include "fbgemm_gpu/utils/ops_utils.h"
|
28 | 29 | #include "fbgemm_gpu/utils/tensor_accessor.h"
|
| 30 | + |
29 | 31 | {%- if is_rocm %}
|
30 | 32 | #include "fbgemm_gpu/rocm/cdna_guard.h"
|
31 | 33 | {%- endif %}
|
@@ -1218,14 +1220,18 @@ Tensor {{ embedding_cuda_op }}(
|
1218 | 1220 | #ifdef USE_ROCM
|
1219 | 1221 | {%- if is_rocm and not is_index_select and optimizer == "rowwise_adagrad" and
|
1220 | 1222 | not dense and not is_gwd_kernel and not vbe and not ssd and not nobag %}
|
1221 |
| - const bool isSupportedWeightsType = dev_weights.scalar_type() == at::ScalarType::Half |
| 1223 | +
|
| 1224 | + const static auto use_hip_kernel = fbgemm_gpu::config::is_feature_enabled(fbgemm_gpu::config::FeatureGateName::TBE_ROCM_HIP_BACKWARD_KERNEL); |
| 1225 | +
|
| 1226 | + const auto supported_weights_type = dev_weights.scalar_type() == at::ScalarType::Half |
1222 | 1227 | || dev_weights.scalar_type() == at::ScalarType::Float;
|
1223 |
| - if(isSupportedWeightsType && !mixed_D && rocm::is_supported_cdna()) |
| 1228 | +
|
| 1229 | + if (use_hip_kernel && supported_weights_type && !mixed_D && rocm::is_supported_cdna()) |
1224 | 1230 | {
|
1225 | 1231 | constexpr int segments_per_workgroup = 4;
|
1226 | 1232 | {%- for kDimSize in [64, 128, 160, 192, 256] %}
|
1227 | 1233 | {%- for kWeightDecayMode in [0, 1, 2] %}
|
1228 |
| - if(max_D == {{ kDimSize }} && weight_decay_mode == {{ kWeightDecayMode }}) |
| 1234 | + if (max_D == {{ kDimSize }} && weight_decay_mode == {{ kWeightDecayMode }}) |
1229 | 1235 | {
|
1230 | 1236 | warp_per_row_grid_size = div_round_up(sorted_linear_indices_num_runs[0].item<int32_t>(), segments_per_workgroup);
|
1231 | 1237 | blockSize = dim3(256);
|
|
0 commit comments