"""Attention layer ROCm GPUs."""
import itertools
from dataclasses import dataclass
+from functools import cache
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type

import torch
@@ -26,6 +27,32 @@
_PARTITION_SIZE_ROCM = 256


+@cache
+def is_rocm_aiter_paged_attn_enabled() -> bool:
+    return envs.VLLM_ROCM_USE_AITER_PAGED_ATTN \
+        and envs.VLLM_ROCM_USE_AITER
+
+
+@cache
+def _get_paged_attn_module() -> PagedAttention:
+    """
+    Initializes the appropriate PagedAttention module from `attention/ops`,
+    which is used as a helper function
+    by `ROCmFlashAttentionImpl` and `ROCmFlashAttentionBackend`.
+
+    The choice of attention module depends on whether
+    AITER paged attention is enabled:
+    - If enabled, `ROCmFlashAttentionImpl` uses `AITERPagedAttention`.
+    - Otherwise, it defaults to using the original `PagedAttention`.
+    """
+    if is_rocm_aiter_paged_attn_enabled():
+        # Import AITERPagedAttention only when the flag is enabled
+        from vllm.attention.ops.rocm_aiter_paged_attn import (
+            AITERPagedAttention)
+        return AITERPagedAttention()
+    return PagedAttention()
+
+
class ROCmFlashAttentionBackend(AttentionBackend):
    accept_output_buffer: bool = True

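The helper pair added in the hunk above is gated by the VLLM_ROCM_USE_AITER and VLLM_ROCM_USE_AITER_PAGED_ATTN environment flags and memoized with functools.cache, so the backend decision is made once per process and every call site shares one PagedAttention (or AITERPagedAttention) instance. A minimal usage sketch follows; the module path and the flag handling are assumptions for illustration only and are not part of this diff.

    # Sketch only: both flags must be truthy for the AITER path, and @cache
    # means repeated calls return the one memoized instance.
    import os

    # Assumed: the flags are set before the helpers are first called, since
    # @cache freezes whatever the first evaluation returns.
    os.environ["VLLM_ROCM_USE_AITER"] = "1"
    os.environ["VLLM_ROCM_USE_AITER_PAGED_ATTN"] = "1"

    from vllm.attention.backends import rocm_flash_attn as rfa  # assumed path

    pa_first = rfa._get_paged_attn_module()
    pa_second = rfa._get_paged_attn_module()
    assert pa_first is pa_second  # cached: a single shared instance
    print(type(pa_first).__name__)  # AITERPagedAttention when both flags are on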
@@ -56,23 +83,26 @@ def get_kv_cache_shape(
        num_kv_heads: int,
        head_size: int,
    ) -> Tuple[int, ...]:
-        return PagedAttention.get_kv_cache_shape(num_blocks, block_size,
-                                                 num_kv_heads, head_size)
+        paged_attn = _get_paged_attn_module()
+        return paged_attn.get_kv_cache_shape(num_blocks, block_size,
+                                             num_kv_heads, head_size)

    @staticmethod
    def swap_blocks(
        src_kv_cache: torch.Tensor,
        dst_kv_cache: torch.Tensor,
        src_to_dst: torch.Tensor,
    ) -> None:
-        PagedAttention.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst)
+        paged_attn = _get_paged_attn_module()
+        paged_attn.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst)

    @staticmethod
    def copy_blocks(
        kv_caches: List[torch.Tensor],
        src_to_dists: torch.Tensor,
    ) -> None:
-        PagedAttention.copy_blocks(kv_caches, src_to_dists)
+        paged_attn = _get_paged_attn_module()
+        paged_attn.copy_blocks(kv_caches, src_to_dists)


@dataclass
@@ -496,7 +526,10 @@ def __init__(
        assert self.num_heads % self.num_kv_heads == 0
        self.num_queries_per_kv = self.num_heads // self.num_kv_heads

-        supported_head_sizes = PagedAttention.get_supported_head_sizes()
+        self.paged_attn_module = _get_paged_attn_module()
+        supported_head_sizes = self.paged_attn_module.get_supported_head_sizes(
+        )
+
        if head_size not in supported_head_sizes:
            raise ValueError(
                f"Head size {head_size} is not supported by PagedAttention. "
@@ -546,6 +579,8 @@ def __init__(
            self.sdpa_attn_func = _sdpa_attention
            logger.debug("Using naive (SDPA) attention in ROCmBackend")

+        self.aiter_kv_scales_initialized = False
+
    def repeat_kv(self, x: torch.Tensor, n_rep: int) -> torch.Tensor:
        """torch.repeat_interleave(x, dim=1, repeats=n_rep)"""
        tokens, n_kv_heads, head_dim = x.shape
@@ -624,20 +659,45 @@ def forward(
        else:
            assert value is None

+        paged_attn = self.paged_attn_module
+
+        # Reshaping the kv tensors is required for the AITER paged attention
+        # kernel because it works on a different tensor shape when the size
+        # of one element is one byte (int8/fp8 dtypes). This reshaping is
+        # only required on the first forward call, and the kv cache must not
+        # be empty.
+        if (is_rocm_aiter_paged_attn_enabled() and kv_cache.dtype.itemsize == 1
+                and not self.aiter_kv_scales_initialized
+                and kv_cache.shape != torch.Size([0])):
+            num_blocks = kv_cache.shape[1]
+            block_size = kv_cache.shape[2] // (self.num_kv_heads *
+                                               self.head_size)
+            k_scale = torch.empty((self.num_kv_heads, num_blocks * block_size),
+                                  dtype=torch.float32,
+                                  device=kv_cache.device)
+            v_scale = torch.empty((self.num_kv_heads, num_blocks * block_size),
+                                  dtype=torch.float32,
+                                  device=kv_cache.device)
+            self.aiter_kv_scales_initialized = True
+            k_scale.fill_(layer._k_scale.item())
+            v_scale.fill_(layer._v_scale.item())
+            layer._k_scale = k_scale
+            layer._v_scale = v_scale
+
627 | 687 | # Only update KV cache for decoder self-attention
|
628 | 688 | # and encoder-decoder cross-attention
|
629 | 689 | if self.attn_type not in [
|
630 | 690 | AttentionType.ENCODER, AttentionType.ENCODER_ONLY
|
631 | 691 | ] and kv_cache.numel() > 0:
|
632 |
| - key_cache, value_cache = PagedAttention.split_kv_cache( |
| 692 | + key_cache, value_cache = paged_attn.split_kv_cache( |
633 | 693 | kv_cache, self.num_kv_heads, self.head_size)
|
634 | 694 |
|
635 | 695 | if key is not None and value is not None:
|
636 | 696 | # Reshape the input keys and values and store them in the
|
637 | 697 | # cache. If kv_cache is not provided, the new key and value
|
638 | 698 | # tensors are not cached. This happens during the initial
|
639 | 699 | # memory profiling run.
|
640 |
| - PagedAttention.write_to_paged_cache( |
| 700 | + paged_attn.write_to_paged_cache( |
641 | 701 | key,
|
642 | 702 | value,
|
643 | 703 | key_cache,
|
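For clarity on the scale-initialization block added at the top of this hunk: it expands the per-layer scalar k/v scales into per-token tensors of shape (num_kv_heads, num_blocks * block_size), the layout the AITER kernel consumes for one-byte (fp8/int8) KV caches. The standalone sketch below illustrates that expansion with made-up sizes and a stand-in for layer._k_scale; none of these names come from this diff.

    # Illustrative only: expand a scalar KV scale into AITER's per-token layout.
    import torch

    num_kv_heads, num_blocks, block_size = 8, 4, 16  # hypothetical sizes
    scalar_k_scale = torch.tensor(0.02)              # stand-in for layer._k_scale

    k_scale = torch.empty((num_kv_heads, num_blocks * block_size),
                          dtype=torch.float32)
    k_scale.fill_(scalar_k_scale.item())  # every token slot gets the same scale
    assert k_scale.shape == (num_kv_heads, num_blocks * block_size)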
@@ -768,23 +828,22 @@ def forward(
                # prefix-enabled attention -
                # not applicable for encoder-only models
                if self.attn_type != AttentionType.ENCODER_ONLY:
-                    output[:
-                           num_prefill_tokens] = PagedAttention.forward_prefix(
-                               query,
-                               key,
-                               value,
-                               self.kv_cache_dtype,
-                               key_cache,
-                               value_cache,
-                               prefill_meta.block_tables,
-                               prefill_meta.query_start_loc,
-                               prefill_meta.seq_lens_tensor,
-                               prefill_meta.max_query_len,
-                               self.alibi_slopes,
-                               self.sliding_window[0],
-                               layer._k_scale,
-                               layer._v_scale,
-                           )
+                    output[:num_prefill_tokens] = paged_attn.forward_prefix(
+                        query,
+                        key,
+                        value,
+                        self.kv_cache_dtype,
+                        key_cache,
+                        value_cache,
+                        prefill_meta.block_tables,
+                        prefill_meta.query_start_loc,
+                        prefill_meta.seq_lens_tensor,
+                        prefill_meta.max_query_len,
+                        self.alibi_slopes,
+                        self.sliding_window[0],
+                        layer._k_scale,
+                        layer._v_scale,
+                    )
        # Skip decode phase for encoder-only models
        if (decode_meta := attn_metadata.decode_metadata) and (
                self.attn_type != AttentionType.ENCODER_ONLY):
@@ -843,7 +902,7 @@ def forward(
                    layer._v_scale,
                )
            else:
-                output[num_prefill_tokens:] = PagedAttention.forward_decode(
+                output[num_prefill_tokens:] = paged_attn.forward_decode(
                    decode_query,
                    key_cache,
                    value_cache,