@@ -40,7 +40,7 @@ struct half2float16<at::Half> {
40
40
} // namespace internal
41
41
42
42
namespace {
43
- template <typename scalar_t , typename grad_t >
43
+ template <typename index_t , typename scalar_t , typename grad_t >
44
44
void split_embedding_backward_exact_cpu_kernel (
45
45
Tensor grad_output,
46
46
Tensor host_weights,
@@ -225,7 +225,7 @@ for (const auto d : c10::irange(D)) {
225
225
} // for each table
226
226
}
227
227
228
- template <typename scalar_t >
228
+ template <typename index_t , typename scalar_t >
229
229
void split_embedding_backward_exact_cpu_dense_kernel (
230
230
Tensor grad,
231
231
Tensor grad_output,
@@ -242,8 +242,8 @@ void split_embedding_backward_exact_cpu_dense_kernel(
242
242
243
243
auto grad_output_data = grad_output.accessor <scalar_t , 2 >();
244
244
245
- const auto indices_data = indices.accessor <int64_t , 1 >();
246
- const auto offsets_data = offsets.accessor <int64_t , 1 >();
245
+ const auto indices_data = indices.accessor <index_t , 1 >();
246
+ const auto offsets_data = offsets.accessor <index_t , 1 >();
247
247
const auto indice_weights_data = indice_weights.defined ()
248
248
?
249
249
// If indice_weights are not defined, then this accessor won't be
@@ -349,34 +349,42 @@ for (const auto d : c10::irange(D)) {
349
349
350
350
grad_output = grad_output.contiguous ();
351
351
352
-
353
- FBGEMM_DISPATCH_FLOAT_AND_HALF (
352
+ FBGEMM_DISPATCH_INTEGRAL_TYPES (
353
+ indices.scalar_type (),
354
+ " split_embedding_backward_exact_cpu_kernel_1" , [&] {
355
+ using index_t = scalar_t ;
356
+
357
+ FBGEMM_DISPATCH_FLOAT_AND_HALF (
354
358
grad_output.scalar_type (),
355
- " split_embedding_backward_exact_cpu_outer" , [&]() {
356
- using grad_t = scalar_t ;
359
+ " split_embedding_backward_exact_cpu_kernel_2" , [&] {
360
+ using grad_t = scalar_t ;
361
+
357
362
FBGEMM_DISPATCH_FLOAT_AND_HALF (
358
- host_weights.scalar_type (), " split_embedding_backward_exact_cpu" , [&] {
359
- split_embedding_backward_exact_cpu_kernel<scalar_t , grad_t >(
360
- grad_output,
361
- host_weights,
362
- weights_offsets_data,
363
- D_offsets_data,
364
- hash_size_cumsum,
365
- indices,
366
- offsets,
367
- pooling_mode,
368
- indice_weights,
369
- num_tables,
370
- B,
371
- table_to_feature_offset,
372
- {% if " momentum1_offsets" in args.split_function_arg_names %}
373
- momentum1_offsets_data,
374
- {% endif %}
375
- {% if " momentum2_offsets" in args.split_function_arg_names %}
376
- momentum2_offsets_data,
377
- {% endif %}
378
- {{ args.split_cpu_kernel_arg_constructors | join (" , " ) }});
379
- });
363
+ host_weights.scalar_type (),
364
+ " split_embedding_backward_exact_cpu_kernel_3" , [&] {
365
+
366
+ split_embedding_backward_exact_cpu_kernel<index_t , scalar_t , grad_t >(
367
+ grad_output,
368
+ host_weights,
369
+ weights_offsets_data,
370
+ D_offsets_data,
371
+ hash_size_cumsum,
372
+ indices,
373
+ offsets,
374
+ pooling_mode,
375
+ indice_weights,
376
+ num_tables,
377
+ B,
378
+ table_to_feature_offset,
379
+ {% if " momentum1_offsets" in args.split_function_arg_names %}
380
+ momentum1_offsets_data,
381
+ {% endif %}
382
+ {% if " momentum2_offsets" in args.split_function_arg_names %}
383
+ momentum2_offsets_data,
384
+ {% endif %}
385
+ {{ args.split_cpu_kernel_arg_constructors | join (" , " ) }});
386
+ });
387
+ });
380
388
});
381
389
382
390
return ;
@@ -385,10 +393,16 @@ for (const auto d : c10::irange(D)) {
385
393
386
394
// When input is dense enough, avoid sorting and just treat as dense.
387
395
auto grad = zeros_like (host_weights, grad_output.dtype ());
388
- FBGEMM_DISPATCH_FLOAT_AND_HALF (
389
- grad_output.scalar_type (), " split_embedding_backward_exact_cpu" , [&] {
396
+ FBGEMM_DISPATCH_INTEGRAL_TYPES (
397
+ indices.scalar_type (),
398
+ " split_embedding_backward_exact_cpu_dense_kernel" , [&] {
399
+ using index_t = scalar_t ;
390
400
391
- split_embedding_backward_exact_cpu_dense_kernel<scalar_t >(
401
+ FBGEMM_DISPATCH_FLOAT_AND_HALF (
402
+ grad_output.scalar_type (),
403
+ " split_embedding_backward_exact_cpu" , [&] {
404
+
405
+ split_embedding_backward_exact_cpu_dense_kernel<index_t , scalar_t >(
392
406
grad,
393
407
grad_output,
394
408
weights_offsets_data,
@@ -400,7 +414,8 @@ for (const auto d : c10::irange(D)) {
400
414
num_tables,
401
415
B,
402
416
table_to_feature_offset);
403
- }); // dispatch host_weights.scalar_type()
417
+ });
418
+ });
404
419
405
420
return grad;
406
421
{% endif %}
0 commit comments