Performance Optimization: Improved TileShape Configuration for Large Llama Shapes (pytorch#3790)
Summary:
## Issue: Suboptimal TileShape Configuration in FBGEMM for Large Llama Shapes
The current FBGEMM F8 kernel utilizes a TileShape configuration of 128x128x128,
which is suboptimal for dense F8 tensor core operations on NVIDIA H100 GPUs.
The optimal configuration for maximizing tensor core throughput and memory bandwidth usage
on H100 is m64n256k32. The current setting leads to inefficiencies, particularly
for large GEMM operations in Llama 70B and 405B models when K >= 4096.
## Proposed Optimization: 128x256x128 TileShape for Large GEMM Operations
This PR modifies the TileShape configuration from 128x128x128 to 128x256x128
for large GEMM workloads. The new configuration is applied via a cooperative kernel schedule,
improving tensor core utilization and memory bandwidth efficiency.
Notably, this tile shape is also used in FlashAttention V3 for F8 precision.
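For reference, here is a minimal CUTLASS 3.x-style sketch of what the new configuration looks like; the alias names and the cluster shape are illustrative assumptions, not the exact FBGEMM source:

```cpp
// Sketch of the tile / schedule choice in CUTLASS 3.x terms (illustrative only).
#include <cute/tensor.hpp>
#include <cutlass/gemm/dispatch_policy.hpp>

// 128x256x128 threadblock tile (M x N x K) used for large GEMM shapes.
using TileShape    = cute::Shape<cute::_128, cute::_256, cute::_128>;
// Example thread-block cluster shape (assumption, not taken from this PR).
using ClusterShape = cute::Shape<cute::_2, cute::_1, cute::_1>;
// Cooperative warp-specialized mainloop schedule, replacing the ping-pong
// schedule that was paired with the previous 128x128x128 tile.
using MainloopSchedule = cutlass::gemm::KernelTmaWarpSpecializedCooperative;
```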
## Benchmark Results on H100 GPU
### Benchmark Setup:
- PyTorch 2.6
- CUDA 12.4
- CPU: AMD EPYC
- GPU: NVIDIA H100
- Benchmarks use 30 kernel launch iterations, averaged over 25 benchmark runs.
- Shapes are taken from Llama 70B and 405B models (M = 16,384).
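For context, the TFLOPS figures in the tables below follow the usual GEMM convention of counting 2*M*N*K floating-point operations per matmul. A minimal sketch of that arithmetic is shown here; the timing value is a placeholder, not a measured result:

```cpp
#include <cstdint>
#include <cstdio>

// Effective TFLOPS for one GEMM using the standard 2*M*N*K flop count.
double gemm_tflops(int64_t M, int64_t N, int64_t K, double avg_time_ms) {
  const double flops = 2.0 * static_cast<double>(M) *
                       static_cast<double>(N) * static_cast<double>(K);
  return flops / (avg_time_ms * 1e-3) / 1e12;
}

int main() {
  // Placeholder timing: ~1.3 ms per launch for M = 16384, N = 7168, K = 8192.
  std::printf("%.0f TFLOPS\n", gemm_tflops(16384, 7168, 8192, 1.3));
  return 0;
}
```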
### Benchmark
#### f8f8bf16_rowwise (M = 16,384)
| Llama Shape         | Old TFLOPS | New TFLOPS | Improvement |
|---------------------|------------|------------|-------------|
| N = 1280 K = 8192 | 1252 | 1492 | +17.4% |
| N = 8192 K = 1024 | 1258 | 1258 | — |
| N = 7168 K = 8192 | 1324 | 1463 | +10.5% |
| N = 8192 K = 3584 | 1401 | 1401 | — |
| N = 13312 K = 6656 | 1259 | 1360 | +8.0% |
| N = 13312 K = 16384 | 1170 | 1388 | +18.6% |
| N = 16384 K = 6656 | 1238 | 1266 | +2.3% |
| N = 16384 K = 16384 | 1166 | 1316 | +12.9% |
The cooperative 128x256x128 TileShape consistently outperforms the 128x128x128
Ping-Pong kernel for all large GEMM sizes where K >= 4096.
For a small subset of cases, the 128x192x128 Ping-Pong kernel achieves a 2-3%
performance advantage, notably for the shape M = 16,384, N = 16,384, K = 16,384.
A more detailed heuristic rule could be explored for these specific cases.
## Technical Implementation
Introduced a 128x256x128 TileShape with a cooperative kernel schedule for f8f8bf16_rowwise.
The new configuration is selectively applied for large matrices where (see the sketch after this list):
- **M > 128 && N > 128**
- **AND (M > 2048 || N > 2048)**
- **AND K > 4096**
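A minimal sketch of that selection logic is shown below; the function and variable names are hypothetical, and the real dispatch logic lives in FBGEMM's f8f8bf16_rowwise kernel selection code:

```cpp
#include <cstdint>

// Illustrative predicate mirroring the conditions listed above.
inline bool use_large_cooperative_tile(int64_t M, int64_t N, int64_t K) {
  const bool both_dims_large    = (M > 128 && N > 128);
  const bool one_dim_very_large = (M > 2048 || N > 2048);
  const bool deep_reduction     = (K > 4096);
  return both_dims_large && one_dim_very_large && deep_reduction;
}
```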
### Performance Validation
- We ensured that the changes do not introduce performance regressions
for existing configurations that do not match the above conditions.
- The code modifications were designed to preserve existing configurations outside of large GEMM cases.
These changes modify the minimum necessary code while respecting
existing coding practices in FBGEMM.
## Test Coverage
### Unit Test Results
The unit tests in fbgemm_gpu/experimental/gen_ai/test/quantize
have been run to verify the modified kernels.
jiawenliu64 jwfromm Thank you!
Differential Revision: D72617756
Pulled By: jiawenliu64