
Commit fcc54d8

more amd tweaks
Signed-off-by: Lucas Wilkinson <[email protected]>
1 parent: 31be47b

2 files changed: 14 additions, 3 deletions

vllm/attention/backends/mla/common.py

Lines changed: 14 additions & 0 deletions
@@ -196,6 +196,7 @@
 import torch
 
 from vllm import _custom_ops as ops
+from vllm import envs
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionLayer,
                                               AttentionMetadata,
                                               AttentionMetadataBuilder,
@@ -215,6 +216,10 @@
 from vllm.utils import async_tensor_h2d, cdiv, make_tensor_with_pad, round_down
 from vllm.vllm_flash_attn.fa_utils import get_flash_attn_version
 
+if HAS_TRITON:
+    from vllm.attention.ops.triton_flash_attention import triton_attention
+else:
+    triton_attention = None
 
 try:
     from vllm.vllm_flash_attn import flash_attn_varlen_func
@@ -1039,6 +1044,7 @@ def __init__(
         self.kv_b_proj = kv_b_proj
         self.o_proj = o_proj
 
+        self.triton_fa_func = triton_attention
         # Handle the differences between the flash_attn_varlen from flash_attn
         # and the one from vllm_flash_attn. The former is used on RoCM and the
         # latter has an additional parameter to control FA2 vs FA3
@@ -1064,6 +1070,14 @@ def _flash_attn_varlen_diff_headdims(self, q, k, v, softmax_scale,
         maybe_padded_v = torch.nn.functional.pad(
             v, [0, q.shape[-1] - v.shape[-1]], value=0)
 
+        if is_hip and envs.VLLM_USE_TRITON_FLASH_ATTN \
+                and not return_softmax_lse:
+            attn_out = self.triton_fa_func(
+                q,
+                k,
+                maybe_padded_v,
+                **kwargs,
+            )
         if is_vllm_fa:
             attn_out = self.flash_attn_varlen_func(
                 q=q,

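For readers skimming the diff, here is a minimal, self-contained sketch of the pattern the hunks above introduce: an optional Triton import guarded by a HAS_TRITON flag, and a call-time dispatch that prefers the Triton flash-attention path on ROCm when VLLM_USE_TRITON_FLASH_ATTN is enabled and the caller does not request the softmax LSE, otherwise falling back to flash_attn_varlen_func. Everything below is an illustrative stand-in (stub kernels, os.environ instead of vllm.envs, and a made-up mla_attention_dispatch helper), not vLLM's actual implementation.

import os


# Probe for Triton, mirroring the HAS_TRITON guard in the hunk above.
try:
    import triton  # noqa: F401
    HAS_TRITON = True
except ImportError:
    HAS_TRITON = False

if HAS_TRITON:
    def triton_attention(q, k, v, **kwargs):
        # Stand-in for the Triton flash-attention kernel wrapper.
        return q
else:
    triton_attention = None  # dispatch must not assume the kernel exists


def flash_attn_varlen_func(q, k, v, **kwargs):
    # Stand-in for the default flash-attention entry point.
    return q


def mla_attention_dispatch(q, k, padded_v, *, is_hip,
                           return_softmax_lse=False, **kwargs):
    """Hypothetical helper mirroring the control flow added to
    _flash_attn_varlen_diff_headdims in this commit."""
    use_triton = (
        is_hip
        # Real code reads envs.VLLM_USE_TRITON_FLASH_ATTN; os.environ is a stub.
        and os.environ.get("VLLM_USE_TRITON_FLASH_ATTN", "1") == "1"
        and triton_attention is not None
        # Only take the Triton path when the caller does not ask for the LSE.
        and not return_softmax_lse
    )
    if use_triton:
        return triton_attention(q, k, padded_v, **kwargs)
    return flash_attn_varlen_func(q, k, padded_v, **kwargs)

The assumed default of "1" for the env flag is only for the sketch; in the diff the value comes from vllm.envs, and the padded value tensor (maybe_padded_v) is what gets handed to the Triton kernel.
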
vllm/v1/attention/backends/mla/common.py

Lines changed: 0 additions & 3 deletions
@@ -195,7 +195,6 @@
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionLayer,
                                               AttentionMetadata,
                                               MLAAttentionImpl)
-
 from vllm.attention.backends.utils import get_mla_dims
 from vllm.attention.ops.merge_attn_states import merge_attn_states
 from vllm.logger import init_logger
@@ -220,8 +219,6 @@
 from vllm.v1.worker.gpu_input_batch import InputBatch
 from vllm.v1.worker.gpu_model_runner import GPUModelRunner
 
-is_hip = current_platform.is_rocm()
-
 logger = init_logger(__name__)
 

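On the removal side, the only substantive change is dropping the V1 backend's cached module-level ROCm flag. A tiny illustrative sketch of that pattern follows; _CurrentPlatformStub is a made-up stand-in for vllm.platforms.current_platform, not vLLM code.

class _CurrentPlatformStub:
    # Hypothetical stub; the real object queries the active vLLM platform.
    def is_rocm(self) -> bool:
        return False


current_platform = _CurrentPlatformStub()

# Before this commit the V1 MLA backend evaluated the check once at import time:
is_hip = current_platform.is_rocm()
# After it, no module-level flag is kept in this file; the ROCm/Triton dispatch
# added in the first file's diff carries its own is_hip check.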