Skip to content

Commit 610e8dd

Browse files
q10 authored and facebook-github-bot committed
Refactor OptionalCUDAGuard -> CUDA_DEVICE_GUARD (pytorch#2270)
Summary:
Pull Request resolved: pytorch#2270
- Refactor OptionalCUDAGuard -> CUDA_DEVICE_GUARD

Reviewed By: jianyuh
Differential Revision: D52820946
fbshipit-source-id: cdf1b98f55450c31f45899fff8bb90b001589a43
1 parent 54a56f0 commit 610e8dd

File tree

65 files changed

+121
-240
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

65 files changed

+121
-240
lines changed

fbgemm_gpu/codegen/embedding_backward_split_indice_weights_template.cu

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -247,8 +247,8 @@ Tensor {{ ddesc }}_embedding_codegen_grad_indice_weights{{ vdesc }}_cuda(
247247
TENSOR_ON_CUDA_GPU(feature_requires_grad);
248248
}
249249

250-
at::cuda::OptionalCUDAGuard device_guard;
251-
device_guard.set_index(dev_weights.get_device());
250+
CUDA_DEVICE_GUARD(dev_weights);
251+
252252
const auto T = D_offsets.size(0) - 1;
253253
TORCH_CHECK_GT(T, 0);
254254
// offsets = [B x T + 1]

fbgemm_gpu/codegen/embedding_backward_split_template.cu

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -427,8 +427,7 @@ Tensor split_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ wdesc }}_e
427427
}
428428
{%- endif %}
429429

430-
at::cuda::OptionalCUDAGuard device_guard;
431-
device_guard.set_index(dev_weights.get_device());
430+
CUDA_DEVICE_GUARD(dev_weights);
432431

433432
{%- if nobag and not is_index_select %}
434433
auto max_D = D;

