
[FEAT][ROCm] Integrate Paged Attention Kernel from AITER #15001


Merged: 23 commits into vllm-project:main on Apr 22, 2025

Conversation

@vllmellm (Contributor) commented Mar 18, 2025

This PR integrates the paged attention kernel from AITER (AI Tensor Engine for ROCm).

The pa_fwd_asm kernel from AITER is exposed as a new paged attention op in /vllm/attention/ops/rocm_aiter_paged_attn.py and wired into the ROCm attention backend in /vllm/attention/backends/rocm_flash_attn.py.

This feature is disabled by default, even when the parent switch (VLLM_ROCM_USE_AITER=1) is enabled. To use this kernel, both the parent switch and its dedicated environment variable VLLM_ROCM_USE_AITER_PAGED_ATTN must be enabled.
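For example, a minimal way to turn the new path on when launching a server (the model name and serve flags below are illustrative, not mandated by this PR):

```bash
# Both switches are required; the parent switch alone does not enable this kernel.
export VLLM_ROCM_USE_AITER=1
export VLLM_ROCM_USE_AITER_PAGED_ATTN=1

# Illustrative launch; any supported model and serve flags can be used.
vllm serve meta-llama/Llama-3.1-8B-Instruct \
    --quantization fp8 \
    --kv-cache-dtype fp8
```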

Note:

  • The AITER paged attention module supports the following kv_cache_dtype values:
    • int8
    • fp8
    • fp8_e4m3
    • bfloat16
    • float16
  • For float16 and bfloat16 KV caches, however, the module does not yet support decoding for models with more than one KV head; in that case it falls back to the original v1/v2 paged attention.

Performance Improvement Tables

The benchmarks/benchmark_serving.py script (https://github.com/vllm-project/vllm/blob/main/benchmarks/benchmark_serving.py) was used to evaluate performance on the following models:

  • Llama-3.1-8B-Instruct
  • Mixtral-8x7B-Instruct-v0.1
  • Llama-3.1-70B-Instruct
  • Mixtral-8x22B-Instruct-v0.1

Dataset: Random
Input length: 1024
Output length: 128

The ROCm Custom Paged Attention method, which can be enabled with the VLLM_ROCM_CUSTOM_PAGED_ATTN=1 flag, was used as the baseline for comparison. All benchmarks were run with the --quantization fp8 and --kv-cache-dtype fp8 arguments.
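A sketch of the benchmark setup described above (the model name, and the split of flags between server and client, are illustrative; the quantization and KV-cache flags apply to the server, while the dataset and length settings go to the benchmark script):

```bash
# Server side: AITER paged attention enabled, fp8 weights and fp8 KV cache.
VLLM_ROCM_USE_AITER=1 VLLM_ROCM_USE_AITER_PAGED_ATTN=1 \
    vllm serve meta-llama/Llama-3.1-8B-Instruct \
    --quantization fp8 --kv-cache-dtype fp8 &
# (wait until the server is up before starting the client)

# Client side: random dataset with 1024 input / 128 output tokens per request.
python3 benchmarks/benchmark_serving.py \
    --model meta-llama/Llama-3.1-8B-Instruct \
    --dataset-name random \
    --random-input-len 1024 \
    --random-output-len 128
```

For the baseline rows, the same run is repeated with VLLM_ROCM_CUSTOM_PAGED_ATTN=1 instead of the two AITER switches.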

Request throughput (req/s)

| Model | Custom Paged Attention | AITER Paged Attention |
|---|---|---|
| Llama-3.1-8B-Instruct | 25.74 | 36.28 |
| Mixtral-8x7B-Instruct-v0.1 | 7.00 | 9.85 |
| Llama-3.1-70B-Instruct | 12.36 | 12.34 |
| Mixtral-8x22B-Instruct-v0.1 | 12.97 | 13.41 |

Output token throughput (tok/s)

| Model | Custom Paged Attention | AITER Paged Attention |
|---|---|---|
| Llama-3.1-8B-Instruct | 3240.12 | 4396.78 |
| Mixtral-8x7B-Instruct-v0.1 | 719.26 | 1123.64 |
| Llama-3.1-70B-Instruct | 1158.62 | 1162.65 |
| Mixtral-8x22B-Instruct-v0.1 | 1365.70 | 1374.49 |

Total token throughput (tok/s)

| Model | Custom Paged Attention | AITER Paged Attention |
|---|---|---|
| Llama-3.1-8B-Instruct | 29595.76 | 41544.51 |
| Mixtral-8x7B-Instruct-v0.1 | 7883.01 | 11208.82 |
| Llama-3.1-70B-Instruct | 13813.79 | 13793.99 |
| Mixtral-8x22B-Instruct-v0.1 | 14651.28 | 15105.98 |

Mean TTFT (ms)

| Model | Custom Paged Attention | AITER Paged Attention |
|---|---|---|
| Llama-3.1-8B-Instruct | 19095.64 | 13246.17 |
| Mixtral-8x7B-Instruct-v0.1 | 70922.54 | 52720.16 |
| Llama-3.1-70B-Instruct | 39295.75 | 39108.72 |
| Mixtral-8x22B-Instruct-v0.1 | 32518.47 | 33817.36 |

Mean TPOT (ms)

| Model | Custom Paged Attention | AITER Paged Attention |
|---|---|---|
| Llama-3.1-8B-Instruct | 53.98 | 52.42 |
| Mixtral-8x7B-Instruct-v0.1 | 620.13 | 402.20 |
| Llama-3.1-70B-Instruct | 316.37 | 323.72 |
| Mixtral-8x22B-Instruct-v0.1 | 329.17 | 342.43 |

Mean ITL (ms)

| Model | Custom Paged Attention | AITER Paged Attention |
|---|---|---|
| Llama-3.1-8B-Instruct | 44.78 | 41.73 |
| Mixtral-8x7B-Instruct-v0.1 | 296.86 | 173.87 |
| Llama-3.1-70B-Instruct | 170.88 | 170.14 |
| Mixtral-8x22B-Instruct-v0.1 | 146.42 | 143.07 |

lm_eval Results

| Attention Type | Tasks | Version | Filter | n-shot | Metric | Value | Stderr |
|---|---|---|---|---|---|---|---|
| Custom PA | gsm8k | 3 | flexible-extract | 5 | exact_match | 0.7703 | ± 0.0116 |
| Custom PA | gsm8k | 3 | strict-match | 5 | exact_match | 0.7422 | ± 0.0120 |
| AITER PA | gsm8k | 3 | flexible-extract | 5 | exact_match | 0.7612 | ± 0.0117 |
| AITER PA | gsm8k | 3 | strict-match | 5 | exact_match | 0.7324 | ± 0.0122 |

AITER Operations Testing Overview

1. High-Level Integration Tests

The integration of AITER ops is tested at a higher module level in the following files under /tests/models/decoder_only/language:

  • test_models.py
  • test_phimoe.py
  • test_mistral.py
  • test_granite.py

These tests involve running various models to ensure overall functionality.

2. AITER MoE Specific Test

  • The AITER Mixture of Experts (MoE) is specifically tested for the Mixtral model in:
    /tests/kernels/test_moe.py

3. Quantization Testing

  • Quantization methods for AITER-enabled modules are tested in:
    /tests/quantization/test_fp8.py

4. Kernel Function Dispatch Testing

  • The correct dispatching of kernel functions (AITER-enabled or not) is verified in:
    /tests/model_executor/test_enabled_custom_ops.py
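A hedged example of how the AITER code paths covered by these suites might be exercised locally on a ROCm machine (the test selection and flags here are illustrative, not the CI invocation):

```bash
# Run the dispatch, MoE, and quantization tests with AITER enabled.
VLLM_ROCM_USE_AITER=1 VLLM_ROCM_USE_AITER_PAGED_ATTN=1 pytest -q \
    tests/model_executor/test_enabled_custom_ops.py \
    tests/kernels/test_moe.py \
    tests/quantization/test_fp8.py
```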

Environment Settings

Updates in Dockerfile.rocm_base:

Added AITER Package:

  • AITER_BRANCH: 7e1ed08

Note:

  • When setting up AITER, it is crucial to clone with git clone --recursive, because the package depends on a third-party package (Composable Kernel).

  • To build and install the AITER Python package, use the PREBUILD_KERNELS=1 flag together with python3 setup.py develop. This ensures that all kernels in the AITER package are built successfully (see the consolidated commands below).
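Putting the notes above together, a typical AITER setup might look like the following; the repository URL and the explicit checkout of the pinned commit are assumptions for illustration (in this PR the branch is pinned via AITER_BRANCH in Dockerfile.rocm_base):

```bash
# --recursive is required because AITER depends on Composable Kernel.
git clone --recursive https://github.com/ROCm/aiter.git
cd aiter
git checkout 7e1ed08  # commit pinned by AITER_BRANCH in Dockerfile.rocm_base

# PREBUILD_KERNELS=1 ensures all kernels in the package are built.
PREBUILD_KERNELS=1 python3 setup.py develop
```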

The following branches were used as references for this integration:


👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, covering a small and essential subset of CI tests to quickly catch errors. You can run additional CI tests on top of those by going to your fastcheck build in the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run full CI, PR reviewers can either add the ready label to the PR or enable auto-merge.

🚀

@mergify mergify bot added the ci/build label Mar 18, 2025
@vllmellm vllmellm marked this pull request as ready for review March 18, 2025 04:41
@SageMoore (Contributor) left a comment

Can you post lm_eval results for the main models that this kernel supports?

@@ -15,6 +15,7 @@
     CommonMetadataBuilder)
 from vllm.attention.ops.paged_attn import (PagedAttention,
                                            PagedAttentionMetadata)
+from vllm.attention.ops.rocm_aiter_paged_attn import AITERPagedAttention
Contributor
Does this always attempt to import AITER even if it's disabled?

Collaborator
@SageMoore Are you suggesting to move this line to line 50 after checking whether it is enabled?

@vllmellm (Contributor, Author) commented Apr 21, 2025
@SageMoore @hongxiayang AITER is now imported only when the flag is set.


mergify bot commented Mar 26, 2025

This pull request has merge conflicts that must be resolved before it can be merged. Please rebase the PR, @vllmellm.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Mar 26, 2025
@mergify mergify bot removed the needs-rebase label Mar 26, 2025

mergify bot commented Mar 31, 2025

This pull request has merge conflicts that must be resolved before it can be merged. Please rebase the PR, @vllmellm.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Mar 31, 2025
Signed-off-by: tjtanaa <[email protected]>
@tjtanaa (Contributor) commented Apr 21, 2025

@sunway513 @hongxiayang We have just updated the PR with lm_eval and performance values for --quantization fp8 --kv-cache-dtype fp8 V0 Engine. It is now ready for review.

@hongxiayang (Collaborator) commented:

> what's left to get this PR merged in? cc @hongxiayang

@sunway513 Based on feedback last time, the lm_eval result was requested.

Hi, @SageMoore: @tjtanaa has updated the description, included the lm_eval result, and addressed the review feedback. Can you please review again at your earliest convenience? As you already know, this is blocking the decommissioning of the ROCm fork. Thanks a lot.

@sunway513 commented:

Great! @gshtras, who can help expedite the review?

Signed-off-by: vllmellm <[email protected]>
@hongxiayang (Collaborator) commented Apr 21, 2025

cc @DarkLight1337 Can you help to expedite the review and merge of this PR?

@tlrmchlsmth tlrmchlsmth added the ready ONLY add when PR is ready to merge/full CI is needed label Apr 21, 2025
@ProExpertProg (Collaborator) commented:

Does the AITER kernel support fused output quantization?

@tlrmchlsmth (Collaborator) commented:

@vllmellm please fix the pre-commit

Signed-off-by: vllmellm <[email protected]>

mergify bot commented Apr 22, 2025

This pull request has merge conflicts that must be resolved before it can be merged. Please rebase the PR, @vllmellm.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Apr 22, 2025
@mergify mergify bot removed the needs-rebase label Apr 22, 2025
@vllm-bot vllm-bot merged commit 0e237f0 into vllm-project:main Apr 22, 2025
43 of 46 checks passed
Commits referencing this pull request were later pushed to downstream forks:

  • frieda-huang/vllm (frieda-huang, Apr 23, 2025)
  • jikunshang/vllm (jikunshang, Apr 29, 2025)
  • lk-chen/vllm (lk-chen, Apr 29, 2025)
  • HabanaAI/vllm-fork (adobrzyn, Apr 30, 2025)
  • RichardoMrMu/vllm (RichardoMrMu, May 12, 2025)
  • minpeter/vllm (minpeter, Jun 24, 2025)
Labels: ci/build, ready
10 participants