Skip to content

Commit 50eeead

Browse files
tissue3 authored and facebook-github-bot committed
Support variable bucket size for block_bucketize_sparse_features (pytorch#2107)
Summary: This diff adds support for variable bucket sizes in block_bucketize_sparse_features for RW sharding. E.g. given bucket_sizes_pos (the `block_bucketize_pos` argument, one tensor per feature) as [[0,5,15], [0,10,13]]: for feature 0, indices in [0,5) will be assigned to bucket 0 and indices in [5,15) will be assigned to bucket 1; for feature 1, indices in [0,10) will be assigned to bucket 0 and indices in [10,13) will be assigned to bucket 1. The new index will be original index - bucket_sizes_pos[feature][new_bucket_id], i.e. the index's offset from the start of the bucket it is assigned to. For example, for feature 0 and index = 12, the index will be assigned to bucket 1 and the new index is 12 - 5 = 7. Differential Revision: D50868649
1 parent 2117dd3 commit 50eeead

File tree

5 files changed

+163
-22
lines changed

5 files changed

+163
-22
lines changed

fbgemm_gpu/include/fbgemm_gpu/sparse_ops.h

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -149,7 +149,8 @@ block_bucketize_sparse_features_cuda(
149149
const int64_t my_size,
150150
const c10::optional<at::Tensor>& weights,
151151
const c10::optional<at::Tensor>& batch_size_per_feature,
152-
const int64_t max_batch_size);
152+
const int64_t max_batch_size,
153+
const c10::optional<std::vector<at::Tensor>>& block_bucketize_pos);
153154

154155
std::tuple<
155156
at::Tensor,
@@ -168,7 +169,8 @@ block_bucketize_sparse_features_cpu(
168169
const int64_t my_size,
169170
const c10::optional<at::Tensor>& weights,
170171
const c10::optional<at::Tensor>& batch_size_per_feature,
171-
const int64_t max_batch_size);
172+
const int64_t max_batch_size,
173+
const c10::optional<std::vector<at::Tensor>>& block_bucketize_pos);
172174

173175
std::tuple<
174176
at::Tensor,

