@@ -284,7 +284,8 @@ void _block_bucketize_sparse_features_cpu(
     c10::optional<Tensor> new_weights,
     c10::optional<Tensor> new_pos,
     const c10::optional<Tensor>& unbucketize_permute,
-    const c10::optional<Tensor>& batch_size_per_feature) {
+    const c10::optional<Tensor>& batch_size_per_feature,
+    const c10::optional<std::vector<at::Tensor>>& block_bucketize_pos) {
   // allocate tensors and buffers
   const auto lengths_size = lengths.numel();
   const auto new_lengths_size = lengths_size * my_size;
@@ -305,9 +306,10 @@ void _block_bucketize_sparse_features_cpu(
   const index_t* const block_sizes_data = block_sizes.data_ptr<index_t>();
   offset_t* batch_sizes_data = nullptr;
   const auto variable_batch_size = batch_size_per_feature.has_value();
-
+  const auto variable_bucket_sizes = block_bucketize_pos.has_value();
   using uindex_t = std::make_unsigned_t<index_t>;
   using uoffset_t = std::make_unsigned_t<offset_t>;
+  std::vector<int64_t> lower_bounds(indices.numel(), 0);

   if constexpr (sequence) {
     unbucketize_permute_data = unbucketize_permute.value().data_ptr<index_t>();
@@ -331,6 +333,12 @@ void _block_bucketize_sparse_features_cpu(
   for (const auto t : c10::irange(T)) {
     const auto blk_size = block_sizes_data[t];
     const auto cur_batch_size = variable_batch_size ? batch_sizes_data[t] : B;
+    const index_t* bucketize_offset = nullptr;
+    int64_t bucket_size = 0;
+    if (variable_bucket_sizes) {
+      bucketize_offset = block_bucketize_pos.value()[t].data_ptr<index_t>();
+      bucket_size = block_bucketize_pos.value()[t].numel();
+    }
     for (const auto b : c10::irange(cur_batch_size)) {
       const auto b_t = (variable_batch_size ? cur_offset : t * B) + b;
       const offset_t rowstart = offsets_data[b_t];
@@ -343,10 +351,21 @@ void _block_bucketize_sparse_features_cpu(
         // range of blk_size, we expect the later embedding module to take care
         // of hashing indices calculation.
         uindex_t idx = static_cast<uindex_t>(indices_data[i]);
-        uindex_t p = idx < static_cast<uindex_t>(blk_size * my_size)
-            ? idx / blk_size
-            : idx % my_size;
-        new_lengths_data[p * lengths_size + b_t]++;
+        if (variable_bucket_sizes) {
+          int64_t lb = std::upper_bound(
+                           bucketize_offset,
+                           bucketize_offset + bucket_size,
+                           indices_data[i]) -
+              bucketize_offset - 1;
+          lower_bounds[i] = lb;
+          uindex_t p = lb < my_size ? lb : idx % my_size;
+          new_lengths_data[p * lengths_size + b_t]++;
+        } else {
+          uindex_t p = idx < static_cast<uindex_t>(blk_size * my_size)
+              ? idx / blk_size
+              : idx % my_size;
+          new_lengths_data[p * lengths_size + b_t]++;
+        }
       }
     }
     cur_offset += cur_batch_size;
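Aside, not part of the patch: a small self-contained sketch of the boundary lookup the variable-bucket path above performs. The boundary values and my_size below are invented for illustration; in the kernel the boundaries come from the per-feature block_bucketize_pos tensors, and the computed lb is cached in lower_bounds so the second pass can reuse it instead of repeating the binary search.

```cpp
#include <algorithm>
#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  // Hypothetical lower-inclusive bucket boundaries, i.e. what one entry of
  // block_bucketize_pos would hold for a single feature.
  const std::vector<int64_t> bucketize_offset = {0, 4, 20, 40};
  const int64_t my_size = 3; // number of buckets / ranks

  for (const int64_t idx : std::vector<int64_t>{2, 7, 25, 100}) {
    // Same lookup as the kernel: position of the last boundary <= idx.
    const int64_t lb =
        std::upper_bound(
            bucketize_offset.begin(), bucketize_offset.end(), idx) -
        bucketize_offset.begin() - 1;
    // Bucket id; indices past the last boundary fall back to idx % my_size.
    const int64_t p = lb < my_size ? lb : idx % my_size;
    // Index relative to the bucket start; the fallback is idx / my_size.
    const int64_t new_idx =
        lb < my_size ? idx - bucketize_offset[lb] : idx / my_size;
    std::cout << "idx=" << idx << " -> bucket " << p << ", new_idx " << new_idx
              << "\n";
  }
  return 0;
}
```

With these made-up boundaries, indices 2, 7, and 25 land in buckets 0, 1, and 2 with local indices 2, 3, and 5, while 100 lies past the last boundary and takes the modulo/divide fallback.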
@@ -359,6 +378,10 @@ void _block_bucketize_sparse_features_cpu(
   for (const auto t : c10::irange(T)) {
     const auto blk_size = block_sizes_data[t];
     const auto cur_batch_size = variable_batch_size ? batch_sizes_data[t] : B;
+    const index_t* bucketize_offset = nullptr;
+    if (variable_bucket_sizes) {
+      bucketize_offset = block_bucketize_pos.value()[t].data_ptr<index_t>();
+    }
     for (const auto b : c10::irange(cur_batch_size)) {
       const auto b_t = (variable_batch_size ? cur_offset : t * B) + b;
       const offset_t rowstart = offsets_data[b_t];
@@ -371,12 +394,19 @@ void _block_bucketize_sparse_features_cpu(
         // range of blk_size, we expect the later embedding module to take care
         // of hashing indices calculation.
         const uindex_t idx = static_cast<uindex_t>(indices_data[i]);
-        const uindex_t p = idx < static_cast<uindex_t>(blk_size * my_size)
-            ? idx / blk_size
-            : idx % my_size;
-        const uindex_t new_idx = idx < static_cast<uindex_t>(blk_size * my_size)
-            ? idx % blk_size
-            : idx / my_size;
+        uindex_t p, new_idx;
+        if (variable_bucket_sizes) {
+          int64_t lb = lower_bounds[i];
+          p = lb < my_size ? lb : idx % my_size;
+          new_idx = lb < my_size ? idx - bucketize_offset[lb] : idx / my_size;
+
+        } else {
+          p = idx < static_cast<uindex_t>(blk_size * my_size) ? idx / blk_size
+                                                              : idx % my_size;
+          new_idx = idx < static_cast<uindex_t>(blk_size * my_size)
+              ? idx % blk_size
+              : idx / my_size;
+        }
         const uoffset_t pos = new_offsets_data[p * lengths_size + b_t];
         new_indices_data[pos] = new_idx;
         if (sequence) {
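Aside, not part of the patch: the two hunks above preserve the kernel's existing two-pass counting-sort shape; only the bucket-assignment rule changes when block_bucketize_pos is supplied. A toy sketch of that shape, using a placeholder idx % num_buckets assignment in place of the real block-size / boundary logic:

```cpp
#include <cstdint>
#include <iostream>
#include <numeric>
#include <vector>

int main() {
  const std::vector<int64_t> indices = {7, 1, 12, 5, 9, 2};
  const int64_t num_buckets = 3;

  // Pass 1: count how many indices fall into each bucket.
  std::vector<int64_t> counts(num_buckets, 0);
  for (const int64_t idx : indices) {
    counts[idx % num_buckets]++;
  }

  // Exclusive scan turns counts into each bucket's starting write offset.
  std::vector<int64_t> offsets(num_buckets, 0);
  std::exclusive_scan(counts.begin(), counts.end(), offsets.begin(), int64_t{0});

  // Pass 2: scatter every index into its bucket's next free slot.
  std::vector<int64_t> bucketized(indices.size(), 0);
  for (const int64_t idx : indices) {
    bucketized[offsets[idx % num_buckets]++] = idx;
  }

  for (const int64_t v : bucketized) {
    std::cout << v << ' ';
  }
  std::cout << '\n'; // prints: 12 9 7 1 5 2  (bucket 0, then 1, then 2)
}
```

In the real kernel the per-bucket counts become new_lengths, their scan gives new_offsets, and the scatter pass fills new_indices (plus weights and positions when requested); the variable-boundary path simply reads the cached lower_bounds in the second pass instead of recomputing the bucket.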
@@ -911,8 +941,8 @@ block_bucketize_sparse_features_cpu(
     const int64_t my_size,
     const c10::optional<Tensor>& weights,
     const c10::optional<Tensor>& batch_size_per_feature,
-    const int64_t /* max_batch_size */ // Only used in GPU variant
-) {
+    const int64_t /* max_batch_size */, // Only used in GPU variant
+    const c10::optional<std::vector<at::Tensor>>& block_bucketize_pos) {
   const auto lengths_size = lengths.numel();
   const auto new_lengths_size = lengths_size * my_size;
   auto new_lengths = at::zeros({new_lengths_size}, lengths.options());
@@ -959,7 +989,8 @@ block_bucketize_sparse_features_cpu(
                 new_weights,
                 new_pos,
                 unbucketize_permute,
-                batch_size_per_feature);
+                batch_size_per_feature,
+                block_bucketize_pos);
             });
         });
     });
@@ -994,7 +1025,8 @@ block_bucketize_sparse_features_cpu(
                 new_weights,
                 new_pos,
                 unbucketize_permute,
-                batch_size_per_feature);
+                batch_size_per_feature,
+                block_bucketize_pos);
             });
         });
     });
@@ -1027,7 +1059,8 @@ block_bucketize_sparse_features_cpu(
               new_weights,
               new_pos,
               unbucketize_permute,
-              batch_size_per_feature);
+              batch_size_per_feature,
+              block_bucketize_pos);
           });
       });
   } else {
@@ -1055,7 +1088,8 @@ block_bucketize_sparse_features_cpu(
               new_weights,
               new_pos,
               unbucketize_permute,
-              batch_size_per_feature);
+              batch_size_per_feature,
+              block_bucketize_pos);
           });
       });
   }
@@ -2702,7 +2736,7 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
   m.def(
       "expand_into_jagged_permute(Tensor permute, Tensor input_offset, Tensor output_offset, SymInt output_size) -> Tensor");
   m.def(
-      "block_bucketize_sparse_features(Tensor lengths, Tensor indices, bool bucketize_pos, bool sequence, Tensor block_sizes, SymInt my_size, Tensor? weights=None, Tensor? batch_size_per_feature=None, SymInt max_B= -1) -> (Tensor, Tensor, Tensor?, Tensor?, Tensor?)");
+      "block_bucketize_sparse_features(Tensor lengths, Tensor indices, bool bucketize_pos, bool sequence, Tensor block_sizes, SymInt my_size, Tensor? weights=None, Tensor? batch_size_per_feature=None, SymInt max_B= -1, Tensor[]? block_bucketize_pos=None) -> (Tensor, Tensor, Tensor?, Tensor?, Tensor?)");
   m.def(
       "bucketize_sparse_features(Tensor lengths, Tensor indices, bool bucketize_pos, SymInt my_size, Tensor? weights=None) -> (Tensor, Tensor, Tensor?, Tensor?)");
   m.def(