@@ -195,7 +195,7 @@ void set_static_kernel_args(
195
195
}
196
196
}
197
197
198
// Fills kernel_args with one KernelArguments entry per group, without
// touching the output tensor (see the *_zeroing variant for the path that
// also clears `output`).
//
// Pointer math below implies densely packed per-group layouts:
//   XQ:      [group][M][K]      WQ:     [group][N][K]
//   output:  [group][M][N]      w_scale:[group][N]    x_scale:[group][M]
// prepad_M[g] holds the valid (pre-padded) row count of group g.
// Launch with at least `group_count` total threads; one thread sets up one
// group and surplus threads fall through the bounds check.
__global__ void set_kernel_args_fixed_nk_kernel_only(
    KernelArguments* kernel_args,
    ADataType* XQ,
    BDataType* WQ,
    D0DataType* w_scale,
    D1DataType* x_scale,
    EDataType* output,
    int64_t* prepad_M,
    int M,
    int N,
    int K,
    int group_count) {
  int thread_idx = blockIdx.x * blockDim.x + threadIdx.x;
  // Each thread is responsible for setting up the arguments for one group.
  if (thread_idx < group_count) {
    // Actual number of valid rows in this group.
    // NOTE(review): narrowed from int64_t; assumes prepad_M[g] <= M fits int.
    int group_M = prepad_M[thread_idx];
    // Do the per-group offset arithmetic in 64 bits: thread_idx * M * K (and
    // friends) are int expressions and can silently overflow for large
    // shapes before the pointer addition happens.
    int64_t g = thread_idx;
    KernelArguments kernel_group_args = {
        XQ + (g * M * K),
        WQ + (g * N * K),
        {w_scale + (g * N), x_scale + (g * M)},
        output + (g * M * N),
        group_M,
        N,
        K,
        K, // stride of XQ rows.
        K, // stride of WQ rows.
        {0, 0}, // scale vectors are rank-1: zero stride along the other dim.
        N}; // stride of output rows.
    // Write kernel args to memory.
    kernel_args[thread_idx] = kernel_group_args;
  }
}
231
+
232
+ __global__ void set_kernel_args_fixed_nk_kernel_zeroing (
233
+ KernelArguments* kernel_args,
234
+ ADataType* XQ,
235
+ BDataType* WQ,
236
+ D0DataType* w_scale,
237
+ D1DataType* x_scale,
238
+ EDataType* output,
239
+ int64_t * prepad_M,
240
+ int M,
241
+ int N,
242
+ int K,
243
+ int group_count) {
211
244
int thread_idx = blockIdx.x * blockDim.x + threadIdx.x ;
212
245
// Each thread is responsible for setting up the arguments for one group.
213
246
if (thread_idx < group_count) {
@@ -228,7 +261,6 @@ __global__ void set_kernel_args_fixed_nk_kernel(
228
261
// Write kernel args to memory.
229
262
kernel_args[thread_idx] = kernel_group_args;
230
263
}
231
- if (!zeroing_output_tensor) return ;
232
264
233
265
// Figure out where in memory we are.
234
266
// Each thread sets one float 4 which corresponds to 8 bf16 values.
@@ -284,19 +316,33 @@ void set_dynamic_kernel_args(
284
316
int block_factor = std::max (group_count, (group_count * M * N) / BLOCK_SIZE);
285
317
int blockSize = std::min (512 , block_factor);
286
318
int numBlocks = (block_factor + blockSize - 1 ) / blockSize;
287
- set_kernel_args_fixed_nk_kernel<<<numBlocks, blockSize, 0 , stream>>>(
288
- reinterpret_cast <KernelArguments*>(kernel_args.data_ptr ()),
289
- reinterpret_cast <ADataType*>(XQ.data_ptr ()),
290
- reinterpret_cast <BDataType*>(WQ.data_ptr ()),
291
- reinterpret_cast <D0DataType*>(w_scale.data_ptr ()),
292
- reinterpret_cast <D1DataType*>(x_scale.data_ptr ()),
293
- reinterpret_cast <EDataType*>(output.data_ptr ()),
294
- reinterpret_cast <int64_t *>(zero_start_index_M.data_ptr ()),
295
- M,
296
- N,
297
- K,
298
- group_count,
299
- zeroing_output_tensor);
319
+ if (zeroing_output_tensor) {
320
+ set_kernel_args_fixed_nk_kernel_zeroing<<<numBlocks, blockSize, 0 , stream>>>(
321
+ reinterpret_cast <KernelArguments*>(kernel_args.data_ptr ()),
322
+ reinterpret_cast <ADataType*>(XQ.data_ptr ()),
323
+ reinterpret_cast <BDataType*>(WQ.data_ptr ()),
324
+ reinterpret_cast <D0DataType*>(w_scale.data_ptr ()),
325
+ reinterpret_cast <D1DataType*>(x_scale.data_ptr ()),
326
+ reinterpret_cast <EDataType*>(output.data_ptr ()),
327
+ reinterpret_cast <int64_t *>(zero_start_index_M.data_ptr ()),
328
+ M,
329
+ N,
330
+ K,
331
+ group_count);
332
+ } else {
333
+ set_kernel_args_fixed_nk_kernel_only<<<1 , group_count, 0 , stream>>>(
334
+ reinterpret_cast <KernelArguments*>(kernel_args.data_ptr ()),
335
+ reinterpret_cast <ADataType*>(XQ.data_ptr ()),
336
+ reinterpret_cast <BDataType*>(WQ.data_ptr ()),
337
+ reinterpret_cast <D0DataType*>(w_scale.data_ptr ()),
338
+ reinterpret_cast <D1DataType*>(x_scale.data_ptr ()),
339
+ reinterpret_cast <EDataType*>(output.data_ptr ()),
340
+ reinterpret_cast <int64_t *>(zero_start_index_M.data_ptr ()),
341
+ M,
342
+ N,
343
+ K,
344
+ group_count);
345
+ }
300
346
}
301
347
302
348
template <typename OutputType>
0 commit comments