Commit 2cef43a

jwfromm authored and facebook-github-bot committed
Optimize zero fill (pytorch#3666)
Summary: Pull Request resolved: pytorch#3666. X-link: facebookresearch/FBGEMM#741. We were spending more time than necessary setting the output tensor to zero during kernel setup. Assuming that N is divisible by 8 and using float4 vectorized writes saves us a good bit of time. This can yield as much as a 10% overall speedup for fp8 grouped gemm.

Reviewed By: jiawenliu64, mxz297

Differential Revision: D69267443

fbshipit-source-id: 527b81f69fc3792c2b41fad0ba8f123de5bafde6
1 parent d564c8c commit 2cef43a
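For readers skimming the diff below, here is a minimal standalone sketch of the general idea, not part of this commit: replace per-element bf16 stores with one float4 store per thread, which clears 8 bf16 values (16 bytes) at a time and therefore assumes the total element count is divisible by 8. Kernel names and launch details here are hypothetical.

// Illustrative sketch only; not the commit's kernel. Assumes a 16-byte-aligned
// bf16 buffer whose length is a multiple of 8.
#include <hip/hip_runtime.h>
#include <hip/hip_bf16.h>

// Scalar variant: one thread clears one bf16 element.
__global__ void zero_fill_scalar(__hip_bfloat16* out, int num_elements) {
  int idx = blockIdx.x * blockDim.x + threadIdx.x;
  if (idx < num_elements) {
    out[idx] = __hip_bfloat16(0);
  }
}

// Vectorized variant: one thread clears 8 bf16 elements with a single
// float4 (16-byte) store, so it needs an eighth as many threads and stores.
__global__ void zero_fill_vectorized(__hip_bfloat16* out, int num_elements) {
  int idx = blockIdx.x * blockDim.x + threadIdx.x;
  int offset = idx * 8;
  if (offset < num_elements) {
    float4* block = reinterpret_cast<float4*>(out + offset);
    *block = {0, 0, 0, 0};
  }
}

The vectorized variant issues an eighth as many store instructions over the same bytes, which is the effect the commit exploits during kernel setup.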

File tree: 1 file changed (+80, -65 lines)

fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_grouped/fp8_rowwise_grouped_gemm.hip

Lines changed: 80 additions & 65 deletions
@@ -18,7 +18,6 @@
 
 #include <ATen/ATen.h>
 #include <c10/hip/HIPStream.h>
-#include <hip_bf16.h>
 #include <torch/torch.h>
 
 #include "ck/ck.hpp"
@@ -27,7 +26,7 @@
 
 namespace fbgemm_gpu {
 
-template<typename InputType, typename OutputType>
+template <typename InputType, typename OutputType>
 using RowwiseGroupedKernel = std::function<OutputType(
     InputType,
     InputType,
@@ -46,49 +45,76 @@ using D1DataType = float;
 using DsDataType = ck::Tuple<D0DataType, D1DataType>;
 using EDataType = ck::bhalf_t;
 
-template<typename InputType, typename OutputType>
-RowwiseGroupedKernel<InputType, OutputType> rowwise_grouped_heuristic_dispatch(int M, int N, int K) {
+template <typename InputType, typename OutputType>
+RowwiseGroupedKernel<InputType, OutputType>
+rowwise_grouped_heuristic_dispatch(int M, int N, int K) {
   // We use shape heuristics to find the best kernel.
   // To do this, we divide by the size of M and find the best
   // option within that grouping.
   if (M <= 16) {
     if (N < 8192 && K <= 8192) {
-      return fp8_rowwise_grouped_64x16x16x256_16x16_1x1_16x4x1_16x4x1_1x4x1x16_4x4x1_1x1_intrawave_v1<InputType, OutputType>;
+      return fp8_rowwise_grouped_64x16x16x256_16x16_1x1_16x4x1_16x4x1_1x4x1x16_4x4x1_1x1_intrawave_v1<
+          InputType,
+          OutputType>;
     }
     if (K <= 8192) {
-      return fp8_rowwise_grouped_128x16x64x128_16x16_1x2_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2<InputType, OutputType>;
+      return fp8_rowwise_grouped_128x16x64x128_16x16_1x2_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2<
+          InputType,
+          OutputType>;
     }
-    return fp8_rowwise_grouped_128x16x32x256_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2<InputType, OutputType>;
+    return fp8_rowwise_grouped_128x16x32x256_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2<
+        InputType,
+        OutputType>;
   }
   if (M <= 32) {
     if (N < 8192 && K <= 8192) {
-      return fp8_rowwise_grouped_128x32x64x128_32x32_1x1_8x16x1_8x16x1_1x16x1x8_8x8x1_1x1_interwave_v2<InputType, OutputType>;
+      return fp8_rowwise_grouped_128x32x64x128_32x32_1x1_8x16x1_8x16x1_1x16x1x8_8x8x1_1x1_interwave_v2<
+          InputType,
+          OutputType>;
     }
     if (K <= 8192) {
-      return fp8_rowwise_grouped_128x32x128x128_32x32_1x2_8x16x1_8x16x1_1x16x1x8_8x8x1_1x1_interwave_v2<InputType, OutputType>;
+      return fp8_rowwise_grouped_128x32x128x128_32x32_1x2_8x16x1_8x16x1_1x16x1x8_8x8x1_1x1_interwave_v2<
+          InputType,
+          OutputType>;
     }
-    return fp8_rowwise_grouped_128x32x128x128_32x32_1x2_8x16x1_8x16x1_1x16x1x8_8x8x1_1x1_intrawave_v2<InputType, OutputType>;
+    return fp8_rowwise_grouped_128x32x128x128_32x32_1x2_8x16x1_8x16x1_1x16x1x8_8x8x1_1x1_intrawave_v2<
+        InputType,
+        OutputType>;
   }
   if (M <= 64) {
-    return fp8_rowwise_grouped_256x64x64x128_32x32_1x1_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3<InputType, OutputType>;
+    return fp8_rowwise_grouped_256x64x64x128_32x32_1x1_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3<
+        InputType,
+        OutputType>;
   }
   if (M <= 128) {
     if (N < 8192 && K <= 8192) {
-      return fp8_rowwise_grouped_256x128x64x128_32x32_2x1_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3<InputType, OutputType>;
+      return fp8_rowwise_grouped_256x128x64x128_32x32_2x1_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3<
+          InputType,
+          OutputType>;
     }
-    return fp8_rowwise_grouped_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3<InputType, OutputType>;
+    return fp8_rowwise_grouped_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3<
+        InputType,
+        OutputType>;
   }
   if (M <= 256) {
-    return fp8_rowwise_grouped_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3<InputType, OutputType>;
+    return fp8_rowwise_grouped_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3<
+        InputType,
+        OutputType>;
   }
   if (M <= 512) {
     if (K <= 8192) {
-      return fp8_rowwise_grouped_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_interwave_v1<InputType, OutputType>;
+      return fp8_rowwise_grouped_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_interwave_v1<
+          InputType,
+          OutputType>;
     }
-    return fp8_rowwise_grouped_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3<InputType, OutputType>;
+    return fp8_rowwise_grouped_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3<
+        InputType,
+        OutputType>;
   }
   // Default kernel for all other shapes.
-  return fp8_rowwise_grouped_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_interwave_v1<InputType, OutputType>;
+  return fp8_rowwise_grouped_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_interwave_v1<
+      InputType,
+      OutputType>;
 }
 
 __global__ void set_kernel_args_kernel(
@@ -139,9 +165,10 @@ void set_static_kernel_args(
     if constexpr (std::is_same_v<OutputType, std::vector<at::Tensor>>) {
       // Output is a list of tensors and we can access each individually.
       output_ptr = reinterpret_cast<EDataType*>(output[i].data_ptr());
-    } else{
+    } else {
       // Output is a single contiguous tensor and must be accessed via offset.
-      output_ptr = reinterpret_cast<EDataType*>(output.data_ptr()) + output_offset;
+      output_ptr =
+          reinterpret_cast<EDataType*>(output.data_ptr()) + output_offset;
       output_offset += M * N;
     }
 
@@ -165,7 +192,6 @@ void set_static_kernel_args(
         M,
         N,
         K);
-
   }
 }
 
@@ -180,8 +206,7 @@ __global__ void set_kernel_args_fixed_nk_kernel(
     int M,
     int N,
     int K,
-    int group_count,
-    const int BLOCK_SIZE) {
+    int group_count) {
   int thread_idx = blockIdx.x * blockDim.x + threadIdx.x;
   // Each thread is responsible for setting up the arguments for one group.
   if (thread_idx < group_count) {
@@ -203,33 +228,21 @@ __global__ void set_kernel_args_fixed_nk_kernel(
     kernel_args[thread_idx] = kernel_group_args;
   }
 
-  // We also fuse in initialization of the output tensor.
-  // We write in chunks of 2 bfloats at a time for efficiency.
-  for (int i = 0; i < BLOCK_SIZE / 2; i++) {
-    // Figure out where in memory we are.
-    int output_offset = (thread_idx * BLOCK_SIZE) + (i * 2);
-    int current_group = output_offset / (M * N);
-    // Skip if outside of valid groups.
-    if (current_group < group_count) {
-      int nonzeros = prepad_M[current_group];
-      int current_M = output_offset / N;
-      // Only write if this block needs initialization.
-      // Avoid writing to final element if number of elements is odd.
-      if (current_M >= nonzeros && output_offset < (M * N * group_count) - 1) {
-        __hip_bfloat162* output_block =
-            reinterpret_cast<__hip_bfloat162*>(output + output_offset);
-        *output_block = __hip_bfloat162(0, 0);
-      }
+  // Figure out where in memory we are.
+  // Each thread sets one float4 which corresponds to 8 bf16 values.
+  int output_offset = (thread_idx * 8);
+  int current_group = output_offset / (M * N);
+  // Skip if outside of valid groups.
+  if (current_group < group_count) {
+    int nonzeros = prepad_M[current_group];
+    int current_M = (output_offset % (M * N)) / N;
+    // Only write zeros if we're currently in a sparse row.
+    if (current_M >= nonzeros) {
+      // Write out a block of 8 output values via vectorized float4.
+      float4* output_block = reinterpret_cast<float4*>(output + output_offset);
+      *output_block = {0, 0, 0, 0};
     }
   }
-  // Handle case where there are an odd number of total elements.
-  if (((M * N * group_count) % 2) != 0 &&
-      ((M * N * group_count) - (thread_idx * BLOCK_SIZE) < BLOCK_SIZE)) {
-    // Write out the final element.
-    __hip_bfloat16* output_block =
-        reinterpret_cast<__hip_bfloat16*>(output + (M * N * group_count) - 1);
-    *output_block = __hip_bfloat16(0);
-  }
 }
 
 void set_dynamic_kernel_args(
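To make the per-thread addressing in the new zero-fill path concrete, here is a small host-side walkthrough of the same index arithmetic. It is illustrative only and not part of the commit; the sizes are made up.

// Illustrative only: with hypothetical sizes M = 4, N = 8, group_count = 2,
// thread 5 owns the 8 bf16 elements starting at offset 40.
#include <cstdio>

int main() {
  const int M = 4, N = 8, group_count = 2;
  const int thread_idx = 5;
  int output_offset = thread_idx * 8;            // 40: first of 8 bf16 values
  int current_group = output_offset / (M * N);   // 40 / 32 = 1 (second group)
  int current_M = (output_offset % (M * N)) / N; // (40 % 32) / 8 = 1 (row 1)
  // If prepad_M[1] <= 1, row 1 of group 1 is padding and gets zero-filled.
  printf("offset=%d group=%d row=%d\n", output_offset, current_group, current_M);
  return 0;
}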
@@ -261,9 +274,12 @@ void set_dynamic_kernel_args(
   int N = WQ.size(1);
 
   // Launch a kernel that sets kernel argument memory.
+  // Each thread sets one float4 which corresponds to 8 bf16 values.
   const int BLOCK_SIZE = 8;
+  TORCH_CHECK(
+      N % BLOCK_SIZE == 0, "N must be divisible 8 for dynamic grouped gemm.");
   int block_factor = std::max(group_count, (group_count * M * N) / BLOCK_SIZE);
-  int blockSize = std::min(1024, block_factor);
+  int blockSize = std::min(512, block_factor);
   int numBlocks = (block_factor + blockSize - 1) / blockSize;
   set_kernel_args_fixed_nk_kernel<<<numBlocks, blockSize, 0, stream>>>(
       reinterpret_cast<KernelArguments*>(kernel_args.data_ptr()),
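A quick sanity check of the launch sizing above, illustrative only with hypothetical shapes: each thread covers one float4, i.e. 8 bf16 output elements, so the grid must span group_count * M * N / 8 threads in blocks of at most 512.

// Illustrative only: reproduces the launch arithmetic with made-up shapes.
#include <algorithm>
#include <cstdio>

int main() {
  const int group_count = 4, M = 128, N = 256;
  const int BLOCK_SIZE = 8;  // 8 bf16 values (one float4) per thread
  int block_factor =
      std::max(group_count, (group_count * M * N) / BLOCK_SIZE);  // 16384
  int blockSize = std::min(512, block_factor);                    // 512
  int numBlocks = (block_factor + blockSize - 1) / blockSize;     // 32
  printf("blocks=%d threads/block=%d\n", numBlocks, blockSize);
  return 0;
}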
@@ -276,8 +292,7 @@ void set_dynamic_kernel_args(
       M,
       N,
       K,
-      group_count,
-      BLOCK_SIZE);
+      group_count);
 }
 
 template <typename OutputType>
@@ -347,22 +362,25 @@ OutputType _f8f8bf16_rowwise_grouped(
         Y.push_back(at::empty({M, N}, XQ[i].options().dtype(at::kBFloat16)));
       }
     }
-  // Now handle single tensor output.
+    // Now handle single tensor output.
   } else {
     // Compute total M across groups.
     int total_M = 0;
     int N = WQ[0].size(0);
     for (int i = 0; i < group_count; i++) {
       total_M += XQ[i].size(0);
       // Also make sure N is constant across shapes.
-      TORCH_CHECK(WQ[i].size(0) == N, "N must be constant across groups for stacked output.");
+      TORCH_CHECK(
+          WQ[i].size(0) == N,
+          "N must be constant across groups for stacked output.");
     }
     if (output.has_value()) {
       Y = output.value();
       // Check that shape is expected.
-      TORCH_CHECK(Y.size(0) == total_M && Y.size(1) == N, "Preallocated output should have size [total_M, N].");
-    }
-    else {
+      TORCH_CHECK(
+          Y.size(0) == total_M && Y.size(1) == N,
+          "Preallocated output should have size [total_M, N].");
+    } else {
       Y = at::empty({total_M, N}, XQ[0].options().dtype(at::kBFloat16));
     }
   }
@@ -383,7 +401,8 @@ OutputType _f8f8bf16_rowwise_grouped(
     MaxK = max(MaxK, XQ[i].size(1));
   }
   RowwiseGroupedKernel<at::TensorList, OutputType> selected_kernel =
-      rowwise_grouped_heuristic_dispatch<at::TensorList, OutputType>(MaxM, MaxN, MaxK);
+      rowwise_grouped_heuristic_dispatch<at::TensorList, OutputType>(
+          MaxM, MaxN, MaxK);
   return selected_kernel(XQ, WQ, x_scale, w_scale, kernel_args, Y);
 }
 
@@ -394,7 +413,8 @@ std::vector<at::Tensor> f8f8bf16_rowwise_grouped(
     at::TensorList x_scale,
     at::TensorList w_scale,
     std::optional<std::vector<at::Tensor>> output = std::nullopt) {
-  return _f8f8bf16_rowwise_grouped<std::vector<at::Tensor>>(XQ, WQ, x_scale, w_scale, output);
+  return _f8f8bf16_rowwise_grouped<std::vector<at::Tensor>>(
+      XQ, WQ, x_scale, w_scale, output);
 }
 
 // Wrapper function for list input single tensor output.
@@ -404,7 +424,8 @@ at::Tensor f8f8bf16_rowwise_grouped_stacked(
     at::TensorList x_scale,
     at::TensorList w_scale,
     std::optional<at::Tensor> output = std::nullopt) {
-  return _f8f8bf16_rowwise_grouped<at::Tensor>(XQ, WQ, x_scale, w_scale, output);
+  return _f8f8bf16_rowwise_grouped<at::Tensor>(
+      XQ, WQ, x_scale, w_scale, output);
 }
 
 at::Tensor f8f8bf16_rowwise_grouped_dynamic(
@@ -452,13 +473,7 @@ at::Tensor f8f8bf16_rowwise_grouped_dynamic(
       {static_cast<long>(group_count * sizeof(KernelArguments))},
       XQ.options().dtype(at::kByte));
   set_dynamic_kernel_args(
-      kernel_args,
-      XQ,
-      WQ,
-      x_scale,
-      w_scale,
-      Y,
-      zero_start_index_M);
+      kernel_args, XQ, WQ, x_scale, w_scale, Y, zero_start_index_M);
 
   RowwiseGroupedKernel<at::Tensor, at::Tensor> selected_kernel =
       rowwise_grouped_heuristic_dispatch<at::Tensor, at::Tensor>(M, N, K);
