[FEAT][ROCm] Integrate Paged Attention Kernel from AITER #15001
Merged: vllm-bot merged 23 commits into vllm-project:main from EmbeddedLLM:aiter-paged-attn-integration on Apr 22, 2025.

The diff below shows changes from 20 of the 23 commits.
Commits (23):
dc09d66 (vllmellm): add AITER paged attention kernel
fe9ff98 (vllmellm): include AITER enablement for ROCm platforms in model end-to-end tests
d7c5dfb (vllmellm): add AITER into ROCm Docker base file
c24fc09 (vllmellm): Merge remote-tracking branch 'origin/main' into aiter-paged-attn-inte…
1732f9a (vllmellm): use clearer name for paged attention module used in ROCmFlashAttentio…
85296f7 (vllmellm): fix getting env variables in unit tests
07ac4d4 (vllmellm): Remove AttentionOps class; instead use a simple function to return appr…
1592e7e (vllmellm): remove cascading logic from vllm.envs
07bf5c6 (tjtanaa): refactor aiter unit test flags into decorator
1fdd695 (vllmellm): modify the ROCm AITER check tests based on new decorator and include …
bb3687d (vllmellm): remove the decorator for enabling ROCm AITER ops in tests
2dfa16f (vllmellm): Merge remote-tracking branch 'origin/main' into aiter-paged-attn-inte…
9087f44 (vllmellm): match the test files and run-amd-test script to the main branch
32b7a9b (tjtanaa): sync with main
052d9e0 (vllmellm): import AITERPagedAttention only if flag is set
15862f1 (vllmellm): prefer current_platform.fp8_dtype over the hardcoded dtype
2e65b95 (vllmellm): Merge remote-tracking branch 'origin/main' into aiter-paged-attn-inte…
15406cb (vllmellm): cache aiter pa import
a9ef9f9 (vllmellm): update aiter commit
e203aed (vllmellm): correct comment
976da61 (vllmellm): fix spelling mistake
0f5f2d0 (vllmellm): prefer utils cdiv
cc79ec9 (vllmellm): Merge remote-tracking branch 'origin/main' into aiter-paged-attn-inte…
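Several commits above describe the integration pattern: 07ac4d4 replaces an AttentionOps class with a simple function that returns the appropriate paged-attention implementation, 052d9e0 imports AITERPagedAttention only when a flag is set, and 15406cb caches that import. Below is a minimal sketch of the combined pattern; the environment-variable name and import path are assumptions for illustration (the real gating lives in vllm.envs and may differ):

# Sketch only: flag-gated, cached backend selection. Both the environment
# variable name and the import path are assumptions, not taken from this PR.
import os
from functools import cache


@cache  # evaluate the flag and perform the conditional import at most once
def get_paged_attn_module():
    if os.environ.get("VLLM_ROCM_USE_AITER_PAGED_ATTN", "0") == "1":
        # Deferred import: the aiter package exists only on ROCm builds.
        from vllm.attention.ops.rocm_aiter_paged_attn import (
            AITERPagedAttention)
        return AITERPagedAttention
    from vllm.attention.ops.paged_attn import PagedAttention
    return PagedAttention

Deferring the import keeps non-ROCm installs from ever touching aiter, and caching the result avoids re-reading the flag on every attention call.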
The diff adds a single new file (100 lines) implementing the AITER-backed paged-attention wrapper:
# SPDX-License-Identifier: Apache-2.0
from typing import Optional

import aiter as rocm_aiter
import torch

from vllm.attention.ops.paged_attn import PagedAttention
from vllm.platforms import current_platform

FP8_DTYPE = current_platform.fp8_dtype()


class AITERPagedAttention(PagedAttention):

    @staticmethod
    def write_to_paged_cache(
        key: torch.Tensor,
        value: torch.Tensor,
        key_cache: torch.Tensor,
        value_cache: torch.Tensor,
        slot_mapping: torch.Tensor,
        kv_cache_dtype: str,
        k_scale: torch.Tensor,
        v_scale: torch.Tensor,
    ) -> None:
        if kv_cache_dtype not in ["int8", "fp8", "fp8_e4m3"]:
            # Unquantized caches take the stock write path.
            PagedAttention.write_to_paged_cache(key, value, key_cache,
                                                value_cache, slot_mapping,
                                                kv_cache_dtype, k_scale,
                                                v_scale)
        else:
            kv_cache_torch_dtype = (FP8_DTYPE
                                    if "fp8" in kv_cache_dtype else torch.int8)
            key_cache = key_cache.view(kv_cache_torch_dtype)
            value_cache = value_cache.view(kv_cache_torch_dtype)

            # AITER fuses the cache write with per-token quantization.
            rocm_aiter.reshape_and_cache_with_pertoken_quant(
                key, value, key_cache, value_cache, k_scale, v_scale,
                slot_mapping.flatten(), True)

    @staticmethod
    def forward_decode(
        query: torch.Tensor,
        key_cache: torch.Tensor,
        value_cache: torch.Tensor,
        block_tables: torch.Tensor,
        seq_lens: torch.Tensor,
        max_seq_len: int,
        kv_cache_dtype: str,
        num_kv_heads: int,
        scale: float,
        alibi_slopes: Optional[torch.Tensor],
        k_scale: torch.Tensor,
        v_scale: torch.Tensor,
        tp_rank: int = 0,
        blocksparse_local_blocks: int = 0,
        blocksparse_vert_stride: int = 0,
        blocksparse_block_size: int = 64,
        blocksparse_head_sliding_step: int = 0,
    ) -> torch.Tensor:
        if kv_cache_dtype not in ["int8", "fp8", "fp8_e4m3"]:
            # Unquantized caches fall back to the stock decode kernel.
            return PagedAttention.forward_decode(
                query=query,
                key_cache=key_cache,
                value_cache=value_cache,
                block_tables=block_tables,
                seq_lens=seq_lens,
                max_seq_len=max_seq_len,
                kv_cache_dtype=kv_cache_dtype,
                num_kv_heads=num_kv_heads,
                scale=scale,
                alibi_slopes=alibi_slopes,
                k_scale=k_scale,
                v_scale=v_scale,
                tp_rank=tp_rank,
                blocksparse_local_blocks=blocksparse_local_blocks,
                blocksparse_vert_stride=blocksparse_vert_stride,
                blocksparse_block_size=blocksparse_block_size,
                blocksparse_head_sliding_step=blocksparse_head_sliding_step)

        if "fp8" in kv_cache_dtype:
            key_cache = key_cache.view(torch.float8_e4m3fnuz)
            value_cache = value_cache.view(torch.float8_e4m3fnuz)

        if blocksparse_vert_stride is not None and blocksparse_vert_stride > 1:
            # use blocksparse paged attention
            block_size = value_cache.size(-1)
            assert (blocksparse_block_size > 0 and
                    blocksparse_block_size % block_size == 0), \
                (f"{blocksparse_block_size=} needs to be a multiple of "
                 f"{block_size=} used in block_tables.")

        output = torch.empty_like(query)
        block_size = value_cache.shape[3]
        # Ceiling division: cache blocks needed for the longest sequence.
        max_num_blocks_per_seq = (max_seq_len + block_size - 1) // block_size

        rocm_aiter.pa_fwd_asm(query, key_cache, value_cache, block_tables,
                              seq_lens, max_num_blocks_per_seq, k_scale,
                              v_scale, output)
        return output
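The fused reshape_and_cache_with_pertoken_quant call above writes keys and values into the paged cache while quantizing them with one scale per token. As a rough illustration of the underlying math, here is a plain-PyTorch sketch of per-token fp8 quantization; this is not the fused AITER kernel, and the default fp8 dtype is an assumption for ROCm:

import torch


def pertoken_quant(x: torch.Tensor,
                   dtype: torch.dtype = torch.float8_e4m3fnuz):
    """Quantize each token (row) of x with its own scale."""
    finfo = torch.finfo(dtype)
    # One scale per token: max magnitude over the hidden dim, mapped to
    # the largest value representable in the target dtype.
    scale = x.abs().amax(dim=-1, keepdim=True).float() / finfo.max
    scale = scale.clamp(min=1e-10)  # guard against all-zero rows
    quantized = (x / scale).clamp(finfo.min, finfo.max).to(dtype)
    return quantized, scale  # dequantize as quantized.float() * scale

Note also the manual ceiling division near the end of forward_decode; per commit 0f5f2d0, it was later replaced with vLLM's cdiv utility, which computes the same quantity.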