#include <ATen/ATen.h>
#include <c10/hip/HIPStream.h>
- #include <hip_bf16.h>
#include <torch/torch.h>

#include "ck/ck.hpp"

namespace fbgemm_gpu {

- template <typename InputType, typename OutputType>
+ template <typename InputType, typename OutputType>
using RowwiseGroupedKernel = std::function<OutputType(
    InputType,
    InputType,
@@ -46,49 +45,76 @@ using D1DataType = float;
using DsDataType = ck::Tuple<D0DataType, D1DataType>;
using EDataType = ck::bhalf_t;

- template <typename InputType, typename OutputType>
- RowwiseGroupedKernel<InputType, OutputType> rowwise_grouped_heuristic_dispatch(int M, int N, int K) {
+ template <typename InputType, typename OutputType>
+ RowwiseGroupedKernel<InputType, OutputType>
+ rowwise_grouped_heuristic_dispatch(int M, int N, int K) {
  // We use shape heuristics to find the best kernel.
  // To do this, we divide by the size of M and find the best
  // option within that grouping.
  if (M <= 16) {
    if (N < 8192 && K <= 8192) {
-       return fp8_rowwise_grouped_64x16x16x256_16x16_1x1_16x4x1_16x4x1_1x4x1x16_4x4x1_1x1_intrawave_v1<InputType, OutputType>;
+       return fp8_rowwise_grouped_64x16x16x256_16x16_1x1_16x4x1_16x4x1_1x4x1x16_4x4x1_1x1_intrawave_v1<
+           InputType,
+           OutputType>;
    }
    if (K <= 8192) {
-       return fp8_rowwise_grouped_128x16x64x128_16x16_1x2_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2<InputType, OutputType>;
+       return fp8_rowwise_grouped_128x16x64x128_16x16_1x2_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2<
+           InputType,
+           OutputType>;
    }
-     return fp8_rowwise_grouped_128x16x32x256_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2<InputType, OutputType>;
+     return fp8_rowwise_grouped_128x16x32x256_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2<
+         InputType,
+         OutputType>;
  }
  if (M <= 32) {
    if (N < 8192 && K <= 8192) {
-       return fp8_rowwise_grouped_128x32x64x128_32x32_1x1_8x16x1_8x16x1_1x16x1x8_8x8x1_1x1_interwave_v2<InputType, OutputType>;
+       return fp8_rowwise_grouped_128x32x64x128_32x32_1x1_8x16x1_8x16x1_1x16x1x8_8x8x1_1x1_interwave_v2<
+           InputType,
+           OutputType>;
    }
    if (K <= 8192) {
-       return fp8_rowwise_grouped_128x32x128x128_32x32_1x2_8x16x1_8x16x1_1x16x1x8_8x8x1_1x1_interwave_v2<InputType, OutputType>;
+       return fp8_rowwise_grouped_128x32x128x128_32x32_1x2_8x16x1_8x16x1_1x16x1x8_8x8x1_1x1_interwave_v2<
+           InputType,
+           OutputType>;
    }
-     return fp8_rowwise_grouped_128x32x128x128_32x32_1x2_8x16x1_8x16x1_1x16x1x8_8x8x1_1x1_intrawave_v2<InputType, OutputType>;
+     return fp8_rowwise_grouped_128x32x128x128_32x32_1x2_8x16x1_8x16x1_1x16x1x8_8x8x1_1x1_intrawave_v2<
+         InputType,
+         OutputType>;
  }
  if (M <= 64) {
-     return fp8_rowwise_grouped_256x64x64x128_32x32_1x1_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3<InputType, OutputType>;
+     return fp8_rowwise_grouped_256x64x64x128_32x32_1x1_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3<
+         InputType,
+         OutputType>;
  }
  if (M <= 128) {
    if (N < 8192 && K <= 8192) {
-       return fp8_rowwise_grouped_256x128x64x128_32x32_2x1_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3<InputType, OutputType>;
+       return fp8_rowwise_grouped_256x128x64x128_32x32_2x1_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3<
+           InputType,
+           OutputType>;
    }
-     return fp8_rowwise_grouped_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3<InputType, OutputType>;
+     return fp8_rowwise_grouped_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3<
+         InputType,
+         OutputType>;
  }
  if (M <= 256) {
-     return fp8_rowwise_grouped_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3<InputType, OutputType>;
+     return fp8_rowwise_grouped_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3<
+         InputType,
+         OutputType>;
  }
  if (M <= 512) {
    if (K <= 8192) {
-       return fp8_rowwise_grouped_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_interwave_v1<InputType, OutputType>;
+       return fp8_rowwise_grouped_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_interwave_v1<
+           InputType,
+           OutputType>;
    }
-     return fp8_rowwise_grouped_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3<InputType, OutputType>;
+     return fp8_rowwise_grouped_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3<
+         InputType,
+         OutputType>;
  }
  // Default kernel for all other shapes.
-   return fp8_rowwise_grouped_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_interwave_v1<InputType, OutputType>;
+   return fp8_rowwise_grouped_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_interwave_v1<
+       InputType,
+       OutputType>;
}

__global__ void set_kernel_args_kernel(
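The value returned by the heuristic is a `RowwiseGroupedKernel`, a `std::function` over the quantized inputs, scales, kernel-argument buffer and output. A minimal sketch of how the dispatch is used, mirroring the call sites further down in this file; `MaxM`, `MaxN`, `MaxK` here stand for the per-group maxima computed by the caller:

```cpp
// Sketch only: select a kernel from the heuristic, then run it on the
// already-prepared quantized inputs, scales, kernel arguments and output.
RowwiseGroupedKernel<at::TensorList, at::Tensor> selected =
    rowwise_grouped_heuristic_dispatch<at::TensorList, at::Tensor>(MaxM, MaxN, MaxK);
at::Tensor result = selected(XQ, WQ, x_scale, w_scale, kernel_args, Y);
```
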
@@ -139,9 +165,10 @@ void set_static_kernel_args(
    if constexpr (std::is_same_v<OutputType, std::vector<at::Tensor>>) {
      // Output is a list of tensors and we can access each individually.
      output_ptr = reinterpret_cast<EDataType*>(output[i].data_ptr());
-     } else {
+     } else {
      // Output is a single contiguous tensor and must be accessed via offset.
-       output_ptr = reinterpret_cast<EDataType*>(output.data_ptr()) + output_offset;
+       output_ptr =
+           reinterpret_cast<EDataType*>(output.data_ptr()) + output_offset;
      output_offset += M * N;
    }
@@ -165,7 +192,6 @@ void set_static_kernel_args(
        M,
        N,
        K);
-
  }
}

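The offset bookkeeping above is the stacked-output convention used throughout this file: each group's [M, N] block begins `output_offset` elements into the single output tensor, and the offset advances by M * N after each group. A minimal sketch of the same walk; `base_ptr` and `Ms` (rows per group) are assumed names for illustration only:

```cpp
// Sketch: per-group output pointers into one stacked output buffer.
std::vector<EDataType*> group_output_ptrs(
    EDataType* base_ptr, const std::vector<int>& Ms, int N) {
  std::vector<EDataType*> ptrs;
  int64_t output_offset = 0;
  for (int M : Ms) {
    ptrs.push_back(base_ptr + output_offset);      // start of this group's [M, N] block
    output_offset += static_cast<int64_t>(M) * N;  // advance past M * N elements
  }
  return ptrs;
}
```
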
@@ -180,8 +206,7 @@ __global__ void set_kernel_args_fixed_nk_kernel(
    int M,
    int N,
    int K,
-     int group_count,
-     const int BLOCK_SIZE) {
+     int group_count) {
  int thread_idx = blockIdx.x * blockDim.x + threadIdx.x;
  // Each thread is responsible for setting up the arguments for one group.
  if (thread_idx < group_count) {
@@ -203,33 +228,21 @@ __global__ void set_kernel_args_fixed_nk_kernel(
    kernel_args[thread_idx] = kernel_group_args;
  }

-   // We also fuse in initialization of the output tensor.
-   // We write in chunks of 2 bfloats at a time for efficiency.
-   for (int i = 0; i < BLOCK_SIZE / 2; i++) {
-     // Figure out where in memory we are.
-     int output_offset = (thread_idx * BLOCK_SIZE) + (i * 2);
-     int current_group = output_offset / (M * N);
-     // Skip if outside of valid groups.
-     if (current_group < group_count) {
-       int nonzeros = prepad_M[current_group];
-       int current_M = output_offset / N;
-       // Only write if this block needs initialization.
-       // Avoid writing to final element if number of elements is odd.
-       if (current_M >= nonzeros && output_offset < (M * N * group_count) - 1) {
-         __hip_bfloat162* output_block =
-             reinterpret_cast<__hip_bfloat162*>(output + output_offset);
-         *output_block = __hip_bfloat162(0, 0);
-       }
+   // Figure out where in memory we are.
+   // Each thread sets one float4 which corresponds to 8 bf16 values.
+   int output_offset = (thread_idx * 8);
+   int current_group = output_offset / (M * N);
+   // Skip if outside of valid groups.
+   if (current_group < group_count) {
+     int nonzeros = prepad_M[current_group];
+     int current_M = (output_offset % (M * N)) / N;
+     // Only write zeros if we're currently in a sparse row.
+     if (current_M >= nonzeros) {
+       // Write out a block of 8 output values via vectorized float4.
+       float4* output_block = reinterpret_cast<float4*>(output + output_offset);
+       *output_block = {0, 0, 0, 0};
    }
  }
-   // Handle case where there are an odd number of total elements.
-   if (((M * N * group_count) % 2) != 0 &&
-       ((M * N * group_count) - (thread_idx * BLOCK_SIZE) < BLOCK_SIZE)) {
-     // Write out the final element.
-     __hip_bfloat16* output_block =
-         reinterpret_cast<__hip_bfloat16*>(output + (M * N * group_count) - 1);
-     *output_block = __hip_bfloat16(0);
-   }
}

void set_dynamic_kernel_args(
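The fused zero-fill above gives each thread exactly one 16-byte `float4` store, i.e. 8 contiguous bf16 outputs, and only rows past a group's valid (non-padded) row count are cleared. A minimal host-side sketch of the same index arithmetic, with the device pointer details left out and `prepad_M` assumed to be a plain int array for illustration:

```cpp
// Sketch: which 8-element block does a thread own, and does it need zeroing?
bool needs_zero_fill(
    int thread_idx, int M, int N, int group_count, const int* prepad_M) {
  int output_offset = thread_idx * 8;             // first bf16 element owned by this thread
  int current_group = output_offset / (M * N);    // each group owns an M x N output block
  if (current_group >= group_count) {
    return false;                                 // thread lands past the last group
  }
  int nonzeros = prepad_M[current_group];         // valid (non-padded) rows in this group
  int current_M = (output_offset % (M * N)) / N;  // row index within the group
  return current_M >= nonzeros;                   // zero only the padded rows
}
```

Because every thread writes exactly 8 elements, a row boundary must fall on an 8-element boundary so a single `float4` store never straddles a valid row and a padded row; that is what the `N % 8 == 0` check added in the next hunk enforces.
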
@@ -261,9 +274,12 @@ void set_dynamic_kernel_args(
  int N = WQ.size(1);

  // Launch a kernel that sets kernel argument memory.
+   // Each thread sets one float4 which corresponds to 8 bf16 values.
  const int BLOCK_SIZE = 8;
+   TORCH_CHECK(
+       N % BLOCK_SIZE == 0, "N must be divisible by 8 for dynamic grouped gemm.");
  int block_factor = std::max(group_count, (group_count * M * N) / BLOCK_SIZE);
-   int blockSize = std::min(1024, block_factor);
+   int blockSize = std::min(512, block_factor);
  int numBlocks = (block_factor + blockSize - 1) / blockSize;
  set_kernel_args_fixed_nk_kernel<<<numBlocks, blockSize, 0, stream>>>(
      reinterpret_cast<KernelArguments*>(kernel_args.data_ptr()),
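To make the launch math concrete, take one hypothetical shape (the numbers below are made up purely for illustration): with group_count = 4, M = 128 and N = 512, one thread is needed per 8 output elements, so:

```cpp
// block_factor = max(4, (4 * 128 * 512) / 8) = 32768  -> total threads needed
// blockSize    = min(512, 32768)             = 512    -> threads per block
// numBlocks    = (32768 + 512 - 1) / 512     = 64     -> blocks in the grid
```
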
@@ -276,8 +292,7 @@ void set_dynamic_kernel_args(
      M,
      N,
      K,
-       group_count,
-       BLOCK_SIZE);
+       group_count);
}

template <typename OutputType>
@@ -347,22 +362,25 @@ OutputType _f8f8bf16_rowwise_grouped(
      Y.push_back(at::empty({M, N}, XQ[i].options().dtype(at::kBFloat16)));
    }
  }
-   // Now handle single tensor output.
+   // Now handle single tensor output.
  } else {
    // Compute total M across groups.
    int total_M = 0;
    int N = WQ[0].size(0);
    for (int i = 0; i < group_count; i++) {
      total_M += XQ[i].size(0);
      // Also make sure N is constant across shapes.
-       TORCH_CHECK(WQ[i].size(0) == N, "N must be constant across groups for stacked output.");
+       TORCH_CHECK(
+           WQ[i].size(0) == N,
+           "N must be constant across groups for stacked output.");
    }
    if (output.has_value()) {
      Y = output.value();
      // Check that shape is expected.
-       TORCH_CHECK(Y.size(0) == total_M && Y.size(1) == N, "Preallocated output should have size [total_M, N].");
-     }
-     else {
+       TORCH_CHECK(
+           Y.size(0) == total_M && Y.size(1) == N,
+           "Preallocated output should have size [total_M, N].");
+     } else {
      Y = at::empty({total_M, N}, XQ[0].options().dtype(at::kBFloat16));
    }
  }
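With the stacked convention, the per-group results are simply concatenated along the row dimension: for example, three groups with M = {3, 5, 2} and a shared N = 16 give total_M = 10 and a single [10, 16] bf16 output. Callers of the stacked variant often need per-group views again; a rough sketch, assuming the caller still knows the per-group row counts (as it does from XQ):

```cpp
// Sketch: recover per-group views from the stacked [total_M, N] output.
// Each view aliases Y; no copy is made.
std::vector<at::Tensor> split_stacked_output(
    const at::Tensor& Y, const std::vector<int64_t>& Ms) {
  std::vector<at::Tensor> views;
  int64_t row = 0;
  for (int64_t M_i : Ms) {
    views.push_back(Y.narrow(0, row, M_i));  // rows [row, row + M_i) belong to group i
    row += M_i;
  }
  return views;
}
```
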
@@ -383,7 +401,8 @@ OutputType _f8f8bf16_rowwise_grouped(
    MaxK = max(MaxK, XQ[i].size(1));
  }
  RowwiseGroupedKernel<at::TensorList, OutputType> selected_kernel =
-     rowwise_grouped_heuristic_dispatch<at::TensorList, OutputType>(MaxM, MaxN, MaxK);
+     rowwise_grouped_heuristic_dispatch<at::TensorList, OutputType>(
+         MaxM, MaxN, MaxK);
  return selected_kernel(XQ, WQ, x_scale, w_scale, kernel_args, Y);
}

@@ -394,7 +413,8 @@ std::vector<at::Tensor> f8f8bf16_rowwise_grouped(
    at::TensorList x_scale,
    at::TensorList w_scale,
    std::optional<std::vector<at::Tensor>> output = std::nullopt) {
-   return _f8f8bf16_rowwise_grouped<std::vector<at::Tensor>>(XQ, WQ, x_scale, w_scale, output);
+   return _f8f8bf16_rowwise_grouped<std::vector<at::Tensor>>(
+       XQ, WQ, x_scale, w_scale, output);
}

// Wrapper function for list input single tensor output.
@@ -404,7 +424,8 @@ at::Tensor f8f8bf16_rowwise_grouped_stacked(
    at::TensorList x_scale,
    at::TensorList w_scale,
    std::optional<at::Tensor> output = std::nullopt) {
-   return _f8f8bf16_rowwise_grouped<at::Tensor>(XQ, WQ, x_scale, w_scale, output);
+   return _f8f8bf16_rowwise_grouped<at::Tensor>(
+       XQ, WQ, x_scale, w_scale, output);
}

at::Tensor f8f8bf16_rowwise_grouped_dynamic(
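A rough usage sketch of the two wrappers above; building the FP8 rowwise-quantized inputs and their per-row scales is assumed to happen elsewhere, and the tensor vectors convert implicitly to `at::TensorList`:

```cpp
// Sketch: per-group quantized inputs and scales, prepared elsewhere.
std::vector<at::Tensor> xq_vec, wq_vec, xs_vec, ws_vec;  // one entry per group
// ... fill xq_vec[i], wq_vec[i] with fp8 data and xs_vec[i], ws_vec[i] with scales ...

// List-output variant: one bf16 [M_i, N] tensor per group.
std::vector<at::Tensor> outs =
    f8f8bf16_rowwise_grouped(xq_vec, wq_vec, xs_vec, ws_vec);

// Stacked variant: a single bf16 [total_M, N] tensor, group rows concatenated.
at::Tensor stacked =
    f8f8bf16_rowwise_grouped_stacked(xq_vec, wq_vec, xs_vec, ws_vec);
```
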
@@ -452,13 +473,7 @@ at::Tensor f8f8bf16_rowwise_grouped_dynamic(
      {static_cast<long>(group_count * sizeof(KernelArguments))},
      XQ.options().dtype(at::kByte));
  set_dynamic_kernel_args(
-       kernel_args,
-       XQ,
-       WQ,
-       x_scale,
-       w_scale,
-       Y,
-       zero_start_index_M);
+       kernel_args, XQ, WQ, x_scale, w_scale, Y, zero_start_index_M);

  RowwiseGroupedKernel<at::Tensor, at::Tensor> selected_kernel =
      rowwise_grouped_heuristic_dispatch<at::Tensor, at::Tensor>(M, N, K);