 import torch
 
 from vllm import _custom_ops as ops
+from vllm import envs
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionLayer,
                                               AttentionMetadata,
                                               AttentionMetadataBuilder,
...
 from vllm.utils import async_tensor_h2d, cdiv, make_tensor_with_pad, round_down
 from vllm.vllm_flash_attn.fa_utils import get_flash_attn_version
 
+if HAS_TRITON:
+    from vllm.attention.ops.triton_flash_attention import triton_attention
+else:
+    triton_attention = None
 
 try:
     from vllm.vllm_flash_attn import flash_attn_varlen_func
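The added import is guarded so the module still loads on builds without Triton: when Triton is unavailable, triton_attention is simply bound to None and the Triton path is never taken. This excerpt does not show where HAS_TRITON itself is imported from (in vLLM it is typically provided by vllm.triton_utils, but treat that as an assumption). A minimal, standalone sketch of the same guard pattern, with placeholder bodies of the editor's choosing:

import importlib.util

# Stand-in for a HAS_TRITON-style flag (assumption: the real flag may do more
# than a bare import check): True when the `triton` package is importable.
HAS_TRITON = importlib.util.find_spec("triton") is not None

if HAS_TRITON:
    # vLLM binds its Triton flash-attention kernel here; this placeholder
    # only marks that the Triton code path could be taken.
    def triton_attention(q, k, v, **kwargs):
        raise NotImplementedError("placeholder for the Triton kernel")
else:
    # Binding the name to None keeps the module importable on installs
    # without Triton; call sites must guard the Triton path themselves.
    triton_attention = None

print("Triton attention path available:", triton_attention is not None)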
@@ -1039,6 +1044,7 @@ def __init__(
         self.kv_b_proj = kv_b_proj
         self.o_proj = o_proj
 
+        self.triton_fa_func = triton_attention
         # Handle the differences between the flash_attn_varlen from flash_attn
         # and the one from vllm_flash_attn. The former is used on ROCm and the
         # latter has an additional parameter to control FA2 vs FA3
@@ -1064,6 +1070,14 @@ def _flash_attn_varlen_diff_headdims(self, q, k, v, softmax_scale,
         maybe_padded_v = torch.nn.functional.pad(
             v, [0, q.shape[-1] - v.shape[-1]], value=0)
 
+        if is_hip and envs.VLLM_USE_TRITON_FLASH_ATTN \
+                and not return_softmax_lse:
+            attn_out = self.triton_fa_func(
+                q,
+                k,
+                maybe_padded_v,
+                **kwargs,
+            )
         if is_vllm_fa:
             attn_out = self.flash_attn_varlen_func(
                 q=q,
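A note on the padding step this hunk relies on: in MLA the value head dim is smaller than the query/key head dim, while flash-attention style kernels expect one shared head dim, so V is zero-padded up to Q's head dim before either the Triton kernel (taken on ROCm when VLLM_USE_TRITON_FLASH_ATTN is enabled and the caller does not need the softmax LSE) or flash_attn_varlen_func runs. The sketch below, using toy shapes of the editor's choosing rather than vLLM's, checks that zero-padding V this way cannot change the attention output: the padded columns only ever produce zeros.

import torch

# Toy shapes (editor's choice, not vLLM's): V's head dim is smaller than
# Q/K's, mirroring the situation _flash_attn_varlen_diff_headdims handles.
torch.manual_seed(0)
batch, heads, seq, d_qk, d_v = 2, 4, 16, 96, 64

q = torch.randn(batch, heads, seq, d_qk)
k = torch.randn(batch, heads, seq, d_qk)
v = torch.randn(batch, heads, seq, d_v)

# Reference attention against the original, narrower V.
probs = torch.softmax(q @ k.transpose(-1, -2) / d_qk**0.5, dim=-1)
ref = probs @ v

# Same attention against V zero-padded to the Q/K head dim, as in the hunk
# above, then sliced back down. The padded columns contribute exactly zero.
v_padded = torch.nn.functional.pad(v, [0, d_qk - d_v], value=0)
padded_out = probs @ v_padded

assert torch.allclose(ref, padded_out[..., :d_v])
assert torch.all(padded_out[..., d_v:] == 0)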