 
 
 class ROCmFlashAttentionBackend(AttentionBackend):
+    accept_output_buffer: bool = True
 
     @staticmethod
     def get_name() -> str:
@@ -515,7 +516,7 @@ def __init__(
 
             from vllm.attention.ops.triton_flash_attention import (  # noqa: F401
                 triton_attention)
-            self.attn_func = triton_attention
+            self.triton_attn_func = triton_attention
             logger.debug("Using Triton FA in ROCmBackend")
             if self.sliding_window != (-1, -1):
                 logger.warning("ROCm Triton FA does not currently support "
@@ -531,7 +532,7 @@ def __init__(
             else:
                 try:
                     from flash_attn import flash_attn_varlen_func  # noqa: F401
-                    self.attn_func = flash_attn_varlen_func
+                    self.fa_attn_func = flash_attn_varlen_func
                     logger.debug("Using CK FA in ROCmBackend")
                 except ModuleNotFoundError:
                     self.use_naive_attn = True
@@ -542,7 +543,7 @@ def __init__(
                         "ROCm Naive FlashAttention does not support "
                         "attention logits soft capping.")
 
-                self.attn_func = _sdpa_attention
+                self.sdpa_attn_func = _sdpa_attention
                 logger.debug("Using naive (SDPA) attention in ROCmBackend")
 
     def repeat_kv(self, x: torch.Tensor, n_rep: int) -> torch.Tensor:
@@ -613,6 +614,8 @@ def forward(
         Returns:
             shape = [num_tokens, num_heads * head_size]
         """
+        assert output is not None, "Output tensor must be provided."
+
         query = query.view(-1, self.num_heads, self.head_size)
         if key is not None:
             assert value is not None
@@ -656,7 +659,6 @@ def forward(
             assert attn_metadata.num_encoder_tokens is not None
             num_prefill_tokens = attn_metadata.num_encoder_tokens
 
-        output = torch.empty_like(query)
         # Query for decode. KV is not needed because it is already cached.
         decode_query = query[num_prefill_tokens:]
         # QKV for prefill.
@@ -704,11 +706,11 @@ def forward(
                             query.dtype,
                             seq_lens,
                             make_attn_mask=causal_mask)  # type: ignore
-                    out, _ = self.attn_func(
+                    self.triton_attn_func(
                         query,
                         key,
                         value,
-                        None,
+                        output[:num_prefill_tokens],
                         query_seq_start_loc,
                         key_seq_start_loc,
                         query_max_seq_len,
@@ -733,10 +735,11 @@ def forward(
                     key = key.movedim(0, key.dim() - 2)
                     value = value.movedim(0, value.dim() - 2)
                     # sdpa math backend attention
-                    out = self.attn_func(
+                    self.sdpa_attn_func(
                         query,
                         key,
                         value,
+                        output[:num_prefill_tokens],
                         query_seq_start_loc,
                         num_prefill_tokens,
                         self.num_heads,
@@ -745,7 +748,8 @@ def forward(
                         attn_masks,
                     )
                 else:
-                    out = self.attn_func(
+                    # upstream FA does not support an output arg, copy
+                    output[:num_prefill_tokens] = self.fa_attn_func(
                         q=query,
                         k=key,
                         v=value,
@@ -760,12 +764,6 @@ def forward(
                         softcap=self.logits_soft_cap,
                     )
 
-                # common code for prefill
-                assert output[:num_prefill_tokens].shape == out.shape
-                if output.shape[0] > num_prefill_tokens:
-                    output[:num_prefill_tokens] = out
-                else:
-                    output = out
             else:
                 # prefix-enabled attention -
                 # not applicable for encoder-only models
@@ -818,14 +816,10 @@ def forward(
                 device=output.device,
             )
             max_logits = torch.empty_like(exp_sums)
-            if num_prefill_tokens > 0:
-                out = output[num_prefill_tokens:]
-            else:
-                out = output
 
             query_start_loc = None
             ops.paged_attention_rocm(
-                out,
+                output[num_prefill_tokens:],
                 exp_sums,
                 max_logits,
                 tmp_output,
@@ -878,17 +872,18 @@ def _sdpa_attention(
     query: torch.Tensor,
     key: torch.Tensor,
     value: torch.Tensor,
-    seq_lens: List[int],
+    output: torch.Tensor,
+    seq_lens: torch.Tensor,
     num_tokens: int,
     num_heads: int,
     head_size: int,
     scale: float,
     attn_masks: Optional[List[torch.Tensor]] = None,
 ) -> torch.Tensor:
     start = 0
-    output = torch.empty((num_tokens, num_heads, head_size),
-                         dtype=query.dtype,
-                         device=query.device)
+    assert output.shape == (num_tokens, num_heads, head_size)
+    assert output.dtype == query.dtype
+    assert output.device == query.device
 
     for i, seq_len in enumerate(seq_lens):
         end = start + seq_len
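The diff threads the caller-provided output buffer through every prefill and decode path instead of allocating an intermediate out tensor and copying it afterwards. Below is a minimal, self-contained sketch of that write-into-a-preallocated-buffer pattern, loosely mirroring _sdpa_attention; the helper name prefill_into and its signature are hypothetical, not part of the vLLM API, and passing scale to scaled_dot_product_attention assumes PyTorch 2.1+.

from typing import List

import torch


def prefill_into(query: torch.Tensor,
                 key: torch.Tensor,
                 value: torch.Tensor,
                 output: torch.Tensor,
                 seq_lens: List[int],
                 scale: float) -> None:
    # Each sequence's result is written straight into the caller-owned
    # buffer, so no intermediate tensor is allocated or copied afterwards.
    start = 0
    for seq_len in seq_lens:
        end = start + seq_len
        # [seq_len, num_heads, head_size] -> [num_heads, seq_len, head_size]
        q = query[start:end].movedim(0, 1)
        k = key[start:end].movedim(0, 1)
        v = value[start:end].movedim(0, 1)
        attn = torch.nn.functional.scaled_dot_product_attention(
            q, k, v, is_causal=True, scale=scale)
        output[start:end] = attn.movedim(0, 1)
        start = end


# The caller pre-allocates one buffer covering all tokens and hands out
# slices per path, the way forward() uses output[:num_prefill_tokens]
# and output[num_prefill_tokens:].
num_tokens, num_heads, head_size = 8, 4, 64
q = torch.randn(num_tokens, num_heads, head_size)
k = torch.randn(num_tokens, num_heads, head_size)
v = torch.randn(num_tokens, num_heads, head_size)
out = torch.empty_like(q)
prefill_into(q, k, v, out, seq_lens=[3, 5], scale=head_size**-0.5)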