Skip to content

Commit 2cd19dd

Browse files
sryap and facebook-github-bot
authored and committed
Prepare bounds_check_indices for VBE (pytorch#1633)
Summary:
Pull Request resolved: pytorch#1633

Prepare `bounds_check_indices` for variable batch size TBE (VBE).
- Update the frontend API to accept VBE args
- Update the backend logic to process VBE data

Reviewed By: jianyuh
Differential Revision: D43253703
fbshipit-source-id: bdc2315ff2849d36cb4202f4883482ad04b1f183
1 parent 2776770 commit 2cd19dd

File tree

3 files changed

+103
-58
lines changed

3 files changed

+103
-58
lines changed

fbgemm_gpu/codegen/embedding_bounds_check.cu

Lines changed: 92 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -23,31 +23,52 @@ __device__ void adjust_offset_kernel(
2323
*offset_acc_end = indices_end;
2424
}
2525

26-
template <typename index_t>
26+
template <typename index_t, bool vbe>
2727
__global__ __launch_bounds__(kMaxThreads) void bounds_check_indices_kernel(
2828
const at::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits>
2929
rows_per_table,
3030
at::PackedTensorAccessor32<index_t, 1, at::RestrictPtrTraits> indices,
3131
at::PackedTensorAccessor32<index_t, 1, at::RestrictPtrTraits> offsets,
32+
const int32_t* const vbe_metadata,
3233
const int64_t bounds_check_mode_,
3334
at::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits> warning,
3435
FixedDivisor fd) {
3536
int32_t T = rows_per_table.size(0);
36-
int32_t B = (offsets.size(0) - 1) / T;
37-
3837
int32_t b_t = blockIdx.x * blockDim.y + threadIdx.y;
39-
int32_t b; // = b_t % B;
40-
int32_t t; // = b_t / B;
41-
fd.DivMod(b_t, &t, &b);
42-
if (t >= T) {
38+
int32_t b;
39+
int32_t t;
40+
int32_t B = 0;
41+
int32_t total_B = offsets.size(0) - 1;
42+
43+
if (!vbe && b_t >= total_B) {
4344
return;
4445
}
45-
auto bounds_check_mode = static_cast<BoundsCheckMode>(bounds_check_mode_);
4646

47-
auto num_rows = rows_per_table[t];
48-
auto indices_start = offsets[t * B + b];
49-
auto indices_end = offsets[t * B + b + 1];
50-
index_t num_indices = indices.size(0);
47+
fd.DivMod(b_t, &t, &b);
48+
49+
if (vbe) {
50+
// Check if t is valid
51+
if (t >= T) {
52+
return;
53+
}
54+
const auto B_start = vbe_metadata[t];
55+
B = vbe_metadata[t + 1] - B_start;
56+
// Check if b is valid
57+
if (b >= B) {
58+
return;
59+
}
60+
// Update b_t value
61+
b_t = B_start + b;
62+
} else {
63+
B = total_B / T;
64+
}
65+
66+
const auto bounds_check_mode =
67+
static_cast<BoundsCheckMode>(bounds_check_mode_);
68+
const auto num_rows = rows_per_table[t];
69+
auto indices_start = offsets[b_t];
70+
auto indices_end = offsets[b_t + 1];
71+
const index_t num_indices = indices.size(0);
5172

5273
if (bounds_check_mode == BoundsCheckMode::FATAL) {
5374
CUDA_KERNEL_ASSERT(indices_start >= 0);
@@ -58,12 +79,13 @@ __global__ __launch_bounds__(kMaxThreads) void bounds_check_indices_kernel(
5879
indices_end > num_indices) {
5980
if (gpuAtomicIncrement(&warning[0]) == 0) {
6081
printf(
61-
"EmbeddingBoundsCheck: (at least one) Out of bounds access for "
62-
"batch: %lld, table: %lld, indices_start: %lld, indices_end: %lld,"
82+
"EmbeddingBoundsCheck (VBE %s): (at least one) Out of bounds access for "
83+
"batch: %d, table: %d, indices_start: %lld, indices_end: %lld,"
6384
" num_indices: %lld. Setting indices_start and indices_end within "
6485
"the range.\n",
65-
static_cast<int64_t>(b),
66-
static_cast<int64_t>(t),
86+
vbe ? "true" : "false",
87+
b,
88+
t,
6789
static_cast<int64_t>(indices_start),
6890
static_cast<int64_t>(indices_end),
6991
static_cast<int64_t>(num_indices));
@@ -72,16 +94,16 @@ __global__ __launch_bounds__(kMaxThreads) void bounds_check_indices_kernel(
7294
indices_start,
7395
indices_end,
7496
num_indices,
75-
&offsets[t * B + b],
76-
&offsets[t * B + b + 1]);
97+
&offsets[b_t],
98+
&offsets[b_t + 1]);
7799
}
78100
} else if (bounds_check_mode == BoundsCheckMode::IGNORE) {
79101
adjust_offset_kernel(
80102
indices_start,
81103
indices_end,
82104
num_indices,
83-
&offsets[t * B + b],
84-
&offsets[t * B + b + 1]);
105+
&offsets[b_t],
106+
&offsets[b_t + 1]);
85107
}
86108

87109
const auto L = indices_end - indices_start;
@@ -100,9 +122,10 @@ __global__ __launch_bounds__(kMaxThreads) void bounds_check_indices_kernel(
100122
if (idx < 0 || idx >= num_rows) {
101123
if (gpuAtomicIncrement(&warning[0]) == 0) {
102124
printf(
103-
"EmbeddingBoundsCheck: (at least one) Out of bounds access for batch: %lld, table: %lld, bag element: %lld, idx: %lld, num_rows: %lld, indices_start: %lld, indices_end: %lld, T: %d, B: %d, b_t: %d. Setting idx to zero.\n",
104-
static_cast<int64_t>(b),
105-
static_cast<int64_t>(t),
125+
"EmbeddingBoundsCheck (VBE %s): (at least one) Out of bounds access for batch: %d, table: %d, bag element: %lld, idx: %lld, num_rows: %lld, indices_start: %lld, indices_end: %lld, T: %d, B: %d, b_t: %d. Setting idx to zero.\n",
126+
vbe ? "true" : "false",
127+
b,
128+
t,
106129
static_cast<int64_t>(i),
107130
static_cast<int64_t>(idx),
108131
num_rows,
@@ -122,25 +145,27 @@ __global__ __launch_bounds__(kMaxThreads) void bounds_check_indices_kernel(
122145
}
123146

124147
if (bounds_check_mode == BoundsCheckMode::FATAL) {
125-
CUDA_KERNEL_ASSERT(num_indices == offsets[B * T]);
148+
CUDA_KERNEL_ASSERT(num_indices == offsets[total_B]);
126149
} else if (bounds_check_mode == BoundsCheckMode::WARNING) {
127-
if (num_indices != offsets[B * T]) {
150+
if (num_indices != offsets[total_B]) {
128151
if (gpuAtomicIncrement(&warning[0]) == 0) {
129152
printf(
130-
"EmbeddingBoundsCheck: the last element in offsets is incorrect for "
131-
"total batch size B: %lld, total table num T: %lld, "
153+
"EmbeddingBoundsCheck (VBE %s): the last element in offsets is incorrect for "
154+
"total batch size %s: %d, total table num T: %d, "
132155
" last element in offsets: %lld, indices size: %lld. "
133156
" Setting the last element in offsets to be indices size.\n",
134-
static_cast<int64_t>(B),
135-
static_cast<int64_t>(T),
136-
static_cast<int64_t>(offsets[B * T]),
157+
vbe ? "true" : "false",
158+
vbe ? "total_B" : "B",
159+
vbe ? total_B : B,
160+
T,
161+
static_cast<int64_t>(offsets[total_B]),
137162
static_cast<int64_t>(num_indices));
138163
}
139-
offsets[B * T] = num_indices;
164+
offsets[total_B] = num_indices;
140165
}
141166
} else if (bounds_check_mode == BoundsCheckMode::IGNORE) {
142-
if (num_indices != offsets[B * T]) {
143-
offsets[B * T] = num_indices;
167+
if (num_indices != offsets[total_B]) {
168+
offsets[total_B] = num_indices;
144169
}
145170
}
146171
}
@@ -151,19 +176,23 @@ void bounds_check_indices_cuda(
151176
Tensor& offsets,
152177
int64_t bounds_check_mode_,
153178
Tensor& warning,
154-
c10::optional<Tensor> weights) {
179+
const c10::optional<Tensor>& weights,
180+
const c10::optional<Tensor>& vbe_metadata,
181+
const int64_t max_B) {
155182
TENSOR_ON_CUDA_GPU(rows_per_table);
156183
TENSOR_ON_CUDA_GPU(indices);
157184
TENSOR_ON_CUDA_GPU(offsets);
158185
TENSOR_ON_CUDA_GPU(warning);
159186
TENSOR_EMPTY_OR_ON_CUDA_GPU(weights);
187+
TENSOR_EMPTY_OR_ON_CUDA_GPU(vbe_metadata);
160188

161189
at::cuda::OptionalCUDAGuard device_guard;
162190
device_guard.set_index(rows_per_table.get_device());
163191

164192
const int32_t T = rows_per_table.size(0);
165-
const int32_t B = (offsets.size(0) - 1) / T;
166-
if (B == 0 || T == 0) {
193+
const int32_t total_B = offsets.size(0) - 1;
194+
const int32_t B = (total_B) / T;
195+
if (total_B == 0 || T == 0) {
167196
return;
168197
}
169198
const auto bounds_check_mode =
@@ -172,12 +201,17 @@ void bounds_check_indices_cuda(
172201
warning.zero_();
173202
}
174203
const int64_t num_indices = indices.size(0);
204+
const auto vbe = vbe_metadata.has_value();
175205

176-
TORCH_CHECK(
177-
offsets.size(0) == B * T + 1,
178-
"offsets size " + std::to_string(offsets.size(0)) +
179-
" is not equal to B (" + std::to_string(B) + ") * T (" +
180-
std::to_string(T) + ") + 1");
206+
if (vbe) {
207+
TORCH_CHECK(max_B >= 0);
208+
} else {
209+
TORCH_CHECK(
210+
offsets.size(0) == B * T + 1,
211+
"offsets size " + std::to_string(offsets.size(0)) +
212+
" is not equal to B (" + std::to_string(B) + ") * T (" +
213+
std::to_string(T) + ") + 1");
214+
}
181215
if (weights.has_value()) {
182216
TORCH_CHECK(
183217
weights.value().size(0) == num_indices,
@@ -186,20 +220,24 @@ void bounds_check_indices_cuda(
186220
}
187221

188222
constexpr size_t kNumThreads = 256;
223+
const auto max_B_ = vbe ? max_B : B;
189224

190225
AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "bounds_check_indices", [&] {
191-
bounds_check_indices_kernel<index_t>
192-
<<<div_round_up(B * T, kNumThreads / fbgemm_gpu::kWarpSize),
193-
dim3(fbgemm_gpu::kWarpSize, kNumThreads / fbgemm_gpu::kWarpSize),
194-
0,
195-
at::cuda::getCurrentCUDAStream()>>>(
196-
rows_per_table
197-
.packed_accessor32<int64_t, 1, at::RestrictPtrTraits>(),
198-
indices.packed_accessor32<index_t, 1, at::RestrictPtrTraits>(),
199-
offsets.packed_accessor32<index_t, 1, at::RestrictPtrTraits>(),
200-
bounds_check_mode_,
201-
warning.packed_accessor32<int64_t, 1, at::RestrictPtrTraits>(),
202-
FixedDivisor(B));
226+
const auto bounds_check_kernel =
227+
(vbe ? bounds_check_indices_kernel<index_t, true>
228+
: bounds_check_indices_kernel<index_t, false>);
229+
bounds_check_kernel<<<
230+
div_round_up(max_B_ * T, kNumThreads / fbgemm_gpu::kWarpSize),
231+
dim3(fbgemm_gpu::kWarpSize, kNumThreads / fbgemm_gpu::kWarpSize),
232+
0,
233+
at::cuda::getCurrentCUDAStream()>>>(
234+
rows_per_table.packed_accessor32<int64_t, 1, at::RestrictPtrTraits>(),
235+
indices.packed_accessor32<index_t, 1, at::RestrictPtrTraits>(),
236+
offsets.packed_accessor32<index_t, 1, at::RestrictPtrTraits>(),
237+
vbe ? vbe_metadata.value().data_ptr<int32_t>() : nullptr,
238+
bounds_check_mode_,
239+
warning.packed_accessor32<int64_t, 1, at::RestrictPtrTraits>(),
240+
FixedDivisor(max_B_));
241+
C10_CUDA_KERNEL_LAUNCH_CHECK();
203242
});
204-
C10_CUDA_KERNEL_LAUNCH_CHECK();
205243
}

fbgemm_gpu/codegen/embedding_bounds_check_host.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,9 @@ void bounds_check_indices_cuda(
2323
Tensor& offsets,
2424
int64_t bounds_check_mode,
2525
Tensor& warning,
26-
c10::optional<Tensor> weights);
26+
const c10::optional<Tensor>& weights,
27+
const c10::optional<Tensor>& vbe_metadata,
28+
const int64_t max_B);
2729

2830
// Deprecated for fb namespace! Please use fbgemm namespace instead!
2931
TORCH_LIBRARY_FRAGMENT(fb, m) {

fbgemm_gpu/codegen/embedding_bounds_check_host_cpu.cpp

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,12 @@ void bounds_check_indices_cpu(
4242
Tensor& offsets,
4343
int64_t bounds_check_mode_,
4444
Tensor& warning,
45-
c10::optional<Tensor> weights) {
45+
const c10::optional<Tensor>& weights,
46+
const c10::optional<Tensor>& vbe_metadata,
47+
const int64_t /*max_B*/) {
48+
TORCH_CHECK(
49+
!vbe_metadata.has_value(),
50+
"bounds_check_indices on CPU does not support variable length (batch size)");
4651
auto bounds_check_mode = static_cast<BoundsCheckMode>(bounds_check_mode_);
4752
if (bounds_check_mode == BoundsCheckMode::WARNING) {
4853
warning.zero_();
@@ -163,14 +168,14 @@ TORCH_LIBRARY_FRAGMENT(fb, m) {
163168
// The (a!) tells PyTorch this is an impure operation and so cannot be CSE'd
164169
// or DCE'd, etc.
165170
m.def(
166-
"bounds_check_indices(Tensor rows_per_table, Tensor(a!) indices, Tensor(b!) offsets, int bounds_check_mode, Tensor(c!) warning, Tensor(d!)? weights=None) -> ()");
171+
"bounds_check_indices(Tensor rows_per_table, Tensor(a!) indices, Tensor(b!) offsets, int bounds_check_mode, Tensor(c!) warning, Tensor(d!)? weights=None, Tensor? vbe_metadata=None, int max_B=-1) -> ()");
167172
DISPATCH_TO_CPU("bounds_check_indices", bounds_check_indices_cpu);
168173
}
169174

170175
TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
171176
// The (a!) tells PyTorch this is an impure operation and so cannot be CSE'd
172177
// or DCE'd, etc.
173178
m.def(
174-
"bounds_check_indices(Tensor rows_per_table, Tensor(a!) indices, Tensor(b!) offsets, int bounds_check_mode, Tensor(c!) warning, Tensor(d!)? weights=None) -> ()");
179+
"bounds_check_indices(Tensor rows_per_table, Tensor(a!) indices, Tensor(b!) offsets, int bounds_check_mode, Tensor(c!) warning, Tensor(d!)? weights=None, Tensor? vbe_metadata=None, int max_B=-1) -> ()");
175180
DISPATCH_TO_CPU("bounds_check_indices", bounds_check_indices_cpu);
176181
}

0 commit comments

Comments
 (0)