Skip to content

Commit fb684f6

Browse files
tissue3facebook-github-bot
authored andcommitted
Support variable bucket size for block_bucketize_sparse_features (pytorch#2107)
Summary: This diff adds support for variable bucket sizes in block_bucketize_sparse_features for RW sharding. E.g., given bucket_sizes_pos as [[0,5,15], [0,10,13]]: for batch 0, indices in [0,5) will be assigned to bucket 0 and indices in [5,15) will be assigned to bucket 1; for batch 1, indices in [0,10) will be assigned to bucket 0 and indices in [10,13) will be assigned to bucket 1. The new index will be the original index minus bucket_sizes_pos[batch][new_bucket_id], i.e. for batch = 0, index = 12, it will be assigned to bucket 1 and the new index is 12 - 5 = 7. Differential Revision: D50868649
1 parent 80990a6 commit fb684f6

File tree

5 files changed

+173
-22
lines changed

5 files changed

+173
-22
lines changed

fbgemm_gpu/include/fbgemm_gpu/sparse_ops.h

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -149,7 +149,8 @@ block_bucketize_sparse_features_cuda(
149149
const int64_t my_size,
150150
const c10::optional<at::Tensor>& weights,
151151
const c10::optional<at::Tensor>& batch_size_per_feature,
152-
const int64_t max_batch_size);
152+
const int64_t max_batch_size,
153+
const c10::optional<std::vector<at::Tensor>>& block_bucketize_pos);
153154

154155
std::tuple<
155156
at::Tensor,
@@ -168,7 +169,8 @@ block_bucketize_sparse_features_cpu(
168169
const int64_t my_size,
169170
const c10::optional<at::Tensor>& weights,
170171
const c10::optional<at::Tensor>& batch_size_per_feature,
171-
const int64_t max_batch_size);
172+
const int64_t max_batch_size,
173+
const c10::optional<std::vector<at::Tensor>>& block_bucketize_pos);
172174

173175
std::tuple<
174176
at::Tensor,

fbgemm_gpu/src/sparse_ops/sparse_block_bucketize_features.cu

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -147,7 +147,10 @@ block_bucketize_sparse_features_cuda(
147147
const int64_t my_size,
148148
const c10::optional<Tensor>& weights,
149149
const c10::optional<Tensor>& batch_size_per_feature,
150-
const int64_t max_B) {
150+
const int64_t max_B,
151+
const c10::optional<std::vector<
152+
at::Tensor>>& /*block_bucketize_pos*/ // Only used in CPU variant
153+
) {
151154
TENSORS_ON_SAME_CUDA_GPU_IF_NOT_OPTIONAL(lengths, indices);
152155

153156
at::cuda::OptionalCUDAGuard device_guard;

fbgemm_gpu/src/sparse_ops/sparse_ops_cpu.cpp

Lines changed: 65 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -264,6 +264,21 @@ void _permute_2D_lengths_cpu_kernel(
264264
input_offsets[i + 1] = lengths[i] + input_offsets[i];
265265
}
266266
}
267+
/// Binary-searches `arr[0, size)` for the last position whose element is
/// less than or equal to `target`.
///
/// Note that, despite the name, this is not the classic "lower bound": it
/// returns the index of the last element that is `<= target`, i.e. the bucket
/// whose start boundary covers `target`.
///
/// @param arr    Pointer to `size` boundary values. The binary search is only
///               meaningful if they are sorted in non-decreasing order.
/// @param size   Number of elements readable through `arr`.
/// @param target Value to locate.
/// @return Index in `[0, size)` of the last element `<= target`, or -1 when
///         `size` is 0 or every element is greater than `target`.
template <typename index_t>
int64_t
_find_lower_bound(const index_t* arr, const int64_t size, index_t target) {
  int64_t l = 0;
  int64_t h = size - 1;
  while (l <= h) {
    // Keep the midpoint 64-bit: narrowing it to `int` would truncate the
    // position once `size` exceeds INT_MAX.
    const int64_t mid = l + (h - l) / 2;
    if (arr[mid] > target) {
      h = mid - 1;
    } else {
      l = mid + 1;
    }
  }
  // `l` now points one past the last element that is <= target.
  return l - 1;
}
267282

