
Commit b2e29a7

q10 authored and facebook-github-bot committed
Add support for int32_t indices in TBE training (2H/N) (pytorch#3539)
Summary:
X-link: facebookresearch/FBGEMM#626

- Update benchmark test for `int32_t` indices

Reviewed By: sryap

Differential Revision: D67784746
1 parent 7cd05f6 commit b2e29a7

File tree

4 files changed: +71 -43 lines changed


fbgemm_gpu/bench/bench_utils.py

Lines changed: 2 additions & 0 deletions
@@ -168,6 +168,8 @@ def benchmark_requests(
 
     if num_warmups > 0:
         indices, offsets, weights = requests[0].unpack_3()
+        print(f"INDICES BENCHMARK {indices.dtype}")
+        print(f"OFFSETS BENCHMARK {offsets.dtype}")
         for _ in range(num_warmups):
             out = func(indices, offsets, weights)
             if bwd_only:
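
For reference, a minimal sketch of what the new warmup prints report; the tensors below are hypothetical stand-ins for the ones returned by requests[0].unpack_3() in the real benchmark.

import torch

# Hypothetical stand-ins for the tensors unpacked from the first request;
# with int32 indices enabled, both prints report torch.int32.
indices = torch.randint(0, 1000, (64,), dtype=torch.int32)
offsets = torch.arange(0, 65, dtype=torch.int32)

print(f"INDICES BENCHMARK {indices.dtype}")  # INDICES BENCHMARK torch.int32
print(f"OFFSETS BENCHMARK {offsets.dtype}")  # OFFSETS BENCHMARK torch.int32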

fbgemm_gpu/bench/split_table_batched_embeddings_benchmark.py

Lines changed: 9 additions & 4 deletions
@@ -125,6 +125,7 @@ def cli() -> None:
 @click.option("--flush-gpu-cache-size-mb", default=0)
 @click.option("--dense", is_flag=True, default=False)
 @click.option("--output-dtype", type=SparseType, default=SparseType.FP32)
+@click.option("--indices-dtype", type=click.Choice(["32", "64"]), default="64")
 @click.option("--requests_data_file", type=str, default=None)
 @click.option("--tables", type=str, default=None)
 @click.option("--export-trace", is_flag=True, default=False)
@@ -166,6 +167,7 @@ def device( # noqa C901
     flush_gpu_cache_size_mb: int,
     dense: bool,
     output_dtype: SparseType,
+    indices_dtype: str,
     requests_data_file: Optional[str],
     tables: Optional[str],
     export_trace: bool,
@@ -176,6 +178,9 @@ def device( # noqa C901
     cache_load_factor: float,
 ) -> None:
     assert not ssd or not dense, "--ssd cannot be used together with --dense"
+    indices_dtype_torch: torch.dtype = (
+        torch.int32 if int(indices_dtype) == 32 else torch.int64
+    )
     np.random.seed(42)
     torch.manual_seed(42)
     B = batch_size
@@ -352,8 +357,8 @@ def context_factory(on_trace_ready: Callable[[profile], None]):
         time_per_iter = benchmark_requests(
             requests,
             lambda indices, offsets, per_sample_weights: emb.forward(
-                indices.long(),
-                offsets.long(),
+                indices.to(dtype=indices_dtype_torch),
+                offsets.to(dtype=indices_dtype_torch),
                 per_sample_weights,
                 feature_requires_grad=feature_requires_grad,
             ),
@@ -384,8 +389,8 @@ def context_factory(on_trace_ready: Callable[[profile], None]):
         time_per_iter = benchmark_requests(
             requests,
             lambda indices, offsets, per_sample_weights: emb(
-                indices.long(),
-                offsets.long(),
+                indices.to(dtype=indices_dtype_torch),
+                offsets.to(dtype=indices_dtype_torch),
                 per_sample_weights,
                 feature_requires_grad=feature_requires_grad,
             ),
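
A small sketch, outside the benchmark harness, of how the new --indices-dtype flag maps to a torch dtype and is then applied to the request tensors; the helper name below is illustrative and not part of the benchmark code.

import torch

# Illustrative helper mirroring the flag handling in the `device` command:
# "--indices-dtype 32" selects torch.int32, anything else falls back to int64.
def indices_dtype_to_torch(indices_dtype: str) -> torch.dtype:
    return torch.int32 if int(indices_dtype) == 32 else torch.int64

dtype = indices_dtype_to_torch("32")
indices = torch.tensor([3, 1, 4, 1, 5], dtype=torch.int64)
offsets = torch.tensor([0, 2, 5], dtype=torch.int64)

# Mirrors the casts now done inside the benchmark lambdas.
indices, offsets = indices.to(dtype=dtype), offsets.to(dtype=dtype)
assert indices.dtype == offsets.dtype == torch.int32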

fbgemm_gpu/codegen/training/backward/embedding_backward_split_cpu_template.cpp

Lines changed: 51 additions & 36 deletions
@@ -40,7 +40,7 @@ struct half2float16<at::Half> {
 } // namespace internal
 
 namespace {
-template <typename scalar_t, typename grad_t>
+template <typename index_t, typename scalar_t, typename grad_t>
 void split_embedding_backward_exact_cpu_kernel(
     Tensor grad_output,
     Tensor host_weights,
@@ -90,8 +90,8 @@ for (const auto t : c10::irange(num_tables)) {
   ::internal::csr2csc(
       cscs[t],
       B,
-      offsets.accessor<int64_t, 1>(),
-      indices.accessor<int64_t, 1>(),
+      offsets.accessor<index_t, 1>(),
+      indices.accessor<index_t, 1>(),
       indice_weights.defined()
           ? indice_weights.accessor<at::acc_type<scalar_t, true>, 1>()
           : at::TensorAccessor<at::acc_type<scalar_t, true>, 1>(nullptr, nullptr, nullptr),
@@ -223,7 +223,7 @@ for (const auto d : c10::irange(D)) {
   } // for each table
 }
 
-template <typename scalar_t>
+template <typename index_t, typename scalar_t>
 void split_embedding_backward_exact_cpu_dense_kernel(
     Tensor grad,
     Tensor grad_output,
@@ -240,8 +240,8 @@ void split_embedding_backward_exact_cpu_dense_kernel(
 
   auto grad_output_data = grad_output.accessor<scalar_t, 2>();
 
-  const auto indices_data = indices.accessor<int64_t, 1>();
-  const auto offsets_data = offsets.accessor<int64_t, 1>();
+  const auto indices_data = indices.accessor<index_t, 1>();
+  const auto offsets_data = offsets.accessor<index_t, 1>();
   const auto indice_weights_data = indice_weights.defined()
       ?
       // If indice_weights are not defined, then this accessor won't be
@@ -347,34 +347,42 @@ for (const auto d : c10::irange(D)) {
 
   grad_output = grad_output.contiguous();
 
-
-  FBGEMM_DISPATCH_FLOAT_AND_HALF(
+  FBGEMM_DISPATCH_INTEGRAL_TYPES(
+      indices.scalar_type(),
+      "split_embedding_backward_exact_cpu_kernel_1", [&] {
+        using index_t = scalar_t;
+
+        FBGEMM_DISPATCH_FLOAT_AND_HALF(
       grad_output.scalar_type(),
-      "split_embedding_backward_exact_cpu_outer", [&]() {
-        using grad_t = scalar_t;
+            "split_embedding_backward_exact_cpu_kernel_2", [&] {
+              using grad_t = scalar_t;
+
         FBGEMM_DISPATCH_FLOAT_AND_HALF(
-            host_weights.scalar_type(), "split_embedding_backward_exact_cpu", [&] {
-              split_embedding_backward_exact_cpu_kernel<scalar_t, grad_t>(
-                  grad_output,
-                  host_weights,
-                  weights_offsets_data,
-                  D_offsets_data,
-                  hash_size_cumsum,
-                  indices,
-                  offsets,
-                  pooling_mode,
-                  indice_weights,
-                  num_tables,
-                  B,
-                  table_to_feature_offset,
-                  {% if "momentum1_offsets" in args.split_function_arg_names %}
-                  momentum1_offsets_data,
-                  {% endif %}
-                  {% if "momentum2_offsets" in args.split_function_arg_names %}
-                  momentum2_offsets_data,
-                  {% endif %}
-                  {{ args.split_cpu_kernel_arg_constructors | join(", ") }});
-            });
+                  host_weights.scalar_type(),
+                  "split_embedding_backward_exact_cpu_kernel_3", [&] {
+
+                    split_embedding_backward_exact_cpu_kernel<index_t, scalar_t, grad_t>(
+                        grad_output,
+                        host_weights,
+                        weights_offsets_data,
+                        D_offsets_data,
+                        hash_size_cumsum,
+                        indices,
+                        offsets,
+                        pooling_mode,
+                        indice_weights,
+                        num_tables,
+                        B,
+                        table_to_feature_offset,
+                        {% if "momentum1_offsets" in args.split_function_arg_names %}
+                        momentum1_offsets_data,
+                        {% endif %}
+                        {% if "momentum2_offsets" in args.split_function_arg_names %}
+                        momentum2_offsets_data,
+                        {% endif %}
+                        {{ args.split_cpu_kernel_arg_constructors | join(", ") }});
+                  });
+            });
       });
 
   return;
@@ -383,10 +391,16 @@ for (const auto d : c10::irange(D)) {
 
   // When input is dense enough, avoid sorting and just treat as dense.
   auto grad = zeros_like(host_weights, grad_output.dtype());
-  FBGEMM_DISPATCH_FLOAT_AND_HALF(
-      grad_output.scalar_type(), "split_embedding_backward_exact_cpu", [&] {
+  FBGEMM_DISPATCH_INTEGRAL_TYPES(
+      indices.scalar_type(),
+      "split_embedding_backward_exact_cpu_dense_kernel", [&] {
+        using index_t = scalar_t;
 
-        split_embedding_backward_exact_cpu_dense_kernel<scalar_t>(
+        FBGEMM_DISPATCH_FLOAT_AND_HALF(
+            grad_output.scalar_type(),
+            "split_embedding_backward_exact_cpu", [&] {
+
+              split_embedding_backward_exact_cpu_dense_kernel<index_t, scalar_t>(
             grad,
             grad_output,
             weights_offsets_data,
@@ -398,7 +412,8 @@ for (const auto d : c10::irange(D)) {
             num_tables,
             B,
             table_to_feature_offset);
-      }); // dispatch host_weights.scalar_type()
+        });
+      });
 
   return grad;
 {% endif %}

fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops_training.py

Lines changed: 9 additions & 3 deletions
@@ -3315,8 +3315,10 @@ def prepare_inputs(
             )
 
         if force_cast_input_types:
-            # Force casting indices and offsets to long
-            (indices, offsets) = indices.long(), offsets.long()
+            # NOTE: Force offsets to have the same dtype as indices since the
+            # kernels assume same dtype. We might need to revisit the assumption
+            # of same dtypes in the future.
+            offsets = offsets.to(dtype=indices.dtype)
 
             # Force casting per_sample_weights to float
             if per_sample_weights is not None:
@@ -3681,7 +3683,11 @@ def forward(
                 offsets, batch_size_per_feature_per_rank
             )
 
-        (indices, offsets) = indices.long(), offsets.long()
+        # NOTE: Force offsets to have the same dtype as indices since the
+        # kernels assume same dtype. We might need to revisit the assumption
+        # of same dtypes in the future.
+        offsets = offsets.to(dtype=indices.dtype)
+
         # Force casting per_sample_weights to float
         if per_sample_weights is not None:
             per_sample_weights = per_sample_weights.float()
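
A minimal sketch of the new input handling shown above: offsets now follow the indices dtype instead of both tensors being forced to int64, so int32 indices are preserved end to end (assuming the caller supplies int32 indices).

import torch

# Before this change both tensors were upcast via .long(); now only offsets
# are aligned to the indices dtype, keeping int32 indices as int32.
indices = torch.tensor([7, 2, 9, 4], dtype=torch.int32)
offsets = torch.tensor([0, 1, 3, 4], dtype=torch.int64)

offsets = offsets.to(dtype=indices.dtype)
assert indices.dtype == offsets.dtype == torch.int32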