fbgemm_gpu/codegen/embedding_bounds_check.cu

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -190,8 +190,7 @@ void bounds_check_indices_cuda(
190190
TENSORS_ON_SAME_CUDA_GPU_IF_NOT_OPTIONAL(
191191
rows_per_table, indices, offsets, warning, weights, B_offsets);
192192

193-
at::cuda::OptionalCUDAGuard device_guard;
194-
device_guard.set_index(rows_per_table.get_device());
193+
CUDA_DEVICE_GUARD(rows_per_table);
195194

196195
const int32_t T = rows_per_table.size(0);
197196
const int32_t total_B = offsets.size(0) - 1;

fbgemm_gpu/codegen/embedding_forward_quantized_split_lookup.cu

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -140,8 +140,8 @@ Tensor pruned_hashmap_lookup_cuda(
140140
TENSORS_ON_SAME_CUDA_GPU_IF_NOT_OPTIONAL(
141141
indices, offsets, hash_table, hash_table_offsets);
142142

143-
at::cuda::OptionalCUDAGuard device_guard;
144-
device_guard.set_index(indices.get_device());
143+
CUDA_DEVICE_GUARD(indices);
144+
145145
auto dense_indices = at::empty_like(indices);
146146
const int32_t T = hash_table_offsets.size(0) - 1;
147147
const int32_t B = (offsets.size(0) - 1) / T;
@@ -179,8 +179,8 @@ Tensor pruned_array_lookup_cuda(
179179
TENSORS_ON_SAME_CUDA_GPU_IF_NOT_OPTIONAL(
180180
indices, offsets, index_remappings, index_remappings_offsets);
181181
182-
at::cuda::OptionalCUDAGuard device_guard;
183-
device_guard.set_index(indices.get_device());
182+
CUDA_DEVICE_GUARD(indices);
183+
184184
auto dense_indices = at::empty_like(indices);
185185
const int32_t T = index_remappings_offsets.size(0) - 1;
186186
TORCH_CHECK(

fbgemm_gpu/codegen/embedding_forward_quantized_split_nbit_host_template.cu

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -107,8 +107,7 @@ Tensor int_nbit_split_embedding{{ "_nobag" if nobag else "" }}_codegen_forward_{
107107
TENSORS_EMPTY_OR_ON_SAME_DEVICE(lxu_cache_weights, dev_weights);
108108
TENSORS_EMPTY_OR_ON_SAME_DEVICE(lxu_cache_locations, dev_weights);
109109

110-
at::cuda::OptionalCUDAGuard device_guard;
111-
device_guard.set_index(dev_weights.get_device());
110+
CUDA_DEVICE_GUARD(dev_weights);
112111

113112
// kernels assume indices are contiguous.
114113
indices = indices.contiguous();

fbgemm_gpu/codegen/embedding_forward_split_template.cu

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -364,8 +364,7 @@ batch_index_select_dim0_codegen_forward_cuda(
364364
}
365365
{%- endif %}
366366

367-
at::cuda::OptionalCUDAGuard device_guard;
368-
device_guard.set_index(dev_weights.get_device());
367+
CUDA_DEVICE_GUARD(dev_weights);
369368

370369
{%- if not nobag %}
371370
int32_t T = D_offsets.numel() - 1;

fbgemm_gpu/codegen/embedding_optimizer_split_template.cu

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -82,8 +82,7 @@ void split_embedding_{{ optimizer }}_update(
8282
return;
8383
}
8484

85-
at::cuda::OptionalCUDAGuard device_guard;
86-
device_guard.set_index(dev_weights.get_device());
85+
CUDA_DEVICE_GUARD(dev_weights);
8786

8887
// Flatten dev_weights because it is currently 2D
8988
dev_weights = dev_weights.flatten();

fbgemm_gpu/src/embedding_inplace_ops/embedding_inplace_update.cu

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -133,8 +133,7 @@ void embedding_inplace_update_cuda(
133133
lxu_cache_weights,
134134
lxu_cache_locations);
135135

136-
at::cuda::OptionalCUDAGuard device_guard;
137-
device_guard.set_index(dev_weights.get_device());
136+
CUDA_DEVICE_GUARD(dev_weights);
138137

139138
const int64_t N = update_row_idx.numel();
140139
if (N == 0) {
@@ -226,9 +225,8 @@ Tensor pruned_array_lookup_from_row_idx_cuda(
226225
update_table_indices,
227226
index_remappings,
228227
index_remappings_offsets);
228+
CUDA_DEVICE_GUARD(update_table_indices);
229229
230-
at::cuda::OptionalCUDAGuard device_guard;
231-
device_guard.set_index(update_table_indices.get_device());
232230
auto dense_indices = at::empty_like(update_row_indices);
233231
const int32_t T = index_remappings_offsets.size(0) - 1;
234232

fbgemm_gpu/src/histogram_binning_calibration_ops.cu

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -64,9 +64,7 @@ std::tuple<Tensor, Tensor> histogram_binning_calibration_cuda(
6464
TENSOR_ON_CUDA_GPU(bin_num_examples);
6565
TENSOR_ON_CUDA_GPU(bin_num_positives);
6666
TORCH_CHECK_EQ(bin_num_examples.numel(), bin_num_positives.numel());
67-
68-
at::cuda::OptionalCUDAGuard device_guard;
69-
device_guard.set_index(logit.get_device());
67+
CUDA_DEVICE_GUARD(logit);
7068

7169
Tensor calibrated_prediction = at::empty_like(logit);
7270
Tensor bin_ids = at::empty({logit.numel()}, logit.options().dtype(at::kLong));
@@ -188,9 +186,7 @@ std::tuple<Tensor, Tensor> histogram_binning_calibration_by_feature_cuda(
188186
TENSOR_ON_CUDA_GPU(bin_num_examples);
189187
TENSOR_ON_CUDA_GPU(bin_num_positives);
190188
TORCH_CHECK_EQ(bin_num_examples.numel(), bin_num_positives.numel());
191-
192-
at::cuda::OptionalCUDAGuard device_guard;
193-
device_guard.set_index(logit.get_device());
189+
CUDA_DEVICE_GUARD(logit);
194190
195191
// Convert lengths to offsets for better handling on GPUs.
196192
const auto segment_lengths_packed = segment_lengths.contiguous();
@@ -351,9 +347,7 @@ generic_histogram_binning_calibration_by_feature_cuda(
351347
TORCH_CHECK(
352348
bin_num_examples.numel() ==
353349
(num_segments + 1) * (bin_boundaries.numel() + 1));
354-
355-
at::cuda::OptionalCUDAGuard device_guard;
356-
device_guard.set_index(logit.get_device());
350+
CUDA_DEVICE_GUARD(logit);
357351
358352
// Convert lengths to offsets for better handling on GPUs.
359353
const auto segment_lengths_packed = segment_lengths.contiguous();

fbgemm_gpu/src/jagged_tensor_ops/batched_dense_vec_jagged_2d_mul_backward.cu

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -91,9 +91,7 @@ std::tuple<Tensor, Tensor> batched_dense_vec_jagged_2d_mul_backward(
9191
const Tensor& a_values,
9292
const Tensor& a_offsets) {
9393
TENSORS_ON_SAME_CUDA_GPU_IF_NOT_OPTIONAL(grad_output, a_values, a_offsets, v);
94-
95-
at::cuda::OptionalCUDAGuard device_guard;
96-
device_guard.set_index(grad_output.get_device());
94+
CUDA_DEVICE_GUARD(grad_output);
9795

9896
const int B = a_offsets.numel() - 1;
9997
const int D = grad_output.size(-1);

fbgemm_gpu/src/jagged_tensor_ops/batched_dense_vec_jagged_2d_mul_forward.cu

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -56,9 +56,7 @@ Tensor batched_dense_vec_jagged_2d_mul_forward(
5656
const Tensor& a_values,
5757
const Tensor& a_offsets) {
5858
TENSORS_ON_SAME_CUDA_GPU_IF_NOT_OPTIONAL(v, a_values, a_offsets);
59-
60-
at::cuda::OptionalCUDAGuard device_guard;
61-
device_guard.set_index(v.get_device());
59+
CUDA_DEVICE_GUARD(v);
6260

6361
const int B = a_offsets.numel() - 1;
6462
TORCH_CHECK(

fbgemm_gpu/src/jagged_tensor_ops/dense_to_jagged_forward.cu

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,8 +29,7 @@ Tensor dense_to_jagged_forward(
2929
auto values = at::empty_symint({total_L_computed, D}, dense.options());
3030
auto output = at::empty_like(values);
3131

32-
at::cuda::OptionalCUDAGuard device_guard;
33-
device_guard.set_index(dense.get_device());
32+
CUDA_DEVICE_GUARD(dense);
3433

3534
#define DISPATCH_DENSE_TO_JAGGED_CASE(TYPE) \
3635
AT_DISPATCH_CASE(TYPE, [&] { \

fbgemm_gpu/src/jagged_tensor_ops/jagged_dense_bmm_forward.cu

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -156,9 +156,7 @@ Tensor jagged_dense_bmm_forward_cuda(
156156
const Tensor& y,
157157
const int64_t max_L) {
158158
TENSORS_ON_SAME_CUDA_GPU_IF_NOT_OPTIONAL(x_values, x_offsets, y);
159-
160-
at::cuda::OptionalCUDAGuard device_guard;
161-
device_guard.set_index(x_values.get_device());
159+
CUDA_DEVICE_GUARD(x_values);
162160

163161
const int B = x_offsets.numel() - 1;
164162
const int M = x_values.size(-1);

fbgemm_gpu/src/jagged_tensor_ops/jagged_dense_dense_elementwise_add_jagged_output_forward.cu

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -218,8 +218,7 @@ Tensor jagged_dense_dense_elementwise_add_jagged_output_forward(
218218
TORCH_CHECK_EQ(dense_0.sizes(), dense_1.sizes());
219219
auto output = at::empty_like(x_values);
220220
221-
at::cuda::OptionalCUDAGuard device_guard;
222-
device_guard.set_index(dense_0.get_device());
221+
CUDA_DEVICE_GUARD(dense_0);
223222
224223
if (x_values.scalar_type() == at::ScalarType::BFloat16 &&
225224
dense_0.scalar_type() == at::ScalarType::BFloat16 &&

fbgemm_gpu/src/jagged_tensor_ops/jagged_dense_elementwise_mul_backward.cu

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -128,8 +128,7 @@ std::tuple<Tensor, Tensor> jagged_dense_elementwise_mul_backward(
128128
const std::vector<Tensor>& x_offsets,
129129
const Tensor& y,
130130
const Tensor& x_values) {
131-
at::cuda::OptionalCUDAGuard device_guard;
132-
device_guard.set_index(grad_output.get_device());
131+
CUDA_DEVICE_GUARD(grad_output);
133132

134133
Tensor x_values_grad = at::empty_like(grad_output);
135134
Tensor y_grad = at::empty_like(y);

fbgemm_gpu/src/jagged_tensor_ops/jagged_dense_elementwise_mul_forward.cu

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,7 @@ Tensor jagged_dense_elementwise_mul_forward(
1616
const Tensor& x_values,
1717
const std::vector<Tensor>& x_offsets,
1818
const Tensor& y) {
19-
at::cuda::OptionalCUDAGuard device_guard;
20-
device_guard.set_index(x_values.get_device());
19+
CUDA_DEVICE_GUARD(x_values);
2120

2221
Tensor output = at::empty_like(x_values);
2322

fbgemm_gpu/src/jagged_tensor_ops/jagged_index_add_2d_forward.cu

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -78,9 +78,7 @@ Tensor jagged_index_add_2d_forward_cuda(
7878
const int64_t num_output_rows) {
7979
TENSORS_ON_SAME_CUDA_GPU_IF_NOT_OPTIONAL(
8080
values, indices, input_offsets, output_offsets);
81-
82-
at::cuda::OptionalCUDAGuard device_guard;
83-
device_guard.set_index(values.get_device());
81+
CUDA_DEVICE_GUARD(values);
8482

8583
auto num_cols = values.size(1);
8684

fbgemm_gpu/src/jagged_tensor_ops/jagged_index_select_2d_forward.cu

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -74,9 +74,7 @@ Tensor jagged_index_select_2d_forward_cuda(
7474
const int64_t num_dense_output_rows) {
7575
TENSORS_ON_SAME_CUDA_GPU_IF_NOT_OPTIONAL(
7676
values, indices, input_offsets, output_offsets);
77-
78-
at::cuda::OptionalCUDAGuard device_guard;
79-
device_guard.set_index(values.get_device());
77+
CUDA_DEVICE_GUARD(values);
8078

8179
auto num_cols = values.size(1);
8280

fbgemm_gpu/src/jagged_tensor_ops/jagged_jagged_bmm_forward.cu

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -162,9 +162,7 @@ Tensor jagged_jagged_bmm_forward_cuda(
162162
const Tensor& offsets,
163163
const int64_t max_L) {
164164
TENSORS_ON_SAME_CUDA_GPU_IF_NOT_OPTIONAL(x_values, y_values, offsets);
165-
166-
at::cuda::OptionalCUDAGuard device_guard;
167-
device_guard.set_index(x_values.get_device());
165+
CUDA_DEVICE_GUARD(x_values);
168166

169167
const int B = offsets.numel() - 1;
170168
const int M = x_values.size(-1);

fbgemm_gpu/src/jagged_tensor_ops/jagged_softmax_backward.cu

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -96,9 +96,7 @@ Tensor jagged_softmax_backward_cuda(
9696
const Tensor& offsets,
9797
const int64_t max_L) {
9898
TENSORS_ON_SAME_CUDA_GPU_IF_NOT_OPTIONAL(grad_output, output, offsets);
99-
100-
at::cuda::OptionalCUDAGuard device_guard;
101-
device_guard.set_index(grad_output.get_device());
99+
CUDA_DEVICE_GUARD(grad_output);
102100

103101
const auto B = offsets.numel() - 1;
104102
const auto D = grad_output.size(1);

fbgemm_gpu/src/jagged_tensor_ops/jagged_softmax_forward.cu

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -119,9 +119,7 @@ Tensor jagged_softmax_forward_cuda(
119119
const Tensor& offsets,
120120
const int64_t max_L) {
121121
TENSORS_ON_SAME_CUDA_GPU_IF_NOT_OPTIONAL(values, offsets);
122-
123-
at::cuda::OptionalCUDAGuard device_guard;
124-
device_guard.set_index(values.get_device());
122+
CUDA_DEVICE_GUARD(values);
125123

126124
const auto B = offsets.numel() - 1;
127125
const auto D = values.size(1);

fbgemm_gpu/src/jagged_tensor_ops/jagged_to_padded_dense_backward.cu

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,7 @@ at::Tensor jagged_to_padded_dense_backward(
1717
const std::vector<Tensor>& offsets,
1818
at::SymInt total_L) {
1919
auto grad_padded_values = grad_output;
20-
at::cuda::OptionalCUDAGuard device_guard;
21-
device_guard.set_index(grad_padded_values.get_device());
20+
CUDA_DEVICE_GUARD(grad_padded_values);
2221

2322
// Canonicalize padded_values by unsqueeze the last dim if the inner dense
2423
// dimension is 1 and folded.

fbgemm_gpu/src/jagged_tensor_ops/jagged_to_padded_dense_forward.cu

Lines changed: 6 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -30,8 +30,7 @@ at::Tensor jagged_to_padded_dense_forward(
3030
max_lengths.size(),
3131
" != num_jagged_dim, ",
3232
num_jagged_dim);
33-
at::cuda::OptionalCUDAGuard device_guard;
34-
device_guard.set_index(values.get_device());
33+
CUDA_DEVICE_GUARD(values);
3534

3635
const Tensor values_canonicalized = values.view(
3736
{values.size(0),
@@ -83,8 +82,7 @@ std::vector<Tensor> stacked_jagged_1d_to_dense_gpu(
8382
int64_t padding_value) {
8483
TORCH_CHECK(values.dim() == 1);
8584
TORCH_CHECK(lengths.dim() == 2);
86-
at::cuda::OptionalCUDAGuard device_guard;
87-
device_guard.set_index(values.get_device());
85+
CUDA_DEVICE_GUARD(values);
8886

8987
const auto lengths_contig = lengths.contiguous();
9088
int32_t B = lengths.size(1);
@@ -138,8 +136,7 @@ stacked_jagged_2d_to_dense_forward_cuda(
138136
int64_t padding_value) {
139137
TORCH_CHECK(values.dim() == 2);
140138
TORCH_CHECK(lengths.dim() == 2);
141-
at::cuda::OptionalCUDAGuard device_guard;
142-
device_guard.set_index(values.get_device());
139+
CUDA_DEVICE_GUARD(values);
143140

144141
const auto lengths_contig = lengths.contiguous();
145142
int32_t D = values.size(1);
@@ -194,8 +191,7 @@ Tensor stacked_jagged_2d_to_dense_backward_cuda(
194191
const std::vector<Tensor>& grad_padded_values_per_key,
195192
const std::vector<Tensor>& offsets_tensor_per_key,
196193
const std::vector<int64_t>& offset_per_key) {
197-
at::cuda::OptionalCUDAGuard device_guard;
198-
device_guard.set_index(grad_padded_values_per_key[0].get_device());
194+
CUDA_DEVICE_GUARD(grad_padded_values_per_key[0]);
199195

200196
auto grad_values =
201197
at::zeros({total_L, D}, grad_padded_values_per_key[0].options());
@@ -321,8 +317,7 @@ class JaggedDenseAddJaggedOutputGPUOp
321317

322318
auto output = at::empty_like(x_values);
323319

324-
at::cuda::OptionalCUDAGuard device_guard;
325-
device_guard.set_index(dense.get_device());
320+
CUDA_DEVICE_GUARD(dense);
326321

327322
AT_DISPATCH_SWITCH(
328323
x_values.scalar_type(),
@@ -364,9 +359,7 @@ class JaggedDenseAddJaggedOutputGPUOp
364359
auto offsets = ctx->get_saved_variables();
365360
auto dense_shape = ctx->saved_data["dense_shape"].toIntVector();
366361
TORCH_CHECK(grad_outputs.size() == 1);
367-
368-
at::cuda::OptionalCUDAGuard device_guard;
369-
device_guard.set_index(grad_outputs[0].get_device());
362+
CUDA_DEVICE_GUARD(grad_outputs[0]);
370363

371364
Tensor dense_values_grad = jagged_to_padded_dense_forward(
372365
grad_outputs[0],

fbgemm_gpu/src/jagged_tensor_ops/keyed_jagged_index_select_dim1.cu

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -194,8 +194,7 @@ class KeyedJaggedIndexSelectDim1GPUOp
194194
"weights size and values size must be the same");
195195
}
196196

197-
at::cuda::OptionalCUDAGuard device_guard;
198-
device_guard.set_index(values.get_device());
197+
CUDA_DEVICE_GUARD(values);
199198

200199
const int num_batches = lengths.numel() / batch_size;
201200
const int num_output_lengths = num_batches * indices.numel();
@@ -380,8 +379,7 @@ class KeyedJaggedIndexSelectDim1GPUOp
380379
int64_t output_batch_size = ctx->saved_data["batch_size"].toInt();
381380
int64_t num_batches = ctx->saved_data["num_batches"].toInt();
382381
383-
at::cuda::OptionalCUDAGuard device_guard;
384-
device_guard.set_index(grad.get_device());
382+
CUDA_DEVICE_GUARD(grad);
385383
386384
Tensor grad_input = at::zeros({num_outputs}, grad.options());
387385
auto grid_size = cuda_calc_xblock_count(grad.numel(), kMaxThreads);

fbgemm_gpu/src/layout_transform_ops/layout_transform_ops.cu

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -37,8 +37,7 @@ Tensor recat_embedding_grad_output_cuda(
3737
const std::vector<int64_t>& num_features_per_rank) {
3838
TENSOR_ON_CUDA_GPU(grad_output);
3939

40-
at::cuda::OptionalCUDAGuard device_guard;
41-
device_guard.set_index(grad_output.get_device());
40+
CUDA_DEVICE_GUARD(grad_output);
4241

4342
TORCH_CHECK(grad_output.is_contiguous());
4443
const auto B_local = grad_output.size(0);
@@ -82,8 +81,7 @@ Tensor recat_embedding_grad_output_mixed_D_cuda(
8281
TENSOR_ON_CUDA_GPU(grad_output);
8382
TORCH_CHECK(grad_output.is_contiguous());
8483

85-
at::cuda::OptionalCUDAGuard device_guard;
86-
device_guard.set_index(grad_output.get_device());
84+
CUDA_DEVICE_GUARD(grad_output);
8785

8886
const auto B_local = grad_output.size(0);
8987
const auto global_dim_sum = at::sum_integers(dim_sum_per_rank);
@@ -129,8 +127,7 @@ Tensor recat_embedding_grad_output_mixed_D_batch_cuda(
129127
grad_output, dim_sum_per_rank, cumsum_dim_sum_per_rank);
130128
TORCH_CHECK(grad_output.is_contiguous());
131129

132-
at::cuda::OptionalCUDAGuard device_guard;
133-
device_guard.set_index(grad_output.get_device());
130+
CUDA_DEVICE_GUARD(grad_output);
134131

135132
const auto B_local = grad_output.size(0);
136133
Tensor sharded_grad_output =

0 commit comments

Comments
 (0)