
Commit 37044aa

sryap authored and facebook-github-bot committed
Improve bounds_check_indices for VBE (pytorch#3386)
Summary: X-link: facebookresearch/FBGEMM#475

Instead of over-launching thread blocks, use `b_t_map` to launch only the necessary thread blocks, increasing occupancy for the VBE case. Note that `b_t_map` is required for the TBE lookup in the VBE case; it is generated during the TBE forward pass. In this diff, we call `generate_vbe_metadata` twice (once before the bounds check and once before the forward lookup). These two calls can be fused into one; we will clean this up in subsequent diffs.

Differential Revision: D65735342
1 parent 5fa2054 commit 37044aa
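
The decode that makes this work appears in embedding_bounds_check_v2.cu below: each 32-bit `b_t_map` entry carries the table index `t` in its high bits and the per-table sample index `b` in its low `info_B_num_bits` bits. As a rough illustration of that layout, here is a minimal host-side sketch; the decode mirrors `bounds_check_indices_kernel_v2`, while the encoder (`pack_b_t`) is hypothetical, since the real map is produced by `generate_vbe_metadata` in the TBE forward pass.

#include <cstdint>

// Hypothetical encoder: real b_t_map entries come from the TBE forward
// pass (generate_vbe_metadata); shown only to document the assumed layout.
inline uint32_t pack_b_t(uint32_t t, uint32_t b, int32_t info_B_num_bits) {
  return (t << info_B_num_bits) | b; // t in the high bits, b in the low bits
}

// Decode, mirroring bounds_check_indices_kernel_v2 below.
// info_B_mask is expected to equal (1u << info_B_num_bits) - 1.
inline void unpack_b_t(
    uint32_t info,
    int32_t info_B_num_bits,
    uint32_t info_B_mask,
    uint32_t* t,
    uint32_t* b) {
  *t = info >> info_B_num_bits; // table index
  *b = info & info_B_mask; // sample index within the table
}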

File tree

6 files changed: +83 −42 lines changed


fbgemm_gpu/codegen/utils/embedding_bounds_check_host.cpp

Lines changed: 16 additions & 4 deletions
@@ -28,7 +28,10 @@ void _bounds_check_indices_cuda_v1(
     Tensor& warning,
     const std::optional<Tensor>& weights,
     const std::optional<Tensor>& B_offsets,
-    const int64_t max_B);
+    const int64_t max_B,
+    const std::optional<Tensor>& b_t_map,
+    const int32_t info_B_num_bits,
+    const uint32_t info_B_mask);

 void _bounds_check_indices_cuda_v2(
     Tensor& rows_per_table,
@@ -38,7 +41,10 @@ void _bounds_check_indices_cuda_v2(
     Tensor& warning,
     const std::optional<Tensor>& weights,
     const std::optional<Tensor>& B_offsets,
-    const int64_t max_B);
+    const int64_t max_B,
+    const std::optional<Tensor>& b_t_map,
+    const int32_t info_B_num_bits,
+    const uint32_t info_B_mask);

 ///@ingroup embedding-cuda
 void bounds_check_indices_cuda(
@@ -49,7 +55,10 @@ void bounds_check_indices_cuda(
     Tensor& warning,
     const std::optional<Tensor>& weights,
     const std::optional<Tensor>& B_offsets,
-    const int64_t max_B) {
+    const int64_t max_B,
+    const std::optional<Tensor>& b_t_map,
+    const int64_t info_B_num_bits,
+    const int64_t info_B_mask) {
   const static bool use_v2 = fbgemm_gpu::config::is_feature_enabled(
       fbgemm_gpu::config::FeatureGateName::BOUNDS_CHECK_INDICES_V2);
   const auto bounds_check_indices_fn =
@@ -62,7 +71,10 @@ void bounds_check_indices_cuda(
       warning,
       weights,
       B_offsets,
-      max_B);
+      max_B,
+      b_t_map,
+      static_cast<int32_t>(info_B_num_bits),
+      static_cast<uint32_t>(info_B_mask));
 }
 // Deprecated for fb namespace! Please use fbgemm namespace instead!
 TORCH_LIBRARY_FRAGMENT(fb, m) {
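
A side note on types in the wrapper above: the operator schema only exposes 64-bit integers (`int`/`SymInt`), so `bounds_check_indices_cuda` accepts `int64_t` and narrows to the kernel-facing `int32_t`/`uint32_t` at the dispatch boundary. A self-contained sketch of that pattern, with hypothetical function names:

#include <cstdint>

// Kernel-facing side: 32-bit metadata, matching the CUDA entry points.
void kernel_facing(int32_t info_B_num_bits, uint32_t info_B_mask) {
  (void)info_B_num_bits; // placeholder body for the sketch
  (void)info_B_mask;
}

// Schema-facing side: receives int64_t from the op registry and narrows,
// assuming the values fit in 32 bits (the metadata pass guarantees this).
void schema_facing(int64_t info_B_num_bits, int64_t info_B_mask) {
  kernel_facing(
      static_cast<int32_t>(info_B_num_bits),
      static_cast<uint32_t>(info_B_mask));
}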

fbgemm_gpu/codegen/utils/embedding_bounds_check_host_cpu.cpp

Lines changed: 30 additions & 3 deletions
@@ -48,7 +48,10 @@ void bounds_check_indices_cpu(
     Tensor& warning,
     const std::optional<Tensor>& weights,
     const std::optional<Tensor>& B_offsets,
-    const int64_t max_B) {
+    const int64_t max_B,
+    const std::optional<Tensor>& /*b_t_map*/,
+    const int64_t /*info_B_num_bits*/,
+    const int64_t /*info_B_mask*/) {
   if (offsets.scalar_type() != indices.scalar_type()) {
     offsets = offsets.toType(indices.scalar_type());
   }
@@ -190,7 +193,19 @@ TORCH_LIBRARY_FRAGMENT(fb, m) {
   // The (a!) tells PyTorch this is an impure operation and so cannot be CSE'd
   // or DCE'd, etc.
   m.def(
-      "bounds_check_indices(Tensor rows_per_table, Tensor(a!) indices, Tensor(b!) offsets, int bounds_check_mode, Tensor(c!) warning, Tensor(d!)? weights=None, Tensor? B_offsets=None, SymInt max_B=-1) -> ()",
+      "bounds_check_indices("
+      "    Tensor rows_per_table, "
+      "    Tensor(a!) indices, "
+      "    Tensor(b!) offsets, "
+      "    int bounds_check_mode, "
+      "    Tensor(c!) warning, "
+      "    Tensor(d!)? weights=None, "
+      "    Tensor? B_offsets=None, "
+      "    SymInt max_B=-1, "
+      "    Tensor? b_t_map=None, "
+      "    int info_B_num_bits=-1, "
+      "    int info_B_mask=-1"
+      ") -> ()",
       {PT2_COMPLIANT_TAG});
   DISPATCH_TO_CPU("bounds_check_indices", bounds_check_indices_cpu);
 }
@@ -202,7 +217,19 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
       "fbgemm_gpu.sparse_ops",
       "//deeplearning/fbgemm/fbgemm_gpu:sparse_ops_py");
   m.def(
-      "bounds_check_indices(Tensor rows_per_table, Tensor(a!) indices, Tensor(b!) offsets, int bounds_check_mode, Tensor(c!) warning, Tensor(d!)? weights=None, Tensor? B_offsets=None, SymInt max_B=-1) -> ()",
+      "bounds_check_indices("
+      "    Tensor rows_per_table, "
+      "    Tensor(a!) indices, "
+      "    Tensor(b!) offsets, "
+      "    int bounds_check_mode, "
+      "    Tensor(c!) warning, "
+      "    Tensor(d!)? weights=None, "
+      "    Tensor? B_offsets=None, "
+      "    SymInt max_B=-1, "
+      "    Tensor? b_t_map=None, "
+      "    int info_B_num_bits=-1, "
+      "    int info_B_mask=-1"
+      ") -> ()",
       {PT2_COMPLIANT_TAG});
   DISPATCH_TO_CPU("bounds_check_indices", bounds_check_indices_cpu);
 }
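
Because the three new arguments are appended with defaults (`b_t_map=None`, `info_B_num_bits=-1`, `info_B_mask=-1`), call sites that stop at `max_B` keep resolving against the new schema. A small C++ sketch of the same backward-compatibility idea, with hypothetical names and an `std::optional<int>` standing in for `Tensor?`:

#include <cstdint>
#include <optional>

// Trailing parameters with defaults keep old call sites valid, mirroring
// how the updated bounds_check_indices schema stays backward compatible.
void bounds_check_like(
    int64_t max_B = -1,
    std::optional<int> b_t_map = std::nullopt, // stand-in for Tensor?
    int64_t info_B_num_bits = -1,
    int64_t info_B_mask = -1);

void caller() {
  bounds_check_like(128); // old-style call: new args take their defaults
  bounds_check_like(128, 7, 26, (1 << 26) - 1); // new-style VBE call
}

void bounds_check_like(int64_t, std::optional<int>, int64_t, int64_t) {}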

fbgemm_gpu/codegen/utils/embedding_bounds_check_v1.cu

Lines changed: 4 additions & 1 deletion
@@ -187,7 +187,10 @@ void _bounds_check_indices_cuda_v1(
     Tensor& warning,
     const std::optional<Tensor>& weights,
     const std::optional<Tensor>& B_offsets,
-    const int64_t max_B) {
+    const int64_t max_B,
+    const std::optional<Tensor>& /*b_t_map*/,
+    const int32_t /*info_b_num_bits*/,
+    const uint32_t /*info_B_mask*/) {
   TENSORS_ON_SAME_CUDA_GPU_IF_NOT_OPTIONAL(
       rows_per_table, indices, offsets, warning, weights, B_offsets);
   TENSOR_NDIM_EQUALS(rows_per_table, 1);

fbgemm_gpu/codegen/utils/embedding_bounds_check_v2.cu

Lines changed: 24 additions & 34 deletions
@@ -41,7 +41,9 @@ __global__ __launch_bounds__(kMaxThreads) void bounds_check_indices_kernel_v2(
     // dummy PackedTensorAccessor
     pta::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits> warning,
     FixedDivisor fd,
-    const int32_t vbe_bound,
+    const int32_t* const b_t_map,
+    const int32_t info_B_num_bits,
+    const int32_t info_B_mask,
     TORCH_DSA_KERNEL_ARGS) {
   int32_t T = rows_per_table.size(0);
   int32_t total_B = offsets.size(0) - 1;
@@ -80,28 +82,17 @@ __global__ __launch_bounds__(kMaxThreads) void bounds_check_indices_kernel_v2(
     }
   }

-  for (int32_t b_t_init = blockIdx.x * blockDim.y + threadIdx.y;
-       b_t_init < (vbe ? vbe_bound : total_B);
-       b_t_init += blockDim.y * gridDim.x) {
+  for (int32_t b_t = blockIdx.x * blockDim.y + threadIdx.y; b_t < total_B;
+       b_t += blockDim.y * gridDim.x) {
+    // Compute b and t
     int32_t b;
     int32_t t;
-    int32_t b_t = b_t_init;
-
-    fd.DivMod(b_t, &t, &b);
-
     if (vbe) {
-      // Check if t is valid
-      if (t >= T) {
-        return;
-      }
-      const auto B_start = B_offsets[t];
-      B = B_offsets[t + 1] - B_start;
-      // Check if b is valid
-      if (b >= B) {
-        continue;
-      }
-      // Update b_t value
-      b_t = B_start + b;
+      const auto info = *reinterpret_cast<const uint32_t*>(&b_t_map[b_t]);
+      *reinterpret_cast<uint32_t*>(&t) = info >> info_B_num_bits;
+      *reinterpret_cast<uint32_t*>(&b) = info & info_B_mask;
+    } else {
+      fd.DivMod(b_t, &t, &b);
     }

     const auto num_rows = rows_per_table[t];
@@ -208,9 +199,12 @@ void _bounds_check_indices_cuda_v2(
     Tensor& warning,
     const std::optional<Tensor>& weights,
     const std::optional<Tensor>& B_offsets,
-    const int64_t max_B) {
+    const int64_t /*max_B*/,
+    const std::optional<Tensor>& b_t_map,
+    const int32_t info_B_num_bits,
+    const uint32_t info_B_mask) {
   TENSORS_ON_SAME_CUDA_GPU_IF_NOT_OPTIONAL(
-      rows_per_table, indices, offsets, warning, weights, B_offsets);
+      rows_per_table, indices, offsets, warning, weights, B_offsets, b_t_map);
   TENSOR_NDIM_EQUALS(rows_per_table, 1);
   TENSOR_NDIM_EQUALS(indices, 1);
   TENSOR_NDIM_EQUALS(offsets, 1);
@@ -219,6 +213,8 @@ void _bounds_check_indices_cuda_v2(
   const auto vbe = B_offsets.has_value();
   if (vbe) {
     TENSOR_NDIM_EQUALS(B_offsets.value(), 1);
+    TORCH_CHECK(b_t_map.has_value());
+    TENSOR_NDIM_EQUALS(b_t_map.value(), 1);
   }

   CUDA_DEVICE_GUARD(rows_per_table);
@@ -236,9 +232,7 @@ void _bounds_check_indices_cuda_v2(
   }
   const int64_t num_indices = indices.size(0);

-  if (vbe) {
-    TORCH_CHECK(max_B >= 0);
-  } else {
+  if (!vbe) {
     TORCH_CHECK(
         offsets.size(0) == B * T + 1,
         "offsets size " + std::to_string(offsets.size(0)) +
@@ -253,11 +247,6 @@ void _bounds_check_indices_cuda_v2(
   }

   constexpr size_t kNumThreads = 1024;
-  const auto max_B_ = vbe ? max_B : B;
-
-  const int32_t vbe_bound = max_B_ * T;
-  TORCH_CHECK(
-      vbe_bound >= 0, "EmbeddingBoundsCheck: vbe_bound is out of bound");

 #define INVOKE_BOUNDS_CHECK_INDICES(MODE) \
   if (bounds_check_mode == MODE) { \
@@ -270,8 +259,7 @@ void _bounds_check_indices_cuda_v2(
             : bounds_check_indices_kernel_v2<index_t, false, MODE>); \
       TORCH_DSA_KERNEL_LAUNCH( \
           bounds_check_kernel, \
-          min(div_round_up( \
-              max_B_* T, kNumThreads / fbgemm_gpu::kWarpSize), \
+          min(div_round_up(total_B, kNumThreads / fbgemm_gpu::kWarpSize), \
              get_max_thread_blocks_()), \
          dim3( \
              fbgemm_gpu::kWarpSize, kNumThreads / fbgemm_gpu::kWarpSize), \
@@ -282,8 +270,10 @@ void _bounds_check_indices_cuda_v2(
          MAKE_PTA_WITH_NAME(func_name, offsets, index_t, 1, 32), \
          vbe ? B_offsets.value().data_ptr<int32_t>() : nullptr, \
          MAKE_PTA_WITH_NAME(func_name, warning, int64_t, 1, 32), \
-          FixedDivisor(max_B_), \
-          vbe_bound); \
+          FixedDivisor(B), \
+          vbe ? b_t_map.value().data_ptr<int32_t>() : nullptr, \
+          info_B_num_bits, \
+          info_B_mask); \
      }); \
  }
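
The launch-site change above is the heart of the optimization: the grid is sized from `total_B` (the actual number of (b, t) pairs, `offsets.size(0) - 1`) rather than `max_B * T`, so no thread blocks are launched for padding when batch sizes vary across tables. A minimal sketch of the grid arithmetic, assuming `kWarpSize = 32` and using a stand-in constant for `get_max_thread_blocks_()`:

#include <algorithm>
#include <cstdint>

constexpr int32_t kWarpSize = 32; // assumption; wavefronts are 64 on AMD
constexpr int32_t kNumThreads = 1024;

inline int32_t div_round_up(int32_t a, int32_t b) {
  return (a + b - 1) / b;
}

// One warp handles one (b, t) pair per grid-stride iteration, so the grid
// needs at most ceil(total_B / warps_per_block) blocks, capped at a maximum.
inline int32_t num_blocks(int32_t total_B, int32_t max_blocks = 4096) {
  const int32_t warps_per_block = kNumThreads / kWarpSize; // 32 warps
  return std::min(div_round_up(total_B, warps_per_block), max_blocks);
}

// Example: total_B = 1000 gives 32 blocks here, whereas the old scheme
// launched div_round_up(max_B * T, 32) blocks even for mostly-empty rows.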

fbgemm_gpu/fbgemm_gpu/sparse_ops.py

Lines changed: 3 additions & 0 deletions
@@ -826,6 +826,9 @@ def bounds_check_indices_abstract(
     per_sample_weights: Optional[torch.Tensor] = None,
     B_offsets: Optional[torch.Tensor] = None,
     max_B: Optional[SymInt] = None,
+    b_t_map: Optional[torch.Tensor] = None,
+    info_B_num_bits: int = -1,
+    info_B_mask: int = -1,
 ) -> None:
     """
     This meta function is used to fake the bounds checking

fbgemm_gpu/src/split_embeddings_utils/split_embeddings_utils.cpp

Lines changed: 6 additions & 0 deletions
@@ -33,6 +33,11 @@ generate_vbe_metadata_meta(
   return {row_output_offsets, b_t_map};
 }

+std::tuple<int64_t, int64_t>
+get_infos_metadata_meta(Tensor /*unused*/, int64_t /*B*/, int64_t /*T*/) {
+  return {-1, -1};
+}
+
 } // namespace

 TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
@@ -43,4 +48,5 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) {

 TORCH_LIBRARY_IMPL(fbgemm, Meta, m) {
   m.impl("generate_vbe_metadata", &generate_vbe_metadata_meta);
+  m.impl("get_infos_metadata", &get_infos_metadata_meta);
 }
