@@ -264,6 +264,21 @@ void _permute_2D_lengths_cpu_kernel(
     input_offsets[i + 1] = lengths[i] + input_offsets[i];
   }
 }
+template <typename index_t>
+int64_t
+_find_lower_bound(const index_t* arr, const int64_t size, index_t target) {
+  int64_t l = 0;
+  int64_t h = size - 1;
+  while (l <= h) {
+    int mid = l + (h - l) / 2;
+    if (arr[mid] > target) {
+      h = mid - 1;
+    } else {
+      l = mid + 1;
+    }
+  }
+  return l - 1;
+}
 
 template <
     bool sequence,
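
A quick aside on what the new helper computes: it returns the position of the last boundary that is <= target, or -1 when target precedes every boundary, and that position is later used as the bucket id. Below is a minimal standalone sketch (a local copy of the helper plus hypothetical boundary values, not part of this diff) showing the behavior:

#include <cstdint>
#include <iostream>

template <typename index_t>
int64_t find_lower_bound(const index_t* arr, const int64_t size, index_t target) {
  int64_t l = 0;
  int64_t h = size - 1;
  while (l <= h) {
    const int64_t mid = l + (h - l) / 2;
    if (arr[mid] > target) {
      h = mid - 1;
    } else {
      l = mid + 1;
    }
  }
  return l - 1;  // index of the last boundary <= target, or -1 if none
}

int main() {
  // Hypothetical per-feature bucket boundaries (lower edge of each bucket).
  const int64_t boundaries[] = {0, 3, 10, 20};
  std::cout << find_lower_bound(boundaries, 4, int64_t{5}) << '\n';   // 1: bucket 1 owns [3, 10)
  std::cout << find_lower_bound(boundaries, 4, int64_t{0}) << '\n';   // 0
  std::cout << find_lower_bound(boundaries, 4, int64_t{-2}) << '\n';  // -1: below every boundary
  return 0;
}
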
@@ -283,7 +298,8 @@ void _block_bucketize_sparse_features_cpu(
     c10::optional<Tensor> new_weights,
     c10::optional<Tensor> new_pos,
     const c10::optional<Tensor>& unbucketize_permute,
-    const c10::optional<Tensor>& batch_size_per_feature) {
+    const c10::optional<Tensor>& batch_size_per_feature,
+    const c10::optional<std::vector<at::Tensor>>& block_bucketize_pos) {
   // allocate tensors and buffers
   const auto lengths_size = lengths.numel();
   const auto new_lengths_size = lengths_size * my_size;
@@ -304,7 +320,7 @@ void _block_bucketize_sparse_features_cpu(
   const index_t* const block_sizes_data = block_sizes.data_ptr<index_t>();
   offset_t* batch_sizes_data = nullptr;
   const auto variable_batch_size = batch_size_per_feature.has_value();
-
+  const auto variable_bucket_sizes = block_bucketize_pos.has_value();
   using uindex_t = std::make_unsigned_t<index_t>;
   using uoffset_t = std::make_unsigned_t<offset_t>;
@@ -330,6 +346,12 @@ void _block_bucketize_sparse_features_cpu(
   for (const auto t : c10::irange(T)) {
     const auto blk_size = block_sizes_data[t];
     const auto cur_batch_size = variable_batch_size ? batch_sizes_data[t] : B;
+    const index_t* bucketize_offset = nullptr;
+    int64_t bucket_size = 0;
+    if (variable_bucket_sizes) {
+      bucketize_offset = block_bucketize_pos.value()[t].data_ptr<index_t>();
+      bucket_size = block_bucketize_pos.value()[t].numel();
+    }
     for (const auto b : c10::irange(cur_batch_size)) {
       const auto b_t = (variable_batch_size ? cur_offset : t * B) + b;
       const offset_t rowstart = offsets_data[b_t];
@@ -342,10 +364,17 @@ void _block_bucketize_sparse_features_cpu(
         // range of blk_size, we expect the later embedding module to take care
         // of hashing indices calculation.
         uindex_t idx = static_cast<uindex_t>(indices_data[i]);
-        uindex_t p = idx < static_cast<uindex_t>(blk_size * my_size)
-            ? idx / blk_size
-            : idx % my_size;
-        new_lengths_data[p * lengths_size + b_t]++;
+        if (variable_bucket_sizes) {
+          auto lb = _find_lower_bound<index_t>(
+              bucketize_offset, bucket_size, indices_data[i]);
+          uindex_t p = lb < my_size ? lb : idx % my_size;
+          new_lengths_data[p * lengths_size + b_t]++;
+        } else {
+          uindex_t p = idx < static_cast<uindex_t>(blk_size * my_size)
+              ? idx / blk_size
+              : idx % my_size;
+          new_lengths_data[p * lengths_size + b_t]++;
+        }
       }
     }
     cur_offset += cur_batch_size;
@@ -358,6 +387,12 @@ void _block_bucketize_sparse_features_cpu(
   for (const auto t : c10::irange(T)) {
     const auto blk_size = block_sizes_data[t];
     const auto cur_batch_size = variable_batch_size ? batch_sizes_data[t] : B;
+    const index_t* bucketize_offset = nullptr;
+    int64_t bucket_size = 0;
+    if (variable_bucket_sizes) {
+      bucketize_offset = block_bucketize_pos.value()[t].data_ptr<index_t>();
+      bucket_size = block_bucketize_pos.value()[t].numel();
+    }
     for (const auto b : c10::irange(cur_batch_size)) {
       const auto b_t = (variable_batch_size ? cur_offset : t * B) + b;
       const offset_t rowstart = offsets_data[b_t];
@@ -370,12 +405,19 @@ void _block_bucketize_sparse_features_cpu(
         // range of blk_size, we expect the later embedding module to take care
         // of hashing indices calculation.
         const uindex_t idx = static_cast<uindex_t>(indices_data[i]);
-        const uindex_t p = idx < static_cast<uindex_t>(blk_size * my_size)
-            ? idx / blk_size
-            : idx % my_size;
-        const uindex_t new_idx = idx < static_cast<uindex_t>(blk_size * my_size)
-            ? idx % blk_size
-            : idx / my_size;
+        uindex_t p, new_idx;
+        if (variable_bucket_sizes) {
+          auto lb = _find_lower_bound<index_t>(
+              bucketize_offset, bucket_size, indices_data[i]);
+          p = lb < my_size ? lb : idx % my_size;
+          new_idx = lb < my_size ? idx - bucketize_offset[lb] : idx / my_size;
+        } else {
+          p = idx < static_cast<uindex_t>(blk_size * my_size) ? idx / blk_size
+                                                              : idx % my_size;
+          new_idx = idx < static_cast<uindex_t>(blk_size * my_size)
+              ? idx % blk_size
+              : idx / my_size;
+        }
         const uoffset_t pos = new_offsets_data[p * lengths_size + b_t];
         new_indices_data[pos] = new_idx;
         if (sequence) {
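
For intuition, here is a small self-contained sketch of how the two paths above pick the bucket p and the within-bucket index new_idx. The concrete numbers and the std::upper_bound shortcut (which yields the same position as _find_lower_bound) are illustrative only, not code from this diff:

#include <algorithm>
#include <cstdint>
#include <iostream>
#include <iterator>

int main() {
  const int64_t my_size = 4;  // number of buckets (e.g. ranks)
  const int64_t idx = 12;     // a raw sparse index

  // Fixed-block path: every bucket owns blk_size consecutive ids.
  const int64_t blk_size = 5;
  const int64_t p_uniform =
      idx < blk_size * my_size ? idx / blk_size : idx % my_size;
  const int64_t new_idx_uniform =
      idx < blk_size * my_size ? idx % blk_size : idx / my_size;
  std::cout << p_uniform << " " << new_idx_uniform << '\n';  // 2 2

  // Variable path: explicit lower boundaries per bucket, as carried by
  // block_bucketize_pos[t]. std::upper_bound - 1 finds the same position
  // as the _find_lower_bound helper in this diff.
  const int64_t boundaries[] = {0, 8, 13, 20};
  const int64_t lb =
      std::upper_bound(std::begin(boundaries), std::end(boundaries), idx) -
      std::begin(boundaries) - 1;
  const int64_t p_var = lb < my_size ? lb : idx % my_size;
  const int64_t new_idx_var = lb < my_size ? idx - boundaries[lb] : idx / my_size;
  std::cout << p_var << " " << new_idx_var << '\n';  // 1 4: bucket 1 owns [8, 13)
  return 0;
}
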
@@ -910,8 +952,8 @@ block_bucketize_sparse_features_cpu(
     const int64_t my_size,
     const c10::optional<Tensor>& weights,
     const c10::optional<Tensor>& batch_size_per_feature,
-    const int64_t /* max_batch_size */ // Only used in GPU variant
-) {
+    const int64_t /* max_batch_size */, // Only used in GPU variant
+    const c10::optional<std::vector<at::Tensor>>& block_bucketize_pos) {
   const auto lengths_size = lengths.numel();
   const auto new_lengths_size = lengths_size * my_size;
   auto new_lengths = at::zeros({new_lengths_size}, lengths.options());
@@ -958,7 +1000,8 @@ block_bucketize_sparse_features_cpu(
                             new_weights,
                             new_pos,
                             unbucketize_permute,
-                            batch_size_per_feature);
+                            batch_size_per_feature,
+                            block_bucketize_pos);
                       });
                 });
           });
@@ -993,7 +1036,8 @@ block_bucketize_sparse_features_cpu(
                             new_weights,
                             new_pos,
                             unbucketize_permute,
-                            batch_size_per_feature);
+                            batch_size_per_feature,
+                            block_bucketize_pos);
                       });
                 });
           });
@@ -1026,7 +1070,8 @@ block_bucketize_sparse_features_cpu(
                       new_weights,
                       new_pos,
                       unbucketize_permute,
-                      batch_size_per_feature);
+                      batch_size_per_feature,
+                      block_bucketize_pos);
                 });
           });
   } else {
@@ -1054,7 +1099,8 @@ block_bucketize_sparse_features_cpu(
                       new_weights,
                       new_pos,
                       unbucketize_permute,
-                      batch_size_per_feature);
+                      batch_size_per_feature,
+                      block_bucketize_pos);
                 });
           });
   }
@@ -2696,7 +2742,7 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
   m.def(
       "expand_into_jagged_permute(Tensor permute, Tensor input_offset, Tensor output_offset, SymInt output_size) -> Tensor");
   m.def(
-      "block_bucketize_sparse_features(Tensor lengths, Tensor indices, bool bucketize_pos, bool sequence, Tensor block_sizes, SymInt my_size, Tensor? weights=None, Tensor? batch_size_per_feature=None, SymInt max_B= -1) -> (Tensor, Tensor, Tensor?, Tensor?, Tensor?)");
+      "block_bucketize_sparse_features(Tensor lengths, Tensor indices, bool bucketize_pos, bool sequence, Tensor block_sizes, SymInt my_size, Tensor? weights=None, Tensor? batch_size_per_feature=None, SymInt max_B= -1, Tensor[]? block_bucketize_pos=None) -> (Tensor, Tensor, Tensor?, Tensor?, Tensor?)");
   m.def(
       "bucketize_sparse_features(Tensor lengths, Tensor indices, bool bucketize_pos, SymInt my_size, Tensor? weights=None) -> (Tensor, Tensor, Tensor?, Tensor?)");
   m.def("asynchronous_exclusive_cumsum(Tensor t_in) -> Tensor");
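
For completeness, a hedged sketch (libtorch C++ frontend, with invented boundary values and the variable name per_feature_bounds chosen here for illustration) of what the new optional Tensor[]? argument looks like on the caller side: one 1-D tensor of ascending bucket lower bounds per feature, with dtype matching the indices tensor.

#include <torch/torch.h>

#include <vector>

int main() {
  // Hypothetical boundaries for two features; passing c10::nullopt instead
  // keeps the original fixed-block-size behavior.
  std::vector<at::Tensor> per_feature_bounds = {
      torch::tensor({0, 8, 13, 20}, torch::kLong),  // feature 0
      torch::tensor({0, 5, 10, 40}, torch::kLong),  // feature 1
  };
  c10::optional<std::vector<at::Tensor>> block_bucketize_pos =
      std::move(per_feature_bounds);
  // This is the value the CPU kernel receives; block_bucketize_pos.value()[t]
  // supplies the bucket boundaries used for feature t.
  return 0;
}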