
Commit 53c7591

jeffdaily authored and facebook-github-bot committed
updates for ROCm 6.0 support (pytorch#2088)
Summary: ROCm 6.0 introduces backwards-incompatible changes, such as removing the long-deprecated `__HIP_PLATFORM_HCC__` symbol. It is better to use the `USE_ROCM` macro, which is already defined and indicates a ROCm build. This PR also defines `__HIP_PLATFORM_AMD__`, the new symbol name, which is still required when compiling with HIP headers but not using hip-clang.

Reviewed By: sryap

Differential Revision: D50580075

Pulled By: sryap
1 parent f94254d commit 53c7591

21 files changed (+57 −59 lines)
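The same guard-macro migration repeats across all the files below: every `#ifdef __HIP_PLATFORM_HCC__` / `#ifndef __HIP_PLATFORM_HCC__` becomes `#ifdef USE_ROCM` / `#ifndef USE_ROCM`. As a minimal sketch of that pattern (not taken from this commit; the standalone main() exists only so the snippet compiles on its own), keyed on the build-level macro that Hip.cmake defines:

// Minimal C++ sketch of the guard pattern this commit standardizes on.
// USE_ROCM is assumed to be defined by the build (see Hip.cmake below);
// the old __HIP_PLATFORM_HCC__ symbol is gone in ROCm 6.0.
#include <cstdint>

#ifdef USE_ROCM
// ROCm build: AMD wavefronts are 64 lanes wide.
static constexpr int32_t kWarpSize = 64;
#else
// CUDA build: NVIDIA warps are 32 lanes wide.
static constexpr int32_t kWarpSize = 32;
#endif

int main() {
  // Compile with -DUSE_ROCM to exercise the ROCm branch.
  return (kWarpSize == 32 || kWarpSize == 64) ? 0 : 1;
}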

fbgemm_gpu/cmake/Hip.cmake

Lines changed: 1 addition & 0 deletions
@@ -103,6 +103,7 @@ set(CMAKE_MODULE_PATH ${HIP_PATH}/cmake ${CMAKE_MODULE_PATH})
 # Disable Asserts In Code (Can't use asserts on HIP stack.)
 ADD_DEFINITIONS(-DNDEBUG)
 ADD_DEFINITIONS(-DUSE_ROCM)
+ADD_DEFINITIONS(-D__HIP_PLATFORM_AMD__)
 
 IF(NOT DEFINED ENV{PYTORCH_ROCM_ARCH})
   SET(FBGEMM_ROCM_ARCH gfx900;gfx906;gfx908;gfx90a)

fbgemm_gpu/codegen/embedding_backward_dense_host.cpp

Lines changed: 1 addition & 1 deletion
@@ -166,7 +166,7 @@ class SplitLookupFunction_Dense_Op
 
     TORCH_CHECK_EQ(grad_outputs.size(), 1);
 
-#ifdef __HIP_PLATFORM_HCC__
+#ifdef USE_ROCM
     constexpr int32_t BT_block_size = 64;
     constexpr int32_t max_segment_length_per_warp = 64;
 #else

fbgemm_gpu/codegen/embedding_backward_split_host_template.cpp

Lines changed: 1 addition & 1 deletion
@@ -398,7 +398,7 @@ class {{ autograd_func }} :
 
     TORCH_CHECK_EQ(grad_outputs.size(), 1);
 
-#ifdef __HIP_PLATFORM_HCC__
+#ifdef USE_ROCM
     constexpr int32_t BT_block_size = 64;
     constexpr int32_t max_segment_length_per_warp = 64;
 #else

fbgemm_gpu/codegen/embedding_backward_split_template.cu

Lines changed: 4 additions & 4 deletions
@@ -459,7 +459,7 @@ Tensor split_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ wdesc }}_e
 
   // V100: 96 KB; A100: 160 KB; H100: 228 KB.
   int max_shared_bytes = 0;
-#ifndef __HIP_PLATFORM_HCC__
+#ifndef USE_ROCM
   cudaDeviceGetAttribute(&max_shared_bytes, cudaDevAttrMaxSharedMemoryPerBlockOptin, dev_weights.get_device());
 #else
   // MI100 has 64 KB local memory (shared memory) per workgroup
@@ -468,7 +468,7 @@ Tensor split_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ wdesc }}_e
   C10_CUDA_KERNEL_LAUNCH_CHECK();
   int shared_kb = max_shared_bytes >> 10;
   // V100: 64 KB; A100: 96 KB; H100: 144 KB
-#ifndef __HIP_PLATFORM_HCC__
+#ifndef USE_ROCM
   // Use 2/3 of the available GPU shared mem; leave rooms for L1$.
   int used_shared_kb = round_down(shared_kb * 2 / 3, 16);
   TORCH_CHECK_GT(used_shared_kb, 0);
@@ -740,7 +740,7 @@ Tensor split_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ wdesc }}_e
       kMaxVecsPerThread,
       kThreadGroupSize>;
 
-#ifndef __HIP_PLATFORM_HCC__
+#ifndef USE_ROCM
   cudaFuncSetAttribute(
       backward_cta_per_row_kernel,
       cudaFuncAttributeMaxDynamicSharedMemorySize,
@@ -851,7 +851,7 @@ Tensor split_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ wdesc }}_e
   if (std::is_same<emb_t, uint8_t>::value) {
     shmem_bytes = BT_block_size * sizeof(
         at::acc_type<cache_t, true>) * 4 * kWarpSize * kMaxVecsPerThread;
-#ifndef __HIP_PLATFORM_HCC__
+#ifndef USE_ROCM
     cudaFuncSetAttribute(
         backward_warp_per_row_kernel,
         cudaFuncAttributeMaxDynamicSharedMemorySize,

fbgemm_gpu/codegen/embedding_forward_quantized_split_lookup.cu

Lines changed: 1 addition & 1 deletion
@@ -53,7 +53,7 @@ __launch_bounds__(kMaxThreads) void int_nbit_split_embedding_codegen_forward_pru
 
   const uint32_t subwarp_id = threadIdx.x / 4;
   const uint32_t subwarp_tid = threadIdx.x % 4;
-#ifdef __HIP_PLATFORM_HCC__
+#ifdef USE_ROCM
   const uint64_t subwarp_mask = static_cast<uint64_t>(0xF) << (4 * subwarp_id);
 #else
   const uint32_t subwarp_mask = static_cast<uint32_t>(0xF) << (4 * subwarp_id);

fbgemm_gpu/codegen/embedding_forward_split_template.cu

Lines changed: 2 additions & 2 deletions
@@ -78,7 +78,7 @@ batch_index_select_dim0_codegen_forward_small_kernel(
 {%- endif %}
 
 {% if not dense %}
-#ifndef __HIP_PLATFORM_HCC__
+#ifndef USE_ROCM
 // Support only the split-pooled TBE case
 template <
     typename emb_t,
@@ -647,7 +647,7 @@ batch_index_select_dim0_codegen_forward_cuda(
     // if (!is_experimental)
   } else {
 
-#ifdef __HIP_PLATFORM_HCC__
+#ifdef USE_ROCM
     TORCH_CHECK(false, "is_experimental=True is not supported in ROCm");
 #else
     // Allocate num warps per table based on max_D

fbgemm_gpu/include/fbgemm_gpu/fbgemm_cuda_utils.cuh

Lines changed: 23 additions & 26 deletions
@@ -13,7 +13,7 @@
 #include <ATen/cuda/CUDAGraphsUtils.cuh>
 
 // clang-format off
-#ifdef __HIP_PLATFORM_HCC__
+#ifdef USE_ROCM
 #define HIPCUB_ARCH 1
 #include <hipcub/backend/rocprim/block/block_scan.hpp>
 #else
@@ -35,8 +35,7 @@
 #include <cuda_fp16.h>
 #include <cuda_runtime.h>
 #include <curand_kernel.h>
-#if !defined(__HIP_PLATFORM_HCC__) && defined(CUDA_VERSION) && \
-    CUDA_VERSION >= 9000
+#if !defined(USE_ROCM) && defined(CUDA_VERSION) && CUDA_VERSION >= 9000
 #define FBGEMM_USE_SUBWARP_SHUFFLE
 #endif
 
@@ -58,14 +57,14 @@ namespace fbgemm_gpu {
 
 enum class PrimitiveType : uint8_t { FP = 0, INT = 1, BF = 2 };
 
-#ifdef __HIP_PLATFORM_HCC__
+#ifdef USE_ROCM
 namespace cub = hipcub;
 #endif
 
 #define DEVICE_INLINE __device__ inline __attribute__((always_inline))
 
 // Warp size
-#ifdef __HIP_PLATFORM_HCC__
+#ifdef USE_ROCM
 static constexpr int32_t kWarpSize = 64;
 #else
 static constexpr int32_t kWarpSize = 32;
@@ -93,7 +92,7 @@ struct Half4 {
   half2 b;
 
   __device__ inline void store(at::Half* p) {
-#ifdef __HIP_PLATFORM_HCC__
+#ifdef USE_ROCM
     p[0] = __low2half(a);
     p[1] = __high2half(a);
     p[2] = __low2half(b);
@@ -157,7 +156,7 @@ struct Vec4T<float> {
   }
 
   DEVICE_INLINE void load(const at::Half* p) {
-#ifdef __HIP_PLATFORM_HCC__
+#ifdef USE_ROCM
     union U {
       half2 h[2];
      uint2 ui;
@@ -311,7 +310,7 @@ struct Vec4T<at::Half> {
   }
 
   DEVICE_INLINE void load(const at::Half* p) {
-#ifdef __HIP_PLATFORM_HCC__
+#ifdef USE_ROCM
     union U {
       half2 h[2];
      uint2 ui;
@@ -409,7 +408,7 @@ struct Vec4T<at::Half> {
   }
 
   DEVICE_INLINE static void copy(const at::Half* src, at::Half* dst) {
-#ifdef __HIP_PLATFORM_HCC__
+#ifdef USE_ROCM
     dst[0] = src[0];
     dst[1] = src[1];
     dst[2] = src[2];
@@ -525,7 +524,7 @@ struct Vec4T<at::BFloat16> {
   }
 
   DEVICE_INLINE void load(const at::Half* p) {
-#ifdef __HIP_PLATFORM_HCC__
+#ifdef USE_ROCM
     union U {
       half2 h[2];
      uint2 ui;
@@ -705,7 +704,7 @@ struct Vec4T<double> {
   }
 
   DEVICE_INLINE void load(const at::Half* p) {
-#ifdef __HIP_PLATFORM_HCC__
+#ifdef USE_ROCM
     union U {
       half2 h[2];
      uint2 ui;
@@ -854,7 +853,7 @@ DEVICE_INLINE T shfl_xor(
     int laneMask,
     int width = kWarpSize,
     unsigned shfl_sync_mask = kFullWarpMask) {
-#if defined(__HIP_PLATFORM_HCC__) || CUDA_VERSION < 9000
+#if defined(USE_ROCM) || CUDA_VERSION < 9000
   return __shfl_xor(val, laneMask, width);
 #else
   return __shfl_xor_sync(shfl_sync_mask, val, laneMask, width);
@@ -867,7 +866,7 @@ DEVICE_INLINE T shfl_sync(
     int srcLane = 0,
     int width = kWarpSize,
     unsigned shfl_sync_mask = kFullWarpMask) {
-#if defined(__HIP_PLATFORM_HCC__) || CUDA_VERSION < 9000
+#if defined(USE_ROCM) || CUDA_VERSION < 9000
   return __shfl(val, srcLane, width);
 #else
   return __shfl_sync(shfl_sync_mask, val, srcLane, width);
@@ -880,21 +879,21 @@ DEVICE_INLINE T shfl_down_sync(
     unsigned delta,
     int width = kWarpSize,
     unsigned shfl_sync_mask = kFullWarpMask) {
-#if defined(__HIP_PLATFORM_HCC__) || CUDA_VERSION < 9000
+#if defined(USE_ROCM) || CUDA_VERSION < 9000
   return __shfl_down(val, delta, width);
 #else
   return __shfl_down_sync(shfl_sync_mask, val, delta, width);
 #endif
 }
 
-#if defined(__HIP_PLATFORM_HCC__) || CUDA_VERSION < 9000
+#if defined(USE_ROCM) || CUDA_VERSION < 9000
 DEVICE_INLINE uint64_t ballot_sync(
 #else
 DEVICE_INLINE uint32_t ballot_sync(
 #endif
     int predicate,
     unsigned shfl_sync_mask = kFullWarpMask) {
-#if defined(__HIP_PLATFORM_HCC__) || CUDA_VERSION < 9000
+#if defined(USE_ROCM) || CUDA_VERSION < 9000
   return __ballot(predicate);
 #else
   return __ballot_sync(shfl_sync_mask, predicate);
@@ -913,7 +912,7 @@ warpReduceAllSum(T val, unsigned shfl_sync_mask = kFullWarpMask) {
 }
 
 DEVICE_INLINE void syncwarp() {
-#ifdef __HIP_PLATFORM_HCC__
+#ifdef USE_ROCM
   // Performance - replace a block level __syncthreads with per CU
   // __threadfence_block. It is a fine replacement for __syncwarp on AMD GPUs,
   // it is because a. memory fencing: __threadfence_block ops. at CU level,
@@ -1002,7 +1001,7 @@ inline __device__ void warpBitonicMergeLE16(K& k, V& v) {
 template <typename K, typename V, bool Dir, typename Comp>
 struct BitonicSort {
   static inline __device__ void sort(K k[1], V v[1]) {
-#ifdef __HIP_PLATFORM_HCC__
+#ifdef USE_ROCM
     static_assert(fbgemm_gpu::kWarpSize == 64, "unexpected warp size");
 #else
     static_assert(fbgemm_gpu::kWarpSize == 32, "unexpected warp size");
@@ -1607,7 +1606,7 @@ struct __align__(32) half16 {
   half2 vals[8];
 };
 
-#ifdef __HIP_PLATFORM_HCC__
+#ifdef USE_ROCM
 using __nv_bfloat16 = hip_bfloat16;
 
 typedef struct __align__(4) {
@@ -1689,7 +1688,7 @@ DEVICE_INLINE half16 to_half16(float_16 v) {
 
 // Override __bfloat162float to accept at::BFloat16
 static DEVICE_INLINE float __bfloat162float(const at::BFloat16 input) {
-#ifdef __HIP_PLATFORM_HCC__
+#ifdef USE_ROCM
   return float(*reinterpret_cast<const __nv_bfloat16*>(&input));
 #else
   return __bfloat162float(*reinterpret_cast<const __nv_bfloat16*>(&input));
@@ -1709,7 +1708,7 @@ static DEVICE_INLINE float to_float(const at::BFloat16 input) {
   return __bfloat162float(input);
 }
 
-#ifdef __HIP_PLATFORM_HCC__
+#ifdef USE_ROCM
 // the descriptions of __float2bfloat16 and __float2bfloat16_rn are identical
 // https://docs.nvidia.com/cuda/cuda-math-api/group__CUDA__MATH____BFLOAT16__MISC.html#group__CUDA__MATH____BFLOAT16__MISC
 static __host__ __device__ __nv_bfloat16 __float2bfloat16(float f) {
@@ -1829,8 +1828,7 @@ DEVICE_INLINE float_16 make_zero_float_16() {
 
 __forceinline__ __device__ __half2
 hfma2(const __half2 a, const __half2 b, const __half2 c) {
-#if (__CUDA_ARCH__ >= 530 && __CUDA_ARCH__ != 610) || \
-    defined(__HIP_PLATFORM_HCC__)
+#if (__CUDA_ARCH__ >= 530 && __CUDA_ARCH__ != 610) || defined(USE_ROCM)
   return __hfma2(a, b, c);
 #else
   float2 fa, fb, fc;
@@ -1844,8 +1842,7 @@ hfma2(const __half2 a, const __half2 b, const __half2 c) {
 }
 
 __forceinline__ __device__ half hmul(half a, half b) {
-#if (__CUDA_ARCH__ >= 530 && __CUDA_ARCH__ != 610) || \
-    defined(__HIP_PLATFORM_HCC__)
+#if (__CUDA_ARCH__ >= 530 && __CUDA_ARCH__ != 610) || defined(USE_ROCM)
   return __hmul(a, b);
 #else
   return __float2half(__half2float(a) * __half2float(b));
@@ -3603,7 +3600,7 @@ DEVICE_INLINE float float16_min(float_16 val) {
 // ROCm does not natively support __any_sync(). Using __ballot()
 // (https://rocmdocs.amd.com/en/latest/Programming_Guides/Kernel_language.html)
 // to implement __any_sync(). Note: the "warp-size" of AMD GPU is 64.
-#ifdef __HIP_PLATFORM_HCC__
+#ifdef USE_ROCM
 __device__ int __any_sync(uint64_t mask, int predicate) {
   uint64_t predicate_bit_pattern = __ballot(predicate);
   return (predicate_bit_pattern & mask) > 0;
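With the shim above, ROCm call sites keep the CUDA-style `__any_sync` spelling; only the mask width differs, because the AMD wavefront is 64 lanes. A hedged usage sketch (this device helper is not part of the commit; it assumes the shim from fbgemm_cuda_utils.cuh is visible in a ROCm build):

// Hypothetical device helper, for illustration only: the call looks the same
// on both platforms, but the full-lane mask is 64 bits wide under USE_ROCM.
__device__ bool any_lane_nonzero(int value) {
#ifdef USE_ROCM
  return __any_sync(0xFFFFFFFFFFFFFFFFull, value != 0) != 0;  // 64-lane wavefront
#else
  return __any_sync(0xFFFFFFFFu, value != 0) != 0;  // 32-lane warp
#endif
}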

fbgemm_gpu/include/fbgemm_gpu/sparse_ops.cuh

Lines changed: 1 addition & 1 deletion
@@ -8,7 +8,7 @@
 
 #pragma once
 
-#ifdef __HIP_PLATFORM_HCC__
+#ifdef USE_ROCM
 #define HIPCUB_ARCH 1
 #endif
 

fbgemm_gpu/src/jagged_tensor_ops/common.cuh

Lines changed: 5 additions & 5 deletions
@@ -661,7 +661,7 @@ inline bool jagged_dense_dense_elementwise_jagged_output_matches_opt(
   matches &= (y_0_reshaped.size(1) < INT_MAX);
 
   int max_shared_bytes;
-#ifndef __HIP_PLATFORM_HCC__
+#ifndef USE_ROCM
   C10_CUDA_CHECK(cudaDeviceGetAttribute(
       &max_shared_bytes,
       cudaDevAttrMaxSharedMemoryPerBlockOptin,
@@ -671,7 +671,7 @@ inline bool jagged_dense_dense_elementwise_jagged_output_matches_opt(
   max_shared_bytes = 64 << 10;
 #endif
   int shared_kb = max_shared_bytes >> 10;
-#ifndef __HIP_PLATFORM_HCC__
+#ifndef USE_ROCM
   // Use 2/3 of the available GPU shared mem; leave rooms for L1$.
   int used_shared_kb = round_down(shared_kb * 2 / 3, 16);
   TORCH_CHECK(used_shared_kb > 0);
@@ -779,7 +779,7 @@ void jagged_dense_elementwise_jagged_output_opt_(
       at::cuda::getCurrentDeviceProperties()->sharedMemPerBlock;
   if (dynamic_smem_size > cur_max_shared_bytes) {
     int max_shared_bytes;
-#ifndef __HIP_PLATFORM_HCC__
+#ifndef USE_ROCM
     C10_CUDA_CHECK(cudaDeviceGetAttribute(
         &max_shared_bytes,
         cudaDevAttrMaxSharedMemoryPerBlockOptin,
@@ -789,7 +789,7 @@ void jagged_dense_elementwise_jagged_output_opt_(
     max_shared_bytes = 64 << 10;
 #endif
     int shared_kb = max_shared_bytes >> 10;
-#ifndef __HIP_PLATFORM_HCC__
+#ifndef USE_ROCM
     // Use 2/3 of the available GPU shared mem; leave rooms for L1$.
     int used_shared_kb = round_down(shared_kb * 2 / 3, 16);
     TORCH_CHECK(used_shared_kb > 0);
@@ -798,7 +798,7 @@ void jagged_dense_elementwise_jagged_output_opt_(
     int used_shared_kb = shared_kb;
 #endif
     int used_shared_bytes = used_shared_kb << 10;
-#ifndef __HIP_PLATFORM_HCC__
+#ifndef USE_ROCM
     C10_CUDA_CHECK(cudaFuncSetAttribute(
         jagged_dense_dense_elementwise_jagged_output_opt_search_kernel_<
             index_t>,

fbgemm_gpu/src/jagged_tensor_ops/jagged_dense_dense_elementwise_add_jagged_output_forward.cu

Lines changed: 3 additions & 3 deletions
@@ -97,7 +97,7 @@ void jagged_dense_dense_elementwise_jagged_output_opt_(
       at::cuda::getCurrentDeviceProperties()->sharedMemPerBlock;
   if (dynamic_smem_size > cur_max_shared_bytes) {
     int max_shared_bytes;
-#ifndef __HIP_PLATFORM_HCC__
+#ifndef USE_ROCM
     C10_CUDA_CHECK(cudaDeviceGetAttribute(
         &max_shared_bytes,
         cudaDevAttrMaxSharedMemoryPerBlockOptin,
@@ -107,7 +107,7 @@ void jagged_dense_dense_elementwise_jagged_output_opt_(
     max_shared_bytes = 64 << 10;
 #endif
     int shared_kb = max_shared_bytes >> 10;
-#ifndef __HIP_PLATFORM_HCC__
+#ifndef USE_ROCM
     // Use 2/3 of the available GPU shared mem; leave rooms for L1$.
     int used_shared_kb = round_down(shared_kb * 2 / 3, 16);
     TORCH_CHECK_GT(used_shared_kb, 0);
@@ -116,7 +116,7 @@ void jagged_dense_dense_elementwise_jagged_output_opt_(
     int used_shared_kb = shared_kb;
 #endif
     int used_shared_bytes = used_shared_kb << 10;
-#ifndef __HIP_PLATFORM_HCC__
+#ifndef USE_ROCM
     C10_CUDA_CHECK(cudaFuncSetAttribute(
         jagged_dense_dense_elementwise_jagged_output_opt_search_kernel_<
            index_t>,

fbgemm_gpu/src/quantize_ops/common.cuh

Lines changed: 1 addition & 1 deletion
@@ -9,7 +9,7 @@
 #include <ATen/TensorIterator.h>
 #include <ATen/cuda/Exceptions.h>
 #include <c10/cuda/CUDAGuard.h>
-#ifndef __HIP_PLATFORM_HCC__
+#ifndef USE_ROCM
 #include <math_constants.h>
 #endif
 

fbgemm_gpu/src/quantize_ops/quantize_fused_8bit_rowwise.cu

Lines changed: 1 addition & 1 deletion
@@ -64,7 +64,7 @@ __global__ inline void _get_8bit_qparam_cuda_kernel(
   const int output_columns = ncols_aligned + 2 * sizeof(float);
 
   // starting values for future reductions
-#ifdef __HIP_PLATFORM_HCC__
+#ifdef USE_ROCM
 #define HIPRT_INF_F __int_as_float(0x7f800000)
   float minimum_element = HIPRT_INF_F;
   float maximum_element = -HIPRT_INF_F;
