FP8 Grouped Gemm Optimization #3655

Closed
wants to merge 1 commit into from

Conversation

@jwfromm (Contributor) commented Feb 4, 2025

Summary:
X-link: https://github.com/facebookresearch/FBGEMM/pull/731

While optimizing MoE, we found that small overheads were a major bottleneck for grouped GEMM performance. This diff tackles a few of them, specifically the overhead from torch.dynamo wrapping `quantize_fp8_row` and from having to slice input tensors before calling `f8f8bf16_rowwise_grouped`.

To fix the former, we allow `triton_quantize_fp8_row` to be called directly, skipping the dynamo compatibility wrapper. In cases where AOTI isn't needed, this removes a bit of overhead.
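
A minimal sketch of the two call paths, assuming the experimental FBGEMM import path and the `(fp8_tensor, row_scales)` return convention shown below; the exact module layout and signatures may differ:

```python
# Illustrative only: assumes quantize_fp8_row and triton_quantize_fp8_row live in
# fbgemm_gpu.experimental.gemm.triton_gemm.fp8_gemm and both return (fp8_tensor, row_scales).
import torch
from fbgemm_gpu.experimental.gemm.triton_gemm.fp8_gemm import (
    quantize_fp8_row,         # custom-op wrapper, safe for dynamo/AOTI tracing
    triton_quantize_fp8_row,  # underlying Triton kernel launcher
)

x = torch.randn(16, 4096, device="cuda", dtype=torch.bfloat16)

# Wrapped path: goes through the registered custom op, which adds a small
# fixed host-side cost per call.
xq, x_scale = quantize_fp8_row(x)

# Direct path: launches the Triton kernel without the wrapper. Only suitable
# when the call does not need to be captured for AOTI export.
xq, x_scale = triton_quantize_fp8_row(x)
```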

To fix the latter, we templatize `f8f8bf16_rowwise_grouped_dynamic` to accept `at::Tensor` arguments instead of lists. We introduce a new wrapper, `f8f8bf16_rowwise_grouped_stacked`, to preserve the behavior where `zero_start_index_M` isn't provided but the user wants a single contiguous output tensor.
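
A hypothetical sketch of the call-pattern difference; the operator names come from this diff, but the argument shapes and signatures below are assumptions and may not match the actual kernels:

```python
# Hypothetical call pattern, for illustration only; actual operator signatures
# and expected tensor layouts in FBGEMM may differ.
import torch

G, M, N, K = 8, 64, 4096, 4096  # small per-group problem sizes
xq = torch.randn(G, M, K, device="cuda").to(torch.float8_e4m3fn)
wq = torch.randn(G, N, K, device="cuda").to(torch.float8_e4m3fn)
x_scale = torch.rand(G, M, device="cuda", dtype=torch.float32)
w_scale = torch.rand(G, N, device="cuda", dtype=torch.float32)

# List-based path: every group is sliced out on the host before the call,
# which adds per-group Python/ATen overhead that dominates small workloads.
out_list = torch.ops.fbgemm.f8f8bf16_rowwise_grouped(
    [xq[g] for g in range(G)],
    [wq[g] for g in range(G)],
    [x_scale[g] for g in range(G)],
    [w_scale[g] for g in range(G)],
)

# Stacked path: whole tensors are handed to the kernel with no slicing, and a
# single contiguous output is returned even though zero_start_index_M is not
# provided.
out_stacked = torch.ops.fbgemm.f8f8bf16_rowwise_grouped_stacked(
    xq, wq, x_scale, w_scale
)
```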

In microbenchmarks, we've found that these seemingly small changes can improve achieved TFLOPS by 2x for small workloads.
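
One way to see the effect is a simple per-call timing harness like the sketch below (not the benchmark used for this diff); the function under test is a stand-in for whichever quantize-plus-grouped-GEMM path is being compared:

```python
# Generic CUDA timing helper for comparing the wrapped vs. direct code paths.
import torch

def bench_ms(fn, iters: int = 100, warmup: int = 10) -> float:
    # Warm up, then time with CUDA events so kernel launch overheads are included.
    for _ in range(warmup):
        fn()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        fn()
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters  # average milliseconds per call
```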

Reviewed By: jiawenliu64

Differential Revision: D69072529

@facebook-github-bot (Contributor) commented:

This pull request was exported from Phabricator. Differential Revision: D69072529

netlify bot commented Feb 4, 2025

Deploy Preview for pytorch-fbgemm-docs ready!

Latest commit: e2dd52b
Latest deploy log: https://app.netlify.com/sites/pytorch-fbgemm-docs/deploys/67a55edff7eb6c0008ac8453
Deploy Preview: https://deploy-preview-3655--pytorch-fbgemm-docs.netlify.app

jwfromm added a commit to jwfromm/FBGEMM that referenced this pull request Feb 6, 2025

jwfromm added a commit to jwfromm/FBGEMM that referenced this pull request Feb 6, 2025

jwfromm pushed a commit to jwfromm/FBGEMM that referenced this pull request Feb 6, 2025

jwfromm added a commit to jwfromm/FBGEMM that referenced this pull request Feb 6, 2025
jwfromm pushed a commit to jwfromm/FBGEMM that referenced this pull request Feb 6, 2025
jwfromm pushed a commit to jwfromm/FBGEMM that referenced this pull request Feb 7, 2025

jwfromm added a commit to jwfromm/FBGEMM that referenced this pull request Feb 7, 2025

jwfromm added a commit to jwfromm/FBGEMM that referenced this pull request Feb 7, 2025

@facebook-github-bot (Contributor) commented:

This pull request has been merged in d564c8c.

@q10 added the feature:fp8 label Feb 8, 2025
q10 pushed a commit to q10/FBGEMM that referenced this pull request Apr 10, 2025