We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
2 parents ad83687 + cf45304 commit 420a21eCopy full SHA for 420a21e
fbgemm_gpu/src/split_embeddings_utils/transpose_embedding_input.cu
@@ -9,7 +9,9 @@
9
#include "fbgemm_gpu/embedding_backward_template_helpers.cuh" // @manual
10
#include "fbgemm_gpu/ops_utils.h" // @manual
11
#include "fbgemm_gpu/split_embeddings_utils.cuh" // @manual
12
+#ifdef USE_ROCM
13
#include <rocprim/device/device_radix_sort.hpp>
14
+#endif
15
// clang-format off
16
#include "fbgemm_gpu/cub_namespace_prefix.cuh" // @manual
17
#include <cub/device/device_radix_sort.cuh>
@@ -297,7 +299,7 @@ transpose_embedding_input(
297
299
}
298
300
{
301
size_t temp_storage_bytes = 0;
-#ifdef __HIP_PLATFORM_NVIDIA__
302
+#ifndef USE_ROCM
303
AT_CUDA_CHECK(
304
FBGEMM_GPU_CUB_NS_PREFIX cub::DeviceRadixSort::SortPairs(
305
nullptr,
0 commit comments