@@ -64,6 +64,7 @@ template <
     typename emb_t,
     typename grad_t,
     typename cache_t,
+    typename index_t,
     int32_t kFixedMaxVecsPerThread
     >
 __global__ __launch_bounds__(kForwardMaxThreads) void
@@ -78,8 +79,8 @@ __global__ __launch_bounds__(kForwardMaxThreads) void
     {%- endif %}
     const pta::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits> weights_offsets,
     const pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits> D_offsets,
-    const pta::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits> indices, // [N = \sum_{b,t} L_{b,t} total indices, i.e. flattened [B][T][L]
-    const pta::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits> offsets, // [B x T + 1]
+    const pta::PackedTensorAccessor32<index_t, 1, at::RestrictPtrTraits> indices, // [N = \sum_{b,t} L_{b,t} total indices, i.e. flattened [B][T][L]
+    const pta::PackedTensorAccessor32<index_t, 1, at::RestrictPtrTraits> offsets, // [B x T + 1]
     {%- if not dense %}
     const pta::PackedTensorAccessor32<{{ locs_or_addrs_type }}, 1, at::RestrictPtrTraits> {{ locs_or_addrs_tensor }},
     {%- endif %}
@@ -113,17 +114,17 @@ __global__ __launch_bounds__(kForwardMaxThreads) void
   fd_B.DivMod(b_t, &t, &b);
   {%- endif %}

-  int64_t weights_offset = weights_offsets[t];
-  int32_t D_start = D_offsets[t];
-  int32_t D_end = D_offsets[t + 1];
-  int32_t D = D_end - D_start;
-  int64_t indices_start = offsets[b_t];
-  int64_t indices_end = offsets[b_t + 1];
-  int32_t L = indices_end - indices_start;
+  const auto weights_offset = weights_offsets[t];
+  const auto D_start = D_offsets[t];
+  const auto D_end = D_offsets[t + 1];
+  const auto D = D_end - D_start;
+  const auto indices_start = offsets[b_t];
+  const auto indices_end = offsets[b_t + 1];
+  const auto L = indices_end - indices_start;
   if (feature_requires_grad.size(0) > 0 && !feature_requires_grad[t]) {
     // If the table does not require gradient computation, we set the gradient to zero.
-    for (int32_t l_start = 0; l_start < L; l_start += kWarpSize) {
-      int32_t l = l_start + threadIdx.x;
+    for (auto l_start = 0; l_start < L; l_start += kWarpSize) {
+      auto l = l_start + threadIdx.x;
       if (l < L) {
         grad_indice_weights[indices_start + l] = 0.0;
       }
@@ -173,14 +174,14 @@ __global__ __launch_bounds__(kForwardMaxThreads) void

   for (int32_t l_start = 0; l_start < L; l_start += kWarpSize) {
     int32_t l = l_start + threadIdx.x;
-    int64_t idx = l < L ? indices[indices_start + l] : 0;
+    index_t idx = l < L ? indices[indices_start + l] : 0;
     {%- if not dense %}
     const auto {{ locs_or_addrs_idx }} =
         (placement == PlacementType::MANAGED_CACHING && l < L)
         ? {{ locs_or_addrs_tensor }}[indices_start + l] : 0;
     {%- endif %}
     for (auto j = 0; j < kWarpSize && l_start + j < L; ++j) {
-      int64_t idx_j = shfl_sync(idx, j);
+      auto idx_j = shfl_sync(idx, j);
       {%- if not dense %}
       const auto {{ locs_or_addrs_idx }}_j = shfl_sync({{ locs_or_addrs_idx }}, j);
       {%- endif %}
@@ -354,6 +355,7 @@ Tensor {{ mdesc }}_embedding_codegen_grad_indice_weights{{ vdesc }}_cuda(
   const uint32_t info_B_mask = info_B_mask_int64;
   {%- endif %}

+  AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "split_embedding_codegen_grad_indice_weights{{ vdesc }}_kernel_1", [&] {
   DISPATCH_EMB_GRAD_CACHE_TYPES(
       dev_weights.scalar_type(),
       aligned_grad_output.scalar_type(),
@@ -362,7 +364,7 @@ Tensor {{ mdesc }}_embedding_codegen_grad_indice_weights{{ vdesc }}_cuda(
       {%- else %}
       dev_weights.scalar_type(),
       {%- endif %}
-      "split_embedding_codegen_grad_indice_weights{{ vdesc }}_kernel",
+      "split_embedding_codegen_grad_indice_weights{{ vdesc }}_kernel_2",
       [&] {
         {%- if vbe %}
         const auto& grad_output_reshaped = aligned_grad_output.reshape({1, -1});
@@ -379,13 +381,13 @@ Tensor {{ mdesc }}_embedding_codegen_grad_indice_weights{{ vdesc }}_cuda(
             mdesc, vdesc, vbdesc)
         %}
         #ifdef FBGEMM_GPU_MEMCHECK
-        const auto func_name =
-            "{{ kernel_name }}";
+        const auto func_name = "{{ kernel_name }}";
         #endif
         {{ kernel_name }}<
             emb_t,
             grad_t,
             cache_t,
+            index_t,
             kFixedMaxVecsPerThread><<<
             div_round_up(total_B, kForwardMaxThreads / kWarpSize),
             dim3(kWarpSize, kForwardMaxThreads / kWarpSize),
@@ -400,8 +402,8 @@ Tensor {{ mdesc }}_embedding_codegen_grad_indice_weights{{ vdesc }}_cuda(
             {%- endif %}
             MAKE_PTA_WITH_NAME(func_name, weights_offsets, int64_t, 1, 32),
             MAKE_PTA_WITH_NAME(func_name, D_offsets, int32_t, 1, 32),
-            MAKE_PTA_WITH_NAME(func_name, indices, int64_t, 1, 32),
-            MAKE_PTA_WITH_NAME(func_name, offsets, int64_t, 1, 32),
+            MAKE_PTA_WITH_NAME(func_name, indices, index_t, 1, 32),
+            MAKE_PTA_WITH_NAME(func_name, offsets, index_t, 1, 32),
             {%- if not dense %}
             MAKE_PTA_WITH_NAME(func_name, {{ locs_or_addrs_tensor }}, {{ locs_or_addrs_type }}, 1, 32),
             {%- endif %}
@@ -421,6 +423,7 @@ Tensor {{ mdesc }}_embedding_codegen_grad_indice_weights{{ vdesc }}_cuda(
       });
   {%- endfor %} {# /* for use_vec_blocking */ #}
   });
+  });

   C10_CUDA_KERNEL_LAUNCH_CHECK();
   return grad_indice_weights;
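For context, the sketch below illustrates the dispatch pattern this diff introduces: the kernel gains an `index_t` template parameter, and `AT_DISPATCH_INDEX_TYPES` (a real ATen macro that defines `index_t` inside its lambda) selects the `int32_t` or `int64_t` instantiation from the runtime dtype of the indices/offsets tensor, just as the new outer dispatch wraps `DISPATCH_EMB_GRAD_CACHE_TYPES` above. The kernel and host function names here (`gather_bag_lengths_kernel`, `bag_lengths`) are made up for illustration and are not part of FBGEMM.

```cpp
// Minimal, self-contained sketch of templating a kernel on index_t and
// dispatching on the index tensor's scalar type. Illustrative only.
#include <ATen/ATen.h>
#include <ATen/Dispatch.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAException.h>

template <typename index_t>
__global__ void gather_bag_lengths_kernel(
    const index_t* __restrict__ offsets, // [B + 1], int32 or int64
    int32_t* __restrict__ lengths,       // [B]
    int32_t B) {
  const int32_t b = blockIdx.x * blockDim.x + threadIdx.x;
  if (b < B) {
    // Same idiom as the kernel in the diff: L = offsets[b + 1] - offsets[b].
    lengths[b] = static_cast<int32_t>(offsets[b + 1] - offsets[b]);
  }
}

at::Tensor bag_lengths(const at::Tensor& offsets) {
  const int32_t B = static_cast<int32_t>(offsets.numel() - 1);
  auto lengths = at::empty({B}, offsets.options().dtype(at::kInt));
  // AT_DISPATCH_INDEX_TYPES defines index_t as int32_t or int64_t based on
  // offsets.scalar_type(), mirroring the outer dispatch added in this diff.
  AT_DISPATCH_INDEX_TYPES(offsets.scalar_type(), "bag_lengths", [&] {
    constexpr int kThreads = 256;
    gather_bag_lengths_kernel<index_t>
        <<<(B + kThreads - 1) / kThreads,
           kThreads,
           0,
           at::cuda::getCurrentCUDAStream()>>>(
            offsets.data_ptr<index_t>(), lengths.data_ptr<int32_t>(), B);
  });
  C10_CUDA_KERNEL_LAUNCH_CHECK();
  return lengths;
}
```

Dispatching on the host side like this keeps the kernel free of runtime type branches and lets callers pass 32-bit index/offset tensors without an upfront cast to int64.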