
Commit a4a6182

vllmellm authored and tjtanaa committed
[FEAT][ROCm] Integrate Paged Attention Kernel from AITER (vllm-project#15001)
Signed-off-by: vllmellm <[email protected]>
Signed-off-by: tjtanaa <[email protected]>
Co-authored-by: tjtanaa <[email protected]>
Signed-off-by: Mu Huai <[email protected]>
1 parent 6a7bfc8 commit a4a6182

File tree

5 files changed: +195 -27 lines changed

docker/Dockerfile.rocm_base

Lines changed: 1 addition & 1 deletion

@@ -12,7 +12,7 @@ ARG PYTORCH_REPO="https://github.com/pytorch/pytorch.git"
 ARG PYTORCH_VISION_REPO="https://github.com/pytorch/vision.git"
 ARG FA_BRANCH="1a7f4dfa"
 ARG FA_REPO="https://github.com/Dao-AILab/flash-attention.git"
-ARG AITER_BRANCH="5a77249"
+ARG AITER_BRANCH="7e1ed08"
 ARG AITER_REPO="https://github.com/ROCm/aiter.git"

 FROM ${BASE_IMAGE} AS base

vllm/attention/backends/rocm_flash_attn.py

Lines changed: 84 additions & 25 deletions

@@ -2,6 +2,7 @@
 """Attention layer ROCm GPUs."""
 import itertools
 from dataclasses import dataclass
+from functools import cache
 from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type

 import torch
@@ -26,6 +27,32 @@
 _PARTITION_SIZE_ROCM = 256


+@cache
+def is_rocm_aiter_paged_attn_enabled() -> bool:
+    return envs.VLLM_ROCM_USE_AITER_PAGED_ATTN \
+        and envs.VLLM_ROCM_USE_AITER \
+
+
+@cache
+def _get_paged_attn_module() -> PagedAttention:
+    """
+    Initializes the appropriate PagedAttention module from `attention/ops`,
+    which is used as helper function
+    by `ROCmFlashAttentionImpl` and `ROCmFlashAttentionBackend`.
+
+    The choice of attention module depends on whether
+    AITER paged attention is enabled:
+    - If enabled, `ROCmFlashAttentionImpl` uses `AITERPagedAttention`.
+    - Otherwise, it defaults to using the original `PagedAttention`.
+    """
+    if is_rocm_aiter_paged_attn_enabled():
+        # Import AITERPagedAttention only when the flag is enabled
+        from vllm.attention.ops.rocm_aiter_paged_attn import (
+            AITERPagedAttention)
+        return AITERPagedAttention()
+    return PagedAttention()
+
+
 class ROCmFlashAttentionBackend(AttentionBackend):
     accept_output_buffer: bool = True

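Side note on the hunk above: the following is a minimal, standalone sketch (not the vLLM code itself) of the cached, flag-gated selection pattern it introduces. The stand-in classes and the simplified environment-variable parsing are assumptions for illustration only.

import os
from functools import cache


class PagedAttention:
    """Stand-in for vllm.attention.ops.paged_attn.PagedAttention."""


class AITERPagedAttention(PagedAttention):
    """Stand-in for the AITER-backed helper added in this commit."""


@cache
def aiter_paged_attn_enabled() -> bool:
    # Both the umbrella AITER flag and the paged-attention flag must be truthy,
    # mirroring is_rocm_aiter_paged_attn_enabled() above.
    return (os.getenv("VLLM_ROCM_USE_AITER", "False").lower() in ("true", "1")
            and os.getenv("VLLM_ROCM_USE_AITER_PAGED_ATTN",
                          "False").lower() in ("true", "1"))


@cache
def get_paged_attn_module() -> PagedAttention:
    # The AITER helper is returned only when both flags are set; otherwise the
    # default helper is used, just like _get_paged_attn_module() above.
    return AITERPagedAttention() if aiter_paged_attn_enabled() else PagedAttention()


if __name__ == "__main__":
    print(type(get_paged_attn_module()).__name__)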

@@ -56,23 +83,26 @@ def get_kv_cache_shape(
         num_kv_heads: int,
         head_size: int,
     ) -> Tuple[int, ...]:
-        return PagedAttention.get_kv_cache_shape(num_blocks, block_size,
-                                                 num_kv_heads, head_size)
+        paged_attn = _get_paged_attn_module()
+        return paged_attn.get_kv_cache_shape(num_blocks, block_size,
+                                             num_kv_heads, head_size)

     @staticmethod
     def swap_blocks(
         src_kv_cache: torch.Tensor,
         dst_kv_cache: torch.Tensor,
         src_to_dst: torch.Tensor,
     ) -> None:
-        PagedAttention.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst)
+        paged_attn = _get_paged_attn_module()
+        paged_attn.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst)

     @staticmethod
     def copy_blocks(
         kv_caches: List[torch.Tensor],
         src_to_dists: torch.Tensor,
     ) -> None:
-        PagedAttention.copy_blocks(kv_caches, src_to_dists)
+        paged_attn = _get_paged_attn_module()
+        paged_attn.copy_blocks(kv_caches, src_to_dists)


 @dataclass
@@ -496,7 +526,10 @@ def __init__(
         assert self.num_heads % self.num_kv_heads == 0
         self.num_queries_per_kv = self.num_heads // self.num_kv_heads

-        supported_head_sizes = PagedAttention.get_supported_head_sizes()
+        self.paged_attn_module = _get_paged_attn_module()
+        supported_head_sizes = self.paged_attn_module.get_supported_head_sizes(
+        )
+
         if head_size not in supported_head_sizes:
             raise ValueError(
                 f"Head size {head_size} is not supported by PagedAttention. "
@@ -546,6 +579,8 @@ def __init__(
             self.sdpa_attn_func = _sdpa_attention
             logger.debug("Using naive (SDPA) attention in ROCmBackend")

+        self.aiter_kv_scales_initialized = False
+
     def repeat_kv(self, x: torch.Tensor, n_rep: int) -> torch.Tensor:
         """torch.repeat_interleave(x, dim=1, repeats=n_rep)"""
         tokens, n_kv_heads, head_dim = x.shape
@@ -624,20 +659,45 @@ def forward(
         else:
             assert value is None

+        paged_attn = self.paged_attn_module
+
+        # Reshaping kv tensors is required for AITER paged attention kernel
+        # because it works on a different tensor shape,
+        # when the size of one element is one byte (int8/fp8 dtypes).
+        # This reshaping is only required on the first forward call
+        # and the kv cache must not be empty.
+        if (is_rocm_aiter_paged_attn_enabled() and kv_cache.dtype.itemsize == 1
+                and not self.aiter_kv_scales_initialized
+                and kv_cache.shape != torch.Size([0])):
+            num_blocks = kv_cache.shape[1]
+            block_size = kv_cache.shape[2] // (self.num_kv_heads *
+                                               self.head_size)
+            k_scale = torch.empty((self.num_kv_heads, num_blocks * block_size),
+                                  dtype=torch.float32,
+                                  device=kv_cache.device)
+            v_scale = torch.empty((self.num_kv_heads, num_blocks * block_size),
+                                  dtype=torch.float32,
+                                  device=kv_cache.device)
+            self.aiter_kv_scales_initialized = True
+            k_scale.fill_(layer._k_scale.item())
+            v_scale.fill_(layer._v_scale.item())
+            layer._k_scale = k_scale
+            layer._v_scale = v_scale
+
         # Only update KV cache for decoder self-attention
         # and encoder-decoder cross-attention
         if self.attn_type not in [
                 AttentionType.ENCODER, AttentionType.ENCODER_ONLY
         ] and kv_cache.numel() > 0:
-            key_cache, value_cache = PagedAttention.split_kv_cache(
+            key_cache, value_cache = paged_attn.split_kv_cache(
                 kv_cache, self.num_kv_heads, self.head_size)

             if key is not None and value is not None:
                 # Reshape the input keys and values and store them in the
                 # cache. If kv_cache is not provided, the new key and value
                 # tensors are not cached. This happens during the initial
                 # memory profiling run.
-                PagedAttention.write_to_paged_cache(
+                paged_attn.write_to_paged_cache(
                     key,
                     value,
                     key_cache,
@@ -768,23 +828,22 @@ def forward(
                 # prefix-enabled attention -
                 # not applicable for encoder-only models
                 if self.attn_type != AttentionType.ENCODER_ONLY:
-                    output[:
-                           num_prefill_tokens] = PagedAttention.forward_prefix(
-                               query,
-                               key,
-                               value,
-                               self.kv_cache_dtype,
-                               key_cache,
-                               value_cache,
-                               prefill_meta.block_tables,
-                               prefill_meta.query_start_loc,
-                               prefill_meta.seq_lens_tensor,
-                               prefill_meta.max_query_len,
-                               self.alibi_slopes,
-                               self.sliding_window[0],
-                               layer._k_scale,
-                               layer._v_scale,
-                           )
+                    output[:num_prefill_tokens] = paged_attn.forward_prefix(
+                        query,
+                        key,
+                        value,
+                        self.kv_cache_dtype,
+                        key_cache,
+                        value_cache,
+                        prefill_meta.block_tables,
+                        prefill_meta.query_start_loc,
+                        prefill_meta.seq_lens_tensor,
+                        prefill_meta.max_query_len,
+                        self.alibi_slopes,
+                        self.sliding_window[0],
+                        layer._k_scale,
+                        layer._v_scale,
+                    )
         # Skip decode phase for encoder-only models
         if (decode_meta := attn_metadata.decode_metadata) and (
                 self.attn_type != AttentionType.ENCODER_ONLY):
@@ -843,7 +902,7 @@ def forward(
                         layer._v_scale,
                     )
                 else:
-                    output[num_prefill_tokens:] = PagedAttention.forward_decode(
+                    output[num_prefill_tokens:] = paged_attn.forward_decode(
                         decode_query,
                         key_cache,
                         value_cache,
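To make the KV-scale handling in the forward() hunk above concrete, here is a small CPU-only sketch of how a single per-layer scale (the diff reads it from layer._k_scale) is broadcast into the per-token layout the AITER kernel consumes: one float32 value per KV head per cache slot. The tensor sizes are made-up assumptions; in the commit this happens once, on the first forward call, and only for 1-byte (int8/fp8) KV caches.

import torch

# Illustrative sizes (assumptions, not values taken from the diff).
num_kv_heads, num_blocks, block_size = 8, 4, 16

# Per-layer scalar scale.
scalar_k_scale = torch.tensor(0.5)

# Expand it into the (num_kv_heads, num_blocks * block_size) float32 layout
# expected by the AITER per-token quantized paged-attention kernel.
k_scale = torch.empty((num_kv_heads, num_blocks * block_size),
                      dtype=torch.float32)
k_scale.fill_(scalar_k_scale.item())

print(k_scale.shape)  # torch.Size([8, 64])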
vllm/attention/ops/rocm_aiter_paged_attn.py

Lines changed: 101 additions & 0 deletions

@@ -0,0 +1,101 @@
+# SPDX-License-Identifier: Apache-2.0
+from typing import Optional
+
+import aiter as rocm_aiter
+import torch
+
+from vllm.attention.ops.paged_attn import PagedAttention
+from vllm.platforms import current_platform
+from vllm.utils import cdiv
+
+FP8_DTYPE = current_platform.fp8_dtype()
+
+
+class AITERPagedAttention(PagedAttention):
+
+    @staticmethod
+    def write_to_paged_cache(
+        key: torch.Tensor,
+        value: torch.Tensor,
+        key_cache: torch.Tensor,
+        value_cache: torch.Tensor,
+        slot_mapping: torch.Tensor,
+        kv_cache_dtype: str,
+        k_scale: torch.Tensor,
+        v_scale: torch.Tensor,
+    ) -> None:
+        if kv_cache_dtype not in ["int8", "fp8", "fp8_e4m3"]:
+            PagedAttention.write_to_paged_cache(key, value, key_cache,
+                                                value_cache, slot_mapping,
+                                                kv_cache_dtype, k_scale,
+                                                v_scale)
+        else:
+            kv_cache_torch_dtype = (FP8_DTYPE
+                                    if "fp8" in kv_cache_dtype else torch.int8)
+            key_cache = key_cache.view(kv_cache_torch_dtype)
+            value_cache = value_cache.view(kv_cache_torch_dtype)
+
+            rocm_aiter.reshape_and_cache_with_pertoken_quant(
+                key, value, key_cache, value_cache, k_scale, v_scale,
+                slot_mapping.flatten(), True)
+
+    @staticmethod
+    def forward_decode(
+        query: torch.Tensor,
+        key_cache: torch.Tensor,
+        value_cache: torch.Tensor,
+        block_tables: torch.Tensor,
+        seq_lens: torch.Tensor,
+        max_seq_len: int,
+        kv_cache_dtype: str,
+        num_kv_heads: int,
+        scale: float,
+        alibi_slopes: Optional[torch.Tensor],
+        k_scale: torch.Tensor,
+        v_scale: torch.Tensor,
+        tp_rank: int = 0,
+        blocksparse_local_blocks: int = 0,
+        blocksparse_vert_stride: int = 0,
+        blocksparse_block_size: int = 64,
+        blocksparse_head_sliding_step: int = 0,
+    ) -> torch.Tensor:
+        if kv_cache_dtype not in ["int8", "fp8", "fp8_e4m3"]:
+            return PagedAttention.forward_decode(
+                query=query,
+                key_cache=key_cache,
+                value_cache=value_cache,
+                block_tables=block_tables,
+                seq_lens=seq_lens,
+                max_seq_len=max_seq_len,
+                kv_cache_dtype=kv_cache_dtype,
+                num_kv_heads=num_kv_heads,
+                scale=scale,
+                alibi_slopes=alibi_slopes,
+                k_scale=k_scale,
+                v_scale=v_scale,
+                tp_rank=tp_rank,
+                blocksparse_local_blocks=blocksparse_local_blocks,
+                blocksparse_vert_stride=blocksparse_vert_stride,
+                blocksparse_block_size=blocksparse_block_size,
+                blocksparse_head_sliding_step=blocksparse_head_sliding_step)
+
+        if "fp8" in kv_cache_dtype:
+            key_cache = key_cache.view(torch.float8_e4m3fnuz)
+            value_cache = value_cache.view(torch.float8_e4m3fnuz)
+
+        if blocksparse_vert_stride is not None and blocksparse_vert_stride > 1:
+            # use blocksparse paged attention
+            block_size = value_cache.size(-1)
+            assert (blocksparse_block_size > 0 and
+                    blocksparse_block_size % block_size == 0), \
+                (f"{blocksparse_block_size=} needs to be a multiple of"
+                 f"{block_size=} used in block_tables.")
+
+        output = torch.empty_like(query)
+        block_size = value_cache.shape[3]
+        max_num_blocks_per_seq = cdiv(max_seq_len, block_size)
+
+        rocm_aiter.pa_fwd_asm(query, key_cache, value_cache, block_tables,
+                              seq_lens, max_num_blocks_per_seq, k_scale,
+                              v_scale, output)
+        return output
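As a reading aid for the decode path above: the AITER kernel is only invoked for 1-byte KV-cache dtypes, and the number of cache blocks it walks per sequence is a ceiling division of the maximum sequence length by the block size. A tiny sketch with made-up numbers:

from math import ceil

def uses_aiter_decode_kernel(kv_cache_dtype: str) -> bool:
    # Mirrors the dtype check in AITERPagedAttention.forward_decode above;
    # anything else falls back to the base PagedAttention implementation.
    return kv_cache_dtype in ("int8", "fp8", "fp8_e4m3")

# max_num_blocks_per_seq = cdiv(max_seq_len, block_size), with block_size
# read from value_cache.shape[3]. Numbers below are illustrative only.
max_seq_len, block_size = 1000, 16
print(uses_aiter_decode_kernel("fp8"))   # True
print(uses_aiter_decode_kernel("auto"))  # False
print(ceil(max_seq_len / block_size))    # 63 blocks per sequence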

vllm/envs.py

Lines changed: 7 additions & 0 deletions

@@ -75,6 +75,7 @@
     VLLM_DISABLED_KERNELS: list[str] = []
     VLLM_USE_V1: bool = True
     VLLM_ROCM_USE_AITER: bool = False
+    VLLM_ROCM_USE_AITER_PAGED_ATTN: bool = False
     VLLM_ROCM_USE_AITER_LINEAR: bool = True
     VLLM_ROCM_USE_AITER_MOE: bool = True
     VLLM_ROCM_USE_AITER_RMSNORM: bool = True
@@ -533,6 +534,12 @@ def maybe_convert_int(value: Optional[str]) -> Optional[int]:
     lambda: (os.getenv("VLLM_ROCM_USE_AITER", "False").lower() in
              ("true", "1")),

+    # Whether to use aiter paged attention.
+    # By default is disabled.
+    "VLLM_ROCM_USE_AITER_PAGED_ATTN":
+    lambda: (os.getenv("VLLM_ROCM_USE_AITER_PAGED_ATTN", "False").lower() in
+             ("true", "1")),
+
     # use aiter linear op if aiter ops are enabled
     # The following list of related ops
     # - scaled_mm (per-tensor / rowwise)
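Based on the two environment variables above, here is a minimal sketch of how one might opt in to the AITER paged-attention path from Python; the flags can equally be exported in the shell before launching vLLM, and setting them before any vLLM import is a conservative assumption since vllm.envs reads the process environment.

import os

# Both flags must be truthy ("true" or "1") for the AITER paged-attention
# path to be selected; see is_rocm_aiter_paged_attn_enabled() above.
os.environ["VLLM_ROCM_USE_AITER"] = "1"
os.environ["VLLM_ROCM_USE_AITER_PAGED_ATTN"] = "1"

# ...then construct the engine as usual on a ROCm build of vLLM.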

vllm/platforms/rocm.py

Lines changed: 2 additions & 1 deletion

@@ -118,7 +118,8 @@ def use_rocm_custom_paged_attention(qtype: torch.dtype, head_size: int,
             and (head_size == 64 or head_size == 128)
             and (block_size == 16 or block_size == 32)
             and (gqa_ratio >= 1 and gqa_ratio <= 16) and max_seq_len <= 32768
-            and envs.VLLM_ROCM_CUSTOM_PAGED_ATTN)
+            and not (envs.VLLM_ROCM_USE_AITER_PAGED_ATTN
+                     and envs.VLLM_ROCM_USE_AITER))


 class RocmPlatform(Platform):
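The change above gives AITER paged attention precedence over the ROCm custom paged-attention kernel. A small sketch of just that flag interaction; the other gating conditions on dtype, head size, block size, GQA ratio and sequence length are omitted here.

def custom_paged_attn_allowed(use_aiter: bool, use_aiter_paged_attn: bool) -> bool:
    # Mirrors the new clause: the custom kernel is ruled out only when both
    # AITER flags are enabled.
    return not (use_aiter_paged_attn and use_aiter)

for flags in [(False, False), (True, False), (False, True), (True, True)]:
    print(flags, "->", custom_paged_attn_allowed(*flags))
# Only (True, True) disables the ROCm custom paged-attention kernel.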
