@@ -40,7 +40,7 @@ struct half2float16<at::Half> {
40
40
} // namespace internal
41
41
42
42
namespace {
43
- template <typename scalar_t , typename grad_t >
43
+ template <typename index_t , typename scalar_t , typename grad_t >
44
44
void split_embedding_backward_exact_cpu_kernel (
45
45
Tensor grad_output,
46
46
Tensor host_weights,
@@ -90,8 +90,8 @@ for (const auto t : c10::irange(num_tables)) {
90
90
::internal::csr2csc (
91
91
cscs[t],
92
92
B,
93
- offsets.accessor<int64_t , 1 >(),
94
- indices.accessor<int64_t , 1>(),
93
+ offsets.accessor<index_t , 1 >(),
94
+ indices.accessor<index_t , 1>(),
95
95
indice_weights.defined()
96
96
? indice_weights.accessor<at::acc_type<scalar_t, true>, 1>()
97
97
: at::TensorAccessor<at::acc_type<scalar_t, true>, 1>(nullptr , nullptr , nullptr ),
@@ -223,7 +223,7 @@ for (const auto d : c10::irange(D)) {
223
223
} // for each table
224
224
}
225
225
226
- template <typename scalar_t >
226
+ template <typename index_t , typename scalar_t >
227
227
void split_embedding_backward_exact_cpu_dense_kernel (
228
228
Tensor grad,
229
229
Tensor grad_output,
@@ -240,8 +240,8 @@ void split_embedding_backward_exact_cpu_dense_kernel(
240
240
241
241
auto grad_output_data = grad_output.accessor <scalar_t , 2 >();
242
242
243
- const auto indices_data = indices.accessor <int64_t , 1 >();
244
- const auto offsets_data = offsets.accessor <int64_t , 1 >();
243
+ const auto indices_data = indices.accessor <index_t , 1 >();
244
+ const auto offsets_data = offsets.accessor <index_t , 1 >();
245
245
const auto indice_weights_data = indice_weights.defined ()
246
246
?
247
247
// If indice_weights are not defined, then this accessor won't be
@@ -347,34 +347,42 @@ for (const auto d : c10::irange(D)) {
347
347
348
348
grad_output = grad_output.contiguous ();
349
349
350
-
351
- FBGEMM_DISPATCH_FLOAT_AND_HALF (
350
+ FBGEMM_DISPATCH_INTEGRAL_TYPES (
351
+ indices.scalar_type (),
352
+ " split_embedding_backward_exact_cpu_kernel_1" , [&] {
353
+ using index_t = scalar_t ;
354
+
355
+ FBGEMM_DISPATCH_FLOAT_AND_HALF (
352
356
grad_output.scalar_type (),
353
- " split_embedding_backward_exact_cpu_outer" , [&]() {
354
- using grad_t = scalar_t ;
357
+ " split_embedding_backward_exact_cpu_kernel_2" , [&] {
358
+ using grad_t = scalar_t ;
359
+
355
360
FBGEMM_DISPATCH_FLOAT_AND_HALF (
356
- host_weights.scalar_type (), " split_embedding_backward_exact_cpu" , [&] {
357
- split_embedding_backward_exact_cpu_kernel<scalar_t , grad_t >(
358
- grad_output,
359
- host_weights,
360
- weights_offsets_data,
361
- D_offsets_data,
362
- hash_size_cumsum,
363
- indices,
364
- offsets,
365
- pooling_mode,
366
- indice_weights,
367
- num_tables,
368
- B,
369
- table_to_feature_offset,
370
- {% if " momentum1_offsets" in args.split_function_arg_names %}
371
- momentum1_offsets_data,
372
- {% endif %}
373
- {% if " momentum2_offsets" in args.split_function_arg_names %}
374
- momentum2_offsets_data,
375
- {% endif %}
376
- {{ args.split_cpu_kernel_arg_constructors | join (" , " ) }});
377
- });
361
+ host_weights.scalar_type (),
362
+ " split_embedding_backward_exact_cpu_kernel_3" , [&] {
363
+
364
+ split_embedding_backward_exact_cpu_kernel<index_t , scalar_t , grad_t >(
365
+ grad_output,
366
+ host_weights,
367
+ weights_offsets_data,
368
+ D_offsets_data,
369
+ hash_size_cumsum,
370
+ indices,
371
+ offsets,
372
+ pooling_mode,
373
+ indice_weights,
374
+ num_tables,
375
+ B,
376
+ table_to_feature_offset,
377
+ {% if " momentum1_offsets" in args.split_function_arg_names %}
378
+ momentum1_offsets_data,
379
+ {% endif %}
380
+ {% if " momentum2_offsets" in args.split_function_arg_names %}
381
+ momentum2_offsets_data,
382
+ {% endif %}
383
+ {{ args.split_cpu_kernel_arg_constructors | join (" , " ) }});
384
+ });
385
+ });
378
386
});
379
387
380
388
return ;
@@ -383,10 +391,16 @@ for (const auto d : c10::irange(D)) {
383
391
384
392
// When input is dense enough, avoid sorting and just treat as dense.
385
393
auto grad = zeros_like (host_weights, grad_output.dtype ());
386
- FBGEMM_DISPATCH_FLOAT_AND_HALF (
387
- grad_output.scalar_type (), " split_embedding_backward_exact_cpu" , [&] {
394
+ FBGEMM_DISPATCH_INTEGRAL_TYPES (
395
+ indices.scalar_type (),
396
+ " split_embedding_backward_exact_cpu_dense_kernel" , [&] {
397
+ using index_t = scalar_t ;
388
398
389
- split_embedding_backward_exact_cpu_dense_kernel<scalar_t >(
399
+ FBGEMM_DISPATCH_FLOAT_AND_HALF (
400
+ grad_output.scalar_type (),
401
+ " split_embedding_backward_exact_cpu" , [&] {
402
+
403
+ split_embedding_backward_exact_cpu_dense_kernel<index_t , scalar_t >(
390
404
grad,
391
405
grad_output,
392
406
weights_offsets_data,
@@ -398,7 +412,8 @@ for (const auto d : c10::irange(D)) {
398
412
num_tables,
399
413
B,
400
414
table_to_feature_offset);
401
- }); // dispatch host_weights.scalar_type()
415
+ });
416
+ });
402
417
403
418
return grad;
404
419
{% endif %}
0 commit comments