
Commit 162e1f1

q10 authored and facebook-github-bot committed
Add support for int32_t indices in TBE training (2H/N) (pytorch#3539)
Summary:
X-link: facebookresearch/FBGEMM#626

- Update benchmark test for `int32_t` indices

Reviewed By: sryap

Differential Revision: D67784746
1 parent 2dc16f8 commit 162e1f1
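For context, a minimal usage sketch of what this change enables: int32 indices and offsets fed straight to a TBE module instead of being upcast to int64. This is a hypothetical example, not part of the commit; the module paths, enum names, and constructor arguments are assumptions based on the fbgemm_gpu layout.

import torch
from fbgemm_gpu.split_table_batched_embeddings_ops_common import EmbeddingLocation
from fbgemm_gpu.split_table_batched_embeddings_ops_training import (
    ComputeDevice,
    SplitTableBatchedEmbeddingBagsCodegen,
)

# One CPU-hosted table of 1000 rows x 64 dims (assumed spec layout).
emb = SplitTableBatchedEmbeddingBagsCodegen(
    embedding_specs=[(1000, 64, EmbeddingLocation.HOST, ComputeDevice.CPU)],
)

# Batch of 4 bags with 2 ids each, expressed as int32 rather than int64.
indices = torch.randint(0, 1000, (8,), dtype=torch.int32)
offsets = torch.arange(0, 9, 2, dtype=torch.int32)  # [0, 2, 4, 6, 8]
out = emb(indices=indices, offsets=offsets)  # pooled output, shape (4, 64)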

File tree

4 files changed: +69 / -41 lines changed

fbgemm_gpu/bench/bench_utils.py

Lines changed: 2 additions & 0 deletions
@@ -168,6 +168,8 @@ def benchmark_requests(
 
     if num_warmups > 0:
         indices, offsets, weights = requests[0].unpack_3()
+        print(f"INDICES BENCHMARK {indices.dtype}")
+        print(f"OFFSETS BENCHMARK {offsets.dtype}")
         for _ in range(num_warmups):
             out = func(indices, offsets, weights)
             if bwd_only:
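The two added prints only surface the dtype of the warmup request, which makes it easy to confirm whether int32 or int64 tensors are flowing through the benchmark. For example, with int32 request tensors the diagnostics would read:

import torch

indices = torch.tensor([1, 2, 3], dtype=torch.int32)
offsets = torch.tensor([0, 1, 3], dtype=torch.int32)
print(f"INDICES BENCHMARK {indices.dtype}")  # INDICES BENCHMARK torch.int32
print(f"OFFSETS BENCHMARK {offsets.dtype}")  # OFFSETS BENCHMARK torch.int32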

fbgemm_gpu/bench/split_table_batched_embeddings_benchmark.py

Lines changed: 9 additions & 4 deletions
@@ -125,6 +125,7 @@ def cli() -> None:
 @click.option("--flush-gpu-cache-size-mb", default=0)
 @click.option("--dense", is_flag=True, default=False)
 @click.option("--output-dtype", type=SparseType, default=SparseType.FP32)
+@click.option("--indices-dtype", type=click.Choice(["32", "64"]), default="64")
 @click.option("--requests_data_file", type=str, default=None)
 @click.option("--tables", type=str, default=None)
 @click.option("--export-trace", is_flag=True, default=False)

@@ -166,6 +167,7 @@ def device( # noqa C901
     flush_gpu_cache_size_mb: int,
     dense: bool,
     output_dtype: SparseType,
+    indices_dtype: str,
     requests_data_file: Optional[str],
     tables: Optional[str],
     export_trace: bool,

@@ -176,6 +178,9 @@
     cache_load_factor: float,
 ) -> None:
     assert not ssd or not dense, "--ssd cannot be used together with --dense"
+    indices_dtype_torch: torch.dtype = (
+        torch.int32 if int(indices_dtype) == 32 else torch.int64
+    )
     np.random.seed(42)
     torch.manual_seed(42)
     B = batch_size

@@ -352,8 +357,8 @@ def context_factory(on_trace_ready: Callable[[profile], None]):
     time_per_iter = benchmark_requests(
         requests,
         lambda indices, offsets, per_sample_weights: emb.forward(
-            indices.long(),
-            offsets.long(),
+            indices.to(dtype=indices_dtype_torch),
+            offsets.to(dtype=indices_dtype_torch),
             per_sample_weights,
             feature_requires_grad=feature_requires_grad,
         ),

@@ -384,8 +389,8 @@ def context_factory(on_trace_ready: Callable[[profile], None]):
     time_per_iter = benchmark_requests(
         requests,
         lambda indices, offsets, per_sample_weights: emb(
-            indices.long(),
-            offsets.long(),
+            indices.to(dtype=indices_dtype_torch),
+            offsets.to(dtype=indices_dtype_torch),
             per_sample_weights,
             feature_requires_grad=feature_requires_grad,
         ),
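A short sketch of how the new flag flows through the benchmark: the "32"/"64" string from --indices-dtype is mapped to a torch dtype, and the request tensors are cast to it before each call. The mapping mirrors the diff above; the CLI invocation in the comment is an assumed example, not taken from the commit.

import torch

# Assumed invocation, for illustration:
#   python split_table_batched_embeddings_benchmark.py device --indices-dtype 32
def to_indices_dtype(indices_dtype: str) -> torch.dtype:
    # "--indices-dtype" accepts "32" or "64" and defaults to "64".
    return torch.int32 if int(indices_dtype) == 32 else torch.int64

indices_dtype_torch = to_indices_dtype("32")
indices = torch.tensor([3, 1, 4, 1, 5]).to(dtype=indices_dtype_torch)
offsets = torch.tensor([0, 2, 5]).to(dtype=indices_dtype_torch)
assert indices.dtype == offsets.dtype == torch.int32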

fbgemm_gpu/codegen/training/backward/embedding_backward_split_cpu_template.cpp

Lines changed: 49 additions & 34 deletions
@@ -40,7 +40,7 @@ struct half2float16<at::Half> {
 } // namespace internal
 
 namespace {
-template <typename scalar_t, typename grad_t>
+template <typename index_t, typename scalar_t, typename grad_t>
 void split_embedding_backward_exact_cpu_kernel(
     Tensor grad_output,
     Tensor host_weights,

@@ -225,7 +225,7 @@ for (const auto d : c10::irange(D)) {
   } // for each table
 }
 
-template <typename scalar_t>
+template <typename index_t, typename scalar_t>
 void split_embedding_backward_exact_cpu_dense_kernel(
     Tensor grad,
     Tensor grad_output,

@@ -242,8 +242,8 @@ void split_embedding_backward_exact_cpu_dense_kernel(
 
   auto grad_output_data = grad_output.accessor<scalar_t, 2>();
 
-  const auto indices_data = indices.accessor<int64_t, 1>();
-  const auto offsets_data = offsets.accessor<int64_t, 1>();
+  const auto indices_data = indices.accessor<index_t, 1>();
+  const auto offsets_data = offsets.accessor<index_t, 1>();
   const auto indice_weights_data = indice_weights.defined()
       ?
       // If indice_weights are not defined, then this accessor won't be

@@ -349,34 +349,42 @@ for (const auto d : c10::irange(D)) {
 
   grad_output = grad_output.contiguous();
 
-
-  FBGEMM_DISPATCH_FLOAT_AND_HALF(
+  FBGEMM_DISPATCH_INTEGRAL_TYPES(
+      indices.scalar_type(),
+      "split_embedding_backward_exact_cpu_kernel_1", [&] {
+        using index_t = scalar_t;
+
+        FBGEMM_DISPATCH_FLOAT_AND_HALF(
       grad_output.scalar_type(),
-      "split_embedding_backward_exact_cpu_outer", [&]() {
-        using grad_t = scalar_t;
+            "split_embedding_backward_exact_cpu_kernel_2", [&] {
+              using grad_t = scalar_t;
+
         FBGEMM_DISPATCH_FLOAT_AND_HALF(
-            host_weights.scalar_type(), "split_embedding_backward_exact_cpu", [&] {
-              split_embedding_backward_exact_cpu_kernel<scalar_t, grad_t>(
-                  grad_output,
-                  host_weights,
-                  weights_offsets_data,
-                  D_offsets_data,
-                  hash_size_cumsum,
-                  indices,
-                  offsets,
-                  pooling_mode,
-                  indice_weights,
-                  num_tables,
-                  B,
-                  table_to_feature_offset,
-                  {% if "momentum1_offsets" in args.split_function_arg_names %}
-                  momentum1_offsets_data,
-                  {% endif %}
-                  {% if "momentum2_offsets" in args.split_function_arg_names %}
-                  momentum2_offsets_data,
-                  {% endif %}
-                  {{ args.split_cpu_kernel_arg_constructors | join(", ") }});
-            });
+            host_weights.scalar_type(),
+            "split_embedding_backward_exact_cpu_kernel_3", [&] {
+
+              split_embedding_backward_exact_cpu_kernel<index_t, scalar_t, grad_t>(
+                  grad_output,
+                  host_weights,
+                  weights_offsets_data,
+                  D_offsets_data,
+                  hash_size_cumsum,
+                  indices,
+                  offsets,
+                  pooling_mode,
+                  indice_weights,
+                  num_tables,
+                  B,
+                  table_to_feature_offset,
+                  {% if "momentum1_offsets" in args.split_function_arg_names %}
+                  momentum1_offsets_data,
+                  {% endif %}
+                  {% if "momentum2_offsets" in args.split_function_arg_names %}
+                  momentum2_offsets_data,
+                  {% endif %}
+                  {{ args.split_cpu_kernel_arg_constructors | join(", ") }});
+            });
+      });
   });
 
   return;

@@ -385,10 +393,16 @@ for (const auto d : c10::irange(D)) {
 
   // When input is dense enough, avoid sorting and just treat as dense.
   auto grad = zeros_like(host_weights, grad_output.dtype());
-  FBGEMM_DISPATCH_FLOAT_AND_HALF(
-      grad_output.scalar_type(), "split_embedding_backward_exact_cpu", [&] {
+  FBGEMM_DISPATCH_INTEGRAL_TYPES(
+      indices.scalar_type(),
+      "split_embedding_backward_exact_cpu_dense_kernel", [&] {
+        using index_t = scalar_t;
 
-        split_embedding_backward_exact_cpu_dense_kernel<scalar_t>(
+        FBGEMM_DISPATCH_FLOAT_AND_HALF(
+            grad_output.scalar_type(),
+            "split_embedding_backward_exact_cpu", [&] {
+
+              split_embedding_backward_exact_cpu_dense_kernel<index_t, scalar_t>(
             grad,
             grad_output,
             weights_offsets_data,

@@ -400,7 +414,8 @@ for (const auto d : c10::irange(D)) {
             num_tables,
             B,
             table_to_feature_offset);
-      }); // dispatch host_weights.scalar_type()
+      });
+  });
 
   return grad;
 {% endif %}
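Conceptually, the change wraps the existing float/half dispatches in an integral-type dispatch over indices.scalar_type(), so the kernel is instantiated for the index width actually present. Below is a rough Python analogue of that type resolution, as illustration only and limited to the two index widths relevant here; it is not the actual codegen.

import torch

def resolve_kernel_types(indices, grad_output, host_weights):
    # Mirrors the nesting: index_t from indices, grad_t from grad_output,
    # scalar_t from host_weights.
    index_t = {torch.int32: "int32_t", torch.int64: "int64_t"}[indices.dtype]
    grad_t = {torch.float32: "float", torch.float16: "at::Half"}[grad_output.dtype]
    scalar_t = {torch.float32: "float", torch.float16: "at::Half"}[host_weights.dtype]
    return index_t, scalar_t, grad_t

print(resolve_kernel_types(
    torch.empty(0, dtype=torch.int32),
    torch.empty(0, dtype=torch.float16),
    torch.empty(0, dtype=torch.float32),
))  # ('int32_t', 'float', 'at::Half')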

fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py

Lines changed: 9 additions & 3 deletions
@@ -3315,8 +3315,10 @@ def prepare_inputs(
         )
 
         if force_cast_input_types:
-            # Force casting indices and offsets to long
-            (indices, offsets) = indices.long(), offsets.long()
+            # NOTE: Force offsets to have the same dtype as indices since the
+            # kernels assume same dtype. We might need to revisit the assumption
+            # of same dtypes in the future.
+            offsets = offsets.to(dtype=indices.dtype)
 
             # Force casting per_sample_weights to float
             if per_sample_weights is not None:

@@ -3681,7 +3683,11 @@ def forward(
             offsets, batch_size_per_feature_per_rank
         )
 
-        (indices, offsets) = indices.long(), offsets.long()
+        # NOTE: Force offsets to have the same dtype as indices since the
+        # kernels assume same dtype. We might need to revisit the assumption
+        # of same dtypes in the future.
+        offsets = offsets.to(dtype=indices.dtype)
+
         # Force casting per_sample_weights to float
         if per_sample_weights is not None:
             per_sample_weights = per_sample_weights.float()
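A minimal standalone sketch of the new casting behavior: offsets now follow the indices dtype instead of both tensors being promoted to int64 (illustration only, not the module code itself).

import torch

indices = torch.tensor([7, 8, 9], dtype=torch.int32)
offsets = torch.tensor([0, 1, 3], dtype=torch.int64)

# Old: indices, offsets = indices.long(), offsets.long()  -> both int64
# New: only offsets is aligned to the indices dtype       -> both int32
offsets = offsets.to(dtype=indices.dtype)
assert indices.dtype == offsets.dtype == torch.int32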
