
Commit 9cc6db5

q10 authored and facebook-github-bot committed
Add support for int32_t indices in TBE training (2I/N) (#3556)
Summary:
X-link: facebookresearch/FBGEMM#642

- Add `int32_t` support to `::internal::csr2csc`, for eventual `int32_t` indices support in TBE CPU

Reviewed By: basilwong, jianyuh

Differential Revision: D67920539
1 parent 8a6e19a commit 9cc6db5

File tree

3 files changed: +76 lines, -63 lines


fbgemm_gpu/codegen/training/forward/embedding_forward_split_cpu.cpp

Lines changed: 55 additions & 53 deletions
@@ -380,18 +380,18 @@ namespace internal {
 
 namespace {
 
-template <typename scalar_t, bool IS_VALUE_PAIR>
+template <typename index_t, typename scalar_t, bool IS_VALUE_PAIR>
 void csr2csc_template_(
     HyperCompressedSparseColumn& csc,
     int B,
-    const at::TensorAccessor<int64_t, 1>& csr_offsets,
-    const at::TensorAccessor<int64_t, 1>& csr_indices,
+    const at::TensorAccessor<index_t, 1>& csr_offsets,
+    const at::TensorAccessor<index_t, 1>& csr_indices,
     const at::TensorAccessor<scalar_t, 1>& csr_weights,
     int64_t pooling_mode,
     const int* table_to_feature_offset,
     int64_t num_embeddings) {
   csc.num_non_zero_columns = 0;
-  int64_t nnz = csr_offsets[table_to_feature_offset[1] * B] -
+  const auto nnz = csr_offsets[table_to_feature_offset[1] * B] -
       csr_offsets[table_to_feature_offset[0] * B];
   if (nnz == 0) {
     return;
@@ -407,7 +407,7 @@ void csr2csc_template_(
   [[maybe_unused]] int column_ptr_curr = 0;
   bool is_shared_table =
       table_to_feature_offset[1] > table_to_feature_offset[0] + 1;
-  auto NS = csr_offsets[table_to_feature_offset[1] * B] -
+  const auto NS = csr_offsets[(size_t)table_to_feature_offset[1] * B] -
       csr_offsets[table_to_feature_offset[0] * B];
 
   using pair_t = std::pair<int, scalar_t>;
@@ -432,9 +432,9 @@ void csr2csc_template_(
 #pragma omp parallel for
   for (int b = 0; b < B; ++b) {
     const auto FBb = feature * B + b;
-    int64_t pool_begin = csr_offsets[FBb];
-    int64_t pool_end = csr_offsets[FBb + 1];
-    int64_t L = pool_end - pool_begin;
+    const auto pool_begin = csr_offsets[FBb];
+    const auto pool_end = csr_offsets[FBb + 1];
+    const auto L = pool_end - pool_begin;
     // MEAN pooling will not work with indice_weights!
     double scale_factor =
         (static_cast<PoolingMode>(pooling_mode) == PoolingMode::MEAN &&
@@ -581,47 +581,48 @@ void csr2csc_template_(
   assert(column_ptr_curr == nnz);
 }
 
-#define INSTANTIATE_BATCHED_CSR2CSC(SCALAR_T)             \
-  template void csr2csc_template_<SCALAR_T, true>(        \
-      HyperCompressedSparseColumn & csc,                  \
-      int B,                                              \
-      const at::TensorAccessor<int64_t, 1>& csr_offsets,  \
-      const at::TensorAccessor<int64_t, 1>& csr_indices,  \
-      const at::TensorAccessor<SCALAR_T, 1>& csr_weights, \
-      int64_t pooling_mode,                               \
-      const int* table_to_feature_offset,                 \
-      int64_t num_embeddings);                            \
-                                                          \
-  template void csr2csc_template_<SCALAR_T, false>(       \
-      HyperCompressedSparseColumn & csc,                  \
-      int B,                                              \
-      const at::TensorAccessor<int64_t, 1>& csr_offsets,  \
-      const at::TensorAccessor<int64_t, 1>& csr_indices,  \
-      const at::TensorAccessor<SCALAR_T, 1>& csr_weights, \
-      int64_t pooling_mode,                               \
-      const int* table_to_feature_offset,                 \
+#define INSTANTIATE_CSR2CSC_TEMPLATE_0(index_t, scalar_t, is_value_pair) \
+  template void csr2csc_template_<index_t, scalar_t, is_value_pair>(     \
+      HyperCompressedSparseColumn & csc,                                 \
+      int B,                                                             \
+      const at::TensorAccessor<index_t, 1>& csr_offsets,                 \
+      const at::TensorAccessor<index_t, 1>& csr_indices,                 \
+      const at::TensorAccessor<scalar_t, 1>& csr_weights,                \
+      int64_t pooling_mode,                                              \
+      const int* table_to_feature_offset,                                \
       int64_t num_embeddings);
 
-INSTANTIATE_BATCHED_CSR2CSC(float)
-INSTANTIATE_BATCHED_CSR2CSC(double)
-#undef INSTANTIATE_BATCHED_CSR2CSC
+#define INSTANTIATE_CSR2CSC_TEMPLATE_1(index_t, scalar_t)  \
+  INSTANTIATE_CSR2CSC_TEMPLATE_0(index_t, scalar_t, true); \
+  INSTANTIATE_CSR2CSC_TEMPLATE_0(index_t, scalar_t, false);
+
+#define INSTANTIATE_CSR2CSC_TEMPLATE_2(index_t)   \
+  INSTANTIATE_CSR2CSC_TEMPLATE_1(index_t, float); \
+  INSTANTIATE_CSR2CSC_TEMPLATE_1(index_t, double);
+
+INSTANTIATE_CSR2CSC_TEMPLATE_2(int32_t);
+INSTANTIATE_CSR2CSC_TEMPLATE_2(int64_t);
+
+#undef INSTANTIATE_CSR2CSC_TEMPLATE_2
+#undef INSTANTIATE_CSR2CSC_TEMPLATE_1
+#undef INSTANTIATE_CSR2CSC_TEMPLATE_0
 
 } // namespace
 
-template <typename scalar_t>
+template <typename index_t, typename scalar_t>
 void csr2csc(
     HyperCompressedSparseColumn& csc,
     int B,
-    const at::TensorAccessor<int64_t, 1>& csr_offsets,
-    const at::TensorAccessor<int64_t, 1>& csr_indices,
+    const at::TensorAccessor<index_t, 1>& csr_offsets,
+    const at::TensorAccessor<index_t, 1>& csr_indices,
     const at::TensorAccessor<scalar_t, 1>& csr_weights,
     int64_t pooling_mode,
     const int* table_to_feature_offset,
     int64_t num_embeddings) {
   bool has_weights = csr_weights.data() != nullptr;
   if (has_weights ||
       static_cast<PoolingMode>(pooling_mode) == PoolingMode::MEAN) {
-    csr2csc_template_<scalar_t, /*IS_VALUE_PAIR=*/true>(
+    csr2csc_template_<index_t, scalar_t, /*IS_VALUE_PAIR=*/true>(
         csc,
         B,
         csr_offsets,
@@ -631,7 +632,7 @@ void csr2csc(
         table_to_feature_offset,
         num_embeddings);
   } else {
-    csr2csc_template_<scalar_t, /*IS_VALUE_PAIR=*/false>(
+    csr2csc_template_<index_t, scalar_t, /*IS_VALUE_PAIR=*/false>(
         csc,
         B,
         csr_offsets,
@@ -643,25 +644,26 @@
   }
 }
 
-template void csr2csc<float>(
-    HyperCompressedSparseColumn& csc,
-    int B,
-    const at::TensorAccessor<int64_t, 1>& csr_offsets,
-    const at::TensorAccessor<int64_t, 1>& csr_indices,
-    const at::TensorAccessor<float, 1>& csr_weights,
-    int64_t pooling_mode,
-    const int* table_to_feature_offset,
-    int64_t num_embeddings);
+#define INSTANTIATE_CSR2CSC_0(index_t, scalar_t)          \
+  template void csr2csc<index_t, scalar_t>(               \
+      HyperCompressedSparseColumn & csc,                  \
+      int B,                                              \
+      const at::TensorAccessor<index_t, 1>& csr_offsets,  \
+      const at::TensorAccessor<index_t, 1>& csr_indices,  \
+      const at::TensorAccessor<scalar_t, 1>& csr_weights, \
+      int64_t pooling_mode,                               \
+      const int* table_to_feature_offset,                 \
+      int64_t num_embeddings);
 
-template void csr2csc<double>(
-    HyperCompressedSparseColumn& csc,
-    int B,
-    const at::TensorAccessor<int64_t, 1>& csr_offsets,
-    const at::TensorAccessor<int64_t, 1>& csr_indices,
-    const at::TensorAccessor<double, 1>& csr_weights,
-    int64_t pooling_mode,
-    const int* table_to_feature_offset,
-    int64_t num_embeddings);
+#define INSTANTIATE_CSR2CSC_1(index_t)   \
+  INSTANTIATE_CSR2CSC_0(index_t, float); \
+  INSTANTIATE_CSR2CSC_0(index_t, double);
+
+INSTANTIATE_CSR2CSC_1(int32_t);
+INSTANTIATE_CSR2CSC_1(int64_t);
+
+#undef INSTANTIATE_CSR2CSC_1
+#undef INSTANTIATE_CSR2CSC_0
 
 } // namespace internal
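Note on the instantiation scheme: each level of the macro hierarchy multiplies out one template parameter, so two index types, two weight types, and the two values of the value-pair flag yield eight explicit instantiations of csr2csc_template_. The following standalone sketch shows the same nesting pattern on a hypothetical toy template (toy_kernel is not FBGEMM code; it only stands in for csr2csc_template_'s parameter list):

// Toy stand-in for csr2csc_template_, to illustrate the macro nesting.
#include <cstdint>

template <typename index_t, typename scalar_t, bool IS_VALUE_PAIR>
void toy_kernel(const index_t* /*indices*/, scalar_t /*weight*/) {}

#define INSTANTIATE_0(index_t, scalar_t, is_value_pair) \
  template void toy_kernel<index_t, scalar_t, is_value_pair>(const index_t*, scalar_t);

#define INSTANTIATE_1(index_t, scalar_t) \
  INSTANTIATE_0(index_t, scalar_t, true) \
  INSTANTIATE_0(index_t, scalar_t, false)

#define INSTANTIATE_2(index_t)    \
  INSTANTIATE_1(index_t, float)   \
  INSTANTIATE_1(index_t, double)

// 2 index types x 2 scalar types x 2 flags = 8 explicit instantiations.
INSTANTIATE_2(int32_t)
INSTANTIATE_2(int64_t)

#undef INSTANTIATE_2
#undef INSTANTIATE_1
#undef INSTANTIATE_0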

fbgemm_gpu/include/fbgemm_gpu/embedding_forward_split_cpu.h

Lines changed: 3 additions & 3 deletions
@@ -116,12 +116,12 @@ struct HyperCompressedSparseColumn {
   }
 };
 
-template <typename scalar_t>
+template <typename index_t, typename scalar_t>
 void csr2csc(
     HyperCompressedSparseColumn& csc,
     int B,
-    const at::TensorAccessor<int64_t, 1>& csr_offsets,
-    const at::TensorAccessor<int64_t, 1>& csr_indices,
+    const at::TensorAccessor<index_t, 1>& csr_offsets,
+    const at::TensorAccessor<index_t, 1>& csr_indices,
     const at::TensorAccessor<scalar_t, 1>& csr_weights,
     int64_t pooling_mode,
     const int* table_to_feature_offset,
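With csr2csc now templated on index_t, a caller holding int32 or int64 offsets/indices tensors can pick the instantiation from the tensor dtype. This commit only changes csr2csc itself; the following is a hedged, hypothetical sketch of such a caller (run_csr2csc and the float weights tensor are illustrative, not part of the diff), using ATen's AT_DISPATCH_INDEX_TYPES macro, which defines index_t as int32_t or int64_t inside the lambda:

#include <ATen/ATen.h>
#include <ATen/Dispatch.h>
#include "fbgemm_gpu/embedding_forward_split_cpu.h"

// Hypothetical caller: dispatch on the indices dtype to select index_t.
void run_csr2csc(
    internal::HyperCompressedSparseColumn& csc,
    int B,
    const at::Tensor& offsets, // int32 or int64, same dtype as indices
    const at::Tensor& indices,
    const at::Tensor& weights, // float weights, for the scalar_t=float instantiation
    int64_t pooling_mode,
    const int* table_to_feature_offset,
    int64_t num_embeddings) {
  AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "csr2csc_dispatch", [&] {
    // index_t is deduced from the tensor dtype inside this lambda.
    ::internal::csr2csc(
        csc,
        B,
        offsets.accessor<index_t, 1>(),
        indices.accessor<index_t, 1>(),
        weights.accessor<float, 1>(),
        pooling_mode,
        table_to_feature_offset,
        num_embeddings);
  });
}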

fbgemm_gpu/test/tbe/utils/cpu_kernel_test.cpp

Lines changed: 18 additions & 7 deletions
@@ -15,20 +15,23 @@
 #include "fbgemm_gpu/embedding_forward_split_cpu.h"
 #include "torch/types.h" // @manual=//caffe2:torch-cpp-cpu
 
-TEST(CpuKernelTest, csr2csc_test) {
+template <c10::ScalarType DType, typename T>
+void test_csr2csc() {
   internal::HyperCompressedSparseColumn csc;
   int B = 2;
-  at::Tensor offsets = torch::tensor({0, 4, 8});
-  at::Tensor indices = torch::tensor({1, 2, 4, 5, 4, 3, 2, 9});
+  at::Tensor offsets =
+      torch::tensor({0, 4, 8}, torch::TensorOptions().dtype(DType));
+  at::Tensor indices = torch::tensor(
+      {1, 2, 4, 5, 4, 3, 2, 9}, torch::TensorOptions().dtype(DType));
   int64_t pooling_mode = (int64_t)fbgemm_gpu::PoolingMode::SUM;
   int table_to_feature_offset[2] = {0, 1};
   int num_embeddings = 10;
 
   ::internal::csr2csc(
       csc,
       B,
-      offsets.accessor<int64_t, 1>(),
-      indices.accessor<int64_t, 1>(),
+      offsets.accessor<T, 1>(),
+      indices.accessor<T, 1>(),
       at::TensorAccessor<at::acc_type<float, true>, 1>(
           nullptr, nullptr, nullptr), // no weights
       pooling_mode,
@@ -61,8 +64,8 @@ TEST(CpuKernelTest, csr2csc_test) {
   ::internal::csr2csc(
       csc_weighted,
       B,
-      offsets.accessor<int64_t, 1>(),
-      indices.accessor<int64_t, 1>(),
+      offsets.accessor<T, 1>(),
+      indices.accessor<T, 1>(),
       indice_weights.accessor<at::acc_type<float, true>, 1>(),
       pooling_mode,
       table_to_feature_offset,
@@ -88,3 +91,11 @@ TEST(CpuKernelTest, csr2csc_test) {
     EXPECT_EQ(expect_weights[i], csc_weighted.weights[i]);
   }
 }
+
+TEST(CpuKernelTest, csr2csc_test_int32) {
+  test_csr2csc<torch::kInt32, int32_t>();
+}
+
+TEST(CpuKernelTest, csr2csc_test_int64) {
+  test_csr2csc<torch::kInt64, int64_t>();
+}
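A side note on the test helper's two template parameters: the dtype enum can in principle be derived from the C++ type with c10::CppTypeToScalarType, so a one-parameter helper would also work. This is only an alternative sketch, not what the commit does, and test_csr2csc_alt is a hypothetical name:

#include <c10/core/ScalarType.h>
#include "torch/types.h"

// Hypothetical alternative: derive the dtype enum from the C++ index type.
template <typename T>
void test_csr2csc_alt() {
  constexpr c10::ScalarType DType = c10::CppTypeToScalarType<T>::value;
  at::Tensor offsets =
      torch::tensor({0, 4, 8}, torch::TensorOptions().dtype(DType));
  // ... body otherwise identical to test_csr2csc<DType, T>() above ...
}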
