Skip to content

Commit cc30828

Browse files
mxz297 and facebook-github-bot
authored and committed
fp8 rowwise regular gemm tuning for llm new shapes (pytorch#730)
Summary: X-link: pytorch#3654 Pull Request resolved: facebookresearch/FBGEMM#730 1. Add more FP8 rowwise instances 2. Extend the template to allow specifying AK1 and BK1 3. Add tuning for new shapes Reviewed By: jianyuh Differential Revision: D69070084 fbshipit-source-id: 7f841bd0855743d5389d7afa0f135619ca0cd61a
1 parent 8968454 commit cc30828

17 files changed

+628
-209
lines changed

fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise/fp8_rowwise_gemm.hip

Lines changed: 62 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -269,11 +269,72 @@ static const std::map<int, RowwiseKernel> N_5120_K_1024_dispatch_table = {
269269
{ 8192, fp8_rowwise_256x256x192x128_16x16_8x6_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3}
270270
};
271271

272+
static const std::map<int, RowwiseKernel> N_2048_K_5120_dispatch_table = {
273+
{ 4, fp8_rowwise_256x16x64x128_16x16_1x1_16x16x1_8x32x1_1x16x1x16_4x4x1_1x1_intrawave_v2_8},
274+
{ 8, fp8_rowwise_128x16x32x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2_4},
275+
{ 64, fp8_rowwise_128x32x16x512_16x16_1x1_32x4x1_32x4x1_1x32x1x4_4x4x1_1x1_intrawave_v2},
276+
{ 288, fp8_rowwise_256x32x64x512_16x16_1x2_32x8x1_32x8x1_1x32x1x8_8x8x1_1x2_intrawave_v3},
277+
{ 576, fp8_rowwise_256x64x64x512_32x32_1x1_32x8x1_32x8x1_1x32x1x8_8x8x1_1x1_intrawave_v3},
278+
{ 1216, fp8_rowwise_256x64x128x256_32x32_1x2_16x16x1_16x16x1_1x32x1x8_8x8x1_1x1_intrawave_v3},
279+
{ 1664, fp8_rowwise_256x128x96x256_32x32_1x3_16x16x1_16x16x1_1x64x1x4_8x8x1_1x1_intrawave_v3},
280+
{ 2432, fp8_rowwise_256x128x128x256_32x32_2x2_16x16x1_16x16x1_1x32x1x8_8x8x1_1x1_intrawave_v3},
281+
{ 2944, fp8_rowwise_256x128x160x128_32x32_1x5_8x32x1_8x32x1_1x64x1x4_8x8x1_1x1_intrawave_v3},
282+
{ 3456, fp8_rowwise_256x128x192x128_32x32_2x3_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3},
283+
{ 4864, fp8_rowwise_256x256x128x128_16x16_8x4_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3},
284+
{ 5888, fp8_rowwise_256x256x160x128_16x16_8x5_8x32x1_8x32x1_1x64x1x4_8x8x1_2x1_intrawave_v3},
285+
{ 5984, fp8_rowwise_256x256x192x128_16x16_8x6_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3},
286+
};
287+
288+
static const std::map<int, RowwiseKernel> N_896_K_5120_dispatch_table = {
289+
{ 64, fp8_rowwise_128x16x32x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2_8},
290+
{ 72, fp8_rowwise_128x16x32x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2_8},
291+
{ 80, fp8_rowwise_128x16x32x512_16x16_1x1_32x4x1_32x4x1_1x16x1x8_4x4x1_1x1_interwave_v2_2},
292+
{ 160, fp8_rowwise_128x32x16x512_16x16_1x1_32x4x1_32x4x1_1x32x1x4_4x4x1_1x1_intrawave_v2},
293+
{ 200, fp8_rowwise_256x16x64x512_16x16_1x1_32x8x1_32x8x1_1x16x1x16_4x4x1_1x1_intrawave_v3},
294+
{ 256, fp8_rowwise_256x64x16x512_16x16_1x1_32x8x1_32x8x1_1x64x1x4_4x4x1_1x1_intrawave_v2},
295+
{ 672, fp8_rowwise_256x32x64x512_16x16_1x2_32x8x1_32x8x1_1x32x1x8_8x8x1_1x2_intrawave_v3},
296+
{ 1344, fp8_rowwise_256x64x64x512_32x32_1x1_32x8x1_32x8x1_1x32x1x8_8x8x1_1x1_intrawave_v3},
297+
{ 2752, fp8_rowwise_256x64x128x256_32x32_1x2_16x16x1_16x16x1_1x32x1x8_8x8x1_1x1_intrawave_v3},
298+
{ 3840, fp8_rowwise_256x128x96x256_32x32_1x3_16x16x1_16x16x1_1x64x1x4_8x8x1_1x1_intrawave_v3},
299+
{ 5504, fp8_rowwise_256x128x128x256_32x32_2x2_16x16x1_16x16x1_1x32x1x8_8x8x1_1x1_intrawave_v3},
300+
{ 5984, fp8_rowwise_256x128x160x128_32x32_1x5_8x32x1_8x32x1_1x64x1x4_8x8x1_1x1_intrawave_v3},
301+
};
302+
303+
static const std::map<int, RowwiseKernel> N_5120_K_640_dispatch_table = {
304+
{ 64, fp8_rowwise_128x16x32x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2},
305+
{ 80, fp8_rowwise_128x32x16x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_interwave_v2},
306+
{ 112, fp8_rowwise_128x16x32x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v1},
307+
{ 192, fp8_rowwise_256x64x64x128_32x32_1x1_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3},
308+
{ 224, fp8_rowwise_128x32x64x128_32x32_1x1_8x16x1_8x16x1_1x16x1x8_8x8x1_1x1_interwave_v2},
309+
{ 256, fp8_rowwise_128x64x32x128_32x32_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2},
310+
{ 384, fp8_rowwise_256x128x64x128_32x32_2x1_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3},
311+
{ 448, fp8_rowwise_256x64x128x128_32x32_1x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3},
312+
{ 512, fp8_rowwise_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3},
313+
{ 704, fp8_rowwise_256x64x192x128_32x32_1x3_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3},
314+
{ 896, fp8_rowwise_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3},
315+
{ 960, fp8_rowwise_256x64x256x128_32x32_1x4_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3},
316+
{ 1152, fp8_rowwise_256x128x160x128_32x32_1x5_8x32x1_8x32x1_1x64x1x4_8x8x1_1x1_intrawave_v3},
317+
{ 1280, fp8_rowwise_256x256x96x128_32x32_2x3_8x32x1_8x32x1_1x64x1x4_8x8x1_2x1_intrawave_v3},
318+
{ 1408, fp8_rowwise_256x128x192x128_32x32_2x3_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3},
319+
{ 1920, fp8_rowwise_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3},
320+
{ 2304, fp8_rowwise_256x256x160x128_16x16_8x5_8x32x1_8x32x1_1x64x1x4_8x8x1_2x1_intrawave_v3},
321+
{ 2816, fp8_rowwise_256x256x192x128_16x16_8x6_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3},
322+
{ 3360, fp8_rowwise_256x224x256x128_16x16_7x8_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3},
323+
{ 3840, fp8_rowwise_256x256x256x128_16x16_8x8_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3},
324+
{ 4864, fp8_rowwise_256x256x160x128_16x16_8x5_8x32x1_8x32x1_1x64x1x4_8x8x1_2x1_intrawave_v3},
325+
{ 5520, fp8_rowwise_256x256x192x128_16x16_8x6_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3},
326+
{ 5760, fp8_rowwise_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3},
327+
{ 5984, fp8_rowwise_256x224x256x128_16x16_7x8_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3},
328+
};
329+
272330
static const std::unordered_map<std::tuple<int, int>, NKLookupTableType, IntTupleHash> NK_lookup_table = {
273331
{{7168, 8192}, N_7168_K_8192_dispatch_table},
274332
{{8192, 3584}, N_8192_K_3584_dispatch_table},
275333
{{1024, 5120}, N_1024_K_5120_dispatch_table},
276-
{{5120, 1024}, N_5120_K_1024_dispatch_table}
334+
{{5120, 1024}, N_5120_K_1024_dispatch_table},
335+
{{2048, 5120}, N_2048_K_5120_dispatch_table},
336+
{{896, 5120}, N_896_K_5120_dispatch_table},
337+
{{5120, 640}, N_5120_K_640_dispatch_table}
277338
};
278339

279340
RowwiseKernel rowwise_nk_lookup(int M, const NKLookupTableType& table) {
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#include "fp8_rowwise_common.h"
10+
11+
at::Tensor
12+
fp8_rowwise_128x16x32x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2_4(
13+
at::Tensor XQ,
14+
at::Tensor WQ,
15+
at::Tensor x_scale,
16+
at::Tensor w_scale,
17+
at::Tensor Y) {
18+
using DeviceGemmInstance = DeviceGemmHelper<
19+
128,
20+
16,
21+
32,
22+
128,
23+
16,
24+
16,
25+
1,
26+
1,
27+
S<8, 16, 1>,
28+
S<8, 16, 1>,
29+
S<1, 16, 1, 8>,
30+
S<4, 4, 1>,
31+
1,
32+
1,
33+
ck::BlockGemmPipelineScheduler::Interwave,
34+
ck::BlockGemmPipelineVersion::v2,
35+
ck::tensor_operation::device::GemmSpecialization::Default>;
36+
// Run kernel instance.
37+
return f8f8bf16_rowwise_impl<DeviceGemmInstance>(XQ, WQ, x_scale, w_scale, Y, 4);
38+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#include "fp8_rowwise_common.h"
10+
11+
at::Tensor
12+
fp8_rowwise_128x16x32x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v1(
13+
at::Tensor XQ,
14+
at::Tensor WQ,
15+
at::Tensor x_scale,
16+
at::Tensor w_scale,
17+
at::Tensor Y) {
18+
using DeviceGemmInstance = DeviceGemmHelper<
19+
128,
20+
16,
21+
32,
22+
128,
23+
16,
24+
16,
25+
1,
26+
1,
27+
S<8, 16, 1>,
28+
S<8, 16, 1>,
29+
S<1, 16, 1, 8>,
30+
S<4, 4, 1>,
31+
1,
32+
1,
33+
ck::BlockGemmPipelineScheduler::Intrawave,
34+
ck::BlockGemmPipelineVersion::v1,
35+
ck::tensor_operation::device::GemmSpecialization::Default>;
36+
// Run kernel instance.
37+
return f8f8bf16_rowwise_impl<DeviceGemmInstance>(XQ, WQ, x_scale, w_scale, Y);
38+
}
39+
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#include "fp8_rowwise_common.h"
10+
11+
at::Tensor
12+
fp8_rowwise_128x16x32x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2_8(
13+
at::Tensor XQ,
14+
at::Tensor WQ,
15+
at::Tensor x_scale,
16+
at::Tensor w_scale,
17+
at::Tensor Y) {
18+
using DeviceGemmInstance = DeviceGemmHelper<
19+
128,
20+
16,
21+
32,
22+
128,
23+
16,
24+
16,
25+
1,
26+
1,
27+
S<8, 16, 1>,
28+
S<8, 16, 1>,
29+
S<1, 16, 1, 8>,
30+
S<4, 4, 1>,
31+
1,
32+
1,
33+
ck::BlockGemmPipelineScheduler::Intrawave,
34+
ck::BlockGemmPipelineVersion::v2,
35+
ck::tensor_operation::device::GemmSpecialization::Default>;
36+
// Run kernel instance.
37+
return f8f8bf16_rowwise_impl<DeviceGemmInstance>(XQ, WQ, x_scale, w_scale, Y, 8);
38+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#include "fp8_rowwise_common.h"
10+
11+
at::Tensor
12+
fp8_rowwise_128x16x32x512_16x16_1x1_32x4x1_32x4x1_1x16x1x8_4x4x1_1x1_interwave_v2_2(
13+
at::Tensor XQ,
14+
at::Tensor WQ,
15+
at::Tensor x_scale,
16+
at::Tensor w_scale,
17+
at::Tensor Y) {
18+
using DeviceGemmInstance = DeviceGemmHelper<
19+
128,
20+
16,
21+
32,
22+
512,
23+
16,
24+
16,
25+
1,
26+
1,
27+
S<32, 4, 1>,
28+
S<32, 4, 1>,
29+
S<1, 16, 1, 8>,
30+
S<4, 4, 1>,
31+
1,
32+
1,
33+
ck::BlockGemmPipelineScheduler::Interwave,
34+
ck::BlockGemmPipelineVersion::v2,
35+
ck::tensor_operation::device::GemmSpecialization::Default>;
36+
// Run kernel instance.
37+
return f8f8bf16_rowwise_impl<DeviceGemmInstance>(XQ, WQ, x_scale, w_scale, Y, 2);
38+
}

fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise/kernels/fp8_rowwise_128x32x16x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_interwave_v2.hip

Lines changed: 21 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -15,54 +15,25 @@ fp8_rowwise_128x32x16x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_interwave_v
1515
at::Tensor x_scale,
1616
at::Tensor w_scale,
1717
at::Tensor Y) {
18-
// A small kernel for small but not tiny shapes.
19-
20-
// Check if this input needs to be padded.
21-
int M = size_to_dim_(XQ.dim() - 1, XQ.sizes());
22-
int N = WQ.size(0);
23-
int K = WQ.size(1);
24-
bool pad = (M % 32 != 0) || (N % 16 != 0) || (K % 128 != 0);
25-
26-
if (pad) {
27-
using DeviceGemmInstance = DeviceGemmHelper<
28-
128,
29-
32,
30-
16,
31-
128,
32-
16,
33-
16,
34-
1,
35-
1,
36-
S<8, 16, 1>,
37-
S<8, 16, 1>,
38-
S<1, 16, 1, 8>,
39-
S<2, 2, 1>,
40-
1,
41-
1,
42-
ck::BlockGemmPipelineScheduler::Interwave,
43-
ck::BlockGemmPipelineVersion::v2>;
44-
// Run kernel instance.
45-
return f8f8bf16_rowwise_impl<DeviceGemmInstance>(XQ, WQ, x_scale, w_scale, Y);
46-
} else {
47-
using DeviceGemmInstance = DeviceGemmHelper<
48-
128,
49-
32,
50-
16,
51-
128,
52-
16,
53-
16,
54-
1,
55-
1,
56-
S<8, 16, 1>,
57-
S<8, 16, 1>,
58-
S<1, 16, 1, 8>,
59-
S<2, 2, 1>,
60-
1,
61-
1,
62-
ck::BlockGemmPipelineScheduler::Interwave,
63-
ck::BlockGemmPipelineVersion::v2,
64-
ck::tensor_operation::device::GemmSpecialization::Default>;
65-
// Run kernel instance.
66-
return f8f8bf16_rowwise_impl<DeviceGemmInstance>(XQ, WQ, x_scale, w_scale, Y);
67-
}
18+
using DeviceGemmInstance = DeviceGemmHelper<
19+
128,
20+
32,
21+
16,
22+
128,
23+
16,
24+
16,
25+
1,
26+
1,
27+
S<8, 16, 1>,
28+
S<8, 16, 1>,
29+
S<1, 16, 1, 8>,
30+
S<2, 2, 1>,
31+
1,
32+
1,
33+
ck::BlockGemmPipelineScheduler::Interwave,
34+
ck::BlockGemmPipelineVersion::v2,
35+
ck::tensor_operation::device::GemmSpecialization::Default>;
36+
// Run kernel instance.
37+
return f8f8bf16_rowwise_impl<DeviceGemmInstance>(XQ, WQ, x_scale, w_scale, Y);
6838
}
39+

fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise/kernels/fp8_rowwise_128x32x64x128_32x32_1x1_8x16x1_8x16x1_1x16x1x8_8x8x1_1x1_interwave_v2.hip

Lines changed: 21 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -15,54 +15,25 @@ fp8_rowwise_128x32x64x128_32x32_1x1_8x16x1_8x16x1_1x16x1x8_8x8x1_1x1_interwave_v
1515
at::Tensor x_scale,
1616
at::Tensor w_scale,
1717
at::Tensor Y) {
18-
// A kernel that works well on small but not super tiny shapes.
19-
20-
// Check if this input needs to be padded.
21-
int M = size_to_dim_(XQ.dim() - 1, XQ.sizes());
22-
int N = WQ.size(0);
23-
int K = WQ.size(1);
24-
bool pad = (M % 32 != 0) || (N % 64 != 0) || (K % 128 != 0);
25-
26-
if (pad) {
27-
using DeviceGemmInstance = DeviceGemmHelper<
28-
128,
29-
32,
30-
64,
31-
128,
32-
32,
33-
32,
34-
1,
35-
1,
36-
S<8, 16, 1>,
37-
S<8, 16, 1>,
38-
S<1, 16, 1, 8>,
39-
S<8, 8, 1>,
40-
1,
41-
1,
42-
ck::BlockGemmPipelineScheduler::Interwave,
43-
ck::BlockGemmPipelineVersion::v2>;
44-
// Run kernel instance.
45-
return f8f8bf16_rowwise_impl<DeviceGemmInstance>(XQ, WQ, x_scale, w_scale, Y);
46-
} else {
47-
using DeviceGemmInstance = DeviceGemmHelper<
48-
128,
49-
32,
50-
64,
51-
128,
52-
32,
53-
32,
54-
1,
55-
1,
56-
S<8, 16, 1>,
57-
S<8, 16, 1>,
58-
S<1, 16, 1, 8>,
59-
S<8, 8, 1>,
60-
1,
61-
1,
62-
ck::BlockGemmPipelineScheduler::Interwave,
63-
ck::BlockGemmPipelineVersion::v2,
64-
ck::tensor_operation::device::GemmSpecialization::Default>;
65-
// Run kernel instance.
66-
return f8f8bf16_rowwise_impl<DeviceGemmInstance>(XQ, WQ, x_scale, w_scale, Y);
67-
}
18+
using DeviceGemmInstance = DeviceGemmHelper<
19+
128,
20+
32,
21+
64,
22+
128,
23+
32,
24+
32,
25+
1,
26+
1,
27+
S<8, 16, 1>,
28+
S<8, 16, 1>,
29+
S<1, 16, 1, 8>,
30+
S<8, 8, 1>,
31+
1,
32+
1,
33+
ck::BlockGemmPipelineScheduler::Interwave,
34+
ck::BlockGemmPipelineVersion::v2,
35+
ck::tensor_operation::device::GemmSpecialization::Default>;
36+
// Run kernel instance.
37+
return f8f8bf16_rowwise_impl<DeviceGemmInstance>(XQ, WQ, x_scale, w_scale, Y);
6838
}
39+

0 commit comments

Comments
 (0)