@@ -284,7 +284,8 @@ void _block_bucketize_sparse_features_cpu(
     c10::optional<Tensor> new_weights,
     c10::optional<Tensor> new_pos,
     const c10::optional<Tensor>& unbucketize_permute,
-    const c10::optional<Tensor>& batch_size_per_feature) {
+    const c10::optional<Tensor>& batch_size_per_feature,
+    const c10::optional<std::vector<at::Tensor>>& block_bucketize_pos) {
   // allocate tensors and buffers
   const auto lengths_size = lengths.numel();
   const auto new_lengths_size = lengths_size * my_size;
@@ -305,7 +306,7 @@ void _block_bucketize_sparse_features_cpu(
   const index_t* const block_sizes_data = block_sizes.data_ptr<index_t>();
   offset_t* batch_sizes_data = nullptr;
   const auto variable_batch_size = batch_size_per_feature.has_value();
-
+  const auto variable_bucket_sizes = block_bucketize_pos.has_value();
   using uindex_t = std::make_unsigned_t<index_t>;
   using uoffset_t = std::make_unsigned_t<offset_t>;

@@ -331,6 +332,12 @@ void _block_bucketize_sparse_features_cpu(
   for (const auto t : c10::irange(T)) {
     const auto blk_size = block_sizes_data[t];
     const auto cur_batch_size = variable_batch_size ? batch_sizes_data[t] : B;
+    const index_t* bucketize_offset = nullptr;
+    int64_t bucket_size = 0;
+    if (variable_bucket_sizes) {
+      bucketize_offset = block_bucketize_pos.value()[t].data_ptr<index_t>();
+      bucket_size = block_bucketize_pos.value()[t].numel();
+    }
     for (const auto b : c10::irange(cur_batch_size)) {
       const auto b_t = (variable_batch_size ? cur_offset : t * B) + b;
       const offset_t rowstart = offsets_data[b_t];
@@ -343,10 +350,20 @@ void _block_bucketize_sparse_features_cpu(
         // range of blk_size, we expect the later embedding module to take care
         // of hashing indices calculation.
         uindex_t idx = static_cast<uindex_t>(indices_data[i]);
-        uindex_t p = idx < static_cast<uindex_t>(blk_size * my_size)
-            ? idx / blk_size
-            : idx % my_size;
-        new_lengths_data[p * lengths_size + b_t]++;
+        if (variable_bucket_sizes) {
+          int64_t lb = std::upper_bound(
+                           bucketize_offset,
+                           bucketize_offset + bucket_size,
+                           indices_data[i]) -
+              bucketize_offset - 1;
+          uindex_t p = lb < my_size ? lb : idx % my_size;
+          new_lengths_data[p * lengths_size + b_t]++;
+        } else {
+          uindex_t p = idx < static_cast<uindex_t>(blk_size * my_size)
+              ? idx / blk_size
+              : idx % my_size;
+          new_lengths_data[p * lengths_size + b_t]++;
+        }
       }
     }
     cur_offset += cur_batch_size;
@@ -359,6 +376,12 @@ void _block_bucketize_sparse_features_cpu(
   for (const auto t : c10::irange(T)) {
     const auto blk_size = block_sizes_data[t];
     const auto cur_batch_size = variable_batch_size ? batch_sizes_data[t] : B;
+    const index_t* bucketize_offset = nullptr;
+    int64_t bucket_size = 0;
+    if (variable_bucket_sizes) {
+      bucketize_offset = block_bucketize_pos.value()[t].data_ptr<index_t>();
+      bucket_size = block_bucketize_pos.value()[t].numel();
+    }
     for (const auto b : c10::irange(cur_batch_size)) {
       const auto b_t = (variable_batch_size ? cur_offset : t * B) + b;
       const offset_t rowstart = offsets_data[b_t];
@@ -371,12 +394,22 @@ void _block_bucketize_sparse_features_cpu(
         // range of blk_size, we expect the later embedding module to take care
        // of hashing indices calculation.
         const uindex_t idx = static_cast<uindex_t>(indices_data[i]);
-        const uindex_t p = idx < static_cast<uindex_t>(blk_size * my_size)
-            ? idx / blk_size
-            : idx % my_size;
-        const uindex_t new_idx = idx < static_cast<uindex_t>(blk_size * my_size)
-            ? idx % blk_size
-            : idx / my_size;
+        uindex_t p, new_idx;
+        if (variable_bucket_sizes) {
+          auto lb = std::upper_bound(
+                        bucketize_offset,
+                        bucketize_offset + bucket_size,
+                        indices_data[i]) -
+              bucketize_offset - 1;
+          p = lb < my_size ? lb : idx % my_size;
+          new_idx = lb < my_size ? idx - bucketize_offset[lb] : idx / my_size;
+        } else {
+          p = idx < static_cast<uindex_t>(blk_size * my_size) ? idx / blk_size
+                                                              : idx % my_size;
+          new_idx = idx < static_cast<uindex_t>(blk_size * my_size)
+              ? idx % blk_size
+              : idx / my_size;
+        }
         const uoffset_t pos = new_offsets_data[p * lengths_size + b_t];
         new_indices_data[pos] = new_idx;
         if (sequence) {
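Note on the new variable-bucket path above: instead of the uniform `idx / blk_size` split, each feature `t` carries a sorted list of bucket start positions (`block_bucketize_pos[t]`), and `std::upper_bound` finds the bucket an index lands in; indices at or past the last boundary fall back to the `% my_size` / `/ my_size` rule. Below is a minimal standalone sketch of that mapping, assuming each feature's boundary tensor holds `my_size + 1` sorted start positions; the boundary and index values are made up for illustration and are not part of the patch.

```cpp
#include <algorithm>
#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  const int64_t my_size = 3;
  // Hypothetical sorted bucket start positions for one feature (my_size + 1 entries).
  const std::vector<int64_t> boundaries = {0, 4, 10, 16};
  const std::vector<int64_t> indices = {2, 7, 15, 40};

  for (const int64_t idx : indices) {
    // Position of the last boundary <= idx, mirroring the upper_bound - 1 in the patch.
    const int64_t lb =
        std::upper_bound(boundaries.begin(), boundaries.end(), idx) -
        boundaries.begin() - 1;
    // Inside the covered range: bucket comes from the boundary table and the
    // index is renumbered relative to that bucket's start. Otherwise fall back
    // to the uniform modulo/division rule.
    const int64_t p = lb < my_size ? lb : idx % my_size;
    const int64_t new_idx = lb < my_size ? idx - boundaries[lb] : idx / my_size;
    std::cout << "idx=" << idx << " -> bucket " << p << ", local index "
              << new_idx << "\n";
  }
  return 0;
}
```

With these made-up values, 2 and 7 land in buckets 0 and 1, 15 lands in bucket 2 with local index 5, and 40 lies past the last boundary so it takes the fallback rule (bucket 40 % 3, local index 40 / 3).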
@@ -911,8 +944,8 @@ block_bucketize_sparse_features_cpu(
     const int64_t my_size,
     const c10::optional<Tensor>& weights,
     const c10::optional<Tensor>& batch_size_per_feature,
-    const int64_t /* max_batch_size */ // Only used in GPU variant
-) {
+    const int64_t /* max_batch_size */, // Only used in GPU variant
+    const c10::optional<std::vector<at::Tensor>>& block_bucketize_pos) {
   const auto lengths_size = lengths.numel();
   const auto new_lengths_size = lengths_size * my_size;
   auto new_lengths = at::zeros({new_lengths_size}, lengths.options());
@@ -959,7 +992,8 @@ block_bucketize_sparse_features_cpu(
             new_weights,
             new_pos,
             unbucketize_permute,
-            batch_size_per_feature);
+            batch_size_per_feature,
+            block_bucketize_pos);
           });
         });
       });
@@ -994,7 +1028,8 @@ block_bucketize_sparse_features_cpu(
             new_weights,
             new_pos,
             unbucketize_permute,
-            batch_size_per_feature);
+            batch_size_per_feature,
+            block_bucketize_pos);
           });
         });
       });
@@ -1027,7 +1062,8 @@ block_bucketize_sparse_features_cpu(
             new_weights,
             new_pos,
             unbucketize_permute,
-            batch_size_per_feature);
+            batch_size_per_feature,
+            block_bucketize_pos);
           });
         });
   } else {
@@ -1055,7 +1091,8 @@ block_bucketize_sparse_features_cpu(
             new_weights,
             new_pos,
             unbucketize_permute,
-            batch_size_per_feature);
+            batch_size_per_feature,
+            block_bucketize_pos);
           });
         });
   }
@@ -2702,7 +2739,7 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
   m.def(
       "expand_into_jagged_permute(Tensor permute, Tensor input_offset, Tensor output_offset, SymInt output_size) -> Tensor");
   m.def(
-      "block_bucketize_sparse_features(Tensor lengths, Tensor indices, bool bucketize_pos, bool sequence, Tensor block_sizes, SymInt my_size, Tensor? weights=None, Tensor? batch_size_per_feature=None, SymInt max_B= -1) -> (Tensor, Tensor, Tensor?, Tensor?, Tensor?)");
+      "block_bucketize_sparse_features(Tensor lengths, Tensor indices, bool bucketize_pos, bool sequence, Tensor block_sizes, SymInt my_size, Tensor? weights=None, Tensor? batch_size_per_feature=None, SymInt max_B= -1, Tensor[]? block_bucketize_pos=None) -> (Tensor, Tensor, Tensor?, Tensor?, Tensor?)");
   m.def(
       "bucketize_sparse_features(Tensor lengths, Tensor indices, bool bucketize_pos, SymInt my_size, Tensor? weights=None) -> (Tensor, Tensor, Tensor?, Tensor?)");
   m.def(
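On the registration side, the new argument is an optional list of tensors (`Tensor[]? block_bucketize_pos=None`), so existing callers that omit it keep the uniform-block behavior. As a rough sketch of how the C++-side value could be assembled, assuming one sorted boundary tensor per feature; the helper name and the boundary values below are hypothetical and not taken from this patch:

```cpp
#include <ATen/ATen.h>
#include <c10/util/Optional.h>
#include <vector>

// Hypothetical helper: builds the per-feature boundary tensors carried by the
// c10::optional<std::vector<at::Tensor>> argument added in this change.
c10::optional<std::vector<at::Tensor>> make_block_bucketize_pos() {
  std::vector<at::Tensor> pos;
  std::vector<int64_t> feature0 = {0, 4, 10, 16};  // made-up boundaries
  std::vector<int64_t> feature1 = {0, 2, 5, 10};   // made-up boundaries
  pos.push_back(at::tensor(feature0, at::dtype(at::kLong)));
  pos.push_back(at::tensor(feature1, at::dtype(at::kLong)));
  return c10::make_optional(std::move(pos));
}
```

Each boundary tensor should use the same integer type as the indices, since the kernel reads it through `data_ptr<index_t>()`.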