@@ -107,30 +107,14 @@ __global__ void set_dynamic_kernel_args_kernel(
     StrideA* stride_a_ptr,
     StrideB* stride_b_ptr,
     StrideC* stride_c_ptr,
-    std::optional<int64_t*> zero_start_index_M,
-    std::optional<int64_t*> M_sizes) {
+    int64_t* zero_start_index_M) {
   uint32_t group_index = blockIdx.x * blockDim.x + threadIdx.x;
   // If this thread corresponds to a valid group, write kernel args to device
   // memory.
   if (group_index < G) {
     // Compute shape for this group.
-    int offset_M;
-    int kernel_M;
-    if (zero_start_index_M.has_value()) {
-      // For inputs with padding, M is fixed and the number of rows
-      // to operate on is available in zero_start_index_M.
-      kernel_M = zero_start_index_M.value()[group_index];
-      offset_M = group_index * M;
-    } else {
-      // M for this group is pulled directly from M_sizes.
-      kernel_M = M_sizes.value()[group_index];
-      // We compute the offset by getting the cumulative sum over
-      // prior groups.
-      offset_M = 0;
-      for (int i = 0; i < group_index; i++) {
-        offset_M += M_sizes.value()[i];
-      }
-    }
+    int kernel_M = zero_start_index_M[group_index];
+    int offset_M = group_index * M;
     // Set the problem shape for this group.
     problem_shape_ptr[group_index] = ProblemShape(N, kernel_M, K);
     // Set input pointers.
@@ -148,6 +132,82 @@ __global__ void set_dynamic_kernel_args_kernel(
   }
 }
 
+template <
+    typename ProblemShape,
+    typename ElementA,
+    typename ElementB,
+    typename ElementC,
+    typename ElementComputeEpilogue,
+    typename StrideA,
+    typename StrideB,
+    typename StrideC>
+__global__ void set_stacked_kernel_args_kernel(
+    int G,
+    int N,
+    int K,
+    ProblemShape* problem_shape_ptr,
+    ElementA* xq,
+    const ElementA** xq_ptr,
+    ElementB* wq,
+    const ElementB** wq_ptr,
+    ElementComputeEpilogue* x_scale,
+    const ElementComputeEpilogue** x_scale_ptr,
+    ElementComputeEpilogue* w_scale,
+    const ElementComputeEpilogue** w_scale_ptr,
+    ElementC* output,
+    ElementC** output_ptr,
+    StrideA* stride_a_ptr,
+    StrideB* stride_b_ptr,
+    StrideC* stride_c_ptr,
+    int64_t* M_sizes) {
+  uint32_t group_index = blockIdx.x * blockDim.x + threadIdx.x;
+  // If this thread corresponds to a valid group, write kernel args to device
+  // memory.
+  if (group_index < G) {
+    // It's possible that we're only writing a subset of the groups to
+    // kernel args. To do this, we need to set all groups initially to empty
+    // and keep a problem counter for the number of non-empty groups.
+    __shared__ int non_zero_counter;
+    // Initialize counter in first group.
+    if (group_index == 0) {
+      non_zero_counter = 0;
+    }
+    // Set problem shapes to empty by default.
+    problem_shape_ptr[group_index] = ProblemShape(0, 0, 0);
+    // Sync threads to get consistent state in the block.
+    __syncthreads();
+
+    // Compute shape for this group.
+    // M for this group is pulled directly from M_sizes.
+    int M = M_sizes[group_index];
+    // Only proceed to writing kernel args if this group is non-empty.
+    if (M > 0) {
+      // Get the index for this group atomically.
+      int non_zero_idx = atomicAdd(&non_zero_counter, 1);
+      // We compute the offset by getting the cumulative sum over
+      // prior groups.
+      int offset_M = 0;
+      for (int i = 0; i < group_index; i++) {
+        offset_M += M_sizes[i];
+      }
+      // Set the problem shape for this group.
+      problem_shape_ptr[non_zero_idx] = ProblemShape(N, M, K);
+      // Set input pointers.
+      xq_ptr[non_zero_idx] = xq + (offset_M * K);
+      wq_ptr[non_zero_idx] = wq + (group_index * N * K);
+      x_scale_ptr[non_zero_idx] = x_scale + offset_M;
+      w_scale_ptr[non_zero_idx] = w_scale + (group_index * N);
+      output_ptr[non_zero_idx] = output + (offset_M * N);
+      stride_a_ptr[non_zero_idx] = cutlass::make_cute_packed_stride(
+          StrideA{}, cute::make_shape(M, K, 1));
+      stride_b_ptr[non_zero_idx] = cutlass::make_cute_packed_stride(
+          StrideB{}, cute::make_shape(N, K, 1));
+      stride_c_ptr[non_zero_idx] = cutlass::make_cute_packed_stride(
+          StrideC{}, cute::make_shape(N, M, 1));
+    }
+  }
+}
+
 template <
     typename InputType,
     int TB_M,
@@ -178,6 +238,8 @@ at::Tensor f8f8bf16_rowwise_grouped_impl(
     G = WQ.size(0);
     options = XQ.options();
   }
+  // The kernel may operate on fewer than G groups when some groups are empty.
+  int kernel_groups = G;
   // Return early if there are no elements in the output.
   if (output.numel() == 0) {
     return output;
@@ -421,41 +483,60 @@ at::Tensor f8f8bf16_rowwise_grouped_impl(
     int M = XQ.size(XQ.dim() - 2);
     int N = WQ.size(1);
     int K = WQ.size(2);
-    std::optional<int64_t*> zero_start_index_M_ptr = std::nullopt;
-    std::optional<int64_t*> M_sizes_ptr = std::nullopt;
     if (zero_start_index_M.has_value()) {
-      zero_start_index_M_ptr =
+      int64_t* zero_start_index_M_ptr =
           reinterpret_cast<int64_t*>(zero_start_index_M.value().data_ptr());
+      set_dynamic_kernel_args_kernel<<<1, G, 0, stream>>>(
+          G,
+          M,
+          N,
+          K,
+          problem_shape_ptr,
+          reinterpret_cast<ElementA*>(XQ.data_ptr()),
+          xq_ptr,
+          reinterpret_cast<ElementB*>(WQ.data_ptr()),
+          wq_ptr,
+          reinterpret_cast<ElementComputeEpilogue*>(x_scale.data_ptr()),
+          x_scale_ptr,
+          reinterpret_cast<ElementComputeEpilogue*>(w_scale.data_ptr()),
+          w_scale_ptr,
+          reinterpret_cast<ElementC*>(output.data_ptr()),
+          output_ptr,
+          stride_a_ptr,
+          stride_b_ptr,
+          stride_c_ptr,
+          zero_start_index_M_ptr);
+    } else {
+      int64_t* M_sizes_ptr =
+          reinterpret_cast<int64_t*>(M_sizes.value().data_ptr());
+      set_stacked_kernel_args_kernel<<<1, G, 0, stream>>>(
+          G,
+          N,
+          K,
+          problem_shape_ptr,
+          reinterpret_cast<ElementA*>(XQ.data_ptr()),
+          xq_ptr,
+          reinterpret_cast<ElementB*>(WQ.data_ptr()),
+          wq_ptr,
+          reinterpret_cast<ElementComputeEpilogue*>(x_scale.data_ptr()),
+          x_scale_ptr,
+          reinterpret_cast<ElementComputeEpilogue*>(w_scale.data_ptr()),
+          w_scale_ptr,
+          reinterpret_cast<ElementC*>(output.data_ptr()),
+          output_ptr,
+          stride_a_ptr,
+          stride_b_ptr,
+          stride_c_ptr,
+          M_sizes_ptr);
+      // Cap the number of groups passed to the kernel: each non-empty group
+      // has at least one row, so there are at most min(M, G) of them.
+      kernel_groups = std::min(M, G);
     }
-    if (M_sizes.has_value()) {
-      M_sizes_ptr = reinterpret_cast<int64_t*>(M_sizes.value().data_ptr());
-    }
-    set_dynamic_kernel_args_kernel<<<1, G, 0, stream>>>(
-        G,
-        M,
-        N,
-        K,
-        problem_shape_ptr,
-        reinterpret_cast<ElementA*>(XQ.data_ptr()),
-        xq_ptr,
-        reinterpret_cast<ElementB*>(WQ.data_ptr()),
-        wq_ptr,
-        reinterpret_cast<ElementComputeEpilogue*>(x_scale.data_ptr()),
-        x_scale_ptr,
-        reinterpret_cast<ElementComputeEpilogue*>(w_scale.data_ptr()),
-        w_scale_ptr,
-        reinterpret_cast<ElementC*>(output.data_ptr()),
-        output_ptr,
-        stride_a_ptr,
-        stride_b_ptr,
-        stride_c_ptr,
-        zero_start_index_M_ptr,
-        M_sizes_ptr);
   }
 
   typename Gemm::Arguments arguments{
       cutlass::gemm::GemmUniversalMode::kGrouped,
-      {G, problem_shape_ptr, nullptr},
+      {kernel_groups, problem_shape_ptr, nullptr},
       {wq_ptr, stride_b_ptr, xq_ptr, stride_a_ptr},
       {{}, nullptr, stride_c_ptr, output_ptr, stride_c_ptr}};
 
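For readers unfamiliar with the stacked (M_sizes) layout this change introduces, the following is a minimal CPU-side sketch of the bookkeeping that set_stacked_kernel_args_kernel performs per group (row offsets from a prefix sum of M_sizes, and compaction of non-empty groups to the front of the argument arrays). It is illustrative only, not part of the commit; the group sizes are made up, and the sequential counter stands in for the device-side atomicAdd.

// Host-side sketch of the per-group argument bookkeeping (illustrative only).
#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
  // Hypothetical example: G = 3 groups with row counts M_sizes = {3, 0, 5},
  // so the stacked activation tensor has M = 8 rows in total.
  std::vector<int64_t> M_sizes = {3, 0, 5};
  int G = static_cast<int>(M_sizes.size());

  int non_zero_counter = 0;
  for (int group_index = 0; group_index < G; ++group_index) {
    int M = static_cast<int>(M_sizes[group_index]);
    if (M <= 0) {
      continue; // Empty groups produce no kernel arguments.
    }
    // Row offset of this group is the cumulative sum of prior group sizes,
    // matching the loop in set_stacked_kernel_args_kernel.
    int offset_M = 0;
    for (int i = 0; i < group_index; ++i) {
      offset_M += static_cast<int>(M_sizes[i]);
    }
    // Non-empty groups are compacted to the front of the argument arrays;
    // on the device this slot is claimed with atomicAdd on a shared counter.
    int non_zero_idx = non_zero_counter++;
    std::printf(
        "group %d -> slot %d, rows %d, row offset %d\n",
        group_index, non_zero_idx, M, offset_M);
  }
  // The grouped GEMM is then launched with at most this many groups,
  // which is what kernel_groups bounds in the host code above.
  std::printf("non-empty groups: %d\n", non_zero_counter);
  return 0;
}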