
[FEAT][ROCm] Integrate Paged Attention Kernel from AITER #15001


Merged · 23 commits · Apr 22, 2025
Changes from 7 commits

Commits (23):
dc09d66 add AITER paged attention kernel (vllmellm, Mar 17, 2025)
fe9ff98 include AITER enable for rocm platforms in model end to end tests (vllmellm, Mar 17, 2025)
d7c5dfb add AITER into rocm docker base file (vllmellm, Mar 17, 2025)
c24fc09 Merge remote-tracking branch 'origin/main' into aiter-paged-attn-inte… (vllmellm, Mar 18, 2025)
1732f9a use clearer name for paged attention module used in ROCmFlashAttentio… (vllmellm, Mar 18, 2025)
85296f7 fix get envs variables in unit tests (vllmellm, Mar 18, 2025)
07ac4d4 Remove AttentionOps class instead use a simple funtion to return appr… (vllmellm, Mar 18, 2025)
1592e7e remove cascading logic from vllm.envs (vllmellm, Mar 19, 2025)
07bf5c6 refactor aiter unit test flags into decorator (tjtanaa, Mar 19, 2025)
1fdd695 modify the rocm AITER check tests based on new decorator and include … (vllmellm, Mar 19, 2025)
bb3687d remove the decorator for enability of rocm AITER ops in tests (vllmellm, Mar 26, 2025)
2dfa16f Merge remote-tracking branch 'origin/main' into aiter-paged-attn-inte… (vllmellm, Mar 26, 2025)
9087f44 match the tests files and run-amd-test script to the main branch (vllmellm, Mar 26, 2025)
32b7a9b sync with main (tjtanaa, Apr 1, 2025)
052d9e0 import AITERPagedAttention only if flag is set (vllmellm, Apr 21, 2025)
15862f1 prefer current_platform.fp8_dtype over the harcoded dtype (vllmellm, Apr 21, 2025)
2e65b95 Merge remote-tracking branch 'origin/main' into aiter-paged-attn-inte… (vllmellm, Apr 21, 2025)
15406cb cache aiter pa import (vllmellm, Apr 21, 2025)
a9ef9f9 update aiter commit (vllmellm, Apr 21, 2025)
e203aed correct comment (vllmellm, Apr 22, 2025)
976da61 fix spelling mistake (vllmellm, Apr 22, 2025)
0f5f2d0 prefer utils cdiv (vllmellm, Apr 22, 2025)
cc79ec9 Merge remote-tracking branch 'origin/main' into aiter-paged-attn-inte… (vllmellm, Apr 22, 2025)
4 changes: 4 additions & 0 deletions .buildkite/run-amd-test.sh
@@ -72,6 +72,10 @@ HF_CACHE="$(realpath ~)/huggingface"
mkdir -p "${HF_CACHE}"
HF_MOUNT="/root/.cache/huggingface"

# environment variables
SKIP_ROCM_ATIER_MODEL_TEST_CASES="True"
echo $SKIP_ROCM_ATIER_MODEL_TEST_CASES

commands=$@
echo "Commands:$commands"
#ignore certain kernels tests
13 changes: 13 additions & 0 deletions Dockerfile.rocm_base
@@ -12,6 +12,8 @@ ARG PYTORCH_REPO="https://github.com/pytorch/pytorch.git"
ARG PYTORCH_VISION_REPO="https://github.com/pytorch/vision.git"
ARG FA_BRANCH="b7d29fb"
ARG FA_REPO="https://github.com/ROCm/flash-attention.git"
ARG AITER_BRANCH="e1ec015"
ARG AITER_REPO="https://github.com/ROCm/aiter.git"

FROM ${BASE_IMAGE} AS base

@@ -129,6 +131,15 @@ RUN --mount=type=bind,from=build_amdsmi,src=/app/install/,target=/install \
RUN --mount=type=bind,from=build_pytorch,src=/app/install/,target=/install \
pip install /install/*.whl

ARG AITER_REPO
ARG AITER_BRANCH
RUN git clone --recursive ${AITER_REPO}
RUN cd aiter \
&& git checkout ${AITER_BRANCH} \
&& git submodule update --init --recursive \
&& pip install -r requirements.txt \
&& PREBUILD_KERNELS=1 GPU_ARCHS=gfx942 python3 setup.py develop && pip show aiter

ARG BASE_IMAGE
ARG HIPBLASLT_BRANCH
ARG LEGACY_HIPBLASLT_OPTION
@@ -156,3 +167,5 @@ RUN echo "BASE_IMAGE: ${BASE_IMAGE}" > /app/versions.txt \
&& echo "PYTORCH_VISION_REPO: ${PYTORCH_VISION_REPO}" >> /app/versions.txt \
&& echo "FA_BRANCH: ${FA_BRANCH}" >> /app/versions.txt \
&& echo "FA_REPO: ${FA_REPO}" >> /app/versions.txt
&& echo "AITER_BRANCH: ${AITER_BRANCH}" >> /app/versions.txt \
&& echo "AITER_REPO: ${AITER_REPO}" >> /app/versions.txt
81 changes: 49 additions & 32 deletions tests/models/decoder_only/language/test_mistral.py
@@ -5,13 +5,15 @@
"""
import copy
import json
import os

import jsonschema
import jsonschema.exceptions
import pytest

from vllm.entrypoints.openai.tool_parsers.mistral_tool_parser import ( # noqa
MistralToolParser)
from vllm.platforms import current_platform
from vllm.sampling_params import GuidedDecodingParams, SamplingParams

from ...utils import check_logprobs_close
@@ -174,15 +176,16 @@
@pytest.mark.parametrize("dtype", ["bfloat16"])
@pytest.mark.parametrize("max_tokens", [64])
@pytest.mark.parametrize("num_logprobs", [5])
def test_models(
hf_runner,
vllm_runner,
example_prompts,
model: str,
dtype: str,
max_tokens: int,
num_logprobs: int,
) -> None:
@pytest.mark.parametrize(
"use_rocm_aiter", [True, False] if current_platform.is_rocm() else [False])
def test_models(hf_runner, vllm_runner, example_prompts, model: str,
dtype: str, max_tokens: int, num_logprobs: int,
use_rocm_aiter: bool, monkeypatch) -> None:
if use_rocm_aiter:
if os.getenv("SKIP_ROCM_ATIER_MODEL_TEST_CASES") == "true":
pytest.skip("Skipping test suite for ROCM AITER")
monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1")

# TODO(sang): Sliding window should be tested separately.
with hf_runner(model, dtype=dtype) as hf_model:
hf_outputs = hf_model.generate_greedy_logprobs_limit(
@@ -206,14 +209,16 @@ def test_models(
@pytest.mark.parametrize("dtype", ["bfloat16"])
@pytest.mark.parametrize("max_tokens", [64])
@pytest.mark.parametrize("num_logprobs", [5])
def test_mistral_format(
vllm_runner,
example_prompts,
model: str,
dtype: str,
max_tokens: int,
num_logprobs: int,
) -> None:
@pytest.mark.parametrize(
"use_rocm_aiter", [True, False] if current_platform.is_rocm() else [False])
def test_mistral_format(vllm_runner, example_prompts, model: str, dtype: str,
max_tokens: int, num_logprobs: int,
use_rocm_aiter: bool, monkeypatch) -> None:
if use_rocm_aiter:
if os.getenv("SKIP_ROCM_ATIER_MODEL_TEST_CASES") == "true":
pytest.skip("Skipping test suite for ROCM AITER")
monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1")

with vllm_runner(
model,
dtype=dtype,
@@ -244,11 +249,15 @@ def test_mistral_format(

@pytest.mark.parametrize("model", MISTRAL_FORMAT_MODELS)
@pytest.mark.parametrize("dtype", ["bfloat16"])
def test_mistral_symbolic_languages(
vllm_runner,
model: str,
dtype: str,
) -> None:
@pytest.mark.parametrize(
"use_rocm_aiter", [True, False] if current_platform.is_rocm() else [False])
def test_mistral_symbolic_languages(vllm_runner, model: str, dtype: str,
use_rocm_aiter: bool, monkeypatch) -> None:
if use_rocm_aiter:
if os.getenv("SKIP_ROCM_ATIER_MODEL_TEST_CASES") == "true":
pytest.skip("Skipping test suite for ROCM AITER")
monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1")

with vllm_runner(model,
dtype=dtype,
max_model_len=8192,
@@ -266,11 +275,15 @@ def test_mistral_symbolic_languages(
@pytest.mark.parametrize("dtype", ["bfloat16"])
@pytest.mark.parametrize("model",
MISTRAL_FORMAT_MODELS) # v1 can't do func calling
def test_mistral_function_calling(
vllm_runner,
model: str,
dtype: str,
) -> None:
@pytest.mark.parametrize(
"use_rocm_aiter", [True, False] if current_platform.is_rocm() else [False])
def test_mistral_function_calling(vllm_runner, model: str, dtype: str,
use_rocm_aiter: bool, monkeypatch) -> None:
if use_rocm_aiter:
if os.getenv("SKIP_ROCM_ATIER_MODEL_TEST_CASES") == "true":
pytest.skip("Skipping test suite for ROCM AITER")
monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1")

with vllm_runner(model,
dtype=dtype,
tokenizer_mode="mistral",
@@ -301,11 +314,15 @@ def test_mistral_function_calling(
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("guided_backend",
["outlines", "lm-format-enforcer", "xgrammar"])
def test_mistral_guided_decoding(
vllm_runner,
model: str,
guided_backend: str,
) -> None:
@pytest.mark.parametrize(
"use_rocm_aiter", [True, False] if current_platform.is_rocm() else [False])
def test_mistral_guided_decoding(vllm_runner, model: str, guided_backend: str,
use_rocm_aiter: bool, monkeypatch) -> None:
if use_rocm_aiter:
if os.getenv("SKIP_ROCM_ATIER_MODEL_TEST_CASES") == "true":
pytest.skip("Skipping test suite for ROCM AITER")
monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1")

with vllm_runner(model, dtype='bfloat16',
tokenizer_mode="mistral") as vllm_model:

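Each of the Mistral tests above repeats the same gating block: skip when SKIP_ROCM_ATIER_MODEL_TEST_CASES is set, otherwise export VLLM_ROCM_USE_AITER=1 through monkeypatch. The commit history shows this logic briefly lived in a decorator ("refactor aiter unit test flags into decorator") before being inlined again. A hypothetical helper that would factor out the repetition, shown only to summarize the pattern; the PR itself keeps the inline version:

# Hypothetical helper; the PR keeps the inline blocks shown in the diff.
import os

import pytest


def maybe_enable_rocm_aiter(use_rocm_aiter: bool, monkeypatch) -> None:
    """Mirror the per-test gating block used in the tests above."""
    if not use_rocm_aiter:
        return
    if os.getenv("SKIP_ROCM_ATIER_MODEL_TEST_CASES") == "true":
        pytest.skip("Skipping test suite for ROCM AITER")
    monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1")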
12 changes: 12 additions & 0 deletions tests/models/decoder_only/language/test_models.py
@@ -3,8 +3,12 @@

Run `pytest tests/models/test_models.py`.
"""
import os

import pytest

from vllm.platforms import current_platform

from ...utils import check_logprobs_close

# These have unsupported head_dim for FA. We do not
@@ -69,6 +73,8 @@
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("max_tokens", [32])
@pytest.mark.parametrize("num_logprobs", [5])
@pytest.mark.parametrize(
"use_rocm_aiter", [True, False] if current_platform.is_rocm() else [False])
def test_models(
hf_runner,
vllm_runner,
@@ -77,11 +83,17 @@ def test_models(
dtype: str,
max_tokens: int,
num_logprobs: int,
use_rocm_aiter: bool,
monkeypatch,
) -> None:
if model in REQUIRES_V0:
monkeypatch.setenv("VLLM_USE_V1", "0")

if use_rocm_aiter:
if os.getenv("SKIP_ROCM_ATIER_MODEL_TEST_CASES") == "true":
pytest.skip("Skipping test suite for ROCM AITER")
monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1")

with hf_runner(model, dtype=dtype) as hf_model:
if model.startswith("THUDM/chatglm3"):
hf_model.model.get_output_embeddings = lambda: \
21 changes: 12 additions & 9 deletions tests/models/decoder_only/language/test_phimoe.py
@@ -3,6 +3,8 @@

Run `pytest tests/models/test_phimoe.py`.
"""
import os

import pytest
import torch

@@ -79,15 +81,16 @@ def test_phimoe_routing_function():
@pytest.mark.parametrize("dtype", ["bfloat16"])
@pytest.mark.parametrize("max_tokens", [64])
@pytest.mark.parametrize("num_logprobs", [5])
def test_models(
hf_runner,
vllm_runner,
example_prompts,
model: str,
dtype: str,
max_tokens: int,
num_logprobs: int,
) -> None:
@pytest.mark.parametrize(
"use_rocm_aiter", [True, False] if current_platform.is_rocm() else [False])
def test_models(hf_runner, vllm_runner, example_prompts, model: str,
dtype: str, max_tokens: int, num_logprobs: int,
use_rocm_aiter: bool, monkeypatch) -> None:
if use_rocm_aiter:
if os.getenv("SKIP_ROCM_ATIER_MODEL_TEST_CASES") == "true":
pytest.skip("Skipping test suite for ROCM AITER")
monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1")

with hf_runner(model, dtype=dtype) as hf_model:
hf_outputs = hf_model.generate_greedy_logprobs_limit(
example_prompts, max_tokens, num_logprobs)
22 changes: 19 additions & 3 deletions tests/quantization/test_fp8.py
@@ -23,11 +23,16 @@
reason="FP8 is not supported on this GPU type.")
@pytest.mark.parametrize("model_id", MODELS)
@pytest.mark.parametrize("force_marlin", [False, True])
@pytest.mark.parametrize(
"use_rocm_aiter", [True, False] if current_platform.is_rocm() else [False])
def test_model_load_and_run(vllm_runner, model_id: str, force_marlin: bool,
monkeypatch) -> None:
use_rocm_aiter: bool, monkeypatch) -> None:
if force_marlin:
monkeypatch.setenv("VLLM_TEST_FORCE_FP8_MARLIN", "1")

if use_rocm_aiter:
monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1")

with vllm_runner(model_id) as llm:
# note: this does not test accuracy, just that we can run through
# see lm-eval tests for accuracy
@@ -47,7 +52,13 @@ def test_model_load_and_run(vllm_runner, model_id: str, force_marlin: bool,
@pytest.mark.skipif(not is_quant_method_supported("fp8"),
reason="FP8 is not supported on this GPU type.")
@pytest.mark.parametrize("model_id", KV_CACHE_MODELS)
def test_kv_cache_model_load_and_run(vllm_runner, model_id: str, monkeypatch):
@pytest.mark.parametrize(
"use_rocm_aiter", [True, False] if current_platform.is_rocm() else [False])
def test_kv_cache_model_load_and_run(vllm_runner, model_id: str,
use_rocm_aiter: bool, monkeypatch):
if use_rocm_aiter:
monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1")

# vllm_runner.apply_model() relies on V0 internals.
monkeypatch.setenv("VLLM_USE_V1", "0")
with vllm_runner(model_id, kv_cache_dtype="fp8") as llm:
@@ -86,8 +97,13 @@ def check_model(model):
reason="FP8 is not supported on this GPU type.")
@pytest.mark.parametrize("kv_cache_dtype", ["auto", "fp8"])
@pytest.mark.parametrize("force_marlin", [False, True])
@pytest.mark.parametrize(
"use_rocm_aiter", [True, False] if current_platform.is_rocm() else [False])
def test_load_fp16_model(vllm_runner, kv_cache_dtype: str, force_marlin: bool,
monkeypatch) -> None:
use_rocm_aiter: bool, monkeypatch) -> None:
if use_rocm_aiter:
monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1")

# vllm_runner.apply_model() relies on V0 internals.
monkeypatch.setenv("VLLM_USE_V1", "0")

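One follow-up commit, "prefer current_platform.fp8_dtype over the harcoded dtype", replaces a hardcoded FP8 dtype with the platform helper. At the time of this PR, ROCm MI300-class GPUs (gfx942) use the float8_e4m3fnuz format while CUDA GPUs use the OCP float8_e4m3fn format, so letting the platform decide avoids baking either one in. A small usage sketch, assuming the helper behaves as described:

# Sketch: let the platform pick the FP8 dtype instead of hardcoding one.
from vllm.platforms import current_platform

# Expected to be torch.float8_e4m3fnuz on gfx942 (MI300) and
# torch.float8_e4m3fn on platforms that use the OCP FP8 format.
fp8_dtype = current_platform.fp8_dtype()
print(fp8_dtype)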