 #include "fbgemm_gpu/utils/cpu_utils.h"
 #include "fbgemm_gpu/utils/dispatch_macros.h"
 #include "fbgemm_gpu/utils/ops_utils.h"
+#include "fbgemm_gpu/utils/tensor_accessor.h"
 #ifdef FBCODE_CAFFE2
 #include <libdivide.h>
 #else
@@ -384,9 +385,9 @@ template <typename index_t, typename scalar_t, bool IS_VALUE_PAIR>
 void csr2csc_template_(
     HyperCompressedSparseColumn& csc,
     int B,
-    const at::TensorAccessor<index_t, 1>& csr_offsets,
-    const at::TensorAccessor<index_t, 1>& csr_indices,
-    const at::TensorAccessor<scalar_t, 1>& csr_weights,
+    const pta::TensorAccessor<index_t, 1>& csr_offsets,
+    const pta::TensorAccessor<index_t, 1>& csr_indices,
+    const pta::TensorAccessor<scalar_t, 1>& csr_weights,
     int64_t pooling_mode,
     const int* table_to_feature_offset,
     int64_t num_embeddings) {
@@ -585,9 +586,9 @@ void csr2csc_template_(
   template void csr2csc_template_<index_t, scalar_t, is_value_pair>( \
       HyperCompressedSparseColumn& csc,                              \
       int B,                                                         \
-      const at::TensorAccessor<index_t, 1>& csr_offsets,             \
-      const at::TensorAccessor<index_t, 1>& csr_indices,             \
-      const at::TensorAccessor<scalar_t, 1>& csr_weights,            \
+      const pta::TensorAccessor<index_t, 1>& csr_offsets,            \
+      const pta::TensorAccessor<index_t, 1>& csr_indices,            \
+      const pta::TensorAccessor<scalar_t, 1>& csr_weights,           \
       int64_t pooling_mode,                                          \
       const int* table_to_feature_offset,                            \
       int64_t num_embeddings);
@@ -613,9 +614,9 @@ template <typename index_t, typename scalar_t>
 void csr2csc(
     HyperCompressedSparseColumn& csc,
     int B,
-    const at::TensorAccessor<index_t, 1>& csr_offsets,
-    const at::TensorAccessor<index_t, 1>& csr_indices,
-    const at::TensorAccessor<scalar_t, 1>& csr_weights,
+    const pta::TensorAccessor<index_t, 1>& csr_offsets,
+    const pta::TensorAccessor<index_t, 1>& csr_indices,
+    const pta::TensorAccessor<scalar_t, 1>& csr_weights,
     int64_t pooling_mode,
     const int* table_to_feature_offset,
     int64_t num_embeddings) {
@@ -644,15 +645,15 @@ void csr2csc(
   }
 }

-#define INSTANTIATE_CSR2CSC_0(index_t, scalar_t)          \
-  template void csr2csc<index_t, scalar_t>(               \
-      HyperCompressedSparseColumn& csc,                   \
-      int B,                                              \
-      const at::TensorAccessor<index_t, 1>& csr_offsets,  \
-      const at::TensorAccessor<index_t, 1>& csr_indices,  \
-      const at::TensorAccessor<scalar_t, 1>& csr_weights, \
-      int64_t pooling_mode,                               \
-      const int* table_to_feature_offset,                 \
+#define INSTANTIATE_CSR2CSC_0(index_t, scalar_t)           \
+  template void csr2csc<index_t, scalar_t>(                \
+      HyperCompressedSparseColumn& csc,                    \
+      int B,                                               \
+      const pta::TensorAccessor<index_t, 1>& csr_offsets,  \
+      const pta::TensorAccessor<index_t, 1>& csr_indices,  \
+      const pta::TensorAccessor<scalar_t, 1>& csr_weights, \
+      int64_t pooling_mode,                                \
+      const int* table_to_feature_offset,                  \
       int64_t num_embeddings);

 #define INSTANTIATE_CSR2CSC_1(index_t) \
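Note (not part of the diff above): a minimal sketch of how a caller might build the pta:: accessors that the updated csr2csc signature expects. It assumes fbgemm_gpu/utils/tensor_accessor.h exposes the MAKE_TA_WITH_NAME helper for constructing a named pta::TensorAccessor from an at::Tensor; the wrapper function name, the dtype choices, and the use of __func__ as the debug name are illustrative assumptions, not part of this change.

// Hedged sketch: adapting at::Tensor inputs to the new pta:: accessor
// parameters. MAKE_TA_WITH_NAME is assumed to come from
// fbgemm_gpu/utils/tensor_accessor.h; csr2csc_from_tensors is hypothetical.
#include <ATen/ATen.h>

#include "fbgemm_gpu/utils/tensor_accessor.h"

void csr2csc_from_tensors(
    HyperCompressedSparseColumn& csc,
    int B,
    const at::Tensor& offsets, // 1-D CSR row offsets (int64)
    const at::Tensor& indices, // 1-D CSR column indices (int64)
    const at::Tensor& weights, // 1-D per-entry weights (float)
    int64_t pooling_mode,
    const int* table_to_feature_offset,
    int64_t num_embeddings) {
  // Wrap each tensor in a named accessor so a bounds-check failure can
  // report which argument was out of range (assumed macro behavior).
  const auto csr_offsets = MAKE_TA_WITH_NAME(__func__, offsets, int64_t, 1);
  const auto csr_indices = MAKE_TA_WITH_NAME(__func__, indices, int64_t, 1);
  const auto csr_weights = MAKE_TA_WITH_NAME(__func__, weights, float, 1);

  csr2csc<int64_t, float>(
      csc,
      B,
      csr_offsets,
      csr_indices,
      csr_weights,
      pooling_mode,
      table_to_feature_offset,
      num_embeddings);
}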