[FEAT][ROCm] Integrate Paged Attention Kernel from AITER #15001

Merged
merged 23 commits into from
Apr 22, 2025
Changes from 20 commits

Commits (23)
dc09d66  add AITER paged attention kernel (vllmellm, Mar 17, 2025)
fe9ff98  include AITER enable for rocm platforms in model end to end tests (vllmellm, Mar 17, 2025)
d7c5dfb  add AITER into rocm docker base file (vllmellm, Mar 17, 2025)
c24fc09  Merge remote-tracking branch 'origin/main' into aiter-paged-attn-inte… (vllmellm, Mar 18, 2025)
1732f9a  use clearer name for paged attention module used in ROCmFlashAttentio… (vllmellm, Mar 18, 2025)
85296f7  fix get envs variables in unit tests (vllmellm, Mar 18, 2025)
07ac4d4  Remove AttentionOps class instead use a simple funtion to return appr… (vllmellm, Mar 18, 2025)
1592e7e  remove cascading logic from vllm.envs (vllmellm, Mar 19, 2025)
07bf5c6  refactor aiter unit test flags into decorator (tjtanaa, Mar 19, 2025)
1fdd695  modify the rocm AITER check tests based on new decorator and include … (vllmellm, Mar 19, 2025)
bb3687d  remove the decorator for enability of rocm AITER ops in tests (vllmellm, Mar 26, 2025)
2dfa16f  Merge remote-tracking branch 'origin/main' into aiter-paged-attn-inte… (vllmellm, Mar 26, 2025)
9087f44  match the tests files and run-amd-test script to the main branch (vllmellm, Mar 26, 2025)
32b7a9b  sync with main (tjtanaa, Apr 1, 2025)
052d9e0  import AITERPagedAttention only if flag is set (vllmellm, Apr 21, 2025)
15862f1  prefer current_platform.fp8_dtype over the harcoded dtype (vllmellm, Apr 21, 2025)
2e65b95  Merge remote-tracking branch 'origin/main' into aiter-paged-attn-inte… (vllmellm, Apr 21, 2025)
15406cb  cache aiter pa import (vllmellm, Apr 21, 2025)
a9ef9f9  update aiter commit (vllmellm, Apr 21, 2025)
e203aed  correct comment (vllmellm, Apr 21, 2025)
976da61  fix spelling mistake (vllmellm, Apr 22, 2025)
0f5f2d0  prefer utils cdiv (vllmellm, Apr 22, 2025)
cc79ec9  Merge remote-tracking branch 'origin/main' into aiter-paged-attn-inte… (vllmellm, Apr 22, 2025)
2 changes: 1 addition & 1 deletion docker/Dockerfile.rocm_base
@@ -12,7 +12,7 @@ ARG PYTORCH_REPO="https://github.com/pytorch/pytorch.git"
ARG PYTORCH_VISION_REPO="https://github.com/pytorch/vision.git"
ARG FA_BRANCH="1a7f4dfa"
ARG FA_REPO="https://github.com/Dao-AILab/flash-attention.git"
ARG AITER_BRANCH="8970b25b"
ARG AITER_BRANCH="7e1ed08"
ARG AITER_REPO="https://github.com/ROCm/aiter.git"

FROM ${BASE_IMAGE} AS base
109 changes: 84 additions & 25 deletions vllm/attention/backends/rocm_flash_attn.py
@@ -2,6 +2,7 @@
"""Attention layer ROCm GPUs."""
import itertools
from dataclasses import dataclass
from functools import cache
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type

import torch
@@ -26,6 +27,32 @@
_PARTITION_SIZE_ROCM = 256


@cache
def is_rocm_aiter_paged_attn_enabled() -> bool:
return envs.VLLM_ROCM_USE_AITER_PAGED_ATTN \
and envs.VLLM_ROCM_USE_AITER


@cache
def _get_paged_attn_module() -> PagedAttention:
"""
Initializes the appropriate PagedAttention module from `attention/ops`,
which is used as a helper function by `ROCmFlashAttentionImpl` and
`ROCmFlashAttentionBackend`.

The choice of attention module depends on whether
AITER paged attention is enabled:
- If enabled, `ROCmFlashAttentionImpl` uses `AITERPagedAttention`.
- Otherwise, it defaults to using the original `PagedAttention`.
"""
if is_rocm_aiter_paged_attn_enabled():
# Import AITERPagedAttention only when the flag is enabled
from vllm.attention.ops.rocm_aiter_paged_attn import (
AITERPagedAttention)
return AITERPagedAttention()
return PagedAttention()


class ROCmFlashAttentionBackend(AttentionBackend):
accept_output_buffer: bool = True

@@ -56,23 +83,26 @@ def get_kv_cache_shape(
num_kv_heads: int,
head_size: int,
) -> Tuple[int, ...]:
return PagedAttention.get_kv_cache_shape(num_blocks, block_size,
num_kv_heads, head_size)
paged_attn = _get_paged_attn_module()
return paged_attn.get_kv_cache_shape(num_blocks, block_size,
num_kv_heads, head_size)

@staticmethod
def swap_blocks(
src_kv_cache: torch.Tensor,
dst_kv_cache: torch.Tensor,
src_to_dst: torch.Tensor,
) -> None:
PagedAttention.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst)
paged_attn = _get_paged_attn_module()
paged_attn.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst)

@staticmethod
def copy_blocks(
kv_caches: List[torch.Tensor],
src_to_dists: torch.Tensor,
) -> None:
PagedAttention.copy_blocks(kv_caches, src_to_dists)
paged_attn = _get_paged_attn_module()
paged_attn.copy_blocks(kv_caches, src_to_dists)


@dataclass
@@ -496,7 +526,10 @@ def __init__(
assert self.num_heads % self.num_kv_heads == 0
self.num_queries_per_kv = self.num_heads // self.num_kv_heads

supported_head_sizes = PagedAttention.get_supported_head_sizes()
self.paged_attn_module = _get_paged_attn_module()
supported_head_sizes = self.paged_attn_module.get_supported_head_sizes(
)

if head_size not in supported_head_sizes:
raise ValueError(
f"Head size {head_size} is not supported by PagedAttention. "
@@ -546,6 +579,8 @@ def __init__(
self.sdpa_attn_func = _sdpa_attention
logger.debug("Using naive (SDPA) attention in ROCmBackend")

self.aiter_kv_scales_initialized = False

def repeat_kv(self, x: torch.Tensor, n_rep: int) -> torch.Tensor:
"""torch.repeat_interleave(x, dim=1, repeats=n_rep)"""
tokens, n_kv_heads, head_dim = x.shape
@@ -624,20 +659,45 @@ def forward(
else:
assert value is None

paged_attn = self.paged_attn_module

# The AITER paged attention kernel works on a different KV-cache layout
# when each element is one byte (int8/fp8 dtypes) and expects per-token
# K/V scales of shape [num_kv_heads, num_blocks * block_size], so the
# scalar scales are expanded here. This is only required on the first
# forward call, and the kv cache must not be empty.
if (is_rocm_aiter_paged_attn_enabled() and kv_cache.dtype.itemsize == 1
and not self.aiter_kv_scales_initialized
and kv_cache.shape != torch.Size([0])):
num_blocks = kv_cache.shape[1]
block_size = kv_cache.shape[2] // (self.num_kv_heads *
self.head_size)
k_scale = torch.empty((self.num_kv_heads, num_blocks * block_size),
dtype=torch.float32,
device=kv_cache.device)
v_scale = torch.empty((self.num_kv_heads, num_blocks * block_size),
dtype=torch.float32,
device=kv_cache.device)
self.aiter_kv_scales_initialized = True
k_scale.fill_(layer._k_scale.item())
v_scale.fill_(layer._v_scale.item())
layer._k_scale = k_scale
layer._v_scale = v_scale

# Only update KV cache for decoder self-attention
# and encoder-decoder cross-attention
if self.attn_type not in [
AttentionType.ENCODER, AttentionType.ENCODER_ONLY
] and kv_cache.numel() > 0:
key_cache, value_cache = PagedAttention.split_kv_cache(
key_cache, value_cache = paged_attn.split_kv_cache(
kv_cache, self.num_kv_heads, self.head_size)

if key is not None and value is not None:
# Reshape the input keys and values and store them in the
# cache. If kv_cache is not provided, the new key and value
# tensors are not cached. This happens during the initial
# memory profiling run.
PagedAttention.write_to_paged_cache(
paged_attn.write_to_paged_cache(
key,
value,
key_cache,
@@ -768,23 +828,22 @@ def forward(
# prefix-enabled attention -
# not applicable for encoder-only models
if self.attn_type != AttentionType.ENCODER_ONLY:
output[:
num_prefill_tokens] = PagedAttention.forward_prefix(
query,
key,
value,
self.kv_cache_dtype,
key_cache,
value_cache,
prefill_meta.block_tables,
prefill_meta.query_start_loc,
prefill_meta.seq_lens_tensor,
prefill_meta.max_query_len,
self.alibi_slopes,
self.sliding_window[0],
layer._k_scale,
layer._v_scale,
)
output[:num_prefill_tokens] = paged_attn.forward_prefix(
query,
key,
value,
self.kv_cache_dtype,
key_cache,
value_cache,
prefill_meta.block_tables,
prefill_meta.query_start_loc,
prefill_meta.seq_lens_tensor,
prefill_meta.max_query_len,
self.alibi_slopes,
self.sliding_window[0],
layer._k_scale,
layer._v_scale,
)
# Skip decode phase for encoder-only models
if (decode_meta := attn_metadata.decode_metadata) and (
self.attn_type != AttentionType.ENCODER_ONLY):
@@ -843,7 +902,7 @@ def forward(
layer._v_scale,
)
else:
output[num_prefill_tokens:] = PagedAttention.forward_decode(
output[num_prefill_tokens:] = paged_attn.forward_decode(
decode_query,
key_cache,
value_cache,
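For reference, the selection logic added to this file reduces to a small, cached factory gated on two environment flags. The sketch below is illustrative only (the class names are hypothetical stand-ins, not vLLM's real ones) and mirrors how `_get_paged_attn_module()` defers the AITER import until both flags are enabled.

# Minimal sketch of the cached, flag-gated selection pattern behind
# _get_paged_attn_module(). Class names here are hypothetical stand-ins.
import os
from functools import cache


class DefaultPagedAttention:
    name = "PagedAttention"


class AiterPagedAttention(DefaultPagedAttention):
    name = "AITERPagedAttention"


@cache
def aiter_pa_enabled() -> bool:
    # Both flags must be truthy, mirroring is_rocm_aiter_paged_attn_enabled().
    return (os.getenv("VLLM_ROCM_USE_AITER", "False").lower() in ("true", "1")
            and os.getenv("VLLM_ROCM_USE_AITER_PAGED_ATTN",
                          "False").lower() in ("true", "1"))


@cache
def get_paged_attn_module() -> DefaultPagedAttention:
    # In the real code the AITER-backed module is imported only inside this
    # branch, so non-AITER deployments never pay for the import.
    if aiter_pa_enabled():
        return AiterPagedAttention()
    return DefaultPagedAttention()


print(get_paged_attn_module().name)  # resolved once, cached for later calls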
100 changes: 100 additions & 0 deletions vllm/attention/ops/rocm_aiter_paged_attn.py
@@ -0,0 +1,100 @@
# SPDX-License-Identifier: Apache-2.0
from typing import Optional

import aiter as rocm_aiter
import torch

from vllm.attention.ops.paged_attn import PagedAttention
from vllm.platforms import current_platform

FP8_DTYPE = current_platform.fp8_dtype()


class AITERPagedAttention(PagedAttention):

@staticmethod
def write_to_paged_cache(
key: torch.Tensor,
value: torch.Tensor,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
slot_mapping: torch.Tensor,
kv_cache_dtype: str,
k_scale: torch.Tensor,
v_scale: torch.Tensor,
) -> None:
if kv_cache_dtype not in ["int8", "fp8", "fp8_e4m3"]:
PagedAttention.write_to_paged_cache(key, value, key_cache,
value_cache, slot_mapping,
kv_cache_dtype, k_scale,
v_scale)
else:
kv_cache_torch_dtype = (FP8_DTYPE
if "fp8" in kv_cache_dtype else torch.int8)
key_cache = key_cache.view(kv_cache_torch_dtype)
value_cache = value_cache.view(kv_cache_torch_dtype)

rocm_aiter.reshape_and_cache_with_pertoken_quant(
key, value, key_cache, value_cache, k_scale, v_scale,
slot_mapping.flatten(), True)

@staticmethod
def forward_decode(
query: torch.Tensor,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
block_tables: torch.Tensor,
seq_lens: torch.Tensor,
max_seq_len: int,
kv_cache_dtype: str,
num_kv_heads: int,
scale: float,
alibi_slopes: Optional[torch.Tensor],
k_scale: torch.Tensor,
v_scale: torch.Tensor,
tp_rank: int = 0,
blocksparse_local_blocks: int = 0,
blocksparse_vert_stride: int = 0,
blocksparse_block_size: int = 64,
blocksparse_head_sliding_step: int = 0,
) -> torch.Tensor:
if kv_cache_dtype not in ["int8", "fp8", "fp8_e4m3"]:
return PagedAttention.forward_decode(
query=query,
key_cache=key_cache,
value_cache=value_cache,
block_tables=block_tables,
seq_lens=seq_lens,
max_seq_len=max_seq_len,
kv_cache_dtype=kv_cache_dtype,
num_kv_heads=num_kv_heads,
scale=scale,
alibi_slopes=alibi_slopes,
k_scale=k_scale,
v_scale=v_scale,
tp_rank=tp_rank,
blocksparse_local_blocks=blocksparse_local_blocks,
blocksparse_vert_stride=blocksparse_vert_stride,
blocksparse_block_size=blocksparse_block_size,
blocksparse_head_sliding_step=blocksparse_head_sliding_step)

if "fp8" in kv_cache_dtype:
key_cache = key_cache.view(torch.float8_e4m3fnuz)
value_cache = value_cache.view(torch.float8_e4m3fnuz)

if blocksparse_vert_stride is not None and blocksparse_vert_stride > 1:
# use blocksparse paged attention
block_size = value_cache.size(-1)
assert (blocksparse_block_size > 0 and
blocksparse_block_size % block_size == 0), \
(f"{blocksparse_block_size=} needs to be a multiple of"
f"{block_size=} used in block_tables.")

output = torch.empty_like(query)
block_size = value_cache.shape[3]
max_num_blocks_per_seq = (max_seq_len + block_size - 1) // block_size

rocm_aiter.pa_fwd_asm(query, key_cache, value_cache, block_tables,
seq_lens, max_num_blocks_per_seq, k_scale,
v_scale, output)
return output
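A self-contained sketch of the routing rule this class implements: the AITER kernels are taken only for one-byte (int8/fp8) KV-cache dtypes, and the per-sequence block count passed to `pa_fwd_asm` is a ceiling division of the maximum sequence length by the block size. The helper names below are hypothetical; only the constants mirror the code above.

# Illustrative sketch (not part of the PR) of AITERPagedAttention's dispatch:
# quantized caches go to the AITER asm kernel, everything else falls back to
# the base PagedAttention implementation.
def uses_aiter_kernel(kv_cache_dtype: str) -> bool:
    return kv_cache_dtype in ("int8", "fp8", "fp8_e4m3")


def max_num_blocks_per_seq(max_seq_len: int, block_size: int) -> int:
    # Ceiling division, as computed before calling rocm_aiter.pa_fwd_asm.
    return (max_seq_len + block_size - 1) // block_size


assert uses_aiter_kernel("fp8") and not uses_aiter_kernel("auto")
assert max_num_blocks_per_seq(max_seq_len=4096, block_size=16) == 256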
7 changes: 7 additions & 0 deletions vllm/envs.py
@@ -75,6 +75,7 @@
VLLM_DISABLED_KERNELS: list[str] = []
VLLM_USE_V1: bool = True
VLLM_ROCM_USE_AITER: bool = False
VLLM_ROCM_USE_AITER_PAGED_ATTN: bool = False
VLLM_ROCM_USE_AITER_LINEAR: bool = True
VLLM_ROCM_USE_AITER_MOE: bool = True
VLLM_ROCM_USE_AITER_FP8_BLOCK_SCALED_MOE: bool = False
@@ -533,6 +534,12 @@ def maybe_convert_int(value: Optional[str]) -> Optional[int]:
lambda: (os.getenv("VLLM_ROCM_USE_AITER", "False").lower() in
("true", "1")),

# Whether to use aiter paged attention.
# By default it is disabled.
"VLLM_ROCM_USE_AITER_PAGED_ATTN":
lambda: (os.getenv("VLLM_ROCM_USE_AITER_PAGED_ATTN", "False").lower() in
("true", "1")),

# use aiter linear op if aiter ops are enabled
# The following list of related ops
# - scaled_mm (per-tensor / rowwise)
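As a usage note, the AITER paged attention path only activates when both flags are set, since `is_rocm_aiter_paged_attn_enabled()` checks `VLLM_ROCM_USE_AITER` and `VLLM_ROCM_USE_AITER_PAGED_ATTN` together. A hedged sketch of enabling it from Python before vLLM reads its environment (exporting the variables in the shell works just as well):

# Hedged usage sketch: set both flags before vLLM reads its environment,
# e.g. at the top of a launch script; shell exports are equivalent.
import os

os.environ["VLLM_ROCM_USE_AITER"] = "1"
os.environ["VLLM_ROCM_USE_AITER_PAGED_ATTN"] = "1"

# from vllm import LLM  # subsequent vLLM imports then pick up the flags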
3 changes: 2 additions & 1 deletion vllm/platforms/rocm.py
@@ -118,7 +118,8 @@ def use_rocm_custom_paged_attention(qtype: torch.dtype, head_size: int,
and (head_size == 64 or head_size == 128)
and (block_size == 16 or block_size == 32)
and (gqa_ratio >= 1 and gqa_ratio <= 16) and max_seq_len <= 32768
and envs.VLLM_ROCM_CUSTOM_PAGED_ATTN)
and not (envs.VLLM_ROCM_USE_AITER_PAGED_ATTN
and envs.VLLM_ROCM_USE_AITER))


class RocmPlatform(Platform):
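Finally, the change to `use_rocm_custom_paged_attention` makes the two decode paths mutually exclusive: the ROCm custom paged-attention kernel is only considered when the AITER paged-attention path is not enabled. A boiled-down illustration (hypothetical helper, not vLLM's code):

# Illustrative reduction of the updated predicate: when both AITER flags are
# set, the custom ROCm paged-attention kernel is skipped so that only the
# AITER decode kernel is selected.
def custom_paged_attn_allowed(use_aiter: bool, use_aiter_paged_attn: bool,
                              other_conditions_ok: bool) -> bool:
    aiter_pa_enabled = use_aiter and use_aiter_paged_attn
    return other_conditions_ok and not aiter_pa_enabled


assert custom_paged_attn_allowed(True, True, other_conditions_ok=True) is False
assert custom_paged_attn_allowed(False, True, other_conditions_ok=True) is True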