Skip to content

Commit 540db1f

Browse files
sryap authored and facebook-github-bot committed
Add a workaround for stochastic rounding for AMD GPUs (pytorch#997)
Summary: X-link: pytorch#3908 Pull Request resolved: facebookresearch/FBGEMM#997 This diff contains a workaround for the stochastic rounding issue for the AMD GPUs. Problem: `quantize_store` calls `nearest_rounding_vector` instead of `stochastic_rounding_vector` when stochastic rounding is used because the `StochasticRoundingRNGState` pointer is a nullptr (https://fburl.com/code/kna14icj) We found that the `WeightRow` constructor also gets a null `StochasticRoundingRNGState` pointer (https://fburl.com/code/vyq53lia) When `WeightRow` is instantiated, we confirm that `stochastic_rounding` is true. `WeightRow` should receive `&state`, but instead it receives a nullptr. (https://fburl.com/code/o3kxgt4z) We suspect that the compiler might have optimized out the `StochasticRoundingRNGState` since it is only passed to `WeightRow` and not utilized anywhere else in the caller kernel. Workaround: We move the `StochasticRoundingRNGState` storage inside the `WeightRow` struct and pass a boolean to the `WeightRow` constructor instead. Reviewed By: q10, yinbinm, jianyuh, xw285cornell, yoyoyocmu, joebos Differential Revision: D72201618 fbshipit-source-id: a2bc7f004ac5183c84eb0501ada6d848ebca17e1
1 parent b57ac02 commit 540db1f

File tree

5 files changed

+17
-19
lines changed

5 files changed

+17
-19
lines changed

fbgemm_gpu/codegen/training/optimizer/embedding_optimizer_split_device_kernel_template.cuh

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -97,13 +97,12 @@ DEVICE_INLINE void {{ mdesc }}_{{ optimizer }}_table_update_kernel(
9797
}
9898
{%- endfor %}
9999

100-
StochasticRoundingRNGState state;
101100
auto weight_row_template =
102101
WeightRow<emb_t, cache_t, at::acc_type<cache_t, true>>(
103102
weights,
104103
cache_weights,
105104
D,
106-
stochastic_rounding ? &state : nullptr,
105+
stochastic_rounding,
107106
&stochastic_rounding_philox_args,
108107
threadIdx.x + run_id * blockDim.x);
109108

fbgemm_gpu/include/fbgemm_gpu/utils/weight_row.cuh

Lines changed: 13 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -120,22 +120,20 @@ struct WeightRow {
120120
: row_(row),
121121
cache_row_(cache_row),
122122
dim_(dim),
123-
stoc_rounding_state_(nullptr) {}
123+
stoc_rounding_state_ptr_(nullptr) {}
124124

125125
// Constructor for stochastic rounding
126126
DEVICE_INLINE WeightRow(
127127
emb_t* row,
128128
cache_t* cache_row,
129129
int dim,
130-
StochasticRoundingRNGState* stoc_rounding_state,
130+
bool stochastic_rounding,
131131
const at::PhiloxCudaState* stochastic_rounding_philox_args,
132132
const uint64_t salt_value)
133133
: row_(row), cache_row_(cache_row), dim_(dim) {
134-
// Set the internal stoc_rounding_state_
135-
stoc_rounding_state_ = stoc_rounding_state;
136-
134+
stoc_rounding_state_ptr_ = nullptr;
137135
if constexpr (!std::is_same_v<emb_t, float>) {
138-
if (stoc_rounding_state != nullptr) {
136+
if (stochastic_rounding) {
139137
const auto stochastic_rounding_seeds =
140138
at::cuda::philox::unpack(*stochastic_rounding_philox_args);
141139

@@ -145,15 +143,18 @@ struct WeightRow {
145143
// The salt value should be different for every *run* and every
146144
// *thread*.
147145
salt_value,
148-
stoc_rounding_state);
146+
&stoc_rounding_state_);
147+
// Store the pointer here to avoid an if-else cond during load/store
148+
stoc_rounding_state_ptr_ = &stoc_rounding_state_;
149149
}
150150
}
151151
}
152152

153153
emb_t* row_;
154154
cache_t* cache_row_;
155155
int dim_;
156-
StochasticRoundingRNGState* stoc_rounding_state_;
156+
StochasticRoundingRNGState stoc_rounding_state_;
157+
StochasticRoundingRNGState* stoc_rounding_state_ptr_;
157158

158159
// Load from cache if resident; else load from embedding
159160
DEVICE_INLINE Vec4T<dst_t> load(const int32_t d, const float2 qparams) const {
@@ -169,9 +170,9 @@ struct WeightRow {
169170
DEVICE_INLINE void
170171
store(const Vec4T<dst_t>& v, const int32_t d, const float2 qparams) {
171172
if (cache_row_) {
172-
quantize_store(cache_row_ + d, v, stoc_rounding_state_, qparams);
173+
quantize_store(cache_row_ + d, v, stoc_rounding_state_ptr_, qparams);
173174
} else {
174-
quantize_store(row_ + d, v, stoc_rounding_state_, qparams);
175+
quantize_store(row_ + d, v, stoc_rounding_state_ptr_, qparams);
175176
}
176177
}
177178

@@ -201,7 +202,7 @@ struct WeightRow {
201202
} else {
202203
// Does 2-step conversion: cache_t -> FP32 -> weight_t
203204
const auto cache_slice = load(d, qparams);
204-
quantize_store(row_ + d, cache_slice, stoc_rounding_state_, qparams);
205+
quantize_store(row_ + d, cache_slice, stoc_rounding_state_ptr_, qparams);
205206
}
206207
}
207208

@@ -236,7 +237,7 @@ struct WeightRow {
236237
// Does 2-step conversion: weight_t -> FP32 -> cache_t
237238
for (int32_t d = lane_id * 4; d < dim_length; d += num_lanes * 4) {
238239
const auto slice = load(d, qparams);
239-
quantize_store(dst_row + d, slice, stoc_rounding_state_, qparams);
240+
quantize_store(dst_row + d, slice, stoc_rounding_state_ptr_, qparams);
240241
}
241242
}
242243
}

fbgemm_gpu/src/split_embeddings_cache/lfu_cache_populate.cu

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -116,12 +116,11 @@ __global__ __launch_bounds__(kCacheMaxThreads) void lfu_cache_insert_kernel(
116116
if constexpr (std::is_same_v<emb_t, uint8_t>) {
117117
D_emb += kINT8QparamsBytes;
118118
}
119-
StochasticRoundingRNGState state;
120119
auto weight_row = WeightRow<emb_t, cache_t, cache_t>(
121120
&weights[weights_offset_current + idx_current * D_emb + 0],
122121
&lxu_cache_weights[cache_set * kWarpSize + insert_slot][0],
123122
D_current,
124-
stochastic_rounding ? &state : nullptr,
123+
stochastic_rounding,
125124
&stochastic_rounding_philox_args,
126125
(blockIdx.x * blockDim.x * blockDim.y + threadIdx.y * blockDim.x +
127126
threadIdx.x) *

fbgemm_gpu/src/split_embeddings_cache/lru_cache_populate.cu

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -123,12 +123,11 @@ __global__ __launch_bounds__(kMaxThreads) void lru_cache_insert_kernel(
123123
D_emb += kINT8QparamsBytes;
124124
}
125125

126-
StochasticRoundingRNGState state;
127126
auto weight_row = WeightRow<emb_t, cache_t, cache_t>(
128127
&weights[weights_offset_current + idx_current * D_emb + 0],
129128
&lxu_cache_weights[cache_set * kWarpSize + insert_slot][0],
130129
D_current,
131-
stochastic_rounding ? &state : nullptr,
130+
stochastic_rounding,
132131
&stochastic_rounding_philox_args,
133132
stoc_rounding_salt + l);
134133

fbgemm_gpu/src/split_embeddings_cache/lxu_cache.cu

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ __global__ __launch_bounds__(kMaxThreads) void lxu_cache_flush_kernel(
6060
&weights[weights_offset_current + idx_current * D_emb + 0],
6161
&lxu_cache_weights[b][0],
6262
D_current,
63-
stochastic_rounding ? &state : nullptr,
63+
stochastic_rounding,
6464
&stochastic_rounding_philox_args,
6565
blockIdx.x * blockDim.x * blockDim.y + threadIdx.y * blockDim.x +
6666
threadIdx.x);

0 commit comments

Comments (0)