
Commit 91e9ddd

q10 authored and facebook-github-bot committed
Add support for int32_t indices in TBE training (2H/N) (pytorch#3539)
Summary:
X-link: facebookresearch/FBGEMM#626

- Add `int32_t` indices support for the TBE CPU kernels

Reviewed By: sryap

Differential Revision: D67784746
1 parent 5b89d8c commit 91e9ddd
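
For context on the mechanism used throughout this diff: AT_DISPATCH_INDEX_TYPES is the stock ATen macro that switches on a tensor's dtype (at::kInt or at::kLong) and defines the alias index_t inside the lambda, so one templated kernel body serves both 32-bit and 64-bit indices. A minimal sketch of the pattern (the kernel and names below are illustrative, not from this commit):

#include <ATen/ATen.h>
#include <ATen/Dispatch.h>

// Hypothetical kernel templated on the index width.
template <typename index_t>
void count_indices(const index_t* indices, int64_t n) { /* ... */ }

void run(const at::Tensor& indices) {
  AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "count_indices", [&] {
    // index_t is int32_t for at::kInt tensors, int64_t for at::kLong.
    count_indices<index_t>(indices.data_ptr<index_t>(), indices.numel());
  });
}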

File tree

5 files changed: +195 −126 lines changed

fbgemm_gpu/codegen/training/backward/embedding_backward_split_cpu_approx_template.cpp

Lines changed: 85 additions & 66 deletions

@@ -19,11 +19,17 @@
 #include "fbgemm_gpu/embedding_common.h"
 #include "fbgemm_gpu/utils/dispatch_macros.h"
 
+#if FBGEMM_GPU_MEMCHECK
+#define FBGEMM_MEM_CHECK_ONLY
+#else
+#define FBGEMM_MEM_CHECK_ONLY maybe_unused
+#endif
+
 using Tensor = at::Tensor;
 using namespace fbgemm_gpu;
 
 namespace {
-template <typename scalar_t, typename grad_t>
+template <typename index_t, typename scalar_t, typename grad_t>
 void split_embedding_backward_approx_cpu_kernel(
     Tensor grad_output,
     Tensor host_weights,
@@ -44,8 +50,11 @@ void split_embedding_backward_approx_cpu_kernel(
     {{ args.split_cpu_kernel_args | join(", ") }}) {
   auto grad_output_data = grad_output.accessor<grad_t, 2>();
   auto host_weights_data = host_weights.accessor<scalar_t, 1>();
-  const auto indices_data = indices.accessor<int64_t, 1>();
-  const auto offsets_data = offsets.accessor<int64_t, 1>();
+
+  [[FBGEMM_MEM_CHECK_ONLY]] const auto func_name = "split_embedding_backward_approx_cpu_kernel";
+  const auto indices_data = MAKE_TA_WITH_NAME(func_name, indices, index_t, 1);
+  const auto offsets_data = MAKE_TA_WITH_NAME(func_name, offsets, index_t, 1);
+
   // If indice_weights are not defined, then this accessor won't be used
   auto indice_weights_data = indice_weights.defined()
       ? indice_weights.accessor<at::acc_type<scalar_t, true>, 1>()
@@ -133,75 +142,84 @@ split_embedding_backward_codegen_{{ optimizer }}_cpu(
       !indice_weights.defined() && static_cast<PoolingMode>(pooling_mode) == PoolingMode::SUM;
 
   if (use_fbgemm) {
-    auto grad_stride = grad_output.size(1);
-    const float* grad_output_data = grad_output.data_ptr<float>();
-    float* host_weights_data = host_weights.data_ptr<float>();
-    const int64_t* indices_data = indices.data_ptr<int64_t>();
-    const int64_t* offsets_data = offsets.data_ptr<int64_t>();
-    const auto hash_size_cumsum_data = hash_size_cumsum.accessor<int64_t, 1>();
-    float* momentum1_data = momentum1_host.data_ptr<float>();
-
-    at::parallel_for(0, T * B, 0, [&](int64_t tb_begin, int64_t tb_end) {
-      int t_begin = tb_begin / B;
-      int t_end = (tb_end + B - 1) / B;
-      for (const auto t : c10::irange(t_begin, t_end)) {
-        auto D_begin = D_offsets_data[t];
-        auto D = D_offsets_data[t + 1] - D_offsets_data[t];
-        auto table_begin = weights_offsets_data[t];
-        auto momentum_begin = momentum1_offsets_data[t];
-
-        int64_t hash_size;
-        int t_temp = t + 1;
-        do {
-          hash_size = hash_size_cumsum_data[t_temp] - hash_size_cumsum_data[t];
-          ++t_temp;
-        } while (hash_size == 0);
-
-        int b_begin = (t == t_begin) ? tb_begin % B : 0;
-        int b_end = (t == t_end - 1 && tb_end % B != 0) ? tb_end % B : B;
-
-        auto kernel =
-            fbgemm::GenerateRowWiseSparseAdaGradFused<int64_t, int64_t, float>(
-                D,
-                /*prefetch=*/16,
-                /*use_offsets=*/true,
-                /*use_stochastic_round=*/true,
-                /*grad_stride=*/grad_stride);
-        auto offsets_begin_ptr = offsets_data + t * B + b_begin;
-        auto index_size = offsets_data[t * B + b_end] - *offsets_begin_ptr;
-        bool success = kernel(
-            b_end - b_begin,
-            index_size,
-            hash_size,
-            reinterpret_cast<float*>(host_weights_data + table_begin),
-            reinterpret_cast<const float*>(
-                grad_output_data + b_begin * grad_stride + D_begin),
-            reinterpret_cast<float*>(momentum1_data + momentum_begin),
-            indices_data + *offsets_begin_ptr,
-            offsets_begin_ptr,
-            eps,
-            // fbgemm follows caffe2 convention of negative learning rate
-            -learning_rate);
-
-        if (!success) {
-          fbgemm_gpu::report_embedding_error(
-              t, B, b_begin, b_end, offsets_data, indices_data, hash_size);
+    AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "split_embedding_backward_approx_cpu_kernel_1", [&] {
+
+      auto grad_stride = grad_output.size(1);
+      const float* grad_output_data = grad_output.data_ptr<float>();
+      float* host_weights_data = host_weights.data_ptr<float>();
+
+      const auto* indices_data = indices.data_ptr<index_t>();
+      const auto* offsets_data = offsets.data_ptr<index_t>();
+
+      const auto hash_size_cumsum_data = hash_size_cumsum.accessor<int64_t, 1>();
+      float* momentum1_data = momentum1_host.data_ptr<float>();
+
+      at::parallel_for(0, T * B, 0, [&](int64_t tb_begin, int64_t tb_end) {
+        int t_begin = tb_begin / B;
+        int t_end = (tb_end + B - 1) / B;
+
+        for (const auto t : c10::irange(t_begin, t_end)) {
+          auto D_begin = D_offsets_data[t];
+          auto D = D_offsets_data[t + 1] - D_offsets_data[t];
+          auto table_begin = weights_offsets_data[t];
+          auto momentum_begin = momentum1_offsets_data[t];
+
+          int64_t hash_size;
+          int t_temp = t + 1;
+          do {
+            hash_size = hash_size_cumsum_data[t_temp] - hash_size_cumsum_data[t];
+            ++t_temp;
+          } while (hash_size == 0);
+
+          int b_begin = (t == t_begin) ? tb_begin % B : 0;
+          int b_end = (t == t_end - 1 && tb_end % B != 0) ? tb_end % B : B;
+
+          auto kernel =
+              fbgemm::GenerateRowWiseSparseAdaGradFused<index_t, index_t, float>(
+                  D,
+                  /*prefetch=*/16,
+                  /*use_offsets=*/true,
+                  /*use_stochastic_round=*/true,
+                  /*grad_stride=*/grad_stride);
+          auto offsets_begin_ptr = offsets_data + t * B + b_begin;
+          auto index_size = offsets_data[t * B + b_end] - *offsets_begin_ptr;
+          bool success = kernel(
+              b_end - b_begin,
+              index_size,
+              hash_size,
+              reinterpret_cast<float*>(host_weights_data + table_begin),
+              reinterpret_cast<const float*>(
+                  grad_output_data + b_begin * grad_stride + D_begin),
+              reinterpret_cast<float*>(momentum1_data + momentum_begin),
+              indices_data + *offsets_begin_ptr,
+              offsets_begin_ptr,
+              eps,
+              // fbgemm follows caffe2 convention of negative learning rate
+              -learning_rate);
+
+          if (!success) {
+            fbgemm_gpu::report_embedding_error(
+                t, B, b_begin, b_end, offsets_data, indices_data, hash_size);
+          }
         }
-      }
-    }); // parallel_for
+      }); // parallel_for
+    }); // dispatch indices.scalar_type()
+
     return;
   } // use_fbgemm
 
 {% endif %}
 
-  FBGEMM_DISPATCH_FLOAT_AND_HALF(
-      grad_output.scalar_type(), "split_embedding_backward_cpu", [&] {
+  AT_DISPATCH_INDEX_TYPES(
+      indices.scalar_type(), "split_embedding_backward_approx_cpu_kernel_1", [&] {
+
+        FBGEMM_DISPATCH_FLOAT_AND_HALF(
+            grad_output.scalar_type(), "split_embedding_backward_approx_cpu_kernel_2", [&] {
              using grad_t = scalar_t;
-        FBGEMM_DISPATCH_FLOAT_AND_HALF(
-            host_weights.scalar_type(),
-            "split_embedding_backward_cpu_inner",
-            [&] {
-              split_embedding_backward_approx_cpu_kernel<scalar_t, grad_t>(
+
+              FBGEMM_DISPATCH_FLOAT_AND_HALF(
+                  host_weights.scalar_type(), "split_embedding_backward_approx_cpu_kernel_3", [&] {
+                    split_embedding_backward_approx_cpu_kernel<index_t, scalar_t, grad_t>(
                         grad_output,
                         host_weights,
                         weights_offsets_data,
@@ -220,7 +238,8 @@ for (const auto t : c10::irange(t_begin,t_end)) {
                         {% endif %}
                         {{ args.split_cpu_kernel_arg_constructors | join(", ") }});
                   }); // dispatch host_weights.scalar_type()
-      }); // dispatch grad_output.scalar_type()
+            }); // dispatch grad_output.scalar_type()
+      }); // dispatch indices.scalar_type()
 
   return;
 }
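
The FBGEMM_MEM_CHECK_ONLY block added at the top of this file is a small attribute trick: func_name is only consumed by the checked tensor accessors when FBGEMM_GPU_MEMCHECK is enabled, so regular builds would otherwise warn about an unused variable. The macro makes the attribute list expand to [[maybe_unused]] in regular builds and to the empty list [[]] under memcheck. A standalone sketch of the mechanism (the function body is illustrative; MAKE_TA_WITH_NAME itself is not reproduced here):

#if FBGEMM_GPU_MEMCHECK
#define FBGEMM_MEM_CHECK_ONLY
#else
#define FBGEMM_MEM_CHECK_ONLY maybe_unused
#endif

void kernel() {
  // Expands to [[maybe_unused]] in regular builds, [[]] under memcheck,
  // so the declaration compiles warning-free either way.
  [[FBGEMM_MEM_CHECK_ONLY]] const auto func_name = "kernel";
#if FBGEMM_GPU_MEMCHECK
  // Under memcheck, func_name would be threaded into checked accessors
  // (e.g. MAKE_TA_WITH_NAME) for better out-of-bounds error reporting.
#endif
}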

fbgemm_gpu/codegen/training/backward/embedding_backward_split_cpu_template.cpp

Lines changed: 61 additions & 38 deletions

@@ -24,6 +24,12 @@
 #include "fbgemm_gpu/utils/cpu_utils.h"
 #include "fbgemm_gpu/utils/ops_utils.h"
 
+#if FBGEMM_GPU_MEMCHECK
+#define FBGEMM_MEM_CHECK_ONLY
+#else
+#define FBGEMM_MEM_CHECK_ONLY maybe_unused
+#endif
+
 using Tensor = at::Tensor;
 using namespace fbgemm_gpu;
 
@@ -40,7 +46,7 @@ struct half2float16<at::Half> {
 } // namespace internal
 
 namespace {
-template <typename scalar_t, typename grad_t>
+template <typename index_t, typename scalar_t, typename grad_t>
 void split_embedding_backward_exact_cpu_kernel(
     Tensor grad_output,
     Tensor host_weights,
@@ -94,8 +100,8 @@ for (const auto t : c10::irange(num_tables)) {
     ::internal::csr2csc(
         cscs[t],
         B,
-        MAKE_TA_WITH_NAME(func_name, offsets, int64_t, 1),
-        MAKE_TA_WITH_NAME(func_name, indices, int64_t, 1),
+        MAKE_TA_WITH_NAME(func_name, offsets, index_t, 1),
+        MAKE_TA_WITH_NAME(func_name, indices, index_t, 1),
         MAKE_TA_WITH_NAME(func_name, indice_weights, weight_t, 1),
         pooling_mode,
         table_to_feature_offset + t,
@@ -196,19 +202,21 @@ for (const auto t : c10::irange(num_tables)) {
       // TODO: to parallelize, we should easily identify segments belong to
       // the same column.
       at::acc_type<grad_t, true> grad_buffer[D];
-      for (const auto c : c10::irange(num_non_zero_columns)) {
+      for (const auto c : c10::irange(num_non_zero_columns)) {
        int64_t idx = col_segment_indices[c];
        if (c == 0 || col_segment_indices[c - 1] != idx) {
          memset(grad_buffer, 0, D * sizeof(at::acc_type<grad_t, true>));
        }
        [[maybe_unused]] const int64_t embedding_begin = table_begin + idx * D;
+
        for (int r = col_segment_ptr[c]; r < col_segment_ptr[c + 1]; ++r) {
          int D_offset = D_begin;
          if (is_shared_table) {
            D_offset += cscs[t].column_segment_ids[r] * D;
          }
          int b = cscs[t].row_indices[r];
-          for (const auto d : c10::irange(D)) {
+
+          for (const auto d : c10::irange(D)) {
            if (cscs[t].weights != nullptr) {
              grad_buffer[d] += grad_output_data[b * grad_stride + D_offset + d] *
                  cscs[t].weights[r];
@@ -225,7 +233,7 @@ for (const auto d : c10::irange(D)) {
   } // for each table
 }
 
-template <typename scalar_t>
+template <typename index_t, typename scalar_t>
 void split_embedding_backward_exact_cpu_dense_kernel(
     Tensor grad,
     Tensor grad_output,
@@ -242,8 +250,10 @@ void split_embedding_backward_exact_cpu_dense_kernel(
 
   auto grad_output_data = grad_output.accessor<scalar_t, 2>();
 
-  const auto indices_data = indices.accessor<int64_t, 1>();
-  const auto offsets_data = offsets.accessor<int64_t, 1>();
+  [[FBGEMM_MEM_CHECK_ONLY]] const auto func_name = "split_embedding_backward_exact_cpu_dense_kernel";
+
+  const auto indices_data = MAKE_TA_WITH_NAME(func_name, indices, index_t, 1);
+  const auto offsets_data = MAKE_TA_WITH_NAME(func_name, offsets, index_t, 1);
   const auto indice_weights_data = indice_weights.defined()
       ?
       // If indice_weights are not defined, then this accessor won't be
@@ -349,34 +359,41 @@ for (const auto d : c10::irange(D)) {
 
   grad_output = grad_output.contiguous();
 
-
-  FBGEMM_DISPATCH_FLOAT_AND_HALF(
+  AT_DISPATCH_INDEX_TYPES(
+      indices.scalar_type(),
+      "split_embedding_backward_exact_cpu_kernel_1", [&] {
+
+  FBGEMM_DISPATCH_FLOAT_AND_HALF(
       grad_output.scalar_type(),
-      "split_embedding_backward_exact_cpu_outer", [&]() {
-        using grad_t = scalar_t;
+      "split_embedding_backward_exact_cpu_kernel_2", [&] {
+        using grad_t = scalar_t;
+
         FBGEMM_DISPATCH_FLOAT_AND_HALF(
-            host_weights.scalar_type(), "split_embedding_backward_exact_cpu", [&] {
-              split_embedding_backward_exact_cpu_kernel<scalar_t, grad_t>(
-                  grad_output,
-                  host_weights,
-                  weights_offsets_data,
-                  D_offsets_data,
-                  hash_size_cumsum,
-                  indices,
-                  offsets,
-                  pooling_mode,
-                  indice_weights,
-                  num_tables,
-                  B,
-                  table_to_feature_offset,
-                  {% if "momentum1_offsets" in args.split_function_arg_names %}
-                  momentum1_offsets_data,
-                  {% endif %}
-                  {% if "momentum2_offsets" in args.split_function_arg_names %}
-                  momentum2_offsets_data,
-                  {% endif %}
-                  {{ args.split_cpu_kernel_arg_constructors | join(", ") }});
-            });
+            host_weights.scalar_type(),
+            "split_embedding_backward_exact_cpu_kernel_3", [&] {
+
+              split_embedding_backward_exact_cpu_kernel<index_t, scalar_t, grad_t>(
+                  grad_output,
+                  host_weights,
+                  weights_offsets_data,
+                  D_offsets_data,
+                  hash_size_cumsum,
+                  indices,
+                  offsets,
+                  pooling_mode,
+                  indice_weights,
+                  num_tables,
+                  B,
+                  table_to_feature_offset,
+                  {% if "momentum1_offsets" in args.split_function_arg_names %}
+                  momentum1_offsets_data,
+                  {% endif %}
+                  {% if "momentum2_offsets" in args.split_function_arg_names %}
+                  momentum2_offsets_data,
+                  {% endif %}
+                  {{ args.split_cpu_kernel_arg_constructors | join(", ") }});
+            });
+      });
   });
 
   return;
@@ -385,10 +402,15 @@ for (const auto d : c10::irange(D)) {
 
   // When input is dense enough, avoid sorting and just treat as dense.
   auto grad = zeros_like(host_weights, grad_output.dtype());
-  FBGEMM_DISPATCH_FLOAT_AND_HALF(
-      grad_output.scalar_type(), "split_embedding_backward_exact_cpu", [&] {
+  AT_DISPATCH_INDEX_TYPES(
+      indices.scalar_type(),
+      "split_embedding_backward_exact_cpu_dense_kernel", [&] {
 
-        split_embedding_backward_exact_cpu_dense_kernel<scalar_t>(
+        FBGEMM_DISPATCH_FLOAT_AND_HALF(
+            grad_output.scalar_type(),
+            "split_embedding_backward_exact_cpu", [&] {
+
+              split_embedding_backward_exact_cpu_dense_kernel<index_t, scalar_t>(
                   grad,
                   grad_output,
                   weights_offsets_data,
@@ -400,7 +422,8 @@ for (const auto d : c10::irange(D)) {
                   num_tables,
                   B,
                   table_to_feature_offset);
-      }); // dispatch host_weights.scalar_type()
+            });
+      });
 
   return grad;
 {% endif %}
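
The nested lambdas above follow the standard ATen dispatch idiom: each layer switches on one tensor's dtype and exposes a type alias to the layer below, so the innermost call instantiates the kernel for the full <index_t, scalar_t, grad_t> combination. A self-contained sketch of the two-level version, using the stock ATen macro AT_DISPATCH_FLOATING_TYPES_AND_HALF in place of FBGEMM's own FBGEMM_DISPATCH_FLOAT_AND_HALF (kernel and names illustrative):

#include <ATen/ATen.h>
#include <ATen/Dispatch.h>

// Illustrative kernel: one instantiation per (index width, gradient dtype).
template <typename index_t, typename grad_t>
void exact_backward_kernel(const at::Tensor& grad_output,
                           const at::Tensor& indices) {
  const index_t* idx = indices.data_ptr<index_t>();
  const grad_t* go = grad_output.data_ptr<grad_t>();
  (void)idx; (void)go; // real work omitted
}

void exact_backward(const at::Tensor& grad_output, const at::Tensor& indices) {
  // Outer layer: index_t becomes int32_t or int64_t.
  AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "exact_backward_1", [&] {
    // Inner layer: scalar_t becomes float or at::Half; rebind it as grad_t,
    // matching the `using grad_t = scalar_t;` line in the diff above.
    AT_DISPATCH_FLOATING_TYPES_AND_HALF(
        grad_output.scalar_type(), "exact_backward_2", [&] {
          using grad_t = scalar_t;
          exact_backward_kernel<index_t, grad_t>(grad_output, indices);
        });
  });
}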

fbgemm_gpu/codegen/training/backward/embedding_backward_split_host_cpu_template.cpp

Lines changed: 1 addition & 0 deletions

@@ -158,6 +158,7 @@ class SplitLookupFunction_{{ optimizer }}_Op : public torch::autograd::Function<
         stochastic_rounding,
         {{ args.split_function_arg_names | join(", ") }},
         output_dtype);
+
     static auto op2 =
         torch::Dispatcher::singleton()
             .findSchemaOrThrow("fbgemm::split_embedding_codegen_grad_indice_weights_cpu", "")
