@@ -25,8 +25,15 @@ namespace fbgemm_gpu {
25
25
using RowwiseKernel = std::function<
26
26
at::Tensor (at::Tensor, at::Tensor, at::Tensor, at::Tensor, at::Tensor)>;
27
27
28
+ using NKLookupTableType = std::map<int , RowwiseKernel>;
29
+
28
30
// Define a custom hash function for std::tuple<int, int, int>
29
31
struct IntTupleHash {
32
+ size_t operator ()(const std::tuple<int , int >& t) const {
33
+ auto hash1 = std::hash<int >{}(std::get<0 >(t));
34
+ auto hash2 = std::hash<int >{}(std::get<1 >(t));
35
+ return hash1 ^ hash2;
36
+ }
30
37
size_t operator ()(const std::tuple<int , int , int >& t) const {
31
38
auto hash1 = std::hash<int >{}(std::get<0 >(t));
32
39
auto hash2 = std::hash<int >{}(std::get<1 >(t));
@@ -38,24 +45,6 @@ struct IntTupleHash {
38
45
// For certain high priority shapes, we directly map to the best kernel rather
39
46
// than use heuristics.
40
47
static const std::unordered_map<std::tuple<int , int , int >, RowwiseKernel, IntTupleHash> rowwise_lookup_dispatch = {
41
- // Support for decode for [1024, 5120]
42
- {{16 , 1024 , 5120 },
43
- fp8_rowwise_128x16x32x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2_8},
44
- {{32 , 1024 , 5120 },
45
- fp8_rowwise_128x16x32x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2_8},
46
- {{64 , 1024 , 5120 },
47
- fp8_rowwise_128x16x32x512_16x16_1x1_32x4x1_32x4x1_1x16x1x8_4x4x1_1x1_intrawave_v2},
48
- {{128 , 1024 , 5120 },
49
- fp8_rowwise_128x16x32x512_16x16_1x1_32x4x1_32x4x1_1x16x1x8_4x4x1_1x1_intrawave_v2},
50
- // Support for decode for [5120, 1024]
51
- {{16 , 5120 , 1024 },
52
- fp8_rowwise_128x32x16x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_interwave_v2},
53
- {{32 , 5120 , 1024 },
54
- fp8_rowwise_64x16x16x128_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_interwave_v2},
55
- {{64 , 5120 , 1024 },
56
- fp8_rowwise_128x32x16x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_interwave_v2},
57
- {{128 , 5120 , 1024 },
58
- fp8_rowwise_128x32x64x128_32x32_1x1_8x16x1_8x16x1_1x16x1x8_8x8x1_1x1_intrawave_v2},
59
48
// LLama 70B Decode shapes.
60
49
// Support for decode across batch sizes for [1280, 8192]
61
50
{{16 , 1280 , 8192 },
@@ -75,40 +64,6 @@ static const std::unordered_map<std::tuple<int, int, int>, RowwiseKernel, IntTup
75
64
fp8_rowwise_128x32x16x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_interwave_v2},
76
65
{{128 , 8192 , 1024 },
77
66
fp8_rowwise_256x64x64x128_32x32_1x1_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3},
78
- // Support for decode across batch sizes for [7168, 8192]
79
- {{16 , 7168 , 8192 },
80
- fp8_rowwise_128x16x32x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2},
81
- {{32 , 7168 , 8192 },
82
- fp8_rowwise_128x32x16x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_interwave_v2},
83
- {{64 , 7168 , 8192 },
84
- fp8_rowwise_128x64x32x128_32x32_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2},
85
- {{128 , 7168 , 8192 },
86
- fp8_rowwise_256x64x64x128_32x32_1x1_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3},
87
- {{1024 , 7168 , 8192 },
88
- fp8_rowwise_256x256x96x128_32x32_2x3_8x32x1_8x32x1_1x64x1x4_8x8x1_2x1_intrawave_v3},
89
- {{2048 , 7168 , 8192 },
90
- fp8_rowwise_256x256x192x128_32x32_4x3_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3},
91
- {{4096 , 7168 , 8192 },
92
- fp8_rowwise_256x256x192x128_32x32_4x3_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3},
93
- {{8192 , 7168 , 8192 },
94
- fp8_rowwise_256x256x256x128_16x16_8x8_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3},
95
- // Support for decode across batch sizes for [8192, 3584]
96
- {{16 , 8192 , 3584 },
97
- fp8_rowwise_128x16x32x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2},
98
- {{32 , 8192 , 3584 },
99
- fp8_rowwise_128x32x16x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_interwave_v2},
100
- {{64 , 8192 , 3584 },
101
- fp8_rowwise_128x64x32x128_32x32_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2},
102
- {{128 , 8192 , 3584 },
103
- fp8_rowwise_256x64x64x128_32x32_1x1_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3},
104
- {{1024 , 8192 , 3584 },
105
- fp8_rowwise_256x256x128x128_32x32_4x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3},
106
- {{2048 , 8192 , 3584 },
107
- fp8_rowwise_256x256x224x128_16x16_8x7_8x32x1_8x32x1_1x64x1x4_8x8x1_2x1_intrawave_v3},
108
- {{4096 , 8192 , 3584 },
109
- fp8_rowwise_256x224x256x128_16x16_7x8_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3},
110
- {{8192 , 8192 , 3584 },
111
- fp8_rowwise_256x224x256x128_16x16_7x8_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3},
112
67
// Llama 405B Decode Shapes.
113
68
// Support for decode across batch sizes for [13312, 6656].
114
69
{{16 , 13312 , 6656 },
@@ -218,6 +173,119 @@ static const std::unordered_map<std::tuple<int, int, int>, RowwiseKernel, IntTup
218
173
{{32768 , 1024 , 8192 },
219
174
fp8_rowwise_256x224x256x128_16x16_7x8_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3}};
220
175
176
+ static const std::map<int , RowwiseKernel> N_7168_K_8192_dispatch_table = {
177
+ { 8 , fp8_rowwise_128x32x16x512_16x16_1x1_32x4x1_32x4x1_1x32x1x4_4x4x1_1x1_interwave_v2},
178
+ { 32 , fp8_rowwise_128x32x16x512_16x16_1x1_32x4x1_32x4x1_1x32x1x4_4x4x1_1x1_intrawave_v2},
179
+ { 64 , fp8_rowwise_256x32x64x512_16x16_1x2_32x8x1_32x8x1_1x32x1x8_8x8x1_1x2_intrawave_v3},
180
+ { 128 , fp8_rowwise_256x64x64x512_32x32_1x1_32x8x1_32x8x1_1x32x1x8_8x8x1_1x1_intrawave_v3},
181
+ { 320 , fp8_rowwise_256x64x128x256_32x32_1x2_16x16x1_16x16x1_1x32x1x8_8x8x1_1x1_intrawave_v3},
182
+ { 512 , fp8_rowwise_256x128x96x256_32x32_1x3_16x16x1_16x16x1_1x64x1x4_8x8x1_1x1_intrawave_v3},
183
+ { 576 , fp8_rowwise_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3},
184
+ { 640 , fp8_rowwise_256x128x128x256_32x32_2x2_16x16x1_16x16x1_1x32x1x8_8x8x1_1x1_intrawave_v3},
185
+ { 768 , fp8_rowwise_256x128x160x128_32x32_1x5_8x32x1_8x32x1_1x64x1x4_8x8x1_1x1_intrawave_v3},
186
+ { 1024 , fp8_rowwise_256x256x96x128_16x16_8x3_8x32x1_8x32x1_1x64x1x4_8x8x1_2x1_intrawave_v3},
187
+ { 1280 , fp8_rowwise_256x256x128x128_16x16_8x4_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3},
188
+ { 1536 , fp8_rowwise_256x256x160x128_16x16_8x5_8x32x1_8x32x1_1x64x1x4_8x8x1_2x1_intrawave_v3},
189
+ { 2048 , fp8_rowwise_256x256x192x128_16x16_8x6_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3},
190
+ { 2304 , fp8_rowwise_256x256x224x128_16x16_8x7_8x32x1_8x32x1_1x64x1x4_8x8x1_2x1_intrawave_v3},
191
+ { 2560 , fp8_rowwise_256x256x256x128_16x16_8x8_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3},
192
+ { 3328 , fp8_rowwise_256x256x160x128_16x16_8x5_8x32x1_8x32x1_1x64x1x4_8x8x1_2x1_intrawave_v3},
193
+ { 4096 , fp8_rowwise_256x256x192x128_16x16_8x6_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3},
194
+ { 4672 , fp8_rowwise_256x224x256x128_16x16_7x8_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3},
195
+ { 4864 , fp8_rowwise_256x256x224x128_16x16_8x7_8x32x1_8x32x1_1x64x1x4_8x8x1_2x1_intrawave_v3},
196
+ { 5376 , fp8_rowwise_256x256x256x128_16x16_8x8_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3},
197
+ { 6144 , fp8_rowwise_256x256x192x128_16x16_8x6_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3},
198
+ { 7168 , fp8_rowwise_256x256x224x128_16x16_8x7_8x32x1_8x32x1_1x64x1x4_8x8x1_2x1_intrawave_v3},
199
+ { 8192 , fp8_rowwise_256x256x256x128_16x16_8x8_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3}
200
+ };
201
+
202
+ static const std::map<int , RowwiseKernel> N_8192_K_3584_dispatch_table = {
203
+ { 8 , fp8_rowwise_128x16x32x512_16x16_1x1_32x4x1_32x4x1_1x16x1x8_4x4x1_1x1_interwave_v2},
204
+ { 16 , fp8_rowwise_128x16x32x512_16x16_1x1_32x4x1_32x4x1_1x16x1x8_4x4x1_1x1_intrawave_v2},
205
+ { 32 , fp8_rowwise_256x16x64x512_16x16_1x1_32x8x1_32x8x1_1x16x1x16_4x4x1_1x1_intrawave_v3},
206
+ { 64 , fp8_rowwise_256x32x64x512_16x16_1x2_32x8x1_32x8x1_1x32x1x8_8x8x1_1x2_intrawave_v3},
207
+ { 128 , fp8_rowwise_256x64x64x512_32x32_1x1_32x8x1_32x8x1_1x32x1x8_8x8x1_1x1_intrawave_v3},
208
+ { 192 , fp8_rowwise_256x64x96x256_16x16_2x3_16x16x1_16x16x1_1x64x1x4_8x8x1_2x1_intrawave_v3},
209
+ { 256 , fp8_rowwise_256x64x128x256_32x32_1x2_16x16x1_16x16x1_1x32x1x8_8x8x1_1x1_intrawave_v3},
210
+ { 384 , fp8_rowwise_256x128x96x256_32x32_1x3_16x16x1_16x16x1_1x64x1x4_8x8x1_1x1_intrawave_v3},
211
+ { 512 , fp8_rowwise_256x128x128x256_32x32_2x2_16x16x1_16x16x1_1x32x1x8_8x8x1_1x1_intrawave_v3},
212
+ { 640 , fp8_rowwise_256x128x160x128_32x32_1x5_8x32x1_8x32x1_1x64x1x4_8x8x1_1x1_intrawave_v3},
213
+ { 896 , fp8_rowwise_256x128x192x128_32x32_2x3_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3},
214
+ { 1024 , fp8_rowwise_256x256x128x128_16x16_8x4_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3},
215
+ { 1280 , fp8_rowwise_256x256x160x128_16x16_8x5_8x32x1_8x32x1_1x64x1x4_8x8x1_2x1_intrawave_v3},
216
+ { 1792 , fp8_rowwise_256x256x192x128_16x16_8x6_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3},
217
+ { 2048 , fp8_rowwise_256x256x224x128_16x16_8x7_8x32x1_8x32x1_1x64x1x4_8x8x1_2x1_intrawave_v3},
218
+ { 2304 , fp8_rowwise_256x256x256x128_16x16_8x8_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3},
219
+ { 2368 , fp8_rowwise_256x128x256x128_32x32_2x4_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3},
220
+ { 2816 , fp8_rowwise_256x256x160x128_16x16_8x5_8x32x1_8x32x1_1x64x1x4_8x8x1_2x1_intrawave_v3},
221
+ { 3584 , fp8_rowwise_256x256x192x128_16x16_8x6_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3},
222
+ { 4256 , fp8_rowwise_256x224x256x128_16x16_7x8_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3},
223
+ { 4864 , fp8_rowwise_256x256x256x128_16x16_8x8_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3},
224
+ { 5376 , fp8_rowwise_256x256x192x128_16x16_8x6_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3},
225
+ { 6272 , fp8_rowwise_256x224x256x128_16x16_7x8_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3},
226
+ { 7168 , fp8_rowwise_256x256x256x128_16x16_8x8_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3},
227
+ { 7424 , fp8_rowwise_256x256x160x128_16x16_8x5_8x32x1_8x32x1_1x64x1x4_8x8x1_2x1_intrawave_v3},
228
+ { 8192 , fp8_rowwise_256x224x256x128_16x16_7x8_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3}
229
+ };
230
+
231
+ static const std::map<int , RowwiseKernel> N_1024_K_5120_dispatch_table = {
232
+ { 32 , fp8_rowwise_128x16x32x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2_8},
233
+ { 64 , fp8_rowwise_128x16x32x512_16x16_1x1_32x4x1_32x4x1_1x16x1x8_4x4x1_1x1_intrawave_v2_2},
234
+ { 128 , fp8_rowwise_128x16x32x512_16x16_1x1_32x4x1_32x4x1_1x16x1x8_4x4x1_1x1_intrawave_v2},
235
+ { 192 , fp8_rowwise_256x16x64x512_16x16_1x1_32x8x1_32x8x1_1x16x1x16_4x4x1_1x1_intrawave_v3},
236
+ { 608 , fp8_rowwise_256x32x64x512_16x16_1x2_32x8x1_32x8x1_1x32x1x8_8x8x1_1x2_intrawave_v3},
237
+ { 1216 , fp8_rowwise_256x64x64x512_32x32_1x1_32x8x1_32x8x1_1x32x1x8_8x8x1_1x1_intrawave_v3},
238
+ { 2432 , fp8_rowwise_256x64x128x256_32x32_1x2_16x16x1_16x16x1_1x32x1x8_8x8x1_1x1_intrawave_v3},
239
+ { 3456 , fp8_rowwise_256x128x96x256_32x32_1x3_16x16x1_16x16x1_1x64x1x4_8x8x1_1x1_intrawave_v3},
240
+ { 4864 , fp8_rowwise_256x128x128x256_32x32_2x2_16x16x1_16x16x1_1x32x1x8_8x8x1_1x1_intrawave_v3},
241
+ { 5472 , fp8_rowwise_256x128x160x128_32x32_1x5_8x32x1_8x32x1_1x64x1x4_8x8x1_1x1_intrawave_v3},
242
+ { 6368 , fp8_rowwise_256x128x192x128_32x32_2x3_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3},
243
+ { 6912 , fp8_rowwise_256x256x96x128_16x16_8x3_8x32x1_8x32x1_1x64x1x4_8x8x1_2x1_intrawave_v3},
244
+ { 8192 , fp8_rowwise_256x256x128x128_16x16_8x4_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3}
245
+ };
246
+
247
+ static const std::map<int , RowwiseKernel> N_5120_K_1024_dispatch_table = {
248
+ { 16 , fp8_rowwise_128x16x32x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2},
249
+ { 32 , fp8_rowwise_64x16x16x256_16x16_1x1_16x4x1_16x4x1_1x16x1x4_4x4x1_1x1_intrawave_v1},
250
+ { 64 , fp8_rowwise_128x32x16x256_16x16_1x1_16x8x1_16x8x1_1x32x1x4_4x4x1_1x1_interwave_v1},
251
+ { 96 , fp8_rowwise_128x16x32x256_16x16_1x1_16x8x1_16x8x1_1x16x1x8_4x4x1_1x1_intrawave_v1},
252
+ { 192 , fp8_rowwise_256x32x128x256_32x32_1x1_16x16x1_16x16x1_1x32x1x8_8x8x1_1x1_intrawave_v3},
253
+ { 256 , fp8_rowwise_256x64x64x128_32x32_1x1_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3},
254
+ { 320 , fp8_rowwise_256x64x96x256_16x16_2x3_16x16x1_16x16x1_1x64x1x4_8x8x1_2x1_intrawave_v3},
255
+ { 448 , fp8_rowwise_256x64x128x256_32x32_1x2_16x16x1_16x16x1_1x32x1x8_8x8x1_1x1_intrawave_v3},
256
+ { 640 , fp8_rowwise_256x128x96x256_32x32_1x3_16x16x1_16x16x1_1x64x1x4_8x8x1_1x1_intrawave_v3},
257
+ { 896 , fp8_rowwise_256x128x128x256_32x32_2x2_16x16x1_16x16x1_1x32x1x8_8x8x1_1x1_intrawave_v3},
258
+ { 1152 , fp8_rowwise_256x128x160x128_32x32_1x5_8x32x1_8x32x1_1x64x1x4_8x8x1_1x1_intrawave_v3},
259
+ { 1408 , fp8_rowwise_256x128x192x128_32x32_2x3_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3},
260
+ { 1920 , fp8_rowwise_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3},
261
+ { 2304 , fp8_rowwise_256x256x160x128_16x16_8x5_8x32x1_8x32x1_1x64x1x4_8x8x1_2x1_intrawave_v3},
262
+ { 2816 , fp8_rowwise_256x256x192x128_16x16_8x6_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3},
263
+ { 3360 , fp8_rowwise_256x224x256x128_16x16_7x8_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3},
264
+ { 3840 , fp8_rowwise_256x256x256x128_16x16_8x8_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3},
265
+ { 4864 , fp8_rowwise_256x256x160x128_16x16_8x5_8x32x1_8x32x1_1x64x1x4_8x8x1_2x1_intrawave_v3},
266
+ { 5632 , fp8_rowwise_256x256x192x128_16x16_8x6_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3},
267
+ { 6720 , fp8_rowwise_256x256x224x128_16x16_8x7_8x32x1_8x32x1_1x64x1x4_8x8x1_2x1_intrawave_v3},
268
+ { 7680 , fp8_rowwise_256x256x256x128_16x16_8x8_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3},
269
+ { 8192 , fp8_rowwise_256x256x192x128_16x16_8x6_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3}
270
+ };
271
+
272
+ static const std::unordered_map<std::tuple<int , int >, NKLookupTableType, IntTupleHash> NK_lookup_table = {
273
+ {{7168 , 8192 }, N_7168_K_8192_dispatch_table},
274
+ {{8192 , 3584 }, N_8192_K_3584_dispatch_table},
275
+ {{1024 , 5120 }, N_1024_K_5120_dispatch_table},
276
+ {{5120 , 1024 }, N_5120_K_1024_dispatch_table}
277
+ };
278
+
279
+ RowwiseKernel rowwise_nk_lookup (int M, const NKLookupTableType& table) {
280
+ auto it = table.lower_bound (M);
281
+ if (it != table.end ()) {
282
+ return it->second ;
283
+ } else {
284
+ --it;
285
+ return it->second ;
286
+ }
287
+ }
288
+
221
289
RowwiseKernel rowwise_heuristic_dispatch (int M, int N, int K) {
222
290
// Apply shape heuristics to find a suitable kernel implementation.
223
291
@@ -281,6 +349,11 @@ RowwiseKernel rowwise_dispatch(int M, int N, int K) {
281
349
// If we found an optimal kernel, use it.
282
350
if (it != rowwise_lookup_dispatch.end ()) {
283
351
return it->second ;
352
+ } else {
353
+ auto nk_lookup_it = NK_lookup_table.find ({N,K});
354
+ if (nk_lookup_it != NK_lookup_table.end ()){
355
+ return rowwise_nk_lookup (M, nk_lookup_it->second );
356
+ }
284
357
}
285
358
// Otherwise, use heuristics.
286
359
return rowwise_heuristic_dispatch (M, N, K);
0 commit comments