[ROCM] enable aiter fused moe kernel for llama4 bf16 checkpoints #16674
Conversation
👋 Hi! Thank you for contributing to the vLLM project. 💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels. Just a reminder: PRs do not trigger a full CI run by default; only a limited set of checks runs automatically. Once the PR is approved and ready to go, your PR reviewer(s) can run the full CI to test the changes comprehensively before merging.
Thanks for the PR. We will validate. @sijiac @houseroad
LGTM. We need to enable graph mode in order to achieve the full performance benefit compared with V1 graph mode without AITER.
@sijiac @hongxiayang we will enable it for V1 after this PR is merged.
As supplementary information for this PR, here is the GSM8K lm_eval score of the AITER kernel:
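For context on how such a score is typically produced, below is a hedged sketch using the lm-evaluation-harness Python API. The model name, parallelism setting, and batch size are placeholders and not necessarily what was used for the numbers referenced above.

```python
# Sketch only: reproduce a GSM8K lm_eval run against a vLLM backend.
# The pretrained model and tensor_parallel_size are illustrative placeholders.
import lm_eval

results = lm_eval.simple_evaluate(
    model="vllm",
    model_args=(
        "pretrained=meta-llama/Llama-4-Scout-17B-16E-Instruct,"
        "tensor_parallel_size=8"
    ),
    tasks=["gsm8k"],
    num_fewshot=5,
    batch_size="auto",
)
# Print the GSM8K accuracy metrics from the results dictionary.
print(results["results"]["gsm8k"])
```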
cc @mgoin, could you take a look at the PR? Is it good to be merged?
@@ -39,6 +40,16 @@ def rocm_aiter_fused_experts(
        from vllm.model_executor.layers.quantization.utils.fp8_utils import (
            per_token_group_quant_fp8)

        if apply_router_weight_on_input:
            _, topk = topk_weights.shape
Can we assert `topk_weights.dim() == 2`?
            hidden_states = hidden_states * topk_weights.to(hidden_states.dtype)
            topk_ids = topk_ids.to(torch.int32)
            topk_weights = torch.ones_like(topk_weights, dtype=torch.float32)
Does AITER require fp32 weight?
Yes, the passed-in `topk_weights` must be in fp32 dtype; otherwise it will have numeric issues.
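Pulling the two review threads above together, here is a minimal standalone sketch of the pre-scaling path, with the suggested dimensionality assert and the fp32 requirement called out in comments. The helper name and free-standing signature are illustrative; in the PR the logic lives inside `rocm_aiter_fused_experts`.

```python
# Illustrative sketch (not the PR's exact code): pre-apply router weights to
# the activations so the AITER fused MoE kernel sees neutral fp32 weights.
import torch


def _apply_router_weight_on_input(
    hidden_states: torch.Tensor,   # [num_tokens, hidden_dim]
    topk_weights: torch.Tensor,    # [num_tokens, topk]
    topk_ids: torch.Tensor,        # [num_tokens, topk]
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    # Reviewer suggestion: guard the unpacking with an explicit shape check.
    assert topk_weights.dim() == 2, "topk_weights must be [num_tokens, topk]"
    _, topk = topk_weights.shape  # `topk` is used by further checks in the full function

    # Fold the routing weights into the activations up front...
    hidden_states = hidden_states * topk_weights.to(hidden_states.dtype)
    # ...then hand the kernel int32 ids and all-ones weights. AITER expects
    # fp32 topk_weights; other dtypes lead to numeric issues.
    topk_ids = topk_ids.to(torch.int32)
    topk_weights = torch.ones_like(topk_weights, dtype=torch.float32)
    return hidden_states, topk_weights, topk_ids
```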
Fix the linter?
Looks good.
…m-project#16674) Signed-off-by: Yang Wang <[email protected]>
…m-project#16674) Signed-off-by: Agata Dobrzyniewicz <[email protected]>
…m-project#16674) Signed-off-by: Mu Huai <[email protected]>
This PR enables the AITER fused MoE kernel for bf16 checkpoints to improve performance.
NOTE: It doesn't support torch.compile() at the moment, so we use eager mode for both configurations when benchmarking. The issue is tracked in ROCm/aiter#244.
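As a rough usage sketch (not part of the PR), this is how one might exercise the new path end to end. The environment variable names and the model identifier are assumptions to verify against vllm/envs.py and the actual checkpoint used.

```python
# Hedged sketch: opt in to ROCm AITER kernels and run a bf16 Llama 4 checkpoint
# in eager mode, since torch.compile() is not yet supported (ROCm/aiter#244).
import os

# Assumed opt-in flags; confirm the exact names in vllm/envs.py.
os.environ["VLLM_ROCM_USE_AITER"] = "1"
os.environ["VLLM_ROCM_USE_AITER_MOE"] = "1"

from vllm import LLM, SamplingParams

llm = LLM(
    model="meta-llama/Llama-4-Scout-17B-16E-Instruct",  # placeholder 16E bf16 checkpoint
    tensor_parallel_size=8,
    enforce_eager=True,  # fall back to eager mode, as noted above
)

outputs = llm.generate(
    ["The capital of France is"],
    SamplingParams(temperature=0.0, max_tokens=32),
)
print(outputs[0].outputs[0].text)
```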
16E bf16 benchmark
128E bf16 benchmark
Generation Test
16E
128E