FP8 Grouped Gemm Optimization #3655
Conversation
This pull request was exported from Phabricator. Differential Revision: D69072529
✅ Deploy Preview for pytorch-fbgemm-docs ready!
Force-pushed from 606449f to 632354b
Summary:
X-link: facebookresearch/FBGEMM#731

While optimizing MoE, we found that small overheads were a major bottleneck for grouped GEMM performance. This diff tackles a few of them, specifically the overhead from torch.dynamo wrapping `quantize_fp8_row` and from having to slice input tensors before calling `f8f8bf16_rowwise_grouped`.

To fix the former, we allow `triton_quantize_fp8_row` to be called directly, skipping dynamo compatibility. In cases where AOTI isn't needed, this removes a bit of overhead.

To fix the latter, we templatize `f8f8bf16_rowwise_grouped_dynamic` to accept `at::Tensor` instead of lists. We introduce a new wrapper, `f8f8bf16_rowwise_grouped_stacked`, to maintain the behavior where `zero_start_index_M` isn't provided but the user wants a single contiguous output tensor.

In microbenchmarks, we've found these seemingly small changes can improve TFLOPs by 2x for small workloads.

Reviewed By: jiawenliu64

Differential Revision: D69072529
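To make the two changes concrete, here is a minimal, hedged sketch of how a caller might use them together. Only the function and operator names (`triton_quantize_fp8_row`, `f8f8bf16_rowwise_grouped_stacked`) come from the summary above; the import path, tensor shapes, and the per-group row-count argument in the operator call are assumptions for illustration, not taken from this diff, and may differ between FBGEMM versions.

```python
# Hedged sketch, not part of the diff: import paths and the stacked grouped-GEMM
# signature are assumptions and may differ across FBGEMM versions.
import torch

# Assumed import path for the Triton row-wise FP8 quantization kernels.
from fbgemm_gpu.experimental.gemm.triton_gemm.fp8_gemm_quantize import (
    triton_quantize_fp8_row,
)

G, M, N, K = 8, 64, 1024, 512  # groups, rows per group, output dim, reduction dim
x = torch.randn(G * M, K, device="cuda", dtype=torch.bfloat16)  # stacked activations
w = torch.randn(G, N, K, device="cuda", dtype=torch.bfloat16)   # per-group weights

# Call the Triton kernel directly instead of the dynamo/torch.library-wrapped
# `quantize_fp8_row`; in eager mode (no AOTI) this skips the wrapper overhead.
xq, x_scale = triton_quantize_fp8_row(x)

# Quantize each group's weight and keep the results stacked. Whether the kernel
# accepts a 3D tensor directly is version-dependent, so quantize per group here.
wq_parts, w_scale_parts = zip(*(triton_quantize_fp8_row(w[g]) for g in range(G)))
wq = torch.stack(wq_parts)            # [G, N, K] FP8 weights
w_scale = torch.stack(w_scale_parts)  # [G, N] row-wise scales

# Stacked grouped GEMM: pass contiguous tensors rather than per-group slices.
# The per-group row-count argument below is an assumed part of the signature.
m_sizes = torch.full((G,), M, device="cuda", dtype=torch.int64)
y = torch.ops.fbgemm.f8f8bf16_rowwise_grouped_stacked(
    xq, wq, x_scale, w_scale, m_sizes
)
# y is a single contiguous bf16 output covering all groups.
```

The point of the stacked wrapper is that the caller no longer slices `xq` into G separate tensors before dispatch; the kernel indexes into the contiguous buffers itself and writes one contiguous bf16 output, which is where the removed overhead comes from.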
Force-pushed from 632354b to 3582d66
Force-pushed from 3582d66 to 08fcc98
Force-pushed from 08fcc98 to 2db3c2f
Force-pushed from 2db3c2f to 67ac9f8
Force-pushed from 67ac9f8 to e2dd52b
This pull request has been merged in d564c8c.