@@ -62,6 +62,7 @@ template <
     typename emb_t,
     typename grad_t,
     typename cache_t,
+    typename index_t,
     {%- for ph_name in args.placeholder_tensor_names %}
     typename {{ ph_name + "_ph_t" }},
     {%- endfor %}
@@ -90,7 +91,7 @@ batch_index_select_dim0_codegen_backward_kernel_warp_per_row(
     int64_t D,
     {%- endif %}
     const pta::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits> hash_size_cumsum,
-    const pta::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits> sorted_linear_indices_run,
+    const pta::PackedTensorAccessor32<index_t, 1, at::RestrictPtrTraits> sorted_linear_indices_run,
     const pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits> sorted_linear_indices_cumulative_run_lengths,
     {%- if not nobag %}
     const pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits> sorted_infos,
@@ -341,6 +342,7 @@ batch_index_select_dim0_codegen_backward_kernel_warp_per_row(
     emb_type,
     grad_type,
     cache_type,
+    index_type,
     ph_type_combo,
     kFixedMaxVecsPerThread,
     kThreadGroupSize,
@@ -358,6 +360,7 @@ batch_index_select_dim0_codegen_backward_kernel_warp_per_row
     < {{ emb_type }},
     {{ grad_type }},
     {{ cache_type }},
+    {{ index_type }},
     {%- for ph_name in args.placeholder_tensor_names %}
     {{ ph_type_combo[ph_name].primitive_type }},
     {%- endfor %}
@@ -381,7 +384,7 @@ batch_index_select_dim0_codegen_backward_kernel_warp_per_row
     int64_t D,
     {%- endif %}
     const pta::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits> hash_size_cumsum,
-    const pta::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits> sorted_linear_indices_run,
+    const pta::PackedTensorAccessor32<{{ index_type }}, 1, at::RestrictPtrTraits> sorted_linear_indices_run,
     const pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits> sorted_linear_indices_cumulative_run_lengths,
     {%- if not nobag %}
     const pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits> sorted_infos,
@@ -441,11 +444,13 @@ batch_index_select_dim0_codegen_backward_kernel_warp_per_row
 {%- for grad_type in ['float', 'at::Half', 'at::BFloat16'] %}
 {%- for emb_type in ['float', 'at::Half'] %}
 {%- for cache_type in ['float', 'at::Half'] %}
+{%- for index_type in ['int32_t', 'int64_t'] %}
 {%- for ph_type_combo in args.placeholder_type_combos %}
 {{ template_instantiation(
     emb_type,
     grad_type,
     cache_type,
+    index_type,
     ph_type_combo,
     kFixedMaxVecsPerThread,
     kThreadGroupSize,
@@ -456,6 +461,7 @@ batch_index_select_dim0_codegen_backward_kernel_warp_per_row
 {%- endfor %}
 {%- endfor %}
 {%- endfor %}
+{%- endfor %}
 {%- endmacro %}


@@ -533,6 +539,7 @@ template <
     typename emb_t,
     typename grad_t,
     typename cache_t,
+    typename index_t,
     int32_t kFixedMaxVecsPerThread,
     int32_t kThreadGroupSize,
     bool kUseVecBlocking,
@@ -556,7 +563,7 @@ hip_split_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ wdesc }}{{ vd
     int64_t D,
     {%- endif %}
     const pta::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits> hash_size_cumsum,
-    const pta::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits> sorted_linear_indices_run,
+    const pta::PackedTensorAccessor32<index_t, 1, at::RestrictPtrTraits> sorted_linear_indices_run,
     const pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits> sorted_linear_indices_cumulative_run_lengths,
     {%- if not nobag %}
     const pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits> sorted_infos,
@@ -652,6 +659,7 @@ hip_split_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ wdesc }}{{ vd
     emb_t,
     cache_t,
     grad_t,
+    index_t,
     BLOCK_SIZE,
     embedding_dim,
     segment_prefetch,
@@ -684,6 +692,7 @@ hip_split_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ wdesc }}{{ vd
     emb_type,
     grad_type,
     cache_type,
+    index_type,
     kFixedMaxVecsPerThread,
     kThreadGroupSize,
     kUseVecBlocking,
@@ -696,6 +705,7 @@ hip_split_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ wdesc }}{{ vd
     < {{ emb_type }},
     {{ grad_type }},
     {{ cache_type }},
+    {{ index_type }},
     {{ kFixedMaxVecsPerThread }},
     {{ kThreadGroupSize }},
     {{ kUseVecBlocking }},
@@ -718,7 +728,7 @@ hip_split_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ wdesc }}{{ vd
     int64_t D,
     {%- endif %}
     const pta::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits> hash_size_cumsum,
-    const pta::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits> sorted_linear_indices_run,
+    const pta::PackedTensorAccessor32<{{ index_type }}, 1, at::RestrictPtrTraits> sorted_linear_indices_run,
     const pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits> sorted_linear_indices_cumulative_run_lengths,
     {%- if not nobag %}
     const pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits> sorted_infos,
@@ -764,12 +774,14 @@ hip_split_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ wdesc }}{{ vd
 {%- for grad_type in ['float', 'at::Half', 'at::BFloat16'] %}
 {%- for emb_type in ['float', 'at::Half'] %}
 {%- for cache_type in ['float', 'at::Half'] %}
+{%- for index_type in ['int32_t', 'int64_t'] %}
 {%- for kEmbeddingDim in [64, 128, 160, 192, 256] %}
 {%- for kWeighDecayMode in [0, 1, 2] %}
 {{ hip_template_instantiation(
     emb_type,
     grad_type,
     cache_type,
+    index_type,
     kFixedMaxVecsPerThread,
     kThreadGroupSize,
     kUseVecBlocking,
@@ -782,6 +794,7 @@ hip_split_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ wdesc }}{{ vd
 {%- endfor %}
 {%- endfor %}
 {%- endfor %}
+{%- endfor %}
 {%- endmacro %}

 {%- macro hip_instantiate_templates(use_subwarp_shuffle) %}