
Commit 973fdb2

ProExpertProg authored and dbyoung18 committed
[ROCm] [Attention] Cleanup ROCm output passing (vllm-project#16431)
Signed-off-by: Luka Govedič <[email protected]>
1 parent 2f785c8 commit 973fdb2
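
The gist of the change, sketched below with hypothetical helper names rather than vLLM's actual call signatures: instead of each prefill path computing into a temporary tensor that is then copied into the caller's `output`, the preallocated `output` buffer is sliced and handed to the attention functions so they write their results in place.

def old_style(attn_func, query, key, value, output, num_prefill_tokens):
    # Kernel allocates its own result tensor ...
    out = attn_func(query, key, value)
    # ... which then has to be copied into the caller's buffer.
    output[:num_prefill_tokens] = out
    return output

def new_style(attn_func, query, key, value, output, num_prefill_tokens):
    # Kernel writes directly into the caller's preallocated slice,
    # avoiding the intermediate allocation and copy.
    attn_func(query, key, value, output[:num_prefill_tokens])
    return output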

1 file changed: 18 additions, 23 deletions

vllm/attention/backends/rocm_flash_attn.py

Lines changed: 18 additions & 23 deletions
@@ -27,6 +27,7 @@


 class ROCmFlashAttentionBackend(AttentionBackend):
+    accept_output_buffer: bool = True

     @staticmethod
     def get_name() -> str:
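
The new accept_output_buffer class attribute advertises that this backend can write its result into an externally allocated tensor. A minimal caller-side sketch, with illustrative names only (not vLLM's actual attention-layer code):

import torch

def run_attention(backend_cls, impl, query, key, value, attn_metadata):
    # Hypothetical dispatch: only backends that advertise the flag are handed
    # a preallocated buffer to fill; other backends keep allocating internally.
    if getattr(backend_cls, "accept_output_buffer", False):
        output = torch.empty_like(query)
        impl.forward(query, key, value, attn_metadata, output=output)
        return output
    return impl.forward(query, key, value, attn_metadata)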
@@ -515,7 +516,7 @@ def __init__(

             from vllm.attention.ops.triton_flash_attention import (  # noqa: F401
                 triton_attention)
-            self.attn_func = triton_attention
+            self.triton_attn_func = triton_attention
             logger.debug("Using Triton FA in ROCmBackend")
             if self.sliding_window != (-1, -1):
                 logger.warning("ROCm Triton FA does not currently support "
@@ -531,7 +532,7 @@ def __init__(
             else:
                 try:
                     from flash_attn import flash_attn_varlen_func  # noqa: F401
-                    self.attn_func = flash_attn_varlen_func
+                    self.fa_attn_func = flash_attn_varlen_func
                     logger.debug("Using CK FA in ROCmBackend")
                 except ModuleNotFoundError:
                     self.use_naive_attn = True
@@ -542,7 +543,7 @@ def __init__(
                         "ROCm Naive FlashAttention does not support "
                         "attention logits soft capping.")

-                self.attn_func = _sdpa_attention
+                self.sdpa_attn_func = _sdpa_attention
                 logger.debug("Using naive (SDPA) attention in ROCmBackend")

     def repeat_kv(self, x: torch.Tensor, n_rep: int) -> torch.Tensor:
@@ -613,6 +614,8 @@ def forward(
         Returns:
             shape = [num_tokens, num_heads * head_size]
         """
+        assert output is not None, "Output tensor must be provided."
+
         query = query.view(-1, self.num_heads, self.head_size)
         if key is not None:
             assert value is not None
@@ -656,7 +659,6 @@ def forward(
             assert attn_metadata.num_encoder_tokens is not None
             num_prefill_tokens = attn_metadata.num_encoder_tokens

-        output = torch.empty_like(query)
         # Query for decode. KV is not needed because it is already cached.
         decode_query = query[num_prefill_tokens:]
         # QKV for prefill.
@@ -704,11 +706,11 @@ def forward(
                         query.dtype,
                         seq_lens,
                         make_attn_mask=causal_mask)  # type: ignore
-                out, _ = self.attn_func(
+                self.triton_attn_func(
                     query,
                     key,
                     value,
-                    None,
+                    output[:num_prefill_tokens],
                     query_seq_start_loc,
                     key_seq_start_loc,
                     query_max_seq_len,
@@ -733,10 +735,11 @@ def forward(
                     key = key.movedim(0, key.dim() - 2)
                     value = value.movedim(0, value.dim() - 2)
                     # sdpa math backend attention
-                    out = self.attn_func(
+                    self.sdpa_attn_func(
                         query,
                         key,
                         value,
+                        output[:num_prefill_tokens],
                         query_seq_start_loc,
                         num_prefill_tokens,
                         self.num_heads,
@@ -745,7 +748,8 @@ def forward(
                         attn_masks,
                     )
                 else:
-                    out = self.attn_func(
+                    # upstream FA does not support an output arg, copy
+                    output[:num_prefill_tokens] = self.fa_attn_func(
                         q=query,
                         k=key,
                         v=value,
@@ -760,12 +764,6 @@ def forward(
                         softcap=self.logits_soft_cap,
                     )

-                # common code for prefill
-                assert output[:num_prefill_tokens].shape == out.shape
-                if output.shape[0] > num_prefill_tokens:
-                    output[:num_prefill_tokens] = out
-                else:
-                    output = out
             else:
                 # prefix-enabled attention -
                 # not applicable for encoder-only models
@@ -818,14 +816,10 @@ def forward(
                     device=output.device,
                 )
                 max_logits = torch.empty_like(exp_sums)
-                if num_prefill_tokens > 0:
-                    out = output[num_prefill_tokens:]
-                else:
-                    out = output

                 query_start_loc = None
                 ops.paged_attention_rocm(
-                    out,
+                    output[num_prefill_tokens:],
                     exp_sums,
                     max_logits,
                     tmp_output,
@@ -878,17 +872,18 @@ def _sdpa_attention(
     query: torch.Tensor,
     key: torch.Tensor,
     value: torch.Tensor,
-    seq_lens: List[int],
+    output: torch.Tensor,
+    seq_lens: torch.Tensor,
     num_tokens: int,
     num_heads: int,
     head_size: int,
     scale: float,
     attn_masks: Optional[List[torch.Tensor]] = None,
 ) -> torch.Tensor:
     start = 0
-    output = torch.empty((num_tokens, num_heads, head_size),
-                         dtype=query.dtype,
-                         device=query.device)
+    assert output.shape == (num_tokens, num_heads, head_size)
+    assert output.dtype == query.dtype
+    assert output.device == query.device

     for i, seq_len in enumerate(seq_lens):
         end = start + seq_len
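
The reworked _sdpa_attention now takes output as a parameter and, per the loop above, fills one sequence's slice at a time. The loop body is outside this diff, so the following is a standalone sketch of the same write-into-slice idea built on PyTorch's scaled_dot_product_attention (mask handling is illustrative, and the scale keyword needs a reasonably recent PyTorch):

from typing import List, Optional

import torch
from torch.nn.functional import scaled_dot_product_attention


def sdpa_into_output(query: torch.Tensor,    # [num_tokens, num_heads, head_size]
                     key: torch.Tensor,
                     value: torch.Tensor,
                     output: torch.Tensor,   # preallocated, same shape as query
                     seq_lens: List[int],
                     scale: float,
                     attn_masks: Optional[List[torch.Tensor]] = None) -> torch.Tensor:
    start = 0
    for i, seq_len in enumerate(seq_lens):
        end = start + seq_len
        # SDPA expects the head dimension ahead of the token dimension.
        q = query[start:end].movedim(0, 1)   # [num_heads, seq_len, head_size]
        k = key[start:end].movedim(0, 1)
        v = value[start:end].movedim(0, 1)
        mask = attn_masks[i] if attn_masks is not None else None
        out = scaled_dot_product_attention(q, k, v, attn_mask=mask,
                                           is_causal=mask is None, scale=scale)
        # Write the result straight into the caller's slice; no temporary
        # full-size buffer is allocated here.
        output[start:end] = out.movedim(0, 1)
        start = end
    return output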
