@@ -118,7 +118,7 @@ __global__ void __launch_bounds__(THREADS_PER_THREAD_BLOCK) LogitsBitmaskKernel(
118
118
}
119
119
}
120
120
121
- template <typename T, typename = std::enable_if_t <std::is_integral<T>::value>>
121
+ template <typename T, typename = std::enable_if_t <std::is_integral<T>::value> >
122
122
constexpr auto CeilDiv (T numerator, T denominator) {
123
123
return (numerator + denominator - 1 ) / denominator;
124
124
}
@@ -142,19 +142,19 @@ void ApplyTokenBitmaskInplaceDispatchToBitsPerThread(
142
142
if (num_bits_per_thread <= 4 && kAlignment <= 4 ) {
143
143
const dim3 grid (CeilDiv (vocab_size, THREADS_PER_THREAD_BLOCK * 4 ), num_rows);
144
144
LogitsBitmaskKernel<T, PackedT, 4 >
145
- <<<grid, block, 0 , stream>> > (logits, bitmask, indices, vocab_size, logits_stride, bitmask_stride);
145
+ <<<grid, block, 0 , stream> > >(logits, bitmask, indices, vocab_size, logits_stride, bitmask_stride);
146
146
} else if (num_bits_per_thread <= 8 && kAlignment <= 8 ) {
147
147
const dim3 grid (CeilDiv (vocab_size, THREADS_PER_THREAD_BLOCK * 8 ), num_rows);
148
148
LogitsBitmaskKernel<T, PackedT, 8 >
149
- <<<grid, block, 0 , stream>> > (logits, bitmask, indices, vocab_size, logits_stride, bitmask_stride);
149
+ <<<grid, block, 0 , stream> > >(logits, bitmask, indices, vocab_size, logits_stride, bitmask_stride);
150
150
} else if (num_bits_per_thread <= 16 && kAlignment <= 16 ) {
151
151
const dim3 grid (CeilDiv (vocab_size, THREADS_PER_THREAD_BLOCK * 16 ), num_rows);
152
152
LogitsBitmaskKernel<T, PackedT, 16 >
153
- <<<grid, block, 0 , stream>> > (logits, bitmask, indices, vocab_size, logits_stride, bitmask_stride);
153
+ <<<grid, block, 0 , stream> > >(logits, bitmask, indices, vocab_size, logits_stride, bitmask_stride);
154
154
} else {
155
155
const dim3 grid (CeilDiv (vocab_size, THREADS_PER_THREAD_BLOCK * 32 ), num_rows);
156
156
LogitsBitmaskKernel<T, PackedT, 32 >
157
- <<<grid, block, 0 , stream>> > (logits, bitmask, indices, vocab_size, logits_stride, bitmask_stride);
157
+ <<<grid, block, 0 , stream> > >(logits, bitmask, indices, vocab_size, logits_stride, bitmask_stride);
158
158
}
159
159
}
160
160
0 commit comments