268283
template <
269284
bool sequence,
@@ -283,7 +298,8 @@ void _block_bucketize_sparse_features_cpu(
283298
c10::optional<Tensor> new_weights,
284299
c10::optional<Tensor> new_pos,
285300
const c10::optional<Tensor>& unbucketize_permute,
286-
const c10::optional<Tensor>& batch_size_per_feature) {
301+
const c10::optional<Tensor>& batch_size_per_feature,
302+
const c10::optional<std::vector<at::Tensor>>& block_bucketize_pos) {
287303
// allocate tensors and buffers
288304
const auto lengths_size = lengths.numel();
289305
const auto new_lengths_size = lengths_size * my_size;
@@ -304,7 +320,7 @@ void _block_bucketize_sparse_features_cpu(
304320
const index_t* const block_sizes_data = block_sizes.data_ptr<index_t>();
305321
offset_t* batch_sizes_data = nullptr;
306322
const auto variable_batch_size = batch_size_per_feature.has_value();
307-
323+
const auto variable_bucket_sizes = block_bucketize_pos.has_value();
308324
using uindex_t = std::make_unsigned_t<index_t>;
309325
using uoffset_t = std::make_unsigned_t<offset_t>;
310326

@@ -330,6 +346,12 @@ void _block_bucketize_sparse_features_cpu(
330346
for (const auto t : c10::irange(T)) {
331347
const auto blk_size = block_sizes_data[t];
332348
const auto cur_batch_size = variable_batch_size ? batch_sizes_data[t] : B;
349+
const index_t* bucketize_offset = nullptr;
350+
int64_t bucket_size = 0;
351+
if (variable_bucket_sizes) {
352+
bucketize_offset = block_bucketize_pos.value()[t].data_ptr<index_t>();
353+
bucket_size = block_bucketize_pos.value()[t].numel();
354+
}
333355
for (const auto b : c10::irange(cur_batch_size)) {
334356
const auto b_t = (variable_batch_size ? cur_offset : t * B) + b;
335357
const offset_t rowstart = offsets_data[b_t];
@@ -342,10 +364,17 @@ void _block_bucketize_sparse_features_cpu(
342364
// range of blk_size, we expect the later embedding module to take care
343365
// of hashing indices calculation.
344366
uindex_t idx = static_cast<uindex_t>(indices_data[i]);
345-
uindex_t p = idx < static_cast<uindex_t>(blk_size * my_size)
346-
? idx / blk_size
347-
: idx % my_size;
348-
new_lengths_data[p * lengths_size + b_t]++;
367+
if (variable_bucket_sizes) {
368+
auto lb = _find_lower_bound<index_t>(
369+
bucketize_offset, bucket_size, indices_data[i]);
370+
uindex_t p = lb < my_size ? lb : idx % my_size;
371+
new_lengths_data[p * lengths_size + b_t]++;
372+
} else {
373+
uindex_t p = idx < static_cast<uindex_t>(blk_size * my_size)
374+
? idx / blk_size
375+
: idx % my_size;
376+
new_lengths_data[p * lengths_size + b_t]++;
377+
}
349378
}
350379
}
351380
cur_offset += cur_batch_size;
@@ -358,6 +387,12 @@ void _block_bucketize_sparse_features_cpu(
358387
for (const auto t : c10::irange(T)) {
359388
const auto blk_size = block_sizes_data[t];
360389
const auto cur_batch_size = variable_batch_size ? batch_sizes_data[t] : B;
390+
const index_t* bucketize_offset = nullptr;
391+
int64_t bucket_size = 0;
392+
if (variable_bucket_sizes) {
393+
bucketize_offset = block_bucketize_pos.value()[t].data_ptr<index_t>();
394+
bucket_size = block_bucketize_pos.value()[t].numel();
395+
}
361396
for (const auto b : c10::irange(cur_batch_size)) {
362397
const auto b_t = (variable_batch_size ? cur_offset : t * B) + b;
363398
const offset_t rowstart = offsets_data[b_t];
@@ -370,12 +405,19 @@ void _block_bucketize_sparse_features_cpu(
370405
// range of blk_size, we expect the later embedding module to take care
371406
// of hashing indices calculation.
372407
const uindex_t idx = static_cast<uindex_t>(indices_data[i]);
373-
const uindex_t p = idx < static_cast<uindex_t>(blk_size * my_size)
374-
? idx / blk_size
375-
: idx % my_size;
376-
const uindex_t new_idx = idx < static_cast<uindex_t>(blk_size * my_size)
377-
? idx % blk_size
378-
: idx / my_size;
408+
uindex_t p, new_idx;
409+
if (variable_bucket_sizes) {
410+
auto lb = _find_lower_bound<index_t>(
411+
bucketize_offset, bucket_size, indices_data[i]);
412+
p = lb < my_size ? lb : idx % my_size;
413+
new_idx = lb < my_size ? idx - bucketize_offset[lb] : idx / my_size;
414+
} else {
415+
p = idx < static_cast<uindex_t>(blk_size * my_size) ? idx / blk_size
416+
: idx % my_size;
417+
new_idx = idx < static_cast<uindex_t>(blk_size * my_size)
418+
? idx % blk_size
419+
: idx / my_size;
420+
}
379421
const uoffset_t pos = new_offsets_data[p * lengths_size + b_t];
380422
new_indices_data[pos] = new_idx;
381423
if (sequence) {
@@ -910,8 +952,8 @@ block_bucketize_sparse_features_cpu(
910952
const int64_t my_size,
911953
const c10::optional<Tensor>& weights,
912954
const c10::optional<Tensor>& batch_size_per_feature,
913-
const int64_t /* max_batch_size */ // Only used in GPU variant
914-
) {
955+
const int64_t /* max_batch_size */, // Only used in GPU variant
956+
const c10::optional<std::vector<at::Tensor>>& block_bucketize_pos) {
915957
const auto lengths_size = lengths.numel();
916958
const auto new_lengths_size = lengths_size * my_size;
917959
auto new_lengths = at::zeros({new_lengths_size}, lengths.options());
@@ -958,7 +1000,8 @@ block_bucketize_sparse_features_cpu(
9581000
new_weights,
9591001
new_pos,
9601002
unbucketize_permute,
961-
batch_size_per_feature);
1003+
batch_size_per_feature,
1004+
block_bucketize_pos);
9621005
});
9631006
});
9641007
});
@@ -993,7 +1036,8 @@ block_bucketize_sparse_features_cpu(
9931036
new_weights,
9941037
new_pos,
9951038
unbucketize_permute,
996-
batch_size_per_feature);
1039+
batch_size_per_feature,
1040+
block_bucketize_pos);
9971041
});
9981042
});
9991043
});
@@ -1026,7 +1070,8 @@ block_bucketize_sparse_features_cpu(
10261070
new_weights,
10271071
new_pos,
10281072
unbucketize_permute,
1029-
batch_size_per_feature);
1073+
batch_size_per_feature,
1074+
block_bucketize_pos);
10301075
});
10311076
});
10321077
} else {
@@ -1054,7 +1099,8 @@ block_bucketize_sparse_features_cpu(
10541099
new_weights,
10551100
new_pos,
10561101
unbucketize_permute,
1057-
batch_size_per_feature);
1102+
batch_size_per_feature,
1103+
block_bucketize_pos);
10581104
});
10591105
});
10601106
}
@@ -2696,7 +2742,7 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
26962742
m.def(
26972743
"expand_into_jagged_permute(Tensor permute, Tensor input_offset, Tensor output_offset, SymInt output_size) -> Tensor");
26982744
m.def(
2699-
"block_bucketize_sparse_features(Tensor lengths, Tensor indices, bool bucketize_pos, bool sequence, Tensor block_sizes, SymInt my_size, Tensor? weights=None, Tensor? batch_size_per_feature=None, SymInt max_B= -1) -> (Tensor, Tensor, Tensor?, Tensor?, Tensor?)");
2745+
"block_bucketize_sparse_features(Tensor lengths, Tensor indices, bool bucketize_pos, bool sequence, Tensor block_sizes, SymInt my_size, Tensor? weights=None, Tensor? batch_size_per_feature=None, SymInt max_B= -1, Tensor[]? block_bucketize_pos=None) -> (Tensor, Tensor, Tensor?, Tensor?, Tensor?)");
27002746
m.def(
27012747
"bucketize_sparse_features(Tensor lengths, Tensor indices, bool bucketize_pos, SymInt my_size, Tensor? weights=None) -> (Tensor, Tensor, Tensor?, Tensor?)");
27022748
m.def("asynchronous_exclusive_cumsum(Tensor t_in) -> Tensor");

