
Commit e762d10

tissue3 authored and facebook-github-bot committed
Support variable bucket size for block_bucketize_sparse_features (pytorch#2107)
Summary: This diff adds support for variable bucket sizes in block_bucketize_sparse_features for RW sharding. For example, given block_bucketize_pos of [[0, 5, 15], [0, 10, 13]]: for batch 0, indices in [0, 5) are assigned to bucket 0 and indices in [5, 15) to bucket 1; for batch 1, indices in [0, 10) are assigned to bucket 0 and indices in [10, 13) to bucket 1. The new index is the original index minus the lower boundary of its bucket, i.e. original_index - block_bucketize_pos[new_bucket_id]; for batch 0 and index 12, the index is assigned to bucket 1 and the new index is 12 - 5 = 7. Differential Revision: D50868649
1 parent 975cb01 commit e762d10
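To make the mapping concrete, here is a minimal standalone C++ sketch (illustrative only, not part of the diff) that reproduces the bucket lookup and index rebasing described in the summary, using the boundaries [0, 5, 15] and index 12 from the example:

#include <algorithm>
#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  // Bucket boundaries for one feature, as in the summary: [0, 5, 15].
  const std::vector<int64_t> bucketize_pos{0, 5, 15};
  const int64_t idx = 12;

  // std::upper_bound returns the first boundary greater than idx, so
  // stepping back one element yields the id of the containing bucket.
  const int64_t bucket =
      std::upper_bound(bucketize_pos.begin(), bucketize_pos.end(), idx) -
      bucketize_pos.begin() - 1;

  // Rebase the index against its bucket's lower boundary.
  const int64_t new_idx = idx - bucketize_pos[bucket];

  std::cout << "bucket=" << bucket << " new_idx=" << new_idx << "\n";
  // Prints: bucket=1 new_idx=7
  return 0;
}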

File tree

3 files changed: +61 −22 lines

fbgemm_gpu/include/fbgemm_gpu/sparse_ops.h

Lines changed: 4 additions & 2 deletions
@@ -149,7 +149,8 @@ block_bucketize_sparse_features_cuda(
     const int64_t my_size,
     const c10::optional<at::Tensor>& weights,
     const c10::optional<at::Tensor>& batch_size_per_feature,
-    const int64_t max_batch_size);
+    const int64_t max_batch_size,
+    const c10::optional<std::vector<at::Tensor>>& block_bucketize_pos);
 
 std::tuple<
     at::Tensor,
@@ -168,7 +169,8 @@ block_bucketize_sparse_features_cpu(
     const int64_t my_size,
     const c10::optional<at::Tensor>& weights,
     const c10::optional<at::Tensor>& batch_size_per_feature,
-    const int64_t max_batch_size);
+    const int64_t max_batch_size,
+    const c10::optional<std::vector<at::Tensor>>& block_bucketize_pos);
 
 std::tuple<
     at::Tensor,

fbgemm_gpu/src/sparse_ops/sparse_block_bucketize_features.cu

Lines changed: 4 additions & 1 deletion
@@ -147,7 +147,10 @@ block_bucketize_sparse_features_cuda(
     const int64_t my_size,
     const c10::optional<Tensor>& weights,
     const c10::optional<Tensor>& batch_size_per_feature,
-    const int64_t max_B) {
+    const int64_t max_B,
+    const c10::optional<std::vector<
+        at::Tensor>>& /*block_bucketize_pos*/ // Only used in GPU variant
+) {
   TENSORS_ON_SAME_CUDA_GPU_IF_NOT_OPTIONAL(lengths, indices);
 
   at::cuda::OptionalCUDAGuard device_guard;

fbgemm_gpu/src/sparse_ops/sparse_ops_cpu.cpp

Lines changed: 53 additions & 19 deletions
@@ -284,7 +284,8 @@ void _block_bucketize_sparse_features_cpu(
     c10::optional<Tensor> new_weights,
     c10::optional<Tensor> new_pos,
     const c10::optional<Tensor>& unbucketize_permute,
-    const c10::optional<Tensor>& batch_size_per_feature) {
+    const c10::optional<Tensor>& batch_size_per_feature,
+    const c10::optional<std::vector<at::Tensor>>& block_bucketize_pos) {
   // allocate tensors and buffers
   const auto lengths_size = lengths.numel();
   const auto new_lengths_size = lengths_size * my_size;
@@ -305,9 +306,10 @@
   const index_t* const block_sizes_data = block_sizes.data_ptr<index_t>();
   offset_t* batch_sizes_data = nullptr;
   const auto variable_batch_size = batch_size_per_feature.has_value();
-
+  const auto variable_bucket_sizes = block_bucketize_pos.has_value();
   using uindex_t = std::make_unsigned_t<index_t>;
   using uoffset_t = std::make_unsigned_t<offset_t>;
+  std::vector<int64_t> lower_bounds(indices.numel(), 0);
 
   if constexpr (sequence) {
     unbucketize_permute_data = unbucketize_permute.value().data_ptr<index_t>();
@@ -331,6 +333,12 @@
   for (const auto t : c10::irange(T)) {
     const auto blk_size = block_sizes_data[t];
     const auto cur_batch_size = variable_batch_size ? batch_sizes_data[t] : B;
+    const index_t* bucketize_offset = nullptr;
+    int64_t bucket_size = 0;
+    if (variable_bucket_sizes) {
+      bucketize_offset = block_bucketize_pos.value()[t].data_ptr<index_t>();
+      bucket_size = block_bucketize_pos.value()[t].numel();
+    }
     for (const auto b : c10::irange(cur_batch_size)) {
       const auto b_t = (variable_batch_size ? cur_offset : t * B) + b;
       const offset_t rowstart = offsets_data[b_t];
@@ -343,10 +351,21 @@
         // range of blk_size, we expect the later embedding module to take care
        // of hashing indices calculation.
        uindex_t idx = static_cast<uindex_t>(indices_data[i]);
-        uindex_t p = idx < static_cast<uindex_t>(blk_size * my_size)
-            ? idx / blk_size
-            : idx % my_size;
-        new_lengths_data[p * lengths_size + b_t]++;
+        if (variable_bucket_sizes) {
+          int64_t lb = std::upper_bound(
+                           bucketize_offset,
+                           bucketize_offset + bucket_size,
+                           indices_data[i]) -
+              bucketize_offset - 1;
+          lower_bounds[i] = lb;
+          uindex_t p = lb < my_size ? lb : idx % my_size;
+          new_lengths_data[p * lengths_size + b_t]++;
+        } else {
+          uindex_t p = idx < static_cast<uindex_t>(blk_size * my_size)
+              ? idx / blk_size
+              : idx % my_size;
+          new_lengths_data[p * lengths_size + b_t]++;
+        }
      }
    }
    cur_offset += cur_batch_size;
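Note the fallback in the new branch above: when an index lies at or beyond the last boundary, the computed lower bound lb can reach bucket_size - 1, and whenever lb is not a valid rank (lb >= my_size) the kernel falls back to idx % my_size, mirroring the fixed-block-size path. A hypothetical standalone helper (values chosen for illustration, not from the diff) showing both cases:

#include <algorithm>
#include <cstdint>
#include <iostream>
#include <vector>

// Hypothetical helper mirroring the first pass above: returns the target
// bucket for one index, falling back to idx % my_size when the bucket id
// found by the binary search is not a valid rank.
int64_t target_bucket(
    const std::vector<int64_t>& pos, int64_t idx, int64_t my_size) {
  const int64_t lb =
      std::upper_bound(pos.begin(), pos.end(), idx) - pos.begin() - 1;
  return lb < my_size ? lb : idx % my_size;
}

int main() {
  const std::vector<int64_t> pos{0, 5, 15};  // boundaries for one feature
  std::cout << target_bucket(pos, 3, 2) << "\n";   // 0: 3 is in [0, 5)
  std::cout << target_bucket(pos, 12, 2) << "\n";  // 1: 12 is in [5, 15)
  std::cout << target_bucket(pos, 20, 2) << "\n";  // 0: 20 >= 15, so 20 % 2
  return 0;
}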
@@ -359,6 +378,10 @@
   for (const auto t : c10::irange(T)) {
     const auto blk_size = block_sizes_data[t];
     const auto cur_batch_size = variable_batch_size ? batch_sizes_data[t] : B;
+    const index_t* bucketize_offset = nullptr;
+    if (variable_bucket_sizes) {
+      bucketize_offset = block_bucketize_pos.value()[t].data_ptr<index_t>();
+    }
     for (const auto b : c10::irange(cur_batch_size)) {
       const auto b_t = (variable_batch_size ? cur_offset : t * B) + b;
       const offset_t rowstart = offsets_data[b_t];
@@ -371,12 +394,19 @@
         // range of blk_size, we expect the later embedding module to take care
        // of hashing indices calculation.
        const uindex_t idx = static_cast<uindex_t>(indices_data[i]);
-        const uindex_t p = idx < static_cast<uindex_t>(blk_size * my_size)
-            ? idx / blk_size
-            : idx % my_size;
-        const uindex_t new_idx = idx < static_cast<uindex_t>(blk_size * my_size)
-            ? idx % blk_size
-            : idx / my_size;
+        uindex_t p, new_idx;
+        if (variable_bucket_sizes) {
+          int64_t lb = lower_bounds[i];
+          p = lb < my_size ? lb : idx % my_size;
+          new_idx = lb < my_size ? idx - bucketize_offset[lb] : idx / my_size;
+
+        } else {
+          p = idx < static_cast<uindex_t>(blk_size * my_size) ? idx / blk_size
+                                                              : idx % my_size;
+          new_idx = idx < static_cast<uindex_t>(blk_size * my_size)
+              ? idx % blk_size
+              : idx / my_size;
+        }
        const uoffset_t pos = new_offsets_data[p * lengths_size + b_t];
        new_indices_data[pos] = new_idx;
        if (sequence) {
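The second pass above reuses the lower_bounds values cached during the first pass, so the binary search runs only once per index; it then rebases each index against its bucket's lower boundary. A compact standalone sketch of this two-pass shape (hypothetical data and simplified scalar types, not code from the diff): pass 1 counts per-bucket lengths and caches each index's bucket, pass 2 turns counts into offsets and writes the rebased indices.

#include <algorithm>
#include <cstdint>
#include <iostream>
#include <numeric>
#include <vector>

int main() {
  const std::vector<int64_t> pos{0, 5, 15};  // bucket boundaries, one feature
  const std::vector<int64_t> indices{3, 12, 7, 0};
  const int64_t my_size = 2;

  std::vector<int64_t> bucket_of(indices.size());
  std::vector<int64_t> lengths(my_size, 0);
  for (size_t i = 0; i < indices.size(); ++i) {  // pass 1: count + cache
    const int64_t lb =
        std::upper_bound(pos.begin(), pos.end(), indices[i]) - pos.begin() - 1;
    bucket_of[i] = lb;
    ++lengths[lb];
  }

  std::vector<int64_t> offsets(my_size, 0);  // exclusive prefix sum
  std::partial_sum(lengths.begin(), lengths.end() - 1, offsets.begin() + 1);

  std::vector<int64_t> new_indices(indices.size());
  for (size_t i = 0; i < indices.size(); ++i) {  // pass 2: place + rebase
    const int64_t lb = bucket_of[i];
    new_indices[offsets[lb]++] = indices[i] - pos[lb];
  }

  for (const int64_t v : new_indices) std::cout << v << ' ';  // 3 0 7 2
  std::cout << '\n';
  return 0;
}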
@@ -911,8 +941,8 @@ block_bucketize_sparse_features_cpu(
     const int64_t my_size,
     const c10::optional<Tensor>& weights,
     const c10::optional<Tensor>& batch_size_per_feature,
-    const int64_t /* max_batch_size */ // Only used in GPU variant
-) {
+    const int64_t /* max_batch_size */, // Only used in GPU variant
+    const c10::optional<std::vector<at::Tensor>>& block_bucketize_pos) {
   const auto lengths_size = lengths.numel();
   const auto new_lengths_size = lengths_size * my_size;
   auto new_lengths = at::zeros({new_lengths_size}, lengths.options());
@@ -959,7 +989,8 @@
               new_weights,
               new_pos,
               unbucketize_permute,
-              batch_size_per_feature);
+              batch_size_per_feature,
+              block_bucketize_pos);
         });
       });
     });
@@ -994,7 +1025,8 @@
               new_weights,
               new_pos,
               unbucketize_permute,
-              batch_size_per_feature);
+              batch_size_per_feature,
+              block_bucketize_pos);
         });
       });
     });
@@ -1027,7 +1059,8 @@
             new_weights,
             new_pos,
             unbucketize_permute,
-            batch_size_per_feature);
+            batch_size_per_feature,
+            block_bucketize_pos);
      });
    });
  } else {
@@ -1055,7 +1088,8 @@
            new_weights,
            new_pos,
            unbucketize_permute,
-            batch_size_per_feature);
+            batch_size_per_feature,
+            block_bucketize_pos);
      });
    });
  }
@@ -2702,7 +2736,7 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
   m.def(
       "expand_into_jagged_permute(Tensor permute, Tensor input_offset, Tensor output_offset, SymInt output_size) -> Tensor");
   m.def(
-      "block_bucketize_sparse_features(Tensor lengths, Tensor indices, bool bucketize_pos, bool sequence, Tensor block_sizes, SymInt my_size, Tensor? weights=None, Tensor? batch_size_per_feature=None, SymInt max_B= -1) -> (Tensor, Tensor, Tensor?, Tensor?, Tensor?)");
+      "block_bucketize_sparse_features(Tensor lengths, Tensor indices, bool bucketize_pos, bool sequence, Tensor block_sizes, SymInt my_size, Tensor? weights=None, Tensor? batch_size_per_feature=None, SymInt max_B= -1, Tensor[]? block_bucketize_pos=None) -> (Tensor, Tensor, Tensor?, Tensor?, Tensor?)");
   m.def(
       "bucketize_sparse_features(Tensor lengths, Tensor indices, bool bucketize_pos, SymInt my_size, Tensor? weights=None) -> (Tensor, Tensor, Tensor?, Tensor?)");
   m.def(
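Finally, a hedged sketch of how a caller might exercise the new optional argument through the CPU entry point declared in sparse_ops.h. The tensor values follow the commit summary, but the surrounding setup (namespace, batch shape, output unpacking) is an assumption for illustration rather than code from this diff:

#include <vector>
#include <ATen/ATen.h>
#include "fbgemm_gpu/sparse_ops.h"

// Illustrative only: two features, batch size 1, two ranks (my_size = 2),
// with per-feature bucket boundaries as in the commit summary.
void example() {
  const auto opts = at::TensorOptions().dtype(at::kLong);
  const auto lengths = at::tensor({1, 1}, opts);      // one index per feature
  const auto indices = at::tensor({12, 11}, opts);    // feature 0: 12, feature 1: 11
  const auto block_sizes = at::tensor({8, 7}, opts);  // unused on the variable path
  const std::vector<at::Tensor> block_bucketize_pos = {
      at::tensor({0, 5, 15}, opts),   // feature 0 buckets: [0, 5), [5, 15)
      at::tensor({0, 10, 13}, opts),  // feature 1 buckets: [0, 10), [10, 13)
  };

  // Argument order follows the registered schema; the namespace is assumed.
  const auto out = fbgemm_gpu::block_bucketize_sparse_features_cpu(
      lengths,
      indices,
      /*bucketize_pos=*/false,
      /*sequence=*/false,
      block_sizes,
      /*my_size=*/2,
      /*weights=*/c10::nullopt,
      /*batch_size_per_feature=*/c10::nullopt,
      /*max_batch_size=*/-1,
      block_bucketize_pos);
  // Per the schema, the op returns five tensors; the bucketized indices
  // should be rebased as in the summary, e.g. 12 -> 7 and 11 -> 1.
}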
