
Commit 3793f85

q10 authored and facebook-github-bot committed
Add support for int32_t indices in TBE training (2G/N) (pytorch#3377)
Summary:
X-link: facebookresearch/FBGEMM#622
X-link: facebookresearch/FBGEMM#468

- Add `index_t` support to TBE training backward kernels

Reviewed By: basilwong

Differential Revision: D65960050
1 parent 0a8ef1a commit 3793f85
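
The change itself is mechanical: the indice-weights kernel gains an `index_t` template parameter, the `indices`/`offsets` accessors and the corresponding `MAKE_PTA_WITH_NAME` calls switch from hard-coded `int64_t` to `index_t`, and the host function wraps the existing dispatch in ATen's `AT_DISPATCH_INDEX_TYPES`, which instantiates the launch for both `int32_t` and `int64_t` index tensors. A minimal sketch of that dispatch pattern follows; the `gather_rows` kernel and function names are hypothetical stand-ins, not the FBGEMM code.

#include <ATen/ATen.h>
#include <ATen/Dispatch.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAException.h>

// Hypothetical kernel (not the FBGEMM one): gathers one embedding row per
// index. The index element type is a template parameter rather than a
// hard-coded int64_t, mirroring the `index_t` parameter added by this commit.
template <typename index_t>
__global__ void gather_rows_kernel(
    const index_t* __restrict__ indices, // int32_t or int64_t
    const float* __restrict__ weights,
    float* __restrict__ out,
    const int64_t N,
    const int64_t D) {
  const int64_t i = blockIdx.x * static_cast<int64_t>(blockDim.x) + threadIdx.x;
  if (i < N * D) {
    const auto row = static_cast<int64_t>(indices[i / D]);
    out[i] = weights[row * D + i % D];
  }
}

at::Tensor gather_rows(const at::Tensor& indices, const at::Tensor& weights) {
  const auto N = indices.numel();
  const auto D = weights.size(1);
  auto out = at::empty({N, D}, weights.options());
  // AT_DISPATCH_INDEX_TYPES instantiates the lambda for at::kInt and at::kLong
  // and defines `index_t` accordingly, so the same launch code serves both
  // index widths without casting the indices tensor to int64 on the host.
  AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "gather_rows", [&] {
    gather_rows_kernel<index_t>
        <<<(N * D + 255) / 256, 256, 0, at::cuda::getCurrentCUDAStream()>>>(
            indices.data_ptr<index_t>(),
            weights.data_ptr<float>(),
            out.data_ptr<float>(),
            N,
            D);
  });
  C10_CUDA_KERNEL_LAUNCH_CHECK();
  return out;
}

This is the same shape as the change in the diff below: the type dispatch picks the index width at runtime, and the templated kernel does the rest.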

File tree: 1 file changed (+21, -18)


fbgemm_gpu/codegen/training/backward/embedding_backward_split_indice_weights_template.cu

Lines changed: 21 additions & 18 deletions
@@ -64,6 +64,7 @@ template <
     typename emb_t,
     typename grad_t,
     typename cache_t,
+    typename index_t,
     int32_t kFixedMaxVecsPerThread
     >
 __global__ __launch_bounds__(kForwardMaxThreads) void
@@ -78,8 +79,8 @@ __global__ __launch_bounds__(kForwardMaxThreads) void
     {%- endif %}
     const pta::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits> weights_offsets,
     const pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits> D_offsets,
-    const pta::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits> indices, // [N = \sum_{b,t} L_{b,t} total indices, i.e. flattened [B][T][L]
-    const pta::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits> offsets, // [B x T + 1]
+    const pta::PackedTensorAccessor32<index_t, 1, at::RestrictPtrTraits> indices, // [N = \sum_{b,t} L_{b,t} total indices, i.e. flattened [B][T][L]
+    const pta::PackedTensorAccessor32<index_t, 1, at::RestrictPtrTraits> offsets, // [B x T + 1]
     {%- if not dense %}
     const pta::PackedTensorAccessor32<{{ locs_or_addrs_type }}, 1, at::RestrictPtrTraits> {{ locs_or_addrs_tensor }},
     {%- endif %}
@@ -113,17 +114,17 @@ __global__ __launch_bounds__(kForwardMaxThreads) void
   fd_B.DivMod(b_t, &t, &b);
   {%- endif %}
 
-  int64_t weights_offset = weights_offsets[t];
-  int32_t D_start = D_offsets[t];
-  int32_t D_end = D_offsets[t + 1];
-  int32_t D = D_end - D_start;
-  int64_t indices_start = offsets[b_t];
-  int64_t indices_end = offsets[b_t + 1];
-  int32_t L = indices_end - indices_start;
+  const auto weights_offset = weights_offsets[t];
+  const auto D_start = D_offsets[t];
+  const auto D_end = D_offsets[t + 1];
+  const auto D = D_end - D_start;
+  const auto indices_start = offsets[b_t];
+  const auto indices_end = offsets[b_t + 1];
+  const auto L = indices_end - indices_start;
   if (feature_requires_grad.size(0) > 0 && !feature_requires_grad[t]) {
     // If the table does not require gradient computation, we set the gradient to zero.
-    for (int32_t l_start = 0; l_start < L; l_start += kWarpSize) {
-      int32_t l = l_start + threadIdx.x;
+    for (auto l_start = 0; l_start < L; l_start += kWarpSize) {
+      auto l = l_start + threadIdx.x;
       if (l < L) {
         grad_indice_weights[indices_start + l] = 0.0;
       }
@@ -173,14 +174,14 @@ __global__ __launch_bounds__(kForwardMaxThreads) void
 
   for (int32_t l_start = 0; l_start < L; l_start += kWarpSize) {
     int32_t l = l_start + threadIdx.x;
-    int64_t idx = l < L ? indices[indices_start + l] : 0;
+    index_t idx = l < L ? indices[indices_start + l] : 0;
     {%- if not dense %}
     const auto {{ locs_or_addrs_idx }} =
         (placement == PlacementType::MANAGED_CACHING && l < L)
             ? {{ locs_or_addrs_tensor }}[indices_start + l] : 0;
     {%- endif %}
     for (auto j = 0; j < kWarpSize && l_start + j < L; ++j) {
-      int64_t idx_j = shfl_sync(idx, j);
+      auto idx_j = shfl_sync(idx, j);
       {%- if not dense %}
       const auto {{ locs_or_addrs_idx }}_j = shfl_sync({{ locs_or_addrs_idx }}, j);
       {%- endif %}
@@ -354,6 +355,7 @@ Tensor {{ mdesc }}_embedding_codegen_grad_indice_weights{{ vdesc }}_cuda(
   const uint32_t info_B_mask = info_B_mask_int64;
   {%- endif %}
 
+  AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "split_embedding_codegen_grad_indice_weights{{ vdesc }}_kernel_1", [&] {
   DISPATCH_EMB_GRAD_CACHE_TYPES(
       dev_weights.scalar_type(),
       aligned_grad_output.scalar_type(),
@@ -362,7 +364,7 @@ Tensor {{ mdesc }}_embedding_codegen_grad_indice_weights{{ vdesc }}_cuda(
       {%- else %}
       dev_weights.scalar_type(),
       {%- endif %}
-      "split_embedding_codegen_grad_indice_weights{{ vdesc }}_kernel",
+      "split_embedding_codegen_grad_indice_weights{{ vdesc }}_kernel_2",
       [&] {
         {%- if vbe %}
         const auto& grad_output_reshaped = aligned_grad_output.reshape({1, -1});
@@ -379,13 +381,13 @@ Tensor {{ mdesc }}_embedding_codegen_grad_indice_weights{{ vdesc }}_cuda(
             mdesc, vdesc, vbdesc)
         %}
 #ifdef FBGEMM_GPU_MEMCHECK
-        const auto func_name =
-            "{{ kernel_name }}";
+        const auto func_name = "{{ kernel_name }}";
 #endif
         {{ kernel_name }}<
             emb_t,
             grad_t,
             cache_t,
+            index_t,
             kFixedMaxVecsPerThread><<<
             div_round_up(total_B, kForwardMaxThreads / kWarpSize),
             dim3(kWarpSize, kForwardMaxThreads / kWarpSize),
@@ -400,8 +402,8 @@ Tensor {{ mdesc }}_embedding_codegen_grad_indice_weights{{ vdesc }}_cuda(
             {%- endif %}
             MAKE_PTA_WITH_NAME(func_name, weights_offsets, int64_t, 1, 32),
             MAKE_PTA_WITH_NAME(func_name, D_offsets, int32_t, 1, 32),
-            MAKE_PTA_WITH_NAME(func_name, indices, int64_t, 1, 32),
-            MAKE_PTA_WITH_NAME(func_name, offsets, int64_t, 1, 32),
+            MAKE_PTA_WITH_NAME(func_name, indices, index_t, 1, 32),
+            MAKE_PTA_WITH_NAME(func_name, offsets, index_t, 1, 32),
             {%- if not dense %}
             MAKE_PTA_WITH_NAME(func_name, {{ locs_or_addrs_tensor }}, {{ locs_or_addrs_type }}, 1, 32),
             {%- endif %}
@@ -421,6 +423,7 @@ Tensor {{ mdesc }}_embedding_codegen_grad_indice_weights{{ vdesc }}_cuda(
         });
       {%- endfor %} {# /* for use_vec_blocking */ #}
   });
+  });
 
   C10_CUDA_KERNEL_LAUNCH_CHECK();
   return grad_indice_weights;