@@ -380,18 +380,18 @@ namespace internal {
 
 namespace {
 
-template <typename scalar_t, bool IS_VALUE_PAIR>
+template <typename index_t, typename scalar_t, bool IS_VALUE_PAIR>
 void csr2csc_template_(
     HyperCompressedSparseColumn& csc,
     int B,
-    const at::TensorAccessor<int64_t, 1>& csr_offsets,
-    const at::TensorAccessor<int64_t, 1>& csr_indices,
+    const at::TensorAccessor<index_t, 1>& csr_offsets,
+    const at::TensorAccessor<index_t, 1>& csr_indices,
     const at::TensorAccessor<scalar_t, 1>& csr_weights,
     int64_t pooling_mode,
     const int* table_to_feature_offset,
     int64_t num_embeddings) {
   csc.num_non_zero_columns = 0;
-  int64_t nnz = csr_offsets[table_to_feature_offset[1] * B] -
+  const auto nnz = csr_offsets[table_to_feature_offset[1] * B] -
       csr_offsets[table_to_feature_offset[0] * B];
   if (nnz == 0) {
     return;
@@ -407,7 +407,7 @@ void csr2csc_template_(
   [[maybe_unused]] int column_ptr_curr = 0;
   bool is_shared_table =
       table_to_feature_offset[1] > table_to_feature_offset[0] + 1;
-  auto NS = csr_offsets[table_to_feature_offset[1] * B] -
+  const auto NS = csr_offsets[(size_t)table_to_feature_offset[1] * B] -
       csr_offsets[table_to_feature_offset[0] * B];
 
   using pair_t = std::pair<int, scalar_t>;
@@ -432,9 +432,9 @@ void csr2csc_template_(
 #pragma omp parallel for
     for (int b = 0; b < B; ++b) {
       const auto FBb = feature * B + b;
-      int64_t pool_begin = csr_offsets[FBb];
-      int64_t pool_end = csr_offsets[FBb + 1];
-      int64_t L = pool_end - pool_begin;
+      const auto pool_begin = csr_offsets[FBb];
+      const auto pool_end = csr_offsets[FBb + 1];
+      const auto L = pool_end - pool_begin;
       // MEAN pooling will not work with indice_weights!
       double scale_factor =
           (static_cast<PoolingMode>(pooling_mode) == PoolingMode::MEAN &&
@@ -581,47 +581,48 @@ void csr2csc_template_(
   assert(column_ptr_curr == nnz);
 }
 
-#define INSTANTIATE_BATCHED_CSR2CSC(SCALAR_T)                            \
-  template void csr2csc_template_<SCALAR_T, true>(                       \
-      HyperCompressedSparseColumn & csc,                                 \
-      int B,                                                             \
-      const at::TensorAccessor<int64_t, 1>& csr_offsets,                 \
-      const at::TensorAccessor<int64_t, 1>& csr_indices,                 \
-      const at::TensorAccessor<SCALAR_T, 1>& csr_weights,                \
-      int64_t pooling_mode,                                              \
-      const int* table_to_feature_offset,                                \
-      int64_t num_embeddings);                                           \
-                                                                         \
-  template void csr2csc_template_<SCALAR_T, false>(                      \
-      HyperCompressedSparseColumn & csc,                                 \
-      int B,                                                             \
-      const at::TensorAccessor<int64_t, 1>& csr_offsets,                 \
-      const at::TensorAccessor<int64_t, 1>& csr_indices,                 \
-      const at::TensorAccessor<SCALAR_T, 1>& csr_weights,                \
-      int64_t pooling_mode,                                              \
-      const int* table_to_feature_offset,                                \
+#define INSTANTIATE_CSR2CSC_TEMPLATE_0(index_t, scalar_t, is_value_pair) \
+  template void csr2csc_template_<index_t, scalar_t, is_value_pair>(     \
+      HyperCompressedSparseColumn & csc,                                 \
+      int B,                                                             \
+      const at::TensorAccessor<index_t, 1>& csr_offsets,                 \
+      const at::TensorAccessor<index_t, 1>& csr_indices,                 \
+      const at::TensorAccessor<scalar_t, 1>& csr_weights,                \
+      int64_t pooling_mode,                                              \
+      const int* table_to_feature_offset,                                \
       int64_t num_embeddings);
 
-INSTANTIATE_BATCHED_CSR2CSC(float)
-INSTANTIATE_BATCHED_CSR2CSC(double)
-#undef INSTANTIATE_BATCHED_CSR2CSC
+#define INSTANTIATE_CSR2CSC_TEMPLATE_1(index_t, scalar_t)   \
+  INSTANTIATE_CSR2CSC_TEMPLATE_0(index_t, scalar_t, true);  \
+  INSTANTIATE_CSR2CSC_TEMPLATE_0(index_t, scalar_t, false);
+
+#define INSTANTIATE_CSR2CSC_TEMPLATE_2(index_t)    \
+  INSTANTIATE_CSR2CSC_TEMPLATE_1(index_t, float);  \
+  INSTANTIATE_CSR2CSC_TEMPLATE_1(index_t, double);
+
+INSTANTIATE_CSR2CSC_TEMPLATE_2(int32_t);
+INSTANTIATE_CSR2CSC_TEMPLATE_2(int64_t);
+
+#undef INSTANTIATE_CSR2CSC_TEMPLATE_2
+#undef INSTANTIATE_CSR2CSC_TEMPLATE_1
+#undef INSTANTIATE_CSR2CSC_TEMPLATE_0
 
 } // namespace
 
-template <typename scalar_t>
+template <typename index_t, typename scalar_t>
 void csr2csc(
     HyperCompressedSparseColumn& csc,
     int B,
-    const at::TensorAccessor<int64_t, 1>& csr_offsets,
-    const at::TensorAccessor<int64_t, 1>& csr_indices,
+    const at::TensorAccessor<index_t, 1>& csr_offsets,
+    const at::TensorAccessor<index_t, 1>& csr_indices,
     const at::TensorAccessor<scalar_t, 1>& csr_weights,
     int64_t pooling_mode,
     const int* table_to_feature_offset,
     int64_t num_embeddings) {
   bool has_weights = csr_weights.data() != nullptr;
   if (has_weights ||
       static_cast<PoolingMode>(pooling_mode) == PoolingMode::MEAN) {
-    csr2csc_template_<scalar_t, /*IS_VALUE_PAIR=*/true>(
+    csr2csc_template_<index_t, scalar_t, /*IS_VALUE_PAIR=*/true>(
         csc,
         B,
         csr_offsets,
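Note on the hunk above: the three nested INSTANTIATE_CSR2CSC_TEMPLATE_* macros replace the old per-scalar-type macro. Level 0 emits one explicit instantiation, level 1 covers IS_VALUE_PAIR in {true, false}, and level 2 covers scalar_t in {float, double}, so the two top-level invocations instantiate csr2csc_template_ for the full {int32_t, int64_t} x {float, double} x {true, false} cross product (eight instantiations). As a sketch of what a single bottom-level invocation produces, INSTANTIATE_CSR2CSC_TEMPLATE_0(int32_t, float, true) expands to roughly:

// Expansion sketch (not part of the diff itself): one explicit instantiation
// of the helper for 32-bit indices, float weights, and value-pair mode.
template void csr2csc_template_<int32_t, float, true>(
    HyperCompressedSparseColumn& csc,
    int B,
    const at::TensorAccessor<int32_t, 1>& csr_offsets,
    const at::TensorAccessor<int32_t, 1>& csr_indices,
    const at::TensorAccessor<float, 1>& csr_weights,
    int64_t pooling_mode,
    const int* table_to_feature_offset,
    int64_t num_embeddings);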
@@ -631,7 +632,7 @@ void csr2csc(
         table_to_feature_offset,
         num_embeddings);
   } else {
-    csr2csc_template_<scalar_t, /*IS_VALUE_PAIR=*/false>(
+    csr2csc_template_<index_t, scalar_t, /*IS_VALUE_PAIR=*/false>(
         csc,
         B,
         csr_offsets,
@@ -643,25 +644,26 @@ void csr2csc(
   }
 }
 
-template void csr2csc<float>(
-    HyperCompressedSparseColumn& csc,
-    int B,
-    const at::TensorAccessor<int64_t, 1>& csr_offsets,
-    const at::TensorAccessor<int64_t, 1>& csr_indices,
-    const at::TensorAccessor<float, 1>& csr_weights,
-    int64_t pooling_mode,
-    const int* table_to_feature_offset,
-    int64_t num_embeddings);
+#define INSTANTIATE_CSR2CSC_0(index_t, scalar_t)           \
+  template void csr2csc<index_t, scalar_t>(                \
+      HyperCompressedSparseColumn & csc,                   \
+      int B,                                               \
+      const at::TensorAccessor<index_t, 1>& csr_offsets,   \
+      const at::TensorAccessor<index_t, 1>& csr_indices,   \
+      const at::TensorAccessor<scalar_t, 1>& csr_weights,  \
+      int64_t pooling_mode,                                \
+      const int* table_to_feature_offset,                  \
+      int64_t num_embeddings);
 
-template void csr2csc<double>(
-    HyperCompressedSparseColumn& csc,
-    int B,
-    const at::TensorAccessor<int64_t, 1>& csr_offsets,
-    const at::TensorAccessor<int64_t, 1>& csr_indices,
-    const at::TensorAccessor<double, 1>& csr_weights,
-    int64_t pooling_mode,
-    const int* table_to_feature_offset,
-    int64_t num_embeddings);
+#define INSTANTIATE_CSR2CSC_1(index_t)    \
+  INSTANTIATE_CSR2CSC_0(index_t, float);  \
+  INSTANTIATE_CSR2CSC_0(index_t, double);
+
+INSTANTIATE_CSR2CSC_1(int32_t);
+INSTANTIATE_CSR2CSC_1(int64_t);
+
+#undef INSTANTIATE_CSR2CSC_1
+#undef INSTANTIATE_CSR2CSC_0
 
 
 } // namespace internal
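With index_t now a template parameter, the offsets/indices accessors can carry either 32-bit or 64-bit indices. The commit does not show the call site; a minimal sketch of how a caller could select index_t and scalar_t from the input tensors, assuming at::Tensor inputs and PyTorch's AT_DISPATCH_INDEX_TYPES / AT_DISPATCH_FLOATING_TYPES macros (the surrounding variable names here are hypothetical), would be:

// Hypothetical caller, not part of this commit: dispatch on the dtype of the
// offsets tensor (int32/int64) and of the weights tensor (float/double),
// then call the now index-type-templated csr2csc.
AT_DISPATCH_INDEX_TYPES(csr_offsets.scalar_type(), "csr2csc_index", [&] {
  AT_DISPATCH_FLOATING_TYPES(csr_weights.scalar_type(), "csr2csc_weight", [&] {
    internal::csr2csc<index_t, scalar_t>(
        csc,
        B,
        csr_offsets.accessor<index_t, 1>(),
        csr_indices.accessor<index_t, 1>(),
        csr_weights.accessor<scalar_t, 1>(),
        pooling_mode,
        table_to_feature_offset,
        num_embeddings);
  });
});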