Skip to content

Commit 5659e9f

Browse files
sryap authored and facebook-github-bot committed
Prepare bounds_check_indices for VLE (pytorch#1633)
Summary: Pull Request resolved: pytorch#1633 Prepare `bounds_check_indices` for variable length TBE (VLE). - Update the frontend API to accept VLE args - Update the backend logic to process VLE data Reviewed By: jianyuh Differential Revision: D43253703 fbshipit-source-id: f8c270fc26501bb43e2eb4a8d3739bb31b31fbe9
1 parent 125ce44 commit 5659e9f

File tree

3 files changed

+98
-58
lines changed

3 files changed

+98
-58
lines changed

fbgemm_gpu/codegen/embedding_bounds_check.cu

Lines changed: 89 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -23,31 +23,46 @@ __device__ void adjust_offset_kernel(
2323
*offset_acc_end = indices_end;
2424
}
2525

26-
template <typename index_t>
26+
template <typename index_t, bool vle>
2727
__global__ __launch_bounds__(kMaxThreads) void bounds_check_indices_kernel(
2828
const at::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits>
2929
rows_per_table,
3030
at::PackedTensorAccessor32<index_t, 1, at::RestrictPtrTraits> indices,
3131
at::PackedTensorAccessor32<index_t, 1, at::RestrictPtrTraits> offsets,
32+
const int32_t* const vle_metadata,
3233
const int64_t bounds_check_mode_,
3334
at::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits> warning,
3435
FixedDivisor fd) {
3536
int32_t T = rows_per_table.size(0);
36-
int32_t B = (offsets.size(0) - 1) / T;
37-
3837
int32_t b_t = blockIdx.x * blockDim.y + threadIdx.y;
39-
int32_t b; // = b_t % B;
40-
int32_t t; // = b_t / B;
41-
fd.DivMod(b_t, &t, &b);
42-
if (t >= T) {
38+
int32_t b;
39+
int32_t t;
40+
int32_t B = 0;
41+
int32_t total_B = offsets.size(0) - 1;
42+
43+
if (b_t >= total_B) {
4344
return;
4445
}
45-
auto bounds_check_mode = static_cast<BoundsCheckMode>(bounds_check_mode_);
4646

47-
auto num_rows = rows_per_table[t];
48-
auto indices_start = offsets[t * B + b];
49-
auto indices_end = offsets[t * B + b + 1];
50-
index_t num_indices = indices.size(0);
47+
if (vle) {
48+
if (threadIdx.x == 0) {
49+
// binary_search_range takes inclusive sumscan array
50+
binary_search_range(&t, vle_metadata + 1, b_t, T);
51+
b = b_t - vle_metadata[t];
52+
}
53+
t = shfl_sync(t, 0);
54+
b = shfl_sync(b, 0);
55+
} else {
56+
B = total_B / T;
57+
fd.DivMod(b_t, &t, &b);
58+
}
59+
60+
const auto bounds_check_mode =
61+
static_cast<BoundsCheckMode>(bounds_check_mode_);
62+
const auto num_rows = rows_per_table[t];
63+
auto indices_start = offsets[b_t];
64+
auto indices_end = offsets[b_t + 1];
65+
const index_t num_indices = indices.size(0);
5166

5267
if (bounds_check_mode == BoundsCheckMode::FATAL) {
5368
CUDA_KERNEL_ASSERT(indices_start >= 0);
@@ -58,12 +73,13 @@ __global__ __launch_bounds__(kMaxThreads) void bounds_check_indices_kernel(
5873
indices_end > num_indices) {
5974
if (gpuAtomicIncrement(&warning[0]) == 0) {
6075
printf(
61-
"EmbeddingBoundsCheck: (at least one) Out of bounds access for "
62-
"batch: %lld, table: %lld, indices_start: %lld, indices_end: %lld,"
76+
"EmbeddingBoundsCheck (VLE %s): (at least one) Out of bounds access for "
77+
"batch: %d, table: %d, indices_start: %lld, indices_end: %lld,"
6378
" num_indices: %lld. Setting indices_start and indices_end within "
6479
"the range.\n",
65-
static_cast<int64_t>(b),
66-
static_cast<int64_t>(t),
80+
vle ? "true" : "false",
81+
b,
82+
t,
6783
static_cast<int64_t>(indices_start),
6884
static_cast<int64_t>(indices_end),
6985
static_cast<int64_t>(num_indices));
@@ -72,16 +88,16 @@ __global__ __launch_bounds__(kMaxThreads) void bounds_check_indices_kernel(
7288
indices_start,
7389
indices_end,
7490
num_indices,
75-
&offsets[t * B + b],
76-
&offsets[t * B + b + 1]);
91+
&offsets[b_t],
92+
&offsets[b_t + 1]);
7793
}
7894
} else if (bounds_check_mode == BoundsCheckMode::IGNORE) {
7995
adjust_offset_kernel(
8096
indices_start,
8197
indices_end,
8298
num_indices,
83-
&offsets[t * B + b],
84-
&offsets[t * B + b + 1]);
99+
&offsets[b_t],
100+
&offsets[b_t + 1]);
85101
}
86102

87103
const auto L = indices_end - indices_start;
@@ -100,9 +116,10 @@ __global__ __launch_bounds__(kMaxThreads) void bounds_check_indices_kernel(
100116
if (idx < 0 || idx >= num_rows) {
101117
if (gpuAtomicIncrement(&warning[0]) == 0) {
102118
printf(
103-
"EmbeddingBoundsCheck: (at least one) Out of bounds access for batch: %lld, table: %lld, bag element: %lld, idx: %lld, num_rows: %lld, indices_start: %lld, indices_end: %lld, T: %d, B: %d, b_t: %d. Setting idx to zero.\n",
104-
static_cast<int64_t>(b),
105-
static_cast<int64_t>(t),
119+
"EmbeddingBoundsCheck (VLE %s): (at least one) Out of bounds access for batch: %d, table: %d, bag element: %lld, idx: %lld, num_rows: %lld, indices_start: %lld, indices_end: %lld, T: %d, B: %d, b_t: %d. Setting idx to zero.\n",
120+
vle ? "true" : "false",
121+
b,
122+
t,
106123
static_cast<int64_t>(i),
107124
static_cast<int64_t>(idx),
108125
num_rows,
@@ -122,25 +139,27 @@ __global__ __launch_bounds__(kMaxThreads) void bounds_check_indices_kernel(
122139
}
123140

124141
if (bounds_check_mode == BoundsCheckMode::FATAL) {
125-
CUDA_KERNEL_ASSERT(num_indices == offsets[B * T]);
142+
CUDA_KERNEL_ASSERT(num_indices == offsets[total_B]);
126143
} else if (bounds_check_mode == BoundsCheckMode::WARNING) {
127-
if (num_indices != offsets[B * T]) {
144+
if (num_indices != offsets[total_B]) {
128145
if (gpuAtomicIncrement(&warning[0]) == 0) {
129146
printf(
130-
"EmbeddingBoundsCheck: the last element in offsets is incorrect for "
131-
"total batch size B: %lld, total table num T: %lld, "
147+
"EmbeddingBoundsCheck (VLE %s): the last element in offsets is incorrect for "
148+
"total batch size %s: %d, total table num T: %d, "
132149
" last element in offsets: %lld, indices size: %lld. "
133150
" Setting the last element in offsets to be indices size.\n",
134-
static_cast<int64_t>(B),
135-
static_cast<int64_t>(T),
136-
static_cast<int64_t>(offsets[B * T]),
151+
vle ? "true" : "false",
152+
vle ? "total_B" : "B",
153+
vle ? total_B : B,
154+
T,
155+
static_cast<int64_t>(offsets[total_B]),
137156
static_cast<int64_t>(num_indices));
138157
}
139-
offsets[B * T] = num_indices;
158+
offsets[total_B] = num_indices;
140159
}
141160
} else if (bounds_check_mode == BoundsCheckMode::IGNORE) {
142-
if (num_indices != offsets[B * T]) {
143-
offsets[B * T] = num_indices;
161+
if (num_indices != offsets[total_B]) {
162+
offsets[total_B] = num_indices;
144163
}
145164
}
146165
}
@@ -151,19 +170,22 @@ void bounds_check_indices_cuda(
151170
Tensor& offsets,
152171
int64_t bounds_check_mode_,
153172
Tensor& warning,
154-
c10::optional<Tensor> weights) {
173+
const c10::optional<Tensor>& weights,
174+
const c10::optional<Tensor>& vle_metadata) {
155175
TENSOR_ON_CUDA_GPU(rows_per_table);
156176
TENSOR_ON_CUDA_GPU(indices);
157177
TENSOR_ON_CUDA_GPU(offsets);
158178
TENSOR_ON_CUDA_GPU(warning);
159179
TENSOR_EMPTY_OR_ON_CUDA_GPU(weights);
180+
TENSOR_EMPTY_OR_ON_CUDA_GPU(vle_metadata);
160181

161182
at::cuda::OptionalCUDAGuard device_guard;
162183
device_guard.set_index(rows_per_table.get_device());
163184

164185
const int32_t T = rows_per_table.size(0);
165-
const int32_t B = (offsets.size(0) - 1) / T;
166-
if (B == 0 || T == 0) {
186+
const int32_t total_B = offsets.size(0) - 1;
187+
const int32_t B = (total_B) / T;
188+
if (total_B == 0 || T == 0) {
167189
return;
168190
}
169191
const auto bounds_check_mode =
@@ -173,11 +195,13 @@ void bounds_check_indices_cuda(
173195
}
174196
const int64_t num_indices = indices.size(0);
175197

176-
TORCH_CHECK(
177-
offsets.size(0) == B * T + 1,
178-
"offsets size " + std::to_string(offsets.size(0)) +
179-
" is not equal to B (" + std::to_string(B) + ") * T (" +
180-
std::to_string(T) + ") + 1");
198+
if (!vle_metadata.has_value()) {
199+
TORCH_CHECK(
200+
offsets.size(0) == B * T + 1,
201+
"offsets size " + std::to_string(offsets.size(0)) +
202+
" is not equal to B (" + std::to_string(B) + ") * T (" +
203+
std::to_string(T) + ") + 1");
204+
}
181205
if (weights.has_value()) {
182206
TORCH_CHECK(
183207
weights.value().size(0) == num_indices,
@@ -187,19 +211,30 @@ void bounds_check_indices_cuda(
187211

188212
constexpr size_t kNumThreads = 256;
189213

214+
#define INVOKE_BOUNDS_CHECK_INDICES_KERNEL(VAR_BATCH_SIZE, VAR_B_METADATA) \
215+
bounds_check_indices_kernel<index_t, VAR_BATCH_SIZE> \
216+
<<<div_round_up(total_B, kNumThreads / fbgemm_gpu::kWarpSize), \
217+
dim3(fbgemm_gpu::kWarpSize, kNumThreads / fbgemm_gpu::kWarpSize), \
218+
0, \
219+
at::cuda::getCurrentCUDAStream()>>>( \
220+
rows_per_table \
221+
.packed_accessor32<int64_t, 1, at::RestrictPtrTraits>(), \
222+
indices.packed_accessor32<index_t, 1, at::RestrictPtrTraits>(), \
223+
offsets.packed_accessor32<index_t, 1, at::RestrictPtrTraits>(), \
224+
VAR_B_METADATA, \
225+
bounds_check_mode_, \
226+
warning.packed_accessor32<int64_t, 1, at::RestrictPtrTraits>(), \
227+
FixedDivisor(B)); \
228+
C10_CUDA_KERNEL_LAUNCH_CHECK()
229+
190230
AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "bounds_check_indices", [&] {
191-
bounds_check_indices_kernel<index_t>
192-
<<<div_round_up(B * T, kNumThreads / fbgemm_gpu::kWarpSize),
193-
dim3(fbgemm_gpu::kWarpSize, kNumThreads / fbgemm_gpu::kWarpSize),
194-
0,
195-
at::cuda::getCurrentCUDAStream()>>>(
196-
rows_per_table
197-
.packed_accessor32<int64_t, 1, at::RestrictPtrTraits>(),
198-
indices.packed_accessor32<index_t, 1, at::RestrictPtrTraits>(),
199-
offsets.packed_accessor32<index_t, 1, at::RestrictPtrTraits>(),
200-
bounds_check_mode_,
201-
warning.packed_accessor32<int64_t, 1, at::RestrictPtrTraits>(),
202-
FixedDivisor(B));
231+
if (vle_metadata.has_value()) {
232+
INVOKE_BOUNDS_CHECK_INDICES_KERNEL(
233+
true, vle_metadata.value().data_ptr<int32_t>());
234+
} else {
235+
INVOKE_BOUNDS_CHECK_INDICES_KERNEL(false, nullptr);
236+
}
203237
});
204-
C10_CUDA_KERNEL_LAUNCH_CHECK();
238+
239+
#undef INVOKE_BOUNDS_CHECK_INDICES_KERNEL
205240
}

fbgemm_gpu/codegen/embedding_bounds_check_host.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,8 @@ void bounds_check_indices_cuda(
2323
Tensor& offsets,
2424
int64_t bounds_check_mode,
2525
Tensor& warning,
26-
c10::optional<Tensor> weights);
26+
const c10::optional<Tensor>& weights,
27+
const c10::optional<Tensor>& vle_metadata);
2728

2829
// Deprecated for fb namespace! Please use fbgemm namespace instead!
2930
TORCH_LIBRARY_FRAGMENT(fb, m) {

fbgemm_gpu/codegen/embedding_bounds_check_host_cpu.cpp

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,11 @@ void bounds_check_indices_cpu(
4242
Tensor& offsets,
4343
int64_t bounds_check_mode_,
4444
Tensor& warning,
45-
c10::optional<Tensor> weights) {
45+
const c10::optional<Tensor>& weights,
46+
const c10::optional<Tensor>& vle_metadata) {
47+
TORCH_CHECK(
48+
!vle_metadata.has_value(),
49+
"bounds_check_indices on CPU does not support variable length (batch size)");
4650
auto bounds_check_mode = static_cast<BoundsCheckMode>(bounds_check_mode_);
4751
if (bounds_check_mode == BoundsCheckMode::WARNING) {
4852
warning.zero_();
@@ -163,14 +167,14 @@ TORCH_LIBRARY_FRAGMENT(fb, m) {
163167
// The (a!) tells PyTorch this is an impure operation and so cannot be CSE'd
164168
// or DCE'd, etc.
165169
m.def(
166-
"bounds_check_indices(Tensor rows_per_table, Tensor(a!) indices, Tensor(b!) offsets, int bounds_check_mode, Tensor(c!) warning, Tensor(d!)? weights=None) -> ()");
170+
"bounds_check_indices(Tensor rows_per_table, Tensor(a!) indices, Tensor(b!) offsets, int bounds_check_mode, Tensor(c!) warning, Tensor(d!)? weights=None, Tensor? vle_metadata=None) -> ()");
167171
DISPATCH_TO_CPU("bounds_check_indices", bounds_check_indices_cpu);
168172
}
169173

170174
TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
171175
// The (a!) tells PyTorch this is an impure operation and so cannot be CSE'd
172176
// or DCE'd, etc.
173177
m.def(
174-
"bounds_check_indices(Tensor rows_per_table, Tensor(a!) indices, Tensor(b!) offsets, int bounds_check_mode, Tensor(c!) warning, Tensor(d!)? weights=None) -> ()");
178+
"bounds_check_indices(Tensor rows_per_table, Tensor(a!) indices, Tensor(b!) offsets, int bounds_check_mode, Tensor(c!) warning, Tensor(d!)? weights=None, Tensor? vle_metadata=None) -> ()");
175179
DISPATCH_TO_CPU("bounds_check_indices", bounds_check_indices_cpu);
176180
}

0 commit comments

Comments (0)