Performance Optimization: Improved TileShape Configuration for Large Llama Shapes #3790

Closed

Conversation

MatrixAssembler
Contributor

Issue: Suboptimal TileShape Configuration in FBGEMM for Large Llama Shapes

The current FBGEMM F8 kernel utilizes a TileShape configuration of 128x128x128,
which is suboptimal for dense F8 tensor core operations on NVIDIA H100 GPUs.

The dense F8 tensor core (WGMMA) instruction that maximizes throughput and memory
bandwidth utilization on H100 is m64n256k32, which a tile with N = 128 cannot issue.
The current setting therefore leads to inefficiencies, particularly for the large
GEMM operations in Llama 70B and 405B models when K ≥ 4096.

Proposed Optimization: 128x256x128 TileShape for Large GEMM Operations

This PR changes the TileShape configuration from 128x128x128 to 128x256x128
for large GEMM workloads. The new configuration is applied via a cooperative kernel,
which improves tensor core utilization and memory bandwidth efficiency.

Notably, this tile shape is also used in FlashAttention V3 for F8 precision.
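
To make the change concrete, below is a minimal, illustrative sketch (not the actual
FBGEMM template code) of how the old and new configurations would be spelled with
CUTLASS 3.x type aliases. The schedule tags follow CUTLASS's SM90 FP8 dispatch
policies; the alias names and header choices are assumptions for illustration only.

```cpp
// Illustrative sketch only: CUTLASS 3.x-style aliases contrasting the previous
// 128x128x128 ping-pong configuration with the 128x256x128 cooperative one.
#include <cute/tensor.hpp>
#include <cutlass/gemm/dispatch_policy.hpp>

// Previous default: 128x128x128 thread-block tile with a ping-pong schedule.
using TileShapeOld = cute::Shape<cute::_128, cute::_128, cute::_128>;
using ScheduleOld  = cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum;

// Large-GEMM path in this PR: 128x256x128 thread-block tile with a cooperative
// (persistent) schedule, the same tile shape FlashAttention V3 uses for FP8.
using TileShapeNew = cute::Shape<cute::_128, cute::_256, cute::_128>;
using ScheduleNew  = cutlass::gemm::KernelTmaWarpSpecializedCooperativeFP8FastAccum;
```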

Benchmark Results on H100 GPU

Benchmark Setup:

PyTorch 2.6
CUDA 12.4
CPU: AMD EPYC
GPU: NVIDIA H100
Each benchmark runs 30 kernel launch iterations, and results are averaged
over 25 benchmark repetitions.
Benchmarks cover GEMM shapes from the Llama 70B and 405B models (M = 16,384)
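
For reference, a hedged sketch of how TFLOPS figures like the ones below can be
derived: time the launch iterations with CUDA events and divide the 2·M·N·K
floating-point operations of a GEMM by the average time per iteration (in the setup
above, this measurement is then repeated and averaged over the 25 benchmark runs).
The kernel launch here is a hypothetical placeholder, not the actual FBGEMM
benchmark harness.

```cpp
// Sketch of the TFLOPS accounting behind the reported numbers; error checking
// omitted. launch_f8f8bf16_rowwise(...) is a hypothetical stand-in for the
// benchmarked op, not a real API.
#include <cuda_runtime.h>
#include <cstdint>

double benchmark_tflops(int64_t M, int64_t N, int64_t K, int iters = 30) {
  cudaEvent_t start, stop;
  cudaEventCreate(&start);
  cudaEventCreate(&stop);

  cudaEventRecord(start);
  for (int i = 0; i < iters; ++i) {
    // launch_f8f8bf16_rowwise(M, N, K);  // hypothetical kernel launch
  }
  cudaEventRecord(stop);
  cudaEventSynchronize(stop);

  float total_ms = 0.0f;
  cudaEventElapsedTime(&total_ms, start, stop);
  const double sec_per_iter = (total_ms / 1e3) / iters;

  // A GEMM of shape (M, N, K) performs 2 * M * N * K floating-point operations.
  const double tflops = 2.0 * M * N * K / sec_per_iter / 1e12;

  cudaEventDestroy(start);
  cudaEventDestroy(stop);
  return tflops;
}
```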

Benchmark

f8f8bf16_rowwise (M = 16,384)

| Llama Shape           | Old TFLOPS | New TFLOPS | Improvement |
|-----------------------|------------|------------|-------------|
| N = 1280, K = 8192    | 1252       | 1492       | +17.4%      |
| N = 8192, K = 1024    | 1258       | 1258       | —           |
| N = 7168, K = 8192    | 1324       | 1463       | +10.5%      |
| N = 8192, K = 3584    | 1401       | 1401       | —           |
| N = 13312, K = 6656   | 1259       | 1360       | +8.0%       |
| N = 13312, K = 16384  | 1170       | 1388       | +18.6%      |
| N = 16384, K = 6656   | 1238       | 1266       | +2.3%       |
| N = 16384, K = 16384  | 1166       | 1316       | +12.9%      |

The cooperative 128x256x128 TileShape consistently outperforms the 128x128x128
Ping-Pong kernel for all large GEMM sizes where K >= 4096.

For a small subset of cases, the 128x192x128 Ping-Pong kernel achieves a 2-3%
performance advantage, notably for the shape M = 16,384, N = 16,384, K = 16,384.

A more detailed heuristic rule could be explored for these specific cases.

Technical Implementation

Introduced a 128x256x128 TileShape with a cooperative kernel for f8f8bf16_rowwise.

The new configuration is selectively applied for large matrices where (see the sketch after this list):

  • M > 128 && N > 128
  • AND (M > 2048 || N > 2048)
  • AND K > 4096
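
A minimal sketch of that selection logic, written as a free-standing helper rather
than FBGEMM's actual dispatch code (the function name is illustrative only):

```cpp
// Minimal sketch of the shape heuristic described above; not the literal
// FBGEMM dispatch code.
#include <cstdint>

inline bool use_large_tile_cooperative(int64_t M, int64_t N, int64_t K) {
  const bool both_dims_nontrivial = (M > 128 && N > 128);
  const bool at_least_one_large = (M > 2048 || N > 2048);
  const bool deep_reduction = (K > 4096);
  // Route to the cooperative 128x256x128 kernel only when all three hold;
  // otherwise keep the existing tile configurations.
  return both_dims_nontrivial && at_least_one_large && deep_reduction;
}
```

For example, the benchmarked shape M = 16,384, N = 13,312, K = 6,656 satisfies all
three conditions and takes the new path, while N = 8,192, K = 1,024 does not and
keeps the existing configuration.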

Performance Validation:

  • We ensured that the changes do not introduce performance regressions
    for existing configurations that do not match the above conditions.
  • The code modifications were designed to preserve existing configurations outside of large GEMM cases.

These changes were made by modifying the minimum necessary code while respecting
existing coding practices in FBGEMM.

Test Coverage

Unit Test Results

The unit tests in fbgemm_gpu/experimental/gen_ai/test/quantize
were run to verify the modified kernels.

@jiawenliu64 @jwfromm Thank you!


netlify bot commented Mar 10, 2025

Deploy Preview for pytorch-fbgemm-docs ready!

| Name | Link |
|------|------|
| 🔨 Latest commit | 0579557 |
| 🔍 Latest deploy log | https://app.netlify.com/sites/pytorch-fbgemm-docs/deploys/67f3f716e5556700089b75ad |
| 😎 Deploy Preview | https://deploy-preview-3790--pytorch-fbgemm-docs.netlify.app |

@facebook-github-bot
Contributor

@jiawenliu64 has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

jiawenliu64 pushed a commit to jiawenliu64/FBGEMM that referenced this pull request Apr 8, 2025
…Llama Shapes (pytorch#3790)

Differential Revision: D72617756

Pulled By: jiawenliu64
@facebook-github-bot
Contributor

@jiawenliu64 merged this pull request in d925730.

q10 pushed a commit to q10/FBGEMM that referenced this pull request Apr 10, 2025
…Llama Shapes (pytorch#1025)

Summary:
X-link: pytorch#3942

Pull Request resolved: facebookresearch/FBGEMM#1025

X-link: pytorch#3790

Reviewed By: jianyuh

Differential Revision: D72617756

Pulled By: jiawenliu64

fbshipit-source-id: ec6f78af18fdda38acae99dda9a772d03e0b1ea7