fbgemm_gpu/src/sparse_ops/sparse_block_bucketize_features.cu

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -147,7 +147,10 @@ block_bucketize_sparse_features_cuda(
147147
const int64_t my_size,
148148
const c10::optional<Tensor>& weights,
149149
const c10::optional<Tensor>& batch_size_per_feature,
150-
const int64_t max_B) {
150+
const int64_t max_B,
151+
const c10::optional<std::vector<
152+
at::Tensor>>& /*block_bucketize_pos*/ // Only used in CPU variant
153+
) {
151154
TENSORS_ON_SAME_CUDA_GPU_IF_NOT_OPTIONAL(lengths, indices);
152155

153156
at::cuda::OptionalCUDAGuard device_guard;

fbgemm_gpu/src/sparse_ops/sparse_ops_cpu.cpp

Lines changed: 56 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -284,7 +284,8 @@ void _block_bucketize_sparse_features_cpu(
284284
c10::optional<Tensor> new_weights,
285285
c10::optional<Tensor> new_pos,
286286
const c10::optional<Tensor>& unbucketize_permute,
287-
const c10::optional<Tensor>& batch_size_per_feature) {
287+
const c10::optional<Tensor>& batch_size_per_feature,
288+
const c10::optional<std::vector<at::Tensor>>& block_bucketize_pos) {
288289
// allocate tensors and buffers
289290
const auto lengths_size = lengths.numel();
290291
const auto new_lengths_size = lengths_size * my_size;
@@ -305,7 +306,7 @@ void _block_bucketize_sparse_features_cpu(
305306
const index_t* const block_sizes_data = block_sizes.data_ptr<index_t>();
306307
offset_t* batch_sizes_data = nullptr;
307308
const auto variable_batch_size = batch_size_per_feature.has_value();
308-
309+
const auto variable_bucket_sizes = block_bucketize_pos.has_value();
309310
using uindex_t = std::make_unsigned_t<index_t>;
310311
using uoffset_t = std::make_unsigned_t<offset_t>;
311312

@@ -331,6 +332,12 @@ void _block_bucketize_sparse_features_cpu(
331332
for (const auto t : c10::irange(T)) {
332333
const auto blk_size = block_sizes_data[t];
333334
const auto cur_batch_size = variable_batch_size ? batch_sizes_data[t] : B;
335+
const index_t* bucketize_offset = nullptr;
336+
int64_t bucket_size = 0;
337+
if (variable_bucket_sizes) {
338+
bucketize_offset = block_bucketize_pos.value()[t].data_ptr<index_t>();
339+
bucket_size = block_bucketize_pos.value()[t].numel();
340+
}
334341
for (const auto b : c10::irange(cur_batch_size)) {
335342
const auto b_t = (variable_batch_size ? cur_offset : t * B) + b;
336343
const offset_t rowstart = offsets_data[b_t];
@@ -343,10 +350,20 @@ void _block_bucketize_sparse_features_cpu(
343350
// range of blk_size, we expect the later embedding module to take care
344351
// of hashing indices calculation.
345352
uindex_t idx = static_cast<uindex_t>(indices_data[i]);
346-
uindex_t p = idx < static_cast<uindex_t>(blk_size * my_size)
347-
? idx / blk_size
348-
: idx % my_size;
349-
new_lengths_data[p * lengths_size + b_t]++;
353+
if (variable_bucket_sizes) {
354+
int64_t lb = std::upper_bound(
355+
bucketize_offset,
356+
bucketize_offset + bucket_size,
357+
indices_data[i]) -
358+
bucketize_offset - 1;
359+
uindex_t p = lb < my_size ? lb : idx % my_size;
360+
new_lengths_data[p * lengths_size + b_t]++;
361+
} else {
362+
uindex_t p = idx < static_cast<uindex_t>(blk_size * my_size)
363+
? idx / blk_size
364+
: idx % my_size;
365+
new_lengths_data[p * lengths_size + b_t]++;
366+
}
350367
}
351368
}
352369
cur_offset += cur_batch_size;
@@ -359,6 +376,12 @@ void _block_bucketize_sparse_features_cpu(
359376
for (const auto t : c10::irange(T)) {
360377
const auto blk_size = block_sizes_data[t];
361378
const auto cur_batch_size = variable_batch_size ? batch_sizes_data[t] : B;
379+
const index_t* bucketize_offset = nullptr;
380+
int64_t bucket_size = 0;
381+
if (variable_bucket_sizes) {
382+
bucketize_offset = block_bucketize_pos.value()[t].data_ptr<index_t>();
383+
bucket_size = block_bucketize_pos.value()[t].numel();
384+
}
362385
for (const auto b : c10::irange(cur_batch_size)) {
363386
const auto b_t = (variable_batch_size ? cur_offset : t * B) + b;
364387
const offset_t rowstart = offsets_data[b_t];
@@ -371,12 +394,22 @@ void _block_bucketize_sparse_features_cpu(
371394
// range of blk_size, we expect the later embedding module to take care
372395
// of hashing indices calculation.
373396
const uindex_t idx = static_cast<uindex_t>(indices_data[i]);
374-
const uindex_t p = idx < static_cast<uindex_t>(blk_size * my_size)
375-
? idx / blk_size
376-
: idx % my_size;
377-
const uindex_t new_idx = idx < static_cast<uindex_t>(blk_size * my_size)
378-
? idx % blk_size
379-
: idx / my_size;
397+
uindex_t p, new_idx;
398+
if (variable_bucket_sizes) {
399+
auto lb = std::upper_bound(
400+
bucketize_offset,
401+
bucketize_offset + bucket_size,
402+
indices_data[i]) -
403+
bucketize_offset - 1;
404+
p = lb < my_size ? lb : idx % my_size;
405+
new_idx = lb < my_size ? idx - bucketize_offset[lb] : idx / my_size;
406+
} else {
407+
p = idx < static_cast<uindex_t>(blk_size * my_size) ? idx / blk_size
408+
: idx % my_size;
409+
new_idx = idx < static_cast<uindex_t>(blk_size * my_size)
410+
? idx % blk_size
411+
: idx / my_size;
412+
}
380413
const uoffset_t pos = new_offsets_data[p * lengths_size + b_t];
381414
new_indices_data[pos] = new_idx;
382415
if (sequence) {
@@ -911,8 +944,8 @@ block_bucketize_sparse_features_cpu(
911944
const int64_t my_size,
912945
const c10::optional<Tensor>& weights,
913946
const c10::optional<Tensor>& batch_size_per_feature,
914-
const int64_t /* max_batch_size */ // Only used in GPU variant
915-
) {
947+
const int64_t /* max_batch_size */, // Only used in GPU variant
948+
const c10::optional<std::vector<at::Tensor>>& block_bucketize_pos) {
916949
const auto lengths_size = lengths.numel();
917950
const auto new_lengths_size = lengths_size * my_size;
918951
auto new_lengths = at::zeros({new_lengths_size}, lengths.options());
@@ -959,7 +992,8 @@ block_bucketize_sparse_features_cpu(
959992
new_weights,
960993
new_pos,
961994
unbucketize_permute,
962-
batch_size_per_feature);
995+
batch_size_per_feature,
996+
block_bucketize_pos);
963997
});
964998
});
965999
});
@@ -994,7 +1028,8 @@ block_bucketize_sparse_features_cpu(
9941028
new_weights,
9951029
new_pos,
9961030
unbucketize_permute,
997-
batch_size_per_feature);
1031+
batch_size_per_feature,
1032+
block_bucketize_pos);
9981033
});
9991034
});
10001035
});
@@ -1027,7 +1062,8 @@ block_bucketize_sparse_features_cpu(
10271062
new_weights,
10281063
new_pos,
10291064
unbucketize_permute,
1030-
batch_size_per_feature);
1065+
batch_size_per_feature,
1066+
block_bucketize_pos);
10311067
});
10321068
});
10331069
} else {
@@ -1055,7 +1091,8 @@ block_bucketize_sparse_features_cpu(
10551091
new_weights,
10561092
new_pos,
10571093
unbucketize_permute,
1058-
batch_size_per_feature);
1094+
batch_size_per_feature,
1095+
block_bucketize_pos);
10591096
});
10601097
});
10611098
}
@@ -2702,7 +2739,7 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
27022739
m.def(
27032740
"expand_into_jagged_permute(Tensor permute, Tensor input_offset, Tensor output_offset, SymInt output_size) -> Tensor");
27042741
m.def(
2705-
"block_bucketize_sparse_features(Tensor lengths, Tensor indices, bool bucketize_pos, bool sequence, Tensor block_sizes, SymInt my_size, Tensor? weights=None, Tensor? batch_size_per_feature=None, SymInt max_B= -1) -> (Tensor, Tensor, Tensor?, Tensor?, Tensor?)");
2742+
"block_bucketize_sparse_features(Tensor lengths, Tensor indices, bool bucketize_pos, bool sequence, Tensor block_sizes, SymInt my_size, Tensor? weights=None, Tensor? batch_size_per_feature=None, SymInt max_B= -1, Tensor[]? block_bucketize_pos=None) -> (Tensor, Tensor, Tensor?, Tensor?, Tensor?)");
27062743
m.def(
27072744
"bucketize_sparse_features(Tensor lengths, Tensor indices, bool bucketize_pos, SymInt my_size, Tensor? weights=None) -> (Tensor, Tensor, Tensor?, Tensor?)");
27082745
m.def(

fbgemm_gpu/test/failures_dict.json

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,10 @@
2929
"comment": "",
3030
"status": "xfail"
3131
},
32+
"SparseOpsTest.test_aot_dispatch_dynamic__test_block_bucketize_sparse_features_with_block_bucketize_pos": {
33+
"comment": "",
34+
"status": "xfail"
35+
},
3236
"SparseOpsTest.test_aot_dispatch_dynamic__test_block_bucketize_sparse_features_with_variable_batch_sizes": {
3337
"comment": "",
3438
"status": "xfail"
@@ -41,6 +45,10 @@
4145
"comment": "",
4246
"status": "xfail"
4347
},
48+
"SparseOpsTest.test_faketensor__test_block_bucketize_sparse_features_with_block_bucketize_pos": {
49+
"comment": "",
50+
"status": "xfail"
51+
},
4452
"SparseOpsTest.test_faketensor__test_block_bucketize_sparse_features_with_variable_batch_sizes": {
4553
"comment": "",
4654
"status": "xfail"

fbgemm_gpu/test/sparse_ops_test.py

Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -869,6 +869,14 @@ def test_block_bucketize_sparse_features_with_variable_batch_sizes(
869869
bucketize_pos: bool,
870870
sequence: bool,
871871
) -> None:
872+
"""
873+
Test variable bucket size for block_bucketize_sparse_features for RW sharding.
874+
E.g. given bucket_sizes_pos as [[0,5,15], [0,10,13]]:
875+
For feature 0, indices in [0,5) will be assigned to bucket 0, indices in [5,15) will be assigned to bucket 1.
876+
For feature 1, indices in [0,10) will be assigned to bucket 0, indices in [10,13) will be assigned to bucket 1.
877+
The new index will be original index - bucket_sizes_pos[feature][new_bucket_id], i.e. the index's offset from the start of its bucket.
878+
E.g. for feature 0, index = 12, it will be assigned to bucket 1 and the new index is 12 - 5 = 7.
879+
"""
872880
lengths = torch.tensor([2, 1, 1, 2, 0, 2], dtype=index_type)
873881
indices = torch.tensor(
874882
[1, 8, 5, 6, 7, 8, 8, 4],
@@ -942,6 +950,89 @@ def test_block_bucketize_sparse_features_with_variable_batch_sizes(
942950
new_indices_gpu.cpu(), new_indices_ref, rtol=0, atol=0
943951
)
944952

953+
@given(
954+
index_type=st.sampled_from([torch.int, torch.long]),
955+
has_weight=st.booleans(),
956+
bucketize_pos=st.booleans(),
957+
sequence=st.booleans(),
958+
)
959+
@settings(verbosity=Verbosity.verbose, max_examples=16, deadline=None)
960+
def test_block_bucketize_sparse_features_with_block_bucketize_pos(
961+
self,
962+
index_type: Optional[torch.dtype],
963+
has_weight: bool,
964+
bucketize_pos: bool,
965+
sequence: bool,
966+
) -> None:
967+
lengths = torch.tensor([2, 1, 1, 2, 0, 2], dtype=index_type)
968+
indices = torch.tensor(
969+
[1, 7, 2, 6, 7, 8, 8, 4],
970+
dtype=index_type,
971+
)
972+
batch_sizes = torch.tensor([3, 1, 2], dtype=index_type)
973+
weights = (
974+
torch.tensor(
975+
[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0],
976+
dtype=torch.float,
977+
)
978+
if has_weight
979+
else None
980+
)
981+
982+
block_sizes = torch.tensor([5, 10, 8], dtype=index_type)
983+
my_size = 2
984+
max_B = batch_sizes.max().item() # unused
985+
986+
block_bucketize_pos = [
987+
torch.tensor([0, 2, 8], dtype=index_type),
988+
torch.tensor([0, 5, 10], dtype=index_type),
989+
torch.tensor([0, 7, 12], dtype=index_type),
990+
]
991+
992+
new_lengths_ref = torch.tensor(
993+
[1, 0, 0, 0, 0, 1, 1, 1, 1, 2, 0, 1],
994+
dtype=index_type,
995+
)
996+
new_indices_ref = torch.tensor(
997+
[1, 4, 5, 0, 4, 2, 3, 1],
998+
dtype=index_type,
999+
)
1000+
new_weights_ref = torch.tensor(
1001+
[
1002+
1.0,
1003+
8.0,
1004+
2.0,
1005+
3.0,
1006+
4.0,
1007+
5.0,
1008+
6.0,
1009+
7.0,
1010+
],
1011+
dtype=torch.float,
1012+
)
1013+
(
1014+
new_lengths_cpu,
1015+
new_indices_cpu,
1016+
new_weights_cpu,
1017+
new_pos_cpu,
1018+
unbucketize_permute,
1019+
) = torch.ops.fbgemm.block_bucketize_sparse_features(
1020+
lengths,
1021+
indices,
1022+
bucketize_pos,
1023+
sequence,
1024+
block_sizes,
1025+
my_size,
1026+
weights,
1027+
batch_sizes,
1028+
max_B,
1029+
block_bucketize_pos,
1030+
)
1031+
torch.testing.assert_close(new_lengths_cpu, new_lengths_ref, rtol=0, atol=0)
1032+
torch.testing.assert_close(new_indices_cpu, new_indices_ref, rtol=0, atol=0)
1033+
if has_weight:
1034+
torch.testing.assert_close(new_weights_cpu, new_weights_ref)
1035+
9451036
@given(
9461037
index_type=st.sampled_from([torch.int, torch.long]),
9471038
has_weight=st.booleans(),

0 commit comments

Comments
 (0)