Skip to content

Commit fe42aed

Browse files
zjing14 authored and facebook-github-bot committed
A hotfix for FBGEMM fp8 rowwise with irregular gemm sizes
Summary: - Hotfix for T219165899 reported by pranavsh, which is caused by some instances requiring the K size to be a multiple of `KTile`. - Added a fallback for GEMM cases where K is not a multiple of the maximum KTile (256). Reviewed By: jianyuh Differential Revision: D71863248
1 parent 74db0ac commit fe42aed

File tree

2 files changed

+33
-5
lines changed

2 files changed

+33
-5
lines changed

fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise/fp8_rowwise_gemm.hip

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -422,8 +422,11 @@ RowwiseKernel rowwise_nk_lookup(int M, const NKLookupTableType& table) {
422422
RowwiseKernel rowwise_heuristic_dispatch(int M, int N, int K) {
423423
// Apply shape heuristics to find a suitable kernel implementation.
424424

425-
//Fallback of irregular data types
426-
if(!((N % 8 == 0) && (K % 16 == 0)))
425+
//Fallback for irregular data types: some instances require K to be a multiple
426+
//of K Tile.
427+
//To-Do: Need a systemic solution for various restrictions from different
428+
//instances.
429+
if(!((N % 8 == 0) && (K % 256 == 0)))
427430
return fp8_rowwise_64x16x16x256_16x16_1x1_16x4x1_16x4x1_1x4x1x16_4x4x1_1x1_intrawave_v1;
428431

429432
if (M < 64 && N < 2048 && K < 2048) {

fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise/kernels/fp8_rowwise_64x16x16x256_16x16_1x1_16x4x1_16x4x1_1x4x1x16_4x4x1_1x1_intrawave_v1.hip

Lines changed: 28 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ fp8_rowwise_64x16x16x256_16x16_1x1_16x4x1_16x4x1_1x4x1x16_4x4x1_1x1_intrawave_v1
2020
int N = WQ.size(0);
2121
int K = WQ.size(1);
2222

23-
if ((K % 16 == 0) && (N % 4 == 0)) {
23+
if ((K % 256 == 0) && (N % 4 == 0)) {
2424
using DeviceGemmInstance = DeviceGemmHelper<
2525
64,
2626
16,
@@ -42,6 +42,31 @@ fp8_rowwise_64x16x16x256_16x16_1x1_16x4x1_16x4x1_1x4x1x16_4x4x1_1x1_intrawave_v1
4242
16,
4343
16>;
4444

45+
// Run kernel instance.
46+
return f8f8bf16_rowwise_impl<DeviceGemmInstance>(
47+
XQ, WQ, x_scale, w_scale, Y);
48+
} else if ((K % 16 == 0) && (N % 4 == 0)) {
49+
using DeviceGemmInstance = DeviceGemmHelper<
50+
64,
51+
16,
52+
16,
53+
256,
54+
16,
55+
16,
56+
1,
57+
1,
58+
S<16, 4, 1>,
59+
S<16, 4, 1>,
60+
S<1, 16, 1, 4>,
61+
S<4, 4, 1>,
62+
1,
63+
1,
64+
ck::BlockGemmPipelineScheduler::Intrawave,
65+
ck::BlockGemmPipelineVersion::v1,
66+
ck::tensor_operation::device::GemmSpecialization::MNKPadding,
67+
16,
68+
16>;
69+
4570
// Run kernel instance.
4671
return f8f8bf16_rowwise_impl<DeviceGemmInstance>(
4772
XQ, WQ, x_scale, w_scale, Y);
@@ -63,7 +88,7 @@ fp8_rowwise_64x16x16x256_16x16_1x1_16x4x1_16x4x1_1x4x1x16_4x4x1_1x1_intrawave_v1
6388
1,
6489
ck::BlockGemmPipelineScheduler::Intrawave,
6590
ck::BlockGemmPipelineVersion::v1,
66-
ck::tensor_operation::device::GemmSpecialization::Default,
91+
ck::tensor_operation::device::GemmSpecialization::MNKPadding,
6792
8,
6893
8>;
6994

@@ -88,7 +113,7 @@ fp8_rowwise_64x16x16x256_16x16_1x1_16x4x1_16x4x1_1x4x1x16_4x4x1_1x1_intrawave_v1
88113
1,
89114
ck::BlockGemmPipelineScheduler::Intrawave,
90115
ck::BlockGemmPipelineVersion::v1,
91-
ck::tensor_operation::device::GemmSpecialization::Default,
116+
ck::tensor_operation::device::GemmSpecialization::MNKPadding,
92117
2,
93118
2>;
94119

0 commit comments

Comments (0)