@@ -23,31 +23,46 @@ __device__ void adjust_offset_kernel(
   *offset_acc_end = indices_end;
 }
 
-template <typename index_t>
+template <typename index_t, bool vle>
 __global__ __launch_bounds__(kMaxThreads) void bounds_check_indices_kernel(
     const at::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits>
         rows_per_table,
     at::PackedTensorAccessor32<index_t, 1, at::RestrictPtrTraits> indices,
     at::PackedTensorAccessor32<index_t, 1, at::RestrictPtrTraits> offsets,
+    const int32_t* const vle_metadata,
     const int64_t bounds_check_mode_,
     at::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits> warning,
     FixedDivisor fd) {
   int32_t T = rows_per_table.size(0);
-  int32_t B = (offsets.size(0) - 1) / T;
-
   int32_t b_t = blockIdx.x * blockDim.y + threadIdx.y;
-  int32_t b; // = b_t % B;
-  int32_t t; // = b_t / B;
-  fd.DivMod(b_t, &t, &b);
-  if (t >= T) {
+  int32_t b;
+  int32_t t;
+  int32_t B = 0;
+  int32_t total_B = offsets.size(0) - 1;
+
+  if (b_t >= total_B) {
     return;
   }
-  auto bounds_check_mode = static_cast<BoundsCheckMode>(bounds_check_mode_);
 
-  auto num_rows = rows_per_table[t];
-  auto indices_start = offsets[t * B + b];
-  auto indices_end = offsets[t * B + b + 1];
-  index_t num_indices = indices.size(0);
+  if (vle) {
+    if (threadIdx.x == 0) {
+      // binary_search_range takes inclusive sumscan array
+      binary_search_range(&t, vle_metadata + 1, b_t, T);
+      b = b_t - vle_metadata[t];
+    }
+    t = shfl_sync(t, 0);
+    b = shfl_sync(b, 0);
+  } else {
+    B = total_B / T;
+    fd.DivMod(b_t, &t, &b);
+  }
+
+  const auto bounds_check_mode =
+      static_cast<BoundsCheckMode>(bounds_check_mode_);
+  const auto num_rows = rows_per_table[t];
+  auto indices_start = offsets[b_t];
+  auto indices_end = offsets[b_t + 1];
+  const index_t num_indices = indices.size(0);
 
   if (bounds_check_mode == BoundsCheckMode::FATAL) {
     CUDA_KERNEL_ASSERT(indices_start >= 0);
@@ -58,12 +73,13 @@ __global__ __launch_bounds__(kMaxThreads) void bounds_check_indices_kernel(
         indices_end > num_indices) {
       if (gpuAtomicIncrement(&warning[0]) == 0) {
         printf(
-            "EmbeddingBoundsCheck: (at least one) Out of bounds access for "
-            "batch: %lld, table: %lld, indices_start: %lld, indices_end: %lld,"
+            "EmbeddingBoundsCheck (VLE %s): (at least one) Out of bounds access for "
+            "batch: %d, table: %d, indices_start: %lld, indices_end: %lld,"
             " num_indices: %lld. Setting indices_start and indices_end within "
             "the range.\n",
-            static_cast<int64_t>(b),
-            static_cast<int64_t>(t),
+            vle ? "true" : "false",
+            b,
+            t,
             static_cast<int64_t>(indices_start),
             static_cast<int64_t>(indices_end),
             static_cast<int64_t>(num_indices));
@@ -72,16 +88,16 @@ __global__ __launch_bounds__(kMaxThreads) void bounds_check_indices_kernel(
           indices_start,
           indices_end,
           num_indices,
-          &offsets[t * B + b],
-          &offsets[t * B + b + 1]);
+          &offsets[b_t],
+          &offsets[b_t + 1]);
     }
   } else if (bounds_check_mode == BoundsCheckMode::IGNORE) {
     adjust_offset_kernel(
         indices_start,
         indices_end,
         num_indices,
-        &offsets[t * B + b],
-        &offsets[t * B + b + 1]);
+        &offsets[b_t],
+        &offsets[b_t + 1]);
   }
 
   const auto L = indices_end - indices_start;
@@ -100,9 +116,10 @@ __global__ __launch_bounds__(kMaxThreads) void bounds_check_indices_kernel(
     if (idx < 0 || idx >= num_rows) {
       if (gpuAtomicIncrement(&warning[0]) == 0) {
         printf(
-            "EmbeddingBoundsCheck: (at least one) Out of bounds access for batch: %lld, table: %lld, bag element: %lld, idx: %lld, num_rows: %lld, indices_start: %lld, indices_end: %lld, T: %d, B: %d, b_t: %d. Setting idx to zero.\n",
-            static_cast<int64_t>(b),
-            static_cast<int64_t>(t),
+            "EmbeddingBoundsCheck (VLE %s): (at least one) Out of bounds access for batch: %d, table: %d, bag element: %lld, idx: %lld, num_rows: %lld, indices_start: %lld, indices_end: %lld, T: %d, B: %d, b_t: %d. Setting idx to zero.\n",
+            vle ? "true" : "false",
+            b,
+            t,
             static_cast<int64_t>(i),
             static_cast<int64_t>(idx),
             num_rows,
@@ -122,25 +139,27 @@ __global__ __launch_bounds__(kMaxThreads) void bounds_check_indices_kernel(
   }
 
   if (bounds_check_mode == BoundsCheckMode::FATAL) {
-    CUDA_KERNEL_ASSERT(num_indices == offsets[B * T]);
+    CUDA_KERNEL_ASSERT(num_indices == offsets[total_B]);
   } else if (bounds_check_mode == BoundsCheckMode::WARNING) {
-    if (num_indices != offsets[B * T]) {
+    if (num_indices != offsets[total_B]) {
       if (gpuAtomicIncrement(&warning[0]) == 0) {
         printf(
-            "EmbeddingBoundsCheck: the last element in offsets is incorrect for "
-            "total batch size B: %lld, total table num T: %lld, "
+            "EmbeddingBoundsCheck (VLE %s): the last element in offsets is incorrect for "
+            "total batch size %s: %d, total table num T: %d, "
             "last element in offsets: %lld, indices size: %lld. "
             "Setting the last element in offsets to be indices size.\n",
-            static_cast<int64_t>(B),
-            static_cast<int64_t>(T),
-            static_cast<int64_t>(offsets[B * T]),
+            vle ? "true" : "false",
+            vle ? "total_B" : "B",
+            vle ? total_B : B,
+            T,
+            static_cast<int64_t>(offsets[total_B]),
             static_cast<int64_t>(num_indices));
       }
-      offsets[B * T] = num_indices;
+      offsets[total_B] = num_indices;
     }
   } else if (bounds_check_mode == BoundsCheckMode::IGNORE) {
-    if (num_indices != offsets[B * T]) {
-      offsets[B * T] = num_indices;
+    if (num_indices != offsets[total_B]) {
+      offsets[total_B] = num_indices;
     }
   }
 }
@@ -151,19 +170,22 @@ void bounds_check_indices_cuda(
     Tensor& offsets,
     int64_t bounds_check_mode_,
     Tensor& warning,
-    c10::optional<Tensor> weights) {
+    const c10::optional<Tensor>& weights,
+    const c10::optional<Tensor>& vle_metadata) {
   TENSOR_ON_CUDA_GPU(rows_per_table);
   TENSOR_ON_CUDA_GPU(indices);
   TENSOR_ON_CUDA_GPU(offsets);
   TENSOR_ON_CUDA_GPU(warning);
   TENSOR_EMPTY_OR_ON_CUDA_GPU(weights);
+  TENSOR_EMPTY_OR_ON_CUDA_GPU(vle_metadata);
 
   at::cuda::OptionalCUDAGuard device_guard;
   device_guard.set_index(rows_per_table.get_device());
 
   const int32_t T = rows_per_table.size(0);
-  const int32_t B = (offsets.size(0) - 1) / T;
-  if (B == 0 || T == 0) {
+  const int32_t total_B = offsets.size(0) - 1;
+  const int32_t B = (total_B) / T;
+  if (total_B == 0 || T == 0) {
     return;
   }
   const auto bounds_check_mode =
@@ -173,11 +195,13 @@ void bounds_check_indices_cuda(
   }
   const int64_t num_indices = indices.size(0);
 
-  TORCH_CHECK(
-      offsets.size(0) == B * T + 1,
-      "offsets size " + std::to_string(offsets.size(0)) +
-          " is not equal to B (" + std::to_string(B) + ") * T (" +
-          std::to_string(T) + ") + 1");
+  if (!vle_metadata.has_value()) {
+    TORCH_CHECK(
+        offsets.size(0) == B * T + 1,
+        "offsets size " + std::to_string(offsets.size(0)) +
+            " is not equal to B (" + std::to_string(B) + ") * T (" +
+            std::to_string(T) + ") + 1");
+  }
   if (weights.has_value()) {
     TORCH_CHECK(
         weights.value().size(0) == num_indices,
@@ -187,19 +211,30 @@ void bounds_check_indices_cuda(
 
   constexpr size_t kNumThreads = 256;
 
+#define INVOKE_BOUNDS_CHECK_INDICES_KERNEL(VAR_BATCH_SIZE, VAR_B_METADATA)  \
+  bounds_check_indices_kernel<index_t, VAR_BATCH_SIZE>                      \
+      <<<div_round_up(total_B, kNumThreads / fbgemm_gpu::kWarpSize),        \
+         dim3(fbgemm_gpu::kWarpSize, kNumThreads / fbgemm_gpu::kWarpSize),  \
+         0,                                                                 \
+         at::cuda::getCurrentCUDAStream()>>>(                               \
+          rows_per_table                                                    \
+              .packed_accessor32<int64_t, 1, at::RestrictPtrTraits>(),      \
+          indices.packed_accessor32<index_t, 1, at::RestrictPtrTraits>(),   \
+          offsets.packed_accessor32<index_t, 1, at::RestrictPtrTraits>(),   \
+          VAR_B_METADATA,                                                   \
+          bounds_check_mode_,                                               \
+          warning.packed_accessor32<int64_t, 1, at::RestrictPtrTraits>(),   \
+          FixedDivisor(B));                                                 \
+  C10_CUDA_KERNEL_LAUNCH_CHECK()
+
   AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "bounds_check_indices", [&] {
-    bounds_check_indices_kernel<index_t>
-        <<<div_round_up(B * T, kNumThreads / fbgemm_gpu::kWarpSize),
-           dim3(fbgemm_gpu::kWarpSize, kNumThreads / fbgemm_gpu::kWarpSize),
-           0,
-           at::cuda::getCurrentCUDAStream()>>>(
-            rows_per_table
-                .packed_accessor32<int64_t, 1, at::RestrictPtrTraits>(),
-            indices.packed_accessor32<index_t, 1, at::RestrictPtrTraits>(),
-            offsets.packed_accessor32<index_t, 1, at::RestrictPtrTraits>(),
-            bounds_check_mode_,
-            warning.packed_accessor32<int64_t, 1, at::RestrictPtrTraits>(),
-            FixedDivisor(B));
+    if (vle_metadata.has_value()) {
+      INVOKE_BOUNDS_CHECK_INDICES_KERNEL(
+          true, vle_metadata.value().data_ptr<int32_t>());
+    } else {
+      INVOKE_BOUNDS_CHECK_INDICES_KERNEL(false, nullptr);
+    }
   });
-  C10_CUDA_KERNEL_LAUNCH_CHECK();
+
+#undef INVOKE_BOUNDS_CHECK_INDICES_KERNEL
 }
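
To make the new indexing concrete, below is a small host-side C++ sketch (not part of the diff) of how the VLE path recovers the table index `t` and within-table sample index `b` from the flattened index `b_t`. It assumes `vle_metadata` is a leading zero followed by the inclusive prefix sum of per-table batch sizes, which is what `binary_search_range(&t, vle_metadata + 1, b_t, T)` and `b = b_t - vle_metadata[t]` imply; `std::upper_bound` stands in for the device-side `binary_search_range`, and the helper name `map_b_t` plus the example batch sizes are made up for illustration.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <cstdio>
#include <vector>

// Mirrors the device-side lookup: find the table whose slice of the flattened
// batch contains b_t, then recover the within-table index b. std::upper_bound
// is an assumption about binary_search_range's tie-breaking behaviour.
static void map_b_t(
    const std::vector<int32_t>& vle_metadata,
    int32_t b_t,
    int32_t* t,
    int32_t* b) {
  // vle_metadata.begin() + 1 points at the inclusive sumscan of length T,
  // matching the `vle_metadata + 1` the kernel passes to binary_search_range.
  const auto it =
      std::upper_bound(vle_metadata.begin() + 1, vle_metadata.end(), b_t);
  *t = static_cast<int32_t>(it - (vle_metadata.begin() + 1));
  *b = b_t - vle_metadata[*t];  // same as `b = b_t - vle_metadata[t]` above
}

int main() {
  // Hypothetical per-table batch sizes for T = 3 tables.
  const std::vector<int32_t> batch_sizes = {2, 4, 3};

  // vle_metadata = [0, 2, 6, 9]: a leading zero followed by the inclusive
  // prefix sum of batch_sizes; total_B then matches offsets.size(0) - 1.
  std::vector<int32_t> vle_metadata = {0};
  for (const int32_t bs : batch_sizes) {
    vle_metadata.push_back(vle_metadata.back() + bs);
  }
  const int32_t total_B = vle_metadata.back();

  for (int32_t b_t = 0; b_t < total_B; ++b_t) {
    int32_t t = 0;
    int32_t b = 0;
    map_b_t(vle_metadata, b_t, &t, &b);
    assert(b >= 0 && b < batch_sizes[t]);
    std::printf("b_t = %d -> table %d, sample %d\n", b_t, t, b);
  }
  return 0;
}
```

This also illustrates why the kernel can index `offsets[b_t]` directly in both paths: the offsets tensor is laid out per (table, sample) pair in the same flattened order, so only the recovery of `(t, b)` differs between the fixed-batch and VLE cases.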