@@ -23,31 +23,52 @@ __device__ void adjust_offset_kernel(
   *offset_acc_end = indices_end;
 }
 
-template <typename index_t>
+template <typename index_t, bool vbe>
 __global__ __launch_bounds__(kMaxThreads) void bounds_check_indices_kernel(
     const at::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits>
         rows_per_table,
     at::PackedTensorAccessor32<index_t, 1, at::RestrictPtrTraits> indices,
     at::PackedTensorAccessor32<index_t, 1, at::RestrictPtrTraits> offsets,
+    const int32_t* const vbe_metadata,
     const int64_t bounds_check_mode_,
     at::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits> warning,
     FixedDivisor fd) {
   int32_t T = rows_per_table.size(0);
-  int32_t B = (offsets.size(0) - 1) / T;
-
   int32_t b_t = blockIdx.x * blockDim.y + threadIdx.y;
-  int32_t b; // = b_t % B;
-  int32_t t; // = b_t / B;
-  fd.DivMod(b_t, &t, &b);
-  if (t >= T) {
+  int32_t b;
+  int32_t t;
+  int32_t B = 0;
+  int32_t total_B = offsets.size(0) - 1;
+
+  if (!vbe && b_t >= total_B) {
     return;
   }
-  auto bounds_check_mode = static_cast<BoundsCheckMode>(bounds_check_mode_);
 
-  auto num_rows = rows_per_table[t];
-  auto indices_start = offsets[t * B + b];
-  auto indices_end = offsets[t * B + b + 1];
-  index_t num_indices = indices.size(0);
+  fd.DivMod(b_t, &t, &b);
+
+  if (vbe) {
+    // Check if t is valid
+    if (t >= T) {
+      return;
+    }
+    const auto B_start = vbe_metadata[t];
+    B = vbe_metadata[t + 1] - B_start;
+    // Check if b is valid
+    if (b >= B) {
+      return;
+    }
+    // Update b_t value
+    b_t = B_start + b;
+  } else {
+    B = total_B / T;
+  }
+
+  const auto bounds_check_mode =
+      static_cast<BoundsCheckMode>(bounds_check_mode_);
+  const auto num_rows = rows_per_table[t];
+  auto indices_start = offsets[b_t];
+  auto indices_end = offsets[b_t + 1];
+  const index_t num_indices = indices.size(0);
 
   if (bounds_check_mode == BoundsCheckMode::FATAL) {
     CUDA_KERNEL_ASSERT(indices_start >= 0);
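The VBE branch above assumes vbe_metadata holds the exclusive prefix sum of per-table batch sizes (length T + 1), so vbe_metadata[t + 1] - vbe_metadata[t] is table t's batch size and B_start + b is the row into the flattened offsets tensor. A minimal host-side sketch of that index math, illustrative only and not part of the patch (the batch sizes are made up):

// Sketch: how the kernel's VBE (t, b) -> b_t remapping behaves, assuming
// vbe_metadata is the exclusive prefix sum of per-table batch sizes,
// e.g. batch sizes {2, 5, 3} -> {0, 2, 7, 10}.
#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
  const std::vector<int32_t> vbe_metadata = {0, 2, 7, 10}; // T = 3 tables
  const int32_t T = 3;
  const int32_t max_B = 5; // largest per-table batch size
  // Mirror the kernel: walk the padded (t, b) grid and recover b_t.
  for (int32_t t = 0; t < T; ++t) {
    const int32_t B_start = vbe_metadata[t];
    const int32_t B = vbe_metadata[t + 1] - B_start;
    for (int32_t b = 0; b < max_B; ++b) {
      if (b >= B) {
        continue; // the kernel returns early for padded slots
      }
      const int32_t b_t = B_start + b; // row into the flattened offsets tensor
      std::printf("t=%d b=%d -> b_t=%d\n", t, b, b_t);
    }
  }
  return 0;
}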
@@ -58,12 +79,13 @@ __global__ __launch_bounds__(kMaxThreads) void bounds_check_indices_kernel(
         indices_end > num_indices) {
       if (gpuAtomicIncrement(&warning[0]) == 0) {
         printf(
-            "EmbeddingBoundsCheck: (at least one) Out of bounds access for "
-            "batch: %lld, table: %lld, indices_start: %lld, indices_end: %lld,"
+            "EmbeddingBoundsCheck (VBE %s): (at least one) Out of bounds access for "
+            "batch: %d, table: %d, indices_start: %lld, indices_end: %lld,"
             " num_indices: %lld. Setting indices_start and indices_end within "
             "the range.\n",
-            static_cast<int64_t>(b),
-            static_cast<int64_t>(t),
+            vbe ? "true" : "false",
+            b,
+            t,
             static_cast<int64_t>(indices_start),
             static_cast<int64_t>(indices_end),
             static_cast<int64_t>(num_indices));
@@ -72,16 +94,16 @@ __global__ __launch_bounds__(kMaxThreads) void bounds_check_indices_kernel(
           indices_start,
           indices_end,
           num_indices,
-          &offsets[t * B + b],
-          &offsets[t * B + b + 1]);
+          &offsets[b_t],
+          &offsets[b_t + 1]);
     }
   } else if (bounds_check_mode == BoundsCheckMode::IGNORE) {
     adjust_offset_kernel(
         indices_start,
         indices_end,
         num_indices,
-        &offsets[t * B + b],
-        &offsets[t * B + b + 1]);
+        &offsets[b_t],
+        &offsets[b_t + 1]);
   }
 
   const auto L = indices_end - indices_start;
@@ -100,9 +122,10 @@ __global__ __launch_bounds__(kMaxThreads) void bounds_check_indices_kernel(
     if (idx < 0 || idx >= num_rows) {
       if (gpuAtomicIncrement(&warning[0]) == 0) {
         printf(
-            "EmbeddingBoundsCheck: (at least one) Out of bounds access for batch: %lld, table: %lld, bag element: %lld, idx: %lld, num_rows: %lld, indices_start: %lld, indices_end: %lld, T: %d, B: %d, b_t: %d. Setting idx to zero.\n",
-            static_cast<int64_t>(b),
-            static_cast<int64_t>(t),
+            "EmbeddingBoundsCheck (VBE %s): (at least one) Out of bounds access for batch: %d, table: %d, bag element: %lld, idx: %lld, num_rows: %lld, indices_start: %lld, indices_end: %lld, T: %d, B: %d, b_t: %d. Setting idx to zero.\n",
+            vbe ? "true" : "false",
+            b,
+            t,
             static_cast<int64_t>(i),
             static_cast<int64_t>(idx),
             num_rows,
@@ -122,25 +145,27 @@ __global__ __launch_bounds__(kMaxThreads) void bounds_check_indices_kernel(
   }
 
   if (bounds_check_mode == BoundsCheckMode::FATAL) {
-    CUDA_KERNEL_ASSERT(num_indices == offsets[B * T]);
+    CUDA_KERNEL_ASSERT(num_indices == offsets[total_B]);
   } else if (bounds_check_mode == BoundsCheckMode::WARNING) {
-    if (num_indices != offsets[B * T]) {
+    if (num_indices != offsets[total_B]) {
       if (gpuAtomicIncrement(&warning[0]) == 0) {
         printf(
-            "EmbeddingBoundsCheck: the last element in offsets is incorrect for "
-            "total batch size B: %lld, total table num T: %lld, "
+            "EmbeddingBoundsCheck (VBE %s): the last element in offsets is incorrect for "
+            "total batch size %s: %d, total table num T: %d, "
             "last element in offsets: %lld, indices size: %lld. "
             "Setting the last element in offsets to be indices size.\n",
-            static_cast<int64_t>(B),
-            static_cast<int64_t>(T),
-            static_cast<int64_t>(offsets[B * T]),
+            vbe ? "true" : "false",
+            vbe ? "total_B" : "B",
+            vbe ? total_B : B,
+            T,
+            static_cast<int64_t>(offsets[total_B]),
             static_cast<int64_t>(num_indices));
       }
-      offsets[B * T] = num_indices;
+      offsets[total_B] = num_indices;
     }
   } else if (bounds_check_mode == BoundsCheckMode::IGNORE) {
-    if (num_indices != offsets[B * T]) {
-      offsets[B * T] = num_indices;
+    if (num_indices != offsets[total_B]) {
+      offsets[total_B] = num_indices;
     }
   }
 }
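The tail of the kernel enforces that the last element of offsets equals indices.size(0); with VBE, total_B is the sum of per-table batch sizes rather than B * T. A small host-side sketch of that invariant, illustrative only and with made-up sizes:

// Sketch: the offsets invariant the kernel checks (and repairs in
// WARNING/IGNORE modes). Hypothetical VBE case: T = 2 tables with
// batch sizes {2, 3}, so offsets has total_B + 1 = 6 entries.
#include <cassert>
#include <cstdint>
#include <vector>

int main() {
  const std::vector<int64_t> offsets = {0, 3, 5, 9, 12, 14};
  const int64_t num_indices = 14; // indices.size(0)
  const int64_t total_B = static_cast<int64_t>(offsets.size()) - 1;
  assert(offsets[total_B] == num_indices); // otherwise the kernel warns and overwrites it
  return 0;
}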
@@ -151,19 +176,23 @@ void bounds_check_indices_cuda(
     Tensor& offsets,
     int64_t bounds_check_mode_,
     Tensor& warning,
-    c10::optional<Tensor> weights) {
+    const c10::optional<Tensor>& weights,
+    const c10::optional<Tensor>& vbe_metadata,
+    const int64_t max_B) {
   TENSOR_ON_CUDA_GPU(rows_per_table);
   TENSOR_ON_CUDA_GPU(indices);
   TENSOR_ON_CUDA_GPU(offsets);
   TENSOR_ON_CUDA_GPU(warning);
   TENSOR_EMPTY_OR_ON_CUDA_GPU(weights);
+  TENSOR_EMPTY_OR_ON_CUDA_GPU(vbe_metadata);
 
   at::cuda::OptionalCUDAGuard device_guard;
   device_guard.set_index(rows_per_table.get_device());
 
   const int32_t T = rows_per_table.size(0);
-  const int32_t B = (offsets.size(0) - 1) / T;
-  if (B == 0 || T == 0) {
+  const int32_t total_B = offsets.size(0) - 1;
+  const int32_t B = total_B / T;
+  if (total_B == 0 || T == 0) {
     return;
   }
   const auto bounds_check_mode =
@@ -172,12 +201,17 @@ void bounds_check_indices_cuda(
     warning.zero_();
   }
   const int64_t num_indices = indices.size(0);
+  const auto vbe = vbe_metadata.has_value();
 
-  TORCH_CHECK(
-      offsets.size(0) == B * T + 1,
-      "offsets size " + std::to_string(offsets.size(0)) +
-          " is not equal to B (" + std::to_string(B) + ") * T (" +
-          std::to_string(T) + ") + 1");
+  if (vbe) {
+    TORCH_CHECK(max_B >= 0);
+  } else {
+    TORCH_CHECK(
+        offsets.size(0) == B * T + 1,
+        "offsets size " + std::to_string(offsets.size(0)) +
+            " is not equal to B (" + std::to_string(B) + ") * T (" +
+            std::to_string(T) + ") + 1");
+  }
   if (weights.has_value()) {
     TORCH_CHECK(
         weights.value().size(0) == num_indices,
@@ -186,20 +220,24 @@ void bounds_check_indices_cuda(
   }
 
   constexpr size_t kNumThreads = 256;
+  const auto max_B_ = vbe ? max_B : B;
 
   AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "bounds_check_indices", [&] {
-    bounds_check_indices_kernel<index_t>
-        <<<div_round_up(B * T, kNumThreads / fbgemm_gpu::kWarpSize),
-           dim3(fbgemm_gpu::kWarpSize, kNumThreads / fbgemm_gpu::kWarpSize),
-           0,
-           at::cuda::getCurrentCUDAStream()>>>(
-            rows_per_table
-                .packed_accessor32<int64_t, 1, at::RestrictPtrTraits>(),
-            indices.packed_accessor32<index_t, 1, at::RestrictPtrTraits>(),
-            offsets.packed_accessor32<index_t, 1, at::RestrictPtrTraits>(),
-            bounds_check_mode_,
-            warning.packed_accessor32<int64_t, 1, at::RestrictPtrTraits>(),
-            FixedDivisor(B));
+    const auto bounds_check_kernel =
+        (vbe ? bounds_check_indices_kernel<index_t, true>
+             : bounds_check_indices_kernel<index_t, false>);
+    bounds_check_kernel<<<
+        div_round_up(max_B_ * T, kNumThreads / fbgemm_gpu::kWarpSize),
+        dim3(fbgemm_gpu::kWarpSize, kNumThreads / fbgemm_gpu::kWarpSize),
+        0,
+        at::cuda::getCurrentCUDAStream()>>>(
+        rows_per_table.packed_accessor32<int64_t, 1, at::RestrictPtrTraits>(),
+        indices.packed_accessor32<index_t, 1, at::RestrictPtrTraits>(),
+        offsets.packed_accessor32<index_t, 1, at::RestrictPtrTraits>(),
+        vbe ? vbe_metadata.value().data_ptr<int32_t>() : nullptr,
+        bounds_check_mode_,
+        warning.packed_accessor32<int64_t, 1, at::RestrictPtrTraits>(),
+        FixedDivisor(max_B_));
+    C10_CUDA_KERNEL_LAUNCH_CHECK();
   });
-  C10_CUDA_KERNEL_LAUNCH_CHECK();
 }
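The launcher keeps the original geometry of one warp per (b, t) pair: blockDim = (kWarpSize, kNumThreads / kWarpSize) and the grid covers max_B_ * T pairs, which over-provisions threads under VBE and relies on the kernel's early returns for padded slots. A host-side sketch of that arithmetic, illustrative only (kWarpSize = 32 is an assumption for the example):

// Sketch: launch-shape math used above, assuming kWarpSize = 32 and
// kNumThreads = 256, so each block covers blockDim.y = 8 (b, t) pairs.
#include <cstdint>
#include <cstdio>

constexpr int32_t kWarpSize = 32;
constexpr int32_t kNumThreads = 256;

static int32_t div_round_up(int32_t a, int32_t b) {
  return (a + b - 1) / b;
}

int main() {
  const int32_t T = 10;      // number of tables
  const int32_t max_B_ = 37; // max batch size under VBE, otherwise B
  const int32_t pairs_per_block = kNumThreads / kWarpSize; // blockDim.y = 8
  const int32_t grid = div_round_up(max_B_ * T, pairs_per_block);
  // Each warp computes b_t = blockIdx.x * blockDim.y + threadIdx.y, and
  // FixedDivisor(max_B_) splits it as t = b_t / max_B_, b = b_t % max_B_.
  std::printf("grid=%d blocks, blockDim=(%d, %d)\n", grid, kWarpSize, pairs_per_block);
  return 0;
}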