fbgemm_gpu/test/failures_dict.json

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,10 @@
3333
"comment": "",
3434
"status": "xfail"
3535
},
36+
"SparseOpsTest.test_aot_dispatch_dynamic__test_block_bucketize_sparse_features_with_block_bucketize_pos": {
37+
"comment": "",
38+
"status": "xfail"
39+
},
3640
"SparseOpsTest.test_aot_dispatch_dynamic__test_block_bucketize_sparse_features_with_variable_batch_sizes": {
3741
"comment": "",
3842
"status": "xfail"
@@ -57,6 +61,10 @@
5761
"comment": "",
5862
"status": "xfail"
5963
},
64+
"SparseOpsTest.test_faketensor__test_block_bucketize_sparse_features_with_block_bucketize_pos": {
65+
"comment": "",
66+
"status": "xfail"
67+
},
6068
"SparseOpsTest.test_faketensor__test_block_bucketize_sparse_features_with_variable_batch_sizes": {
6169
"comment": "",
6270
"status": "xfail"

fbgemm_gpu/test/sparse_ops_test.py

Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -869,6 +869,14 @@ def test_block_bucketize_sparse_features_with_variable_batch_sizes(
869869
bucketize_pos: bool,
870870
sequence: bool,
871871
) -> None:
872+
"""
873+
Test variable bucket size for block bucketize_sparse features for RW sharding.
874+
E.g. Given bucket_sizes_pos as [[0,5,15], [0,10,13]]
875+
For batch 0, indices in [0,5) will be assigned to bucket 0, indices in [5,15) will be assigned to bucket 1.
876+
For batch 1, indices in [0,10) will be assigned to bucket 0, indices in [10,13) will be assigned to bucket 1.
877+
The new index will be original index - bucket_sizes_pos[batch][new_bucket_id]
878+
i.e. for batch = 0, index = 12, it will be assigned to bucket 1 and the new index is 12 - 5 = 7.
879+
"""
872880
lengths = torch.tensor([2, 1, 1, 2, 0, 2], dtype=index_type)
873881
indices = torch.tensor(
874882
[1, 8, 5, 6, 7, 8, 8, 4],
@@ -942,6 +950,90 @@ def test_block_bucketize_sparse_features_with_variable_batch_sizes(
942950
new_indices_gpu.cpu(), new_indices_ref, rtol=0, atol=0
943951
)
944952

953+
@given(
    index_type=st.sampled_from([torch.int, torch.long]),
    has_weight=st.booleans(),
    bucketize_pos=st.booleans(),
    sequence=st.booleans(),
)
@settings(verbosity=Verbosity.verbose, max_examples=16, deadline=None)
def test_block_bucketize_sparse_features_with_block_bucketize_pos(
    self,
    index_type: Optional[torch.dtype],
    has_weight: bool,
    bucketize_pos: bool,
    sequence: bool,
) -> None:
    """
    Check block_bucketize_sparse_features when explicit bucket boundaries
    (block_bucketize_pos) are supplied, one boundary tensor per feature.

    For a boundary tensor such as [0, 2, 8], indices in [0, 2) go to
    bucket 0 and indices in [2, 8) go to bucket 1; the new index is the
    original index minus the start boundary of the bucket it lands in
    (e.g. index 7 -> bucket 1, new index 7 - 2 = 5).
    """
    lengths = torch.tensor([2, 1, 1, 2, 0, 2], dtype=index_type)
    indices = torch.tensor(
        [1, 7, 2, 6, 7, 8, 8, 4],
        dtype=index_type,
    )
    batch_sizes = torch.tensor([3, 1, 2], dtype=index_type)
    weights = (
        torch.tensor(
            [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0],
            dtype=torch.float,
        )
        if has_weight
        else None
    )

    block_sizes = torch.tensor([5, 10, 8], dtype=index_type)
    my_size = 2
    max_B = batch_sizes.max().item()  # unused

    # One boundary tensor per feature (three features, two buckets each).
    block_bucketize_pos = [
        torch.tensor([0, 2, 8], dtype=index_type),
        torch.tensor([0, 5, 10], dtype=index_type),
        torch.tensor([0, 7, 12], dtype=index_type),
    ]

    # Hand-computed references: lengths are laid out bucket-major
    # (bucket 0 for all six bags, then bucket 1 for all six bags).
    new_lengths_ref = torch.tensor(
        [1, 0, 0, 0, 0, 1, 1, 1, 1, 2, 0, 1],
        dtype=index_type,
    )
    new_indices_ref = torch.tensor(
        [1, 4, 5, 0, 4, 2, 3, 1],
        dtype=index_type,
    )
    new_weights_ref = torch.tensor(
        [
            1.0,
            8.0,
            2.0,
            3.0,
            4.0,
            5.0,
            6.0,
            7.0,
        ],
        dtype=torch.float,
    )
    (
        new_lengths_cpu,
        new_indices_cpu,
        new_weights_cpu,
        new_pos_cpu,
        unbucketize_permute,
    ) = torch.ops.fbgemm.block_bucketize_sparse_features(
        lengths,
        indices,
        bucketize_pos,
        sequence,
        block_sizes,
        my_size,
        weights,
        batch_sizes,
        max_B,
        block_bucketize_pos,
    )
    torch.testing.assert_close(new_lengths_cpu, new_lengths_ref, rtol=0, atol=0)
    torch.testing.assert_close(new_indices_cpu, new_indices_ref, rtol=0, atol=0)
    if has_weight:
        torch.testing.assert_close(new_weights_cpu, new_weights_ref)
1036+
9451037
@given(
9461038
index_type=st.sampled_from([torch.int, torch.long]),
9471039
has_weight=st.booleans(),

0 commit comments

Comments
 (0)