
Commit 74490d6

mxz297 authored and facebook-github-bot committed
amd fp8 rowwise gemm prefill shape tuning (#3607)
Summary:
X-link: facebookresearch/FBGEMM#685
Pull Request resolved: #3607

This diff adds a more robust FP8 rowwise heuristic for LLMs, especially for prefill cases. Consider an input [M, K] and a weight [N, K]. For LLMs, N and K are fixed across different prefill/decode lengths, so the new heuristic first looks up (N, K) and then does a range-based lookup on M. For each combination of N and K, an offline tuning sweep is run over many values of M, which looks like:

```
5280, 8192, 3584, 0.318272, 256x256x192x128_16x16_8x6_intrawave_v3_kbatch_1
5312, 8192, 3584, 0.322179, 256x256x192x128_16x16_8x6_intrawave_v3_kbatch_1
5344, 8192, 3584, 0.320632, 256x256x192x128_16x16_8x6_intrawave_v3_kbatch_1
5376, 8192, 3584, 0.317728, 256x256x192x128_16x16_8x6_intrawave_v3_kbatch_1
5408, 8192, 3584, 0.338742, 256x224x256x128_16x16_7x8_intrawave_v3_kbatch_1
5440, 8192, 3584, 0.341432, 256x224x256x128_16x16_7x8_intrawave_v3_kbatch_1
5472, 8192, 3584, 0.3436, 256x224x256x128_16x16_7x8_intrawave_v3_kbatch_1
5536, 8192, 3584, 0.341703, 256x224x256x128_16x16_7x8_intrawave_v3_kbatch_1
5568, 8192, 3584, 0.342054, 256x224x256x128_16x16_7x8_intrawave_v3_kbatch_1
5632, 8192, 3584, 0.347904, 256x224x256x128_16x16_7x8_intrawave_v3_kbatch_1
5664, 8192, 3584, 0.345129, 256x224x256x128_16x16_7x8_intrawave_v3_kbatch_1
```

A clear pattern is that a single instance is the top choice across a large range of M, which justifies the M-range-based heuristic. The full tuning log is parsed and converted into a std::map for range-based lookup.

One key question is which instance to use right at the boundary where the best instance changes. For example:

```
5376, 8192, 3584, 0.317728, 256x256x192x128_16x16_8x6_intrawave_v3_kbatch_1
5408, 8192, 3584, 0.338742, 256x224x256x128_16x16_7x8_intrawave_v3_kbatch_1
```

Should we use 256x256x192x128 or 256x224x256x128 for M = 5377 to 5407? The implementation uses the tuning entry for the larger value (so 256x224x256x128). The rationale is that using the smaller entry may require more thread blocks and hurt performance, whereas with the larger entry the performance should in theory match that of the larger tuned M itself. Empirically, using the smaller entry led to degraded performance for untuned values.

Reviewed By: jwfromm

Differential Revision: D68521662

fbshipit-source-id: e59a8634678a77e4e4d5c2110dbe5d92febc3ad8
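
To make the boundary rule concrete, here is a small self-contained sketch (not part of the diff) of the range-lookup semantics: the map is keyed on the tuned M values, and std::map::lower_bound returns the entry for the first tuned M at or above the query, i.e. the larger entry at a boundary. Kernel instances are stubbed out as strings purely for illustration.

```
#include <iostream>
#include <map>
#include <string>

// Keys are tuned M values from the log above; values stand in for the
// selected kernel instance.
static const std::map<int, std::string> n8192_k3584_demo = {
    {5376, "256x256x192x128_16x16_8x6_intrawave_v3"},
    {5408, "256x224x256x128_16x16_7x8_intrawave_v3"},
};

// Mirrors the selection rule described above: pick the entry for the first
// tuned M at or above the query; past the last tuned M, reuse the last entry.
std::string range_lookup(int M, const std::map<int, std::string>& table) {
  auto it = table.lower_bound(M);
  if (it == table.end()) {
    --it;
  }
  return it->second;
}

int main() {
  std::cout << range_lookup(5390, n8192_k3584_demo) << "\n"; // larger entry: 256x224x256x128
  std::cout << range_lookup(5376, n8192_k3584_demo) << "\n"; // exact hit: 256x256x192x128
  std::cout << range_lookup(9000, n8192_k3584_demo) << "\n"; // beyond last key: reuse last entry
}
```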
1 parent 5f3adca commit 74490d6

File tree

32 files changed, +1279 −473 lines changed


fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise/fp8_rowwise_gemm.hip

Lines changed: 125 additions & 52 deletions
@@ -25,8 +25,15 @@ namespace fbgemm_gpu {
 using RowwiseKernel = std::function<
     at::Tensor(at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor)>;
 
+using NKLookupTableType = std::map<int, RowwiseKernel>;
+
 // Define a custom hash function for std::tuple<int, int, int>
 struct IntTupleHash {
+  size_t operator()(const std::tuple<int, int>& t) const {
+    auto hash1 = std::hash<int>{}(std::get<0>(t));
+    auto hash2 = std::hash<int>{}(std::get<1>(t));
+    return hash1 ^ hash2;
+  }
   size_t operator()(const std::tuple<int, int, int>& t) const {
     auto hash1 = std::hash<int>{}(std::get<0>(t));
     auto hash2 = std::hash<int>{}(std::get<1>(t));
@@ -38,24 +45,6 @@ struct IntTupleHash {
 // For certain high priority shapes, we directly map to the best kernel rather
 // than use heuristics.
 static const std::unordered_map<std::tuple<int, int, int>, RowwiseKernel, IntTupleHash> rowwise_lookup_dispatch = {
-    // Support for decode for [1024, 5120]
-    {{16, 1024, 5120},
-     fp8_rowwise_128x16x32x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2_8},
-    {{32, 1024, 5120},
-     fp8_rowwise_128x16x32x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2_8},
-    {{64, 1024, 5120},
-     fp8_rowwise_128x16x32x512_16x16_1x1_32x4x1_32x4x1_1x16x1x8_4x4x1_1x1_intrawave_v2},
-    {{128, 1024, 5120},
-     fp8_rowwise_128x16x32x512_16x16_1x1_32x4x1_32x4x1_1x16x1x8_4x4x1_1x1_intrawave_v2},
-    // Support for decode for [5120, 1024]
-    {{16, 5120, 1024},
-     fp8_rowwise_128x32x16x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_interwave_v2},
-    {{32, 5120, 1024},
-     fp8_rowwise_64x16x16x128_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_interwave_v2},
-    {{64, 5120, 1024},
-     fp8_rowwise_128x32x16x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_interwave_v2},
-    {{128, 5120, 1024},
-     fp8_rowwise_128x32x64x128_32x32_1x1_8x16x1_8x16x1_1x16x1x8_8x8x1_1x1_intrawave_v2},
     // LLama 70B Decode shapes.
     // Support for decode across batch sizes for [1280, 8192]
     {{16, 1280, 8192},
@@ -75,40 +64,6 @@ static const std::unordered_map<std::tuple<int, int, int>, RowwiseKernel, IntTup
      fp8_rowwise_128x32x16x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_interwave_v2},
     {{128, 8192, 1024},
      fp8_rowwise_256x64x64x128_32x32_1x1_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3},
-    // Support for decode across batch sizes for [7168, 8192]
-    {{16, 7168, 8192},
-     fp8_rowwise_128x16x32x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2},
-    {{32, 7168, 8192},
-     fp8_rowwise_128x32x16x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_interwave_v2},
-    {{64, 7168, 8192},
-     fp8_rowwise_128x64x32x128_32x32_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2},
-    {{128, 7168, 8192},
-     fp8_rowwise_256x64x64x128_32x32_1x1_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3},
-    {{1024, 7168, 8192},
-     fp8_rowwise_256x256x96x128_32x32_2x3_8x32x1_8x32x1_1x64x1x4_8x8x1_2x1_intrawave_v3},
-    {{2048, 7168, 8192},
-     fp8_rowwise_256x256x192x128_32x32_4x3_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3},
-    {{4096, 7168, 8192},
-     fp8_rowwise_256x256x192x128_32x32_4x3_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3},
-    {{8192, 7168, 8192},
-     fp8_rowwise_256x256x256x128_16x16_8x8_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3},
-    // Support for decode across batch sizes for [8192, 3584]
-    {{16, 8192, 3584},
-     fp8_rowwise_128x16x32x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2},
-    {{32, 8192, 3584},
-     fp8_rowwise_128x32x16x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_interwave_v2},
-    {{64, 8192, 3584},
-     fp8_rowwise_128x64x32x128_32x32_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2},
-    {{128, 8192, 3584},
-     fp8_rowwise_256x64x64x128_32x32_1x1_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3},
-    {{1024, 8192, 3584},
-     fp8_rowwise_256x256x128x128_32x32_4x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3},
-    {{2048, 8192, 3584},
-     fp8_rowwise_256x256x224x128_16x16_8x7_8x32x1_8x32x1_1x64x1x4_8x8x1_2x1_intrawave_v3},
-    {{4096, 8192, 3584},
-     fp8_rowwise_256x224x256x128_16x16_7x8_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3},
-    {{8192, 8192, 3584},
-     fp8_rowwise_256x224x256x128_16x16_7x8_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3},
     // Llama 405B Decode Shapes.
     // Support for decode across batch sizes for [13312, 6656].
     {{16, 13312, 6656},
@@ -218,6 +173,119 @@ static const std::unordered_map<std::tuple<int, int, int>, RowwiseKernel, IntTup
     {{32768, 1024, 8192},
      fp8_rowwise_256x224x256x128_16x16_7x8_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3}};
 
+static const std::map<int, RowwiseKernel> N_7168_K_8192_dispatch_table = {
+  { 8, fp8_rowwise_128x32x16x512_16x16_1x1_32x4x1_32x4x1_1x32x1x4_4x4x1_1x1_interwave_v2},
+  { 32, fp8_rowwise_128x32x16x512_16x16_1x1_32x4x1_32x4x1_1x32x1x4_4x4x1_1x1_intrawave_v2},
+  { 64, fp8_rowwise_256x32x64x512_16x16_1x2_32x8x1_32x8x1_1x32x1x8_8x8x1_1x2_intrawave_v3},
+  { 128, fp8_rowwise_256x64x64x512_32x32_1x1_32x8x1_32x8x1_1x32x1x8_8x8x1_1x1_intrawave_v3},
+  { 320, fp8_rowwise_256x64x128x256_32x32_1x2_16x16x1_16x16x1_1x32x1x8_8x8x1_1x1_intrawave_v3},
+  { 512, fp8_rowwise_256x128x96x256_32x32_1x3_16x16x1_16x16x1_1x64x1x4_8x8x1_1x1_intrawave_v3},
+  { 576, fp8_rowwise_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3},
+  { 640, fp8_rowwise_256x128x128x256_32x32_2x2_16x16x1_16x16x1_1x32x1x8_8x8x1_1x1_intrawave_v3},
+  { 768, fp8_rowwise_256x128x160x128_32x32_1x5_8x32x1_8x32x1_1x64x1x4_8x8x1_1x1_intrawave_v3},
+  { 1024, fp8_rowwise_256x256x96x128_16x16_8x3_8x32x1_8x32x1_1x64x1x4_8x8x1_2x1_intrawave_v3},
+  { 1280, fp8_rowwise_256x256x128x128_16x16_8x4_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3},
+  { 1536, fp8_rowwise_256x256x160x128_16x16_8x5_8x32x1_8x32x1_1x64x1x4_8x8x1_2x1_intrawave_v3},
+  { 2048, fp8_rowwise_256x256x192x128_16x16_8x6_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3},
+  { 2304, fp8_rowwise_256x256x224x128_16x16_8x7_8x32x1_8x32x1_1x64x1x4_8x8x1_2x1_intrawave_v3},
+  { 2560, fp8_rowwise_256x256x256x128_16x16_8x8_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3},
+  { 3328, fp8_rowwise_256x256x160x128_16x16_8x5_8x32x1_8x32x1_1x64x1x4_8x8x1_2x1_intrawave_v3},
+  { 4096, fp8_rowwise_256x256x192x128_16x16_8x6_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3},
+  { 4672, fp8_rowwise_256x224x256x128_16x16_7x8_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3},
+  { 4864, fp8_rowwise_256x256x224x128_16x16_8x7_8x32x1_8x32x1_1x64x1x4_8x8x1_2x1_intrawave_v3},
+  { 5376, fp8_rowwise_256x256x256x128_16x16_8x8_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3},
+  { 6144, fp8_rowwise_256x256x192x128_16x16_8x6_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3},
+  { 7168, fp8_rowwise_256x256x224x128_16x16_8x7_8x32x1_8x32x1_1x64x1x4_8x8x1_2x1_intrawave_v3},
+  { 8192, fp8_rowwise_256x256x256x128_16x16_8x8_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3}
+};
+
+static const std::map<int, RowwiseKernel> N_8192_K_3584_dispatch_table = {
+  { 8, fp8_rowwise_128x16x32x512_16x16_1x1_32x4x1_32x4x1_1x16x1x8_4x4x1_1x1_interwave_v2},
+  { 16, fp8_rowwise_128x16x32x512_16x16_1x1_32x4x1_32x4x1_1x16x1x8_4x4x1_1x1_intrawave_v2},
+  { 32, fp8_rowwise_256x16x64x512_16x16_1x1_32x8x1_32x8x1_1x16x1x16_4x4x1_1x1_intrawave_v3},
+  { 64, fp8_rowwise_256x32x64x512_16x16_1x2_32x8x1_32x8x1_1x32x1x8_8x8x1_1x2_intrawave_v3},
+  { 128, fp8_rowwise_256x64x64x512_32x32_1x1_32x8x1_32x8x1_1x32x1x8_8x8x1_1x1_intrawave_v3},
+  { 192, fp8_rowwise_256x64x96x256_16x16_2x3_16x16x1_16x16x1_1x64x1x4_8x8x1_2x1_intrawave_v3},
+  { 256, fp8_rowwise_256x64x128x256_32x32_1x2_16x16x1_16x16x1_1x32x1x8_8x8x1_1x1_intrawave_v3},
+  { 384, fp8_rowwise_256x128x96x256_32x32_1x3_16x16x1_16x16x1_1x64x1x4_8x8x1_1x1_intrawave_v3},
+  { 512, fp8_rowwise_256x128x128x256_32x32_2x2_16x16x1_16x16x1_1x32x1x8_8x8x1_1x1_intrawave_v3},
+  { 640, fp8_rowwise_256x128x160x128_32x32_1x5_8x32x1_8x32x1_1x64x1x4_8x8x1_1x1_intrawave_v3},
+  { 896, fp8_rowwise_256x128x192x128_32x32_2x3_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3},
+  { 1024, fp8_rowwise_256x256x128x128_16x16_8x4_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3},
+  { 1280, fp8_rowwise_256x256x160x128_16x16_8x5_8x32x1_8x32x1_1x64x1x4_8x8x1_2x1_intrawave_v3},
+  { 1792, fp8_rowwise_256x256x192x128_16x16_8x6_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3},
+  { 2048, fp8_rowwise_256x256x224x128_16x16_8x7_8x32x1_8x32x1_1x64x1x4_8x8x1_2x1_intrawave_v3},
+  { 2304, fp8_rowwise_256x256x256x128_16x16_8x8_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3},
+  { 2368, fp8_rowwise_256x128x256x128_32x32_2x4_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3},
+  { 2816, fp8_rowwise_256x256x160x128_16x16_8x5_8x32x1_8x32x1_1x64x1x4_8x8x1_2x1_intrawave_v3},
+  { 3584, fp8_rowwise_256x256x192x128_16x16_8x6_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3},
+  { 4256, fp8_rowwise_256x224x256x128_16x16_7x8_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3},
+  { 4864, fp8_rowwise_256x256x256x128_16x16_8x8_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3},
+  { 5376, fp8_rowwise_256x256x192x128_16x16_8x6_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3},
+  { 6272, fp8_rowwise_256x224x256x128_16x16_7x8_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3},
+  { 7168, fp8_rowwise_256x256x256x128_16x16_8x8_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3},
+  { 7424, fp8_rowwise_256x256x160x128_16x16_8x5_8x32x1_8x32x1_1x64x1x4_8x8x1_2x1_intrawave_v3},
+  { 8192, fp8_rowwise_256x224x256x128_16x16_7x8_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3}
+};
+
+static const std::map<int, RowwiseKernel> N_1024_K_5120_dispatch_table = {
+  { 32, fp8_rowwise_128x16x32x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2_8},
+  { 64, fp8_rowwise_128x16x32x512_16x16_1x1_32x4x1_32x4x1_1x16x1x8_4x4x1_1x1_intrawave_v2_2},
+  { 128, fp8_rowwise_128x16x32x512_16x16_1x1_32x4x1_32x4x1_1x16x1x8_4x4x1_1x1_intrawave_v2},
+  { 192, fp8_rowwise_256x16x64x512_16x16_1x1_32x8x1_32x8x1_1x16x1x16_4x4x1_1x1_intrawave_v3},
+  { 608, fp8_rowwise_256x32x64x512_16x16_1x2_32x8x1_32x8x1_1x32x1x8_8x8x1_1x2_intrawave_v3},
+  { 1216, fp8_rowwise_256x64x64x512_32x32_1x1_32x8x1_32x8x1_1x32x1x8_8x8x1_1x1_intrawave_v3},
+  { 2432, fp8_rowwise_256x64x128x256_32x32_1x2_16x16x1_16x16x1_1x32x1x8_8x8x1_1x1_intrawave_v3},
+  { 3456, fp8_rowwise_256x128x96x256_32x32_1x3_16x16x1_16x16x1_1x64x1x4_8x8x1_1x1_intrawave_v3},
+  { 4864, fp8_rowwise_256x128x128x256_32x32_2x2_16x16x1_16x16x1_1x32x1x8_8x8x1_1x1_intrawave_v3},
+  { 5472, fp8_rowwise_256x128x160x128_32x32_1x5_8x32x1_8x32x1_1x64x1x4_8x8x1_1x1_intrawave_v3},
+  { 6368, fp8_rowwise_256x128x192x128_32x32_2x3_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3},
+  { 6912, fp8_rowwise_256x256x96x128_16x16_8x3_8x32x1_8x32x1_1x64x1x4_8x8x1_2x1_intrawave_v3},
+  { 8192, fp8_rowwise_256x256x128x128_16x16_8x4_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3}
+};
+
+static const std::map<int, RowwiseKernel> N_5120_K_1024_dispatch_table = {
+  { 16, fp8_rowwise_128x16x32x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2},
+  { 32, fp8_rowwise_64x16x16x256_16x16_1x1_16x4x1_16x4x1_1x16x1x4_4x4x1_1x1_intrawave_v1},
+  { 64, fp8_rowwise_128x32x16x256_16x16_1x1_16x8x1_16x8x1_1x32x1x4_4x4x1_1x1_interwave_v1},
+  { 96, fp8_rowwise_128x16x32x256_16x16_1x1_16x8x1_16x8x1_1x16x1x8_4x4x1_1x1_intrawave_v1},
+  { 192, fp8_rowwise_256x32x128x256_32x32_1x1_16x16x1_16x16x1_1x32x1x8_8x8x1_1x1_intrawave_v3},
+  { 256, fp8_rowwise_256x64x64x128_32x32_1x1_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3},
+  { 320, fp8_rowwise_256x64x96x256_16x16_2x3_16x16x1_16x16x1_1x64x1x4_8x8x1_2x1_intrawave_v3},
+  { 448, fp8_rowwise_256x64x128x256_32x32_1x2_16x16x1_16x16x1_1x32x1x8_8x8x1_1x1_intrawave_v3},
+  { 640, fp8_rowwise_256x128x96x256_32x32_1x3_16x16x1_16x16x1_1x64x1x4_8x8x1_1x1_intrawave_v3},
+  { 896, fp8_rowwise_256x128x128x256_32x32_2x2_16x16x1_16x16x1_1x32x1x8_8x8x1_1x1_intrawave_v3},
+  { 1152, fp8_rowwise_256x128x160x128_32x32_1x5_8x32x1_8x32x1_1x64x1x4_8x8x1_1x1_intrawave_v3},
+  { 1408, fp8_rowwise_256x128x192x128_32x32_2x3_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3},
+  { 1920, fp8_rowwise_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3},
+  { 2304, fp8_rowwise_256x256x160x128_16x16_8x5_8x32x1_8x32x1_1x64x1x4_8x8x1_2x1_intrawave_v3},
+  { 2816, fp8_rowwise_256x256x192x128_16x16_8x6_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3},
+  { 3360, fp8_rowwise_256x224x256x128_16x16_7x8_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3},
+  { 3840, fp8_rowwise_256x256x256x128_16x16_8x8_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3},
+  { 4864, fp8_rowwise_256x256x160x128_16x16_8x5_8x32x1_8x32x1_1x64x1x4_8x8x1_2x1_intrawave_v3},
+  { 5632, fp8_rowwise_256x256x192x128_16x16_8x6_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3},
+  { 6720, fp8_rowwise_256x256x224x128_16x16_8x7_8x32x1_8x32x1_1x64x1x4_8x8x1_2x1_intrawave_v3},
+  { 7680, fp8_rowwise_256x256x256x128_16x16_8x8_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3},
+  { 8192, fp8_rowwise_256x256x192x128_16x16_8x6_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3}
+};
+
+static const std::unordered_map<std::tuple<int, int>, NKLookupTableType, IntTupleHash> NK_lookup_table = {
+  {{7168, 8192}, N_7168_K_8192_dispatch_table},
+  {{8192, 3584}, N_8192_K_3584_dispatch_table},
+  {{1024, 5120}, N_1024_K_5120_dispatch_table},
+  {{5120, 1024}, N_5120_K_1024_dispatch_table}
+};
+
+RowwiseKernel rowwise_nk_lookup(int M, const NKLookupTableType& table) {
+  auto it = table.lower_bound(M);
+  if (it != table.end()) {
+    return it->second;
+  } else {
+    --it;
+    return it->second;
+  }
+}
+
 RowwiseKernel rowwise_heuristic_dispatch(int M, int N, int K) {
   // Apply shape heuristics to find a suitable kernel implementation.

@@ -281,6 +349,11 @@ RowwiseKernel rowwise_dispatch(int M, int N, int K) {
   // If we found an optimal kernel, use it.
   if (it != rowwise_lookup_dispatch.end()) {
     return it->second;
+  } else {
+    auto nk_lookup_it = NK_lookup_table.find({N,K});
+    if (nk_lookup_it != NK_lookup_table.end()){
+      return rowwise_nk_lookup(M, nk_lookup_it->second);
+    }
   }
   // Otherwise, use heuristics.
   return rowwise_heuristic_dispatch(M, N, K);
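
For orientation, here is a hedged sketch of the resulting three-tier dispatch: exact (M, N, K) entries, then the per-(N, K) range tables added above, then the generic heuristic. Kernels are stubbed as strings, an ordered std::map over (N, K) pairs stands in for the real unordered_map with IntTupleHash, and the shapes are illustrative only.

```
#include <iostream>
#include <map>
#include <string>
#include <utility>

using Kernel = std::string;            // stand-in for RowwiseKernel
using NKTable = std::map<int, Kernel>; // stand-in for NKLookupTableType

// Same selection rule as rowwise_nk_lookup above: first tuned M >= query M,
// falling back to the largest tuned M.
Kernel nk_range_lookup(int M, const NKTable& table) {
  auto it = table.lower_bound(M);
  if (it == table.end()) {
    --it;
  }
  return it->second;
}

Kernel dispatch(int M, int N, int K) {
  // Tier 1: exact (M, N, K) decode entries -- omitted in this sketch.
  // Tier 2: per-(N, K) range tables tuned offline for prefill.
  static const std::map<std::pair<int, int>, NKTable> nk_tables = {
      {{8192, 3584},
       {{5376, "256x256x192x128_16x16_8x6_intrawave_v3"},
        {5408, "256x224x256x128_16x16_7x8_intrawave_v3"}}},
  };
  auto it = nk_tables.find({N, K});
  if (it != nk_tables.end()) {
    return nk_range_lookup(M, it->second);
  }
  // Tier 3: generic shape heuristic (rowwise_heuristic_dispatch in the diff).
  return "heuristic_fallback";
}

int main() {
  std::cout << dispatch(5390, 8192, 3584) << "\n"; // larger-M entry wins
  std::cout << dispatch(5390, 4096, 4096) << "\n"; // no (N, K) table -> heuristic
}
```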

fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise/kernels/fp8_rowwise_128x16x32x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2.hip

Lines changed: 21 additions & 51 deletions
@@ -15,55 +15,25 @@ fp8_rowwise_128x16x32x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v
     at::Tensor x_scale,
     at::Tensor w_scale,
     at::Tensor Y) {
-  // The smallest kernel we have available. Works well for memory bound shapes.
-
-  // Check if this input needs to be padded.
-  int M = size_to_dim_(XQ.dim() - 1, XQ.sizes());
-  int N = WQ.size(0);
-  int K = WQ.size(1);
-  bool pad = (M % 16 != 0) || (N % 32 != 0) || (K % 128 != 0);
-
-  if (pad) {
-    using DeviceGemmInstance = DeviceGemmHelper<
-        128,
-        16,
-        32,
-        128,
-        16,
-        16,
-        1,
-        1,
-        S<8, 16, 1>,
-        S<8, 16, 1>,
-        S<1, 16, 1, 8>,
-        S<4, 4, 1>,
-        1,
-        1,
-        ck::BlockGemmPipelineScheduler::Interwave,
-        ck::BlockGemmPipelineVersion::v2>;
-    // Run kernel instance.
-    return f8f8bf16_rowwise_impl<DeviceGemmInstance>(XQ, WQ, x_scale, w_scale, Y);
-  } else {
-    using DeviceGemmInstance = DeviceGemmHelper<
-        128,
-        16,
-        32,
-        128,
-        16,
-        16,
-        1,
-        1,
-        S<8, 16, 1>,
-        S<8, 16, 1>,
-        S<1, 16, 1, 8>,
-        S<4, 4, 1>,
-        1,
-        1,
-        ck::BlockGemmPipelineScheduler::Interwave,
-        ck::BlockGemmPipelineVersion::v2,
-        ck::tensor_operation::device::GemmSpecialization::Default>;
-    // Run kernel instance.
-    return f8f8bf16_rowwise_impl<DeviceGemmInstance>(XQ, WQ, x_scale, w_scale, Y);
-
-  }
+  using DeviceGemmInstance = DeviceGemmHelper<
+      128,
+      16,
+      32,
+      128,
+      16,
+      16,
+      1,
+      1,
+      S<8, 16, 1>,
+      S<8, 16, 1>,
+      S<1, 16, 1, 8>,
+      S<4, 4, 1>,
+      1,
+      1,
+      ck::BlockGemmPipelineScheduler::Interwave,
+      ck::BlockGemmPipelineVersion::v2,
+      ck::tensor_operation::device::GemmSpecialization::Default>;
+  // Run kernel instance.
+  return f8f8bf16_rowwise_impl<DeviceGemmInstance>(XQ, WQ, x_scale, w_scale, Y);
 }
+
@@ -0,0 +1,39 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include "fp8_rowwise_common.h"
+
+at::Tensor
+fp8_rowwise_128x16x32x256_16x16_1x1_16x8x1_16x8x1_1x16x1x8_4x4x1_1x1_intrawave_v1(
+    at::Tensor XQ,
+    at::Tensor WQ,
+    at::Tensor x_scale,
+    at::Tensor w_scale,
+    at::Tensor Y) {
+  using DeviceGemmInstance = DeviceGemmHelper<
+      128,
+      16,
+      32,
+      256,
+      16,
+      16,
+      1,
+      1,
+      S<16, 8, 1>,
+      S<16, 8, 1>,
+      S<1, 16, 1, 8>,
+      S<4, 4, 1>,
+      1,
+      1,
+      ck::BlockGemmPipelineScheduler::Intrawave,
+      ck::BlockGemmPipelineVersion::v1,
+      ck::tensor_operation::device::GemmSpecialization::Default>;
+  // Run kernel instance.
+  return f8f8bf16_rowwise_impl<DeviceGemmInstance>(XQ, WQ, x_scale, w_scale, Y);
+}
+
@@ -0,0 +1,39 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include "fp8_rowwise_common.h"
+
+at::Tensor
+fp8_rowwise_128x16x32x512_16x16_1x1_32x4x1_32x4x1_1x16x1x8_4x4x1_1x1_interwave_v2(
+    at::Tensor XQ,
+    at::Tensor WQ,
+    at::Tensor x_scale,
+    at::Tensor w_scale,
+    at::Tensor Y) {
+  using DeviceGemmInstance = DeviceGemmHelper<
+      128,
+      16,
+      32,
+      512,
+      16,
+      16,
+      1,
+      1,
+      S<32, 4, 1>,
+      S<32, 4, 1>,
+      S<1, 16, 1, 8>,
+      S<4, 4, 1>,
+      1,
+      1,
+      ck::BlockGemmPipelineScheduler::Interwave,
+      ck::BlockGemmPipelineVersion::v2,
+      ck::tensor_operation::device::GemmSpecialization::Default>;
+  // Run kernel instance.
+  return f8f8bf16_rowwise_impl<DeviceGemmInstance>(XQ, WQ, x_scale, w_scale, Y);
+}
+

fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise/kernels/fp8_rowwise_128x16x32x512_16x16_1x1_32x4x1_32x4x1_1x16x1x8_4x4x1_1x1_intrawave_v2.hip

Lines changed: 1 addition & 0 deletions
@@ -36,3 +36,4 @@ fp8_rowwise_128x16x32x512_16x16_1x1_32x4x1_32x4x1_1x16x1x8_4x4x1_1x1_intrawave_v
   // Run kernel instance.
   return f8f8bf16_rowwise_impl<DeviceGemmInstance>(XQ, WQ, x_scale, w_scale, Y);
 }
+
