Commit c2a8dc8

q10 authored and facebook-github-bot committed
Add support for int32_t indices in TBE training (2F/N) (pytorch#3376)
Summary:
X-link: facebookresearch/FBGEMM#623
X-link: facebookresearch/FBGEMM#467

- Add `index_t` support to TBE training backward kernels

Reviewed By: basilwong

Differential Revision: D65938455
1 parent 3e0db25 commit c2a8dc8
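In short, the backward kernels gain an `index_t` template parameter so that 32-bit index tensors can be consumed natively instead of being upcast to `int64_t`. A minimal sketch of the pattern, with hypothetical names and a stand-in SGD update (the real generated kernels take far more parameters and process roughly one row per warp):

```cu
#include <cstdint>

// Sketch only: templated on index_t instead of hard-coding int64_t, mirroring
// the change in this commit. Names and the update rule are illustrative.
template <typename emb_t, typename grad_t, typename index_t>
__global__ void backward_warp_per_row_sketch(
    const index_t* __restrict__ sorted_linear_indices_run, // int32_t or int64_t
    const grad_t* __restrict__ grad_output,
    emb_t* __restrict__ weights,
    int64_t D) {
  const uint32_t run_id = blockIdx.x;
  // One unique linearized embedding row per run.
  const index_t linear_index = sorted_linear_indices_run[run_id];
  for (int64_t d = threadIdx.x; d < D; d += blockDim.x) {
    // Stand-in for the real optimizer update (plain SGD, lr = 1).
    weights[linear_index * D + d] -=
        static_cast<emb_t>(grad_output[run_id * D + d]);
  }
}

// The Jinja loops added below emit one instantiation per index type:
template __global__ void backward_warp_per_row_sketch<float, float, int32_t>(
    const int32_t*, const float*, float*, int64_t);
template __global__ void backward_warp_per_row_sketch<float, float, int64_t>(
    const int64_t*, const float*, float*, int64_t);
```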

File tree: 3 files changed (+36, −18 lines)

fbgemm_gpu/codegen/training/backward/embedding_backward_split_kernel_warp_template.cu

Lines changed: 17 additions & 4 deletions
@@ -62,6 +62,7 @@ template <
    typename emb_t,
    typename grad_t,
    typename cache_t,
+   typename index_t,
    {%- for ph_name in args.placeholder_tensor_names %}
    typename {{ ph_name + "_ph_t"}},
    {%- endfor %}
@@ -90,7 +91,7 @@ batch_index_select_dim0_codegen_backward_kernel_warp_per_row(
    int64_t D,
    {%- endif %}
    const pta::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits> hash_size_cumsum,
-   const pta::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits> sorted_linear_indices_run,
+   const pta::PackedTensorAccessor32<index_t, 1, at::RestrictPtrTraits> sorted_linear_indices_run,
    const pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits> sorted_linear_indices_cumulative_run_lengths,
    {%- if not nobag %}
    const pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits> sorted_infos,
@@ -341,6 +342,7 @@ batch_index_select_dim0_codegen_backward_kernel_warp_per_row(
        emb_type,
        grad_type,
        cache_type,
+       index_type,
        ph_type_combo,
        kFixedMaxVecsPerThread,
        kThreadGroupSize,
@@ -358,6 +360,7 @@ batch_index_select_dim0_codegen_backward_kernel_warp_per_row
    < {{ emb_type }},
      {{ grad_type }},
      {{ cache_type }},
+     {{ index_type }},
      {%- for ph_name in args.placeholder_tensor_names %}
      {{ ph_type_combo[ph_name].primitive_type }},
      {%- endfor %}
@@ -381,7 +384,7 @@ batch_index_select_dim0_codegen_backward_kernel_warp_per_row
    int64_t D,
    {%- endif %}
    const pta::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits> hash_size_cumsum,
-   const pta::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits> sorted_linear_indices_run,
+   const pta::PackedTensorAccessor32<{{ index_type }}, 1, at::RestrictPtrTraits> sorted_linear_indices_run,
    const pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits> sorted_linear_indices_cumulative_run_lengths,
    {%- if not nobag %}
    const pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits> sorted_infos,
@@ -441,11 +444,13 @@ batch_index_select_dim0_codegen_backward_kernel_warp_per_row
    {%- for grad_type in ['float', 'at::Half', 'at::BFloat16'] %}
    {%- for emb_type in ['float', 'at::Half'] %}
    {%- for cache_type in ['float', 'at::Half'] %}
+   {%- for index_type in ['int32_t', 'int64_t'] %}
    {%- for ph_type_combo in args.placeholder_type_combos %}
    {{ template_instantiation(
        emb_type,
        grad_type,
        cache_type,
+       index_type,
        ph_type_combo,
        kFixedMaxVecsPerThread,
        kThreadGroupSize,
@@ -456,6 +461,7 @@ batch_index_select_dim0_codegen_backward_kernel_warp_per_row
    {%- endfor %}
    {%- endfor %}
    {%- endfor %}
+   {%- endfor %}
    {%- endmacro %}

@@ -533,6 +539,7 @@ template <
    typename emb_t,
    typename grad_t,
    typename cache_t,
+   typename index_t,
    int32_t kFixedMaxVecsPerThread,
    int32_t kThreadGroupSize,
    bool kUseVecBlocking,
@@ -556,7 +563,7 @@ hip_split_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ wdesc }}{{ vd
    int64_t D,
    {%- endif %}
    const pta::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits> hash_size_cumsum,
-   const pta::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits> sorted_linear_indices_run,
+   const pta::PackedTensorAccessor32<index_t, 1, at::RestrictPtrTraits> sorted_linear_indices_run,
    const pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits> sorted_linear_indices_cumulative_run_lengths,
    {%- if not nobag %}
    const pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits> sorted_infos,
@@ -652,6 +659,7 @@ hip_split_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ wdesc }}{{ vd
        emb_t,
        cache_t,
        grad_t,
+       index_t,
        BLOCK_SIZE,
        embedding_dim,
        segment_prefetch,
@@ -684,6 +692,7 @@ hip_split_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ wdesc }}{{ vd
        emb_type,
        grad_type,
        cache_type,
+       index_type,
        kFixedMaxVecsPerThread,
        kThreadGroupSize,
        kUseVecBlocking,
@@ -696,6 +705,7 @@ hip_split_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ wdesc }}{{ vd
    < {{ emb_type }},
      {{ grad_type }},
      {{ cache_type }},
+     {{ index_type }},
      {{ kFixedMaxVecsPerThread }},
      {{ kThreadGroupSize }},
      {{ kUseVecBlocking }},
@@ -718,7 +728,7 @@ hip_split_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ wdesc }}{{ vd
    int64_t D,
    {%- endif %}
    const pta::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits> hash_size_cumsum,
-   const pta::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits> sorted_linear_indices_run,
+   const pta::PackedTensorAccessor32<{{ index_type }}, 1, at::RestrictPtrTraits> sorted_linear_indices_run,
    const pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits> sorted_linear_indices_cumulative_run_lengths,
    {%- if not nobag %}
    const pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits> sorted_infos,
@@ -764,12 +774,14 @@ hip_split_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ wdesc }}{{ vd
    {%- for grad_type in ['float', 'at::Half', 'at::BFloat16'] %}
    {%- for emb_type in ['float', 'at::Half'] %}
    {%- for cache_type in ['float', 'at::Half'] %}
+   {%- for index_type in ['int32_t', 'int64_t'] %}
    {%- for kEmbeddingDim in [64, 128, 160, 192, 256] %}
    {%- for kWeighDecayMode in [0, 1, 2] %}
    {{ hip_template_instantiation(
        emb_type,
        grad_type,
        cache_type,
+       index_type,
        kFixedMaxVecsPerThread,
        kThreadGroupSize,
        kUseVecBlocking,
@@ -782,6 +794,7 @@ hip_split_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ wdesc }}{{ vd
    {%- endfor %}
    {%- endfor %}
    {%- endfor %}
+   {%- endfor %}
    {%- endmacro %}

    {%- macro hip_instantiate_templates(use_subwarp_shuffle) %}
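The new `{%- for index_type in ['int32_t', 'int64_t'] %}` loops double the explicit instantiations each macro emits: every existing (emb, grad, cache) combination now appears once per index type. A trimmed, hypothetical stand-in for the effect (the real instantiations are the generated kernel templates above):

```cu
#include <ATen/ATen.h>

// Hypothetical stub standing in for the warp-per-row backward kernel.
template <typename emb_t, typename grad_t, typename cache_t, typename index_t>
void warp_per_row_stub(const index_t* /*sorted_linear_indices_run*/) {
  // Body elided; only the instantiation pattern matters here.
}

// Each pre-existing type combination is now emitted twice:
template void warp_per_row_stub<float, float, float, int32_t>(const int32_t*);
template void warp_per_row_stub<float, float, float, int64_t>(const int64_t*);
template void warp_per_row_stub<at::Half, float, float, int32_t>(const int32_t*);
template void warp_per_row_stub<at::Half, float, float, int64_t>(const int64_t*);
```

This 2x growth in generated kernels is the usual cost of moving a dtype decision to codegen time.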

fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu

Lines changed: 15 additions & 11 deletions
@@ -139,6 +139,7 @@ template <
    typename emb_t,
    typename grad_t,
    typename cache_t,
+   typename index_t,
    {%- for ph_name in args.placeholder_tensor_names %}
    typename {{ ph_name + "_ph_t" }},
    {%- endfor %}
@@ -167,7 +168,7 @@ batch_index_select_dim0_codegen_backward_kernel_warp_per_row(
    int64_t D,
    {%- endif %}
    const pta::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits> hash_size_cumsum,
-   const pta::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits> sorted_linear_indices_run,
+   const pta::PackedTensorAccessor32<index_t, 1, at::RestrictPtrTraits> sorted_linear_indices_run,
    const pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits> sorted_linear_indices_cumulative_run_lengths,
    {%- if not nobag %}
    const pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits> sorted_infos,
@@ -224,6 +225,7 @@ template <
    typename emb_t,
    typename grad_t,
    typename cache_t,
+   typename index_t,
    int32_t kFixedMaxVecsPerThread,
    int32_t kThreadGroupSize,
    bool kUseVecBlocking,
@@ -247,7 +249,7 @@ hip_split_embedding{{ ndesc }}_backward_codegen_{{ optimizer }}_{{ wdesc }}{{ vd
    int64_t D,
    {%- endif %}
    const pta::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits> hash_size_cumsum,
-   const pta::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits> sorted_linear_indices_run,
+   const pta::PackedTensorAccessor32<index_t, 1, at::RestrictPtrTraits> sorted_linear_indices_run,
    const pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits> sorted_linear_indices_cumulative_run_lengths,
    {%- if not nobag %}
    const pta::PackedTensorAccessor32<int32_t, 1, at::RestrictPtrTraits> sorted_infos,
@@ -826,8 +828,8 @@ Tensor {{ embedding_cuda_op }}(
    AT_CUDA_CHECK(radix_sort_pairs(
        nullptr,
        temp_storage_bytes,
-       linear_indices.data_ptr<int64_t>(),
-       linear_indices_sorted.data_ptr<int64_t>(),
+       linear_indices.data_ptr<index_t>(),
+       linear_indices_sorted.data_ptr<index_t>(),
        {{ locs_or_addrs_tensor }}.data_ptr<{{ locs_or_addrs_type }}>(),
        {{ locs_or_addrs_tensor }}_sorted.data_ptr<{{ locs_or_addrs_type }}>(),
        linear_indices.numel(),
@@ -842,8 +844,8 @@ Tensor {{ embedding_cuda_op }}(
    AT_CUDA_CHECK(radix_sort_pairs(
        temp_storage.data_ptr(),
        temp_storage_bytes,
-       linear_indices.data_ptr<int64_t>(),
-       linear_indices_sorted.data_ptr<int64_t>(),
+       linear_indices.data_ptr<index_t>(),
+       linear_indices_sorted.data_ptr<index_t>(),
        {{ locs_or_addrs_tensor }}.data_ptr<{{ locs_or_addrs_type }}>(),
        {{ locs_or_addrs_tensor }}_sorted.data_ptr<{{ locs_or_addrs_type }}>(),
        linear_indices.numel(),
@@ -888,8 +890,8 @@ Tensor {{ embedding_cuda_op }}(
    AT_CUDA_CHECK(radix_sort_pairs(
        nullptr,
        temp_storage_bytes,
-       linear_indices.data_ptr<int64_t>(),
-       linear_indices_sorted.data_ptr<int64_t>(),
+       linear_indices.data_ptr<index_t>(),
+       linear_indices_sorted.data_ptr<index_t>(),
        indice_weights.data_ptr<at::acc_type<cache_t, true>>(),
        indice_weights_sorted.data_ptr<at::acc_type<cache_t, true>>(),
        linear_indices.numel(),
@@ -904,8 +906,8 @@ Tensor {{ embedding_cuda_op }}(
    AT_CUDA_CHECK(radix_sort_pairs(
        temp_storage.data_ptr(),
        temp_storage_bytes,
-       linear_indices.data_ptr<int64_t>(),
-       linear_indices_sorted.data_ptr<int64_t>(),
+       linear_indices.data_ptr<index_t>(),
+       linear_indices_sorted.data_ptr<index_t>(),
        indice_weights.data_ptr<at::acc_type<cache_t, true>>(),
        indice_weights_sorted.data_ptr<at::acc_type<cache_t, true>>(),
        linear_indices.numel(),
@@ -1174,6 +1176,7 @@ Tensor {{ embedding_cuda_op }}(
    <emb_t,
     grad_t,
     cache_t,
+    index_t,
     {%- for ph_name in args.placeholder_tensor_names %}
     {{ ph_name + "_ph_t" }},
     {%- endfor %}
@@ -1225,6 +1228,7 @@ Tensor {{ embedding_cuda_op }}(
    <emb_t,
     grad_t,
     cache_t,
+    index_t,
     kFixedMaxVecsPerThread,
     kThreadGroupSize,
     kUseVecBlocking,
@@ -1264,7 +1268,7 @@ Tensor {{ embedding_cuda_op }}(
    D,
    {%- endif %}
    MAKE_PTA_WITH_NAME(func_name4, hash_size_cumsum, int64_t, 1, 32),
-   MAKE_PTA_WITH_NAME(func_name4, sorted_linear_indices_run, int64_t, 1, 32),
+   MAKE_PTA_WITH_NAME(func_name4, sorted_linear_indices_run, index_t, 1, 32),
    MAKE_PTA_WITH_NAME(func_name4, sorted_linear_indices_cumulative_run_lengths, int32_t, 1, 32),
    {%- if not nobag %}
    MAKE_PTA_WITH_NAME(func_name4, infos_sorted, int32_t, 1, 32),
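The host-side `data_ptr<index_t>()` calls in the hunks above only compile if the operator body runs under a dispatch on the indices dtype. A sketch of how that typically looks with PyTorch's stock macro (assumed surrounding context, not necessarily the exact wiring in the generated operator):

```cu
#include <ATen/Dispatch.h>
#include <ATen/core/Tensor.h>

// Sketch: AT_DISPATCH_INDEX_TYPES expands the lambda once for int32_t and once
// for int64_t, binding the alias index_t inside each branch -- which is what
// lets the template body use linear_indices.data_ptr<index_t>() generically.
void sort_indices_sketch(
    const at::Tensor& linear_indices,
    at::Tensor& linear_indices_sorted) {
  AT_DISPATCH_INDEX_TYPES(
      linear_indices.scalar_type(), "split_embedding_backward", [&] {
        const auto* keys_in = linear_indices.data_ptr<index_t>();
        auto* keys_out = linear_indices_sorted.data_ptr<index_t>();
        // ... hand keys_in / keys_out to radix_sort_pairs as in the diff ...
        (void)keys_in;
        (void)keys_out;
      });
}
```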

fbgemm_gpu/codegen/training/backward/rocm/embedding_backward_split_device_kernel_template.hip

Lines changed: 4 additions & 3 deletions
@@ -108,6 +108,7 @@ template <typename optimizer_t,
    typename emb_t,
    typename cache_t,
    typename grad_t,
+   typename index_t,
    int32_t block_size,
    int32_t embedding_dim,
    int32_t segment_prefetch, // 2
@@ -118,7 +119,7 @@ __device__ void split_tbe_backward_hip_kernel_{{ kdesc }}(
    const grad_t* p_output_grad,
    emb_t* p_emb_table,
    const int64_t* p_hash_size_cumsum,
-   const int64_t* p_sorted_linear_indices_run,
+   const index_t* p_sorted_linear_indices_run,
    const int32_t* p_sorted_linear_indices_cumulative_run_lengths,
    const int32_t* p_sorted_linear_indices_num_runs,
    {%- if not nobag %}
@@ -151,7 +152,7 @@ __device__ void split_tbe_backward_hip_kernel_{{ kdesc }}(
        return;
    }

-   const int64_t linear_index = p_sorted_linear_indices_run[run_id];
+   const auto linear_index = p_sorted_linear_indices_run[run_id];

    const int32_t segment_start = p_sorted_linear_indices_cumulative_run_lengths[run_id];
    const int32_t segment_end = p_sorted_linear_indices_cumulative_run_lengths[run_id + 1];
@@ -458,4 +459,4 @@ L_tail_grad_acc:

    store_row_per_warp<emb_t, embedding_dim, emb_t>::run(&emb_data[0], p_emb_table + emb_idx * embedding_dim, lane_id);
    }
-} // namespace fbgemm_gpu::rocm
+} // namespace fbgemm_gpu::rocm
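One detail worth noting in this file: switching `const int64_t linear_index` to `const auto linear_index` lets the load keep the kernel's native index width instead of silently widening every `int32_t` read to `int64_t`. A small sketch (hypothetical helper) of the deduction:

```cu
#include <cstdint>
#include <type_traits>

// With the pointer now typed const index_t*, auto deduces the instantiated
// index type rather than forcing a conversion to int64_t.
template <typename index_t>
__device__ void consume_run_sketch(const index_t* p_run, uint32_t run_id) {
  const auto linear_index = p_run[run_id];
  static_assert(
      std::is_same_v<decltype(linear_index), const index_t>,
      "auto keeps the kernel's index_t");
}
```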
