
Commit ec57cc6

Performance Optimization: Improved TileShape Configuration for Large Llama Shapes (pytorch#3790)
Summary:

## Issue: Suboptimal TileShape Configuration in FBGEMM for Large Llama Shapes

The current FBGEMM F8 kernel uses a TileShape configuration of 128x128x128, which is suboptimal for dense F8 tensor core operations on NVIDIA H100 GPUs. The optimal configuration for maximizing tensor core throughput and memory bandwidth usage on H100 is m64n256k32. The current setting leads to inefficiencies, particularly for large GEMM operations in Llama 70B and 405B models where K >= 4096.

## Proposed Optimization: 128x256x128 TileShape for Large GEMM Operations

This PR changes the TileShape configuration from 128x128x128 to 128x256x128 for large GEMM workloads. The new configuration is applied via a cooperative kernel, which improves tensor core utilization and memory bandwidth efficiency. Notably, the same tile shape is used in FlashAttention V3 for F8 precision.

## Benchmark Results on H100 GPU

### Benchmark Setup

- PyTorch 2.6, CUDA 12.4
- CPU: AMD EPYC
- GPU: NVIDIA H100
- Each benchmark runs 30 kernel launch iterations and is averaged over 25 benchmark measurements.
- Benchmarks cover Llama 70B and 405B shapes with M = 16,384.

### Benchmark

#### f8f8bf16_rowwise (M = 16,384)

| Llama Shape          | Old TFLOPS | New TFLOPS | Improvement |
|----------------------|------------|------------|-------------|
| N = 1280, K = 8192   | 1252       | 1492       | +17.4%      |
| N = 8192, K = 1024   | 1258       | 1258       | —           |
| N = 7168, K = 8192   | 1324       | 1463       | +10.5%      |
| N = 8192, K = 3584   | 1401       | 1401       | —           |
| N = 13312, K = 6656  | 1259       | 1360       | +8.0%       |
| N = 13312, K = 16384 | 1170       | 1388       | +18.6%      |
| N = 16384, K = 6656  | 1238       | 1266       | +2.3%       |
| N = 16384, K = 16384 | 1166       | 1316       | +12.9%      |

The cooperative 128x256x128 TileShape consistently outperforms the 128x128x128 Ping-Pong kernel for all large GEMM sizes where K >= 4096. For a small subset of cases, a 128x192x128 Ping-Pong kernel achieves a 2-3% performance advantage, notably for M = 16,384, N = 16,384, K = 16,384; a more detailed heuristic rule could be explored for these specific cases.

## Technical Implementation

Introduced TileShape 128x256x128 with a cooperative kernel for f8f8bf16_rowwise. The new configuration is selectively applied for large matrices where (see the sketch after this summary for the combined predicate):

- **M > 128 && N > 128**
- **AND (M > 2048 || N > 2048)**
- **AND K >= 4096**

Performance validation:

- The changes do not introduce performance regressions for existing configurations that do not match the above conditions.
- The code modifications preserve existing configurations outside of the large GEMM cases.

These changes were made by modifying the minimum necessary code while respecting existing coding practices in FBGEMM.

## Test Coverage

### Unit Test Results

The unit tests in fbgemm_gpu/experimental/gen_ai/test/quantize have been verified for the modified kernels.

jiawenliu64 jwfromm Thank you!

Differential Revision: D72617756

Pulled By: jiawenliu64
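The dispatch conditions above can be summarized as a single predicate. The sketch below is illustrative only and is not code from this diff: the helper name `use_cooperative_128x256x128` is hypothetical, and in the actual kernel the shape check lives in `get_kernel_mode` while the K threshold is applied inside `dispatch_fp8_rowwise_kernel` (the `K < 4096` branch in the diff below).

```cpp
#include <cstdint>

// Hypothetical consolidation of the dispatch heuristic described above.
// In the real code the shape check is split between get_kernel_mode and
// dispatch_fp8_rowwise_kernel; this sketch only illustrates the conditions.
bool use_cooperative_128x256x128(int64_t M, int64_t N, int64_t K) {
  const bool large_mn = (M > 128 && N > 128) && (M > 2048 || N > 2048);
  // The diff keeps the previous 128x128x128 Ping-Pong kernel when K < 4096.
  return large_mn && K >= 4096;
}
```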
1 parent 23fe369 commit ec57cc6


fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f8f8bf16_rowwise.cu

Lines changed: 51 additions & 25 deletions

```diff
@@ -34,6 +34,7 @@ template <
     int TBS_N,
     int TBS_K,
     bool PONG,
+    bool COOP,
     bool FAST_ACCUM,
     bool USE_BIAS,
     typename INPUT_DTYPE,
@@ -170,6 +171,23 @@ at::Tensor f8f8bf16_rowwise_impl(
   using EpilogueEVT =
       cute::conditional_t<USE_BIAS, EVTComputeBias, EVTCompute1>;
 
+  using DefaultSchedule = cutlass::gemm::KernelTmaWarpSpecialized;
+  using PongSchedule = cutlass::gemm::KernelTmaWarpSpecializedPingpong;
+  using SlowAccum = cute::conditional_t<PONG, PongSchedule, DefaultSchedule>;
+  using FastAccum = cute::conditional_t<
+      COOP,
+      cutlass::gemm::KernelTmaWarpSpecializedCooperativeFP8FastAccum,
+      cute::conditional_t<
+          PONG,
+          cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum,
+          cutlass::gemm::KernelTmaWarpSpecializedFP8FastAccum>>;
+  using MainLoopSchedule =
+      cute::conditional_t<FAST_ACCUM, FastAccum, SlowAccum>;
+  using EpilogueSchedule = cute::conditional_t<
+      COOP,
+      cutlass::epilogue::TmaWarpSpecializedCooperative,
+      cutlass::epilogue::TmaWarpSpecialized>;
+
   using CollectiveEpilogue =
       typename cutlass::epilogue::collective::CollectiveBuilder<
           cutlass::arch::Sm90,
@@ -185,21 +203,9 @@ at::Tensor f8f8bf16_rowwise_impl(
           ElementOutput,
           LayoutOutput,
           AlignmentOutput,
-          cutlass::epilogue::TmaWarpSpecialized,
+          EpilogueSchedule,
           EpilogueEVT>::CollectiveOp;
 
-  using DefaultSchedule = cutlass::gemm::KernelTmaWarpSpecialized;
-  using PongSchedule = cutlass::gemm::KernelTmaWarpSpecializedPingpong;
-  using FastDefaultSchedule =
-      cutlass::gemm::KernelTmaWarpSpecializedFP8FastAccum;
-  using FastPongSchedule =
-      cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum;
-  using SlowAccum = cute::conditional_t<PONG, PongSchedule, DefaultSchedule>;
-  using FastAccum =
-      cute::conditional_t<PONG, FastPongSchedule, FastDefaultSchedule>;
-  using MainLoopSchedule =
-      cute::conditional_t<FAST_ACCUM, FastAccum, SlowAccum>;
-
   using CollectiveMainloop =
       typename cutlass::gemm::collective::CollectiveBuilder<
           ArchTag,
@@ -322,6 +328,7 @@ at::Tensor dispatch_fp8_rowwise_kernel(
     at::Tensor w_scale,
     std::optional<at::Tensor> bias,
     std::optional<at::Tensor> output) {
+  auto K = XQ.size(1);
   KernelMode kernel = get_kernel_mode(XQ, WQ);
   if (kernel == KernelMode::Small) {
     return f8f8bf16_rowwise_impl<
@@ -332,23 +339,41 @@ at::Tensor dispatch_fp8_rowwise_kernel(
         1,
         1,
         false,
+        false,
         FastAccum,
         UseBias,
         InputDType,
         BiasDType>(XQ, WQ, x_scale, w_scale, bias, output);
   } else if (kernel == KernelMode::Large) {
-    return f8f8bf16_rowwise_impl<
-        128,
-        128,
-        128,
-        2,
-        1,
-        1,
-        true,
-        FastAccum,
-        UseBias,
-        InputDType,
-        BiasDType>(XQ, WQ, x_scale, w_scale, bias, output);
+    if (K < 4096) {
+      return f8f8bf16_rowwise_impl<
+          128,
+          128,
+          128,
+          2,
+          1,
+          1,
+          true,
+          false,
+          FastAccum,
+          UseBias,
+          InputDType,
+          BiasDType>(XQ, WQ, x_scale, w_scale, bias, output);
+    } else {
+      return f8f8bf16_rowwise_impl<
+          128,
+          256,
+          128,
+          2,
+          1,
+          1,
+          false,
+          true,
+          FastAccum,
+          UseBias,
+          InputDType,
+          BiasDType>(XQ, WQ, x_scale, w_scale, bias, output);
+    }
   } else {
     return f8f8bf16_rowwise_impl<
         128,
@@ -358,6 +383,7 @@ at::Tensor dispatch_fp8_rowwise_kernel(
         2,
         1,
         false,
+        false,
         FastAccum,
         UseBias,
         InputDType,
```

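For reference, the TFLOPS figures in the benchmark table above follow the conventional 2·M·N·K FLOP count for a GEMM divided by the measured kernel time. The snippet below is a sketch of that arithmetic only; it makes no claim about the actual benchmark harness, and the ~5.15 ms example time is simply back-computed from the reported 1388 TFLOPS entry.

```cpp
#include <cstdint>
#include <cstdio>

// Converts a measured GEMM runtime into TFLOPS using the conventional
// 2 * M * N * K floating-point operation count.
double gemm_tflops(int64_t M, int64_t N, int64_t K, double seconds) {
  const double flops = 2.0 * static_cast<double>(M) * static_cast<double>(N) *
      static_cast<double>(K);
  return flops / seconds / 1e12;
}

int main() {
  // M = 16384, N = 13312, K = 16384 at roughly 5.15 ms per GEMM corresponds
  // to about the 1388 TFLOPS reported for the 128x256x128 cooperative kernel.
  std::printf("%.0f TFLOPS\n", gemm_tflops(16384, 13312, 16384, 5.15e-3));
  return 0;
}
```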