Skip to content

Commit 420a21e

Browse files
authored
Merge pull request #66 from avbokovoy/fix-rocm-header
Wrap rocprim header with #ifdef USE_ROCM
2 parents ad83687 + cf45304 commit 420a21e

File tree

1 file changed

+3
-1
lines changed

1 file changed

+3
-1
lines changed

fbgemm_gpu/src/split_embeddings_utils/transpose_embedding_input.cu

Lines changed: 3 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -9,7 +9,9 @@
 #include "fbgemm_gpu/embedding_backward_template_helpers.cuh" // @manual
 #include "fbgemm_gpu/ops_utils.h" // @manual
 #include "fbgemm_gpu/split_embeddings_utils.cuh" // @manual
+#ifdef USE_ROCM
 #include <rocprim/device/device_radix_sort.hpp>
+#endif
 // clang-format off
 #include "fbgemm_gpu/cub_namespace_prefix.cuh" // @manual
 #include <cub/device/device_radix_sort.cuh>
@@ -297,7 +299,7 @@ transpose_embedding_input(
 }
 {
 size_t temp_storage_bytes = 0;
-#ifdef __HIP_PLATFORM_NVIDIA__
+#ifndef USE_ROCM
 AT_CUDA_CHECK(
 FBGEMM_GPU_CUB_NS_PREFIX cub::DeviceRadixSort::SortPairs(
 nullptr,

0 commit comments

Comments
 (0)