Skip to content

Commit 736737e

Browse files
q10 authored and facebook-github-bot committed
Updates and fixes to tensor_accessor.h (pytorch#3571)
Summary: X-link: facebookresearch/FBGEMM#656. - Fix the `TensorAccessorBase` constructor to work with empty tensors, which are used in FBGEMM code. - Add better logging for errors. Differential Revision: D68048640
1 parent 310982f commit 736737e

File tree

7 files changed

+368
-81
lines changed

7 files changed

+368
-81
lines changed

fbgemm_gpu/codegen/training/backward/embedding_backward_split_cpu_template.cpp

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -87,14 +87,16 @@ for (const auto t : c10::irange(num_tables)) {
8787
int feature_begin = table_to_feature_offset[t];
8888
int64_t hash_size = get_hash_size(feature_begin);
8989

90+
#ifdef FBGEMM_GPU_MEMCHECK
91+
const auto func_name = "::internal::csr2csc";
92+
#endif
93+
using weight_t = at::acc_type<scalar_t, true>;
9094
::internal::csr2csc(
9195
cscs[t],
9296
B,
93-
offsets.accessor<int64_t, 1>(),
94-
indices.accessor<int64_t, 1>(),
95-
indice_weights.defined()
96-
? indice_weights.accessor<at::acc_type<scalar_t, true>, 1>()
97-
: at::TensorAccessor<at::acc_type<scalar_t, true>, 1>(nullptr, nullptr, nullptr),
97+
MAKE_TA_WITH_NAME(func_name, offsets, int64_t, 1),
98+
MAKE_TA_WITH_NAME(func_name, indices, int64_t, 1),
99+
MAKE_TA_WITH_NAME(func_name, indice_weights, weight_t, 1),
98100
pooling_mode,
99101
table_to_feature_offset + t,
100102
hash_size);

fbgemm_gpu/codegen/training/forward/embedding_forward_split_cpu.cpp

Lines changed: 19 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
#include "fbgemm_gpu/utils/cpu_utils.h"
1515
#include "fbgemm_gpu/utils/dispatch_macros.h"
1616
#include "fbgemm_gpu/utils/ops_utils.h"
17+
#include "fbgemm_gpu/utils/tensor_accessor.h"
1718
#ifdef FBCODE_CAFFE2
1819
#include <libdivide.h>
1920
#else
@@ -384,9 +385,9 @@ template <typename index_t, typename scalar_t, bool IS_VALUE_PAIR>
384385
void csr2csc_template_(
385386
HyperCompressedSparseColumn& csc,
386387
int B,
387-
const at::TensorAccessor<index_t, 1>& csr_offsets,
388-
const at::TensorAccessor<index_t, 1>& csr_indices,
389-
const at::TensorAccessor<scalar_t, 1>& csr_weights,
388+
const pta::TensorAccessor<index_t, 1>& csr_offsets,
389+
const pta::TensorAccessor<index_t, 1>& csr_indices,
390+
const pta::TensorAccessor<scalar_t, 1>& csr_weights,
390391
int64_t pooling_mode,
391392
const int* table_to_feature_offset,
392393
int64_t num_embeddings) {
@@ -585,9 +586,9 @@ void csr2csc_template_(
585586
template void csr2csc_template_<index_t, scalar_t, is_value_pair>( \
586587
HyperCompressedSparseColumn & csc, \
587588
int B, \
588-
const at::TensorAccessor<index_t, 1>& csr_offsets, \
589-
const at::TensorAccessor<index_t, 1>& csr_indices, \
590-
const at::TensorAccessor<scalar_t, 1>& csr_weights, \
589+
const pta::TensorAccessor<index_t, 1>& csr_offsets, \
590+
const pta::TensorAccessor<index_t, 1>& csr_indices, \
591+
const pta::TensorAccessor<scalar_t, 1>& csr_weights, \
591592
int64_t pooling_mode, \
592593
const int* table_to_feature_offset, \
593594
int64_t num_embeddings);
@@ -613,9 +614,9 @@ template <typename index_t, typename scalar_t>
613614
void csr2csc(
614615
HyperCompressedSparseColumn& csc,
615616
int B,
616-
const at::TensorAccessor<index_t, 1>& csr_offsets,
617-
const at::TensorAccessor<index_t, 1>& csr_indices,
618-
const at::TensorAccessor<scalar_t, 1>& csr_weights,
617+
const pta::TensorAccessor<index_t, 1>& csr_offsets,
618+
const pta::TensorAccessor<index_t, 1>& csr_indices,
619+
const pta::TensorAccessor<scalar_t, 1>& csr_weights,
619620
int64_t pooling_mode,
620621
const int* table_to_feature_offset,
621622
int64_t num_embeddings) {
@@ -644,15 +645,15 @@ void csr2csc(
644645
}
645646
}
646647

647-
#define INSTANTIATE_CSR2CSC_0(index_t, scalar_t) \
648-
template void csr2csc<index_t, scalar_t>( \
649-
HyperCompressedSparseColumn & csc, \
650-
int B, \
651-
const at::TensorAccessor<index_t, 1>& csr_offsets, \
652-
const at::TensorAccessor<index_t, 1>& csr_indices, \
653-
const at::TensorAccessor<scalar_t, 1>& csr_weights, \
654-
int64_t pooling_mode, \
655-
const int* table_to_feature_offset, \
648+
#define INSTANTIATE_CSR2CSC_0(index_t, scalar_t) \
649+
template void csr2csc<index_t, scalar_t>( \
650+
HyperCompressedSparseColumn & csc, \
651+
int B, \
652+
const pta::TensorAccessor<index_t, 1>& csr_offsets, \
653+
const pta::TensorAccessor<index_t, 1>& csr_indices, \
654+
const pta::TensorAccessor<scalar_t, 1>& csr_weights, \
655+
int64_t pooling_mode, \
656+
const int* table_to_feature_offset, \
656657
int64_t num_embeddings);
657658

658659
#define INSTANTIATE_CSR2CSC_1(index_t) \

fbgemm_gpu/include/fbgemm_gpu/embedding_forward_split_cpu.h

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
#include <ATen/ATen.h>
1212
#include <ATen/Parallel.h>
1313
#include "fbgemm/Utils.h"
14+
#include "fbgemm_gpu/utils/tensor_accessor.h"
1415

1516
at::Tensor split_embedding_codegen_forward_cpu(
1617
at::Tensor weights,
@@ -120,9 +121,9 @@ template <typename index_t, typename scalar_t>
120121
void csr2csc(
121122
HyperCompressedSparseColumn& csc,
122123
int B,
123-
const at::TensorAccessor<index_t, 1>& csr_offsets,
124-
const at::TensorAccessor<index_t, 1>& csr_indices,
125-
const at::TensorAccessor<scalar_t, 1>& csr_weights,
124+
const pta::TensorAccessor<index_t, 1>& csr_offsets,
125+
const pta::TensorAccessor<index_t, 1>& csr_indices,
126+
const pta::TensorAccessor<scalar_t, 1>& csr_weights,
126127
int64_t pooling_mode,
127128
const int* table_to_feature_offset,
128129
int64_t num_embeddings);

0 commit comments

Comments (0)