Commit e7cfb5a

LucasWilkinson authored and dbyoung18 committed
[Attention] Update to latest FA3 code (vllm-project#13111)
Signed-off-by: Lucas Wilkinson <[email protected]>
1 parent 874764b commit e7cfb5a

File tree: 5 files changed, +241 -118 lines

cmake/external_projects/vllm_flash_attn.cmake

Lines changed: 1 addition & 1 deletion
@@ -38,7 +38,7 @@ else()
   FetchContent_Declare(
         vllm-flash-attn
         GIT_REPOSITORY https://github.com/vllm-project/flash-attention.git
-        GIT_TAG dc9d410b3e2d6534a4c70724c2515f4def670a22
+        GIT_TAG 0a721daebe4fa7149f06ecf3d3eabeb6dcd0f1fa
         GIT_PROGRESS TRUE
         # Don't share the vllm-flash-attn build between build types
         BINARY_DIR ${CMAKE_BINARY_DIR}/vllm-flash-attn

vllm/attention/backends/mla/common.py

Lines changed: 92 additions & 90 deletions
@@ -1043,8 +1043,8 @@ def __init__(
         self.q_proj = q_proj
         self.kv_b_proj = kv_b_proj
         self.o_proj = o_proj
-        self.triton_fa_func = triton_attention
 
+        self.triton_fa_func = triton_attention
         # Handle the differences between the flash_attn_varlen from flash_attn
         # and the one from vllm_flash_attn. The former is used on RoCM and the
         # latter has an additional parameter to control FA2 vs FA3
@@ -1055,6 +1055,70 @@ def __init__(
             functools.partial(flash_attn_varlen_func,
                               fa_version=self.vllm_flash_attn_version)
 
+        # For MLA the v head dim is smaller than qk head dim so we pad out
+        # v with 0s to match the qk head dim for attention backends that do
+        # not support different headdims
+        # We don't need to pad V if we are on a hopper system with FA3
+        self._pad_v = self.vllm_flash_attn_version is None or not (
+            self.vllm_flash_attn_version == 3
+            and current_platform.get_device_capability()[0] == 9)
+
+    def _flash_attn_varlen_diff_headdims(self, q, k, v, softmax_scale,
+                                         return_softmax_lse, **kwargs):
+        maybe_padded_v = v
+        if self._pad_v:
+            maybe_padded_v = torch.nn.functional.pad(
+                v, [0, q.shape[-1] - v.shape[-1]], value=0)
+
+        if is_hip and envs.VLLM_USE_TRITON_FLASH_ATTN \
+            and not return_softmax_lse:
+            attn_out = self.triton_fa_func(
+                q,
+                k,
+                maybe_padded_v,
+                **kwargs,
+            )
+        if is_vllm_fa:
+            attn_out = self.flash_attn_varlen_func(
+                q=q,
+                k=k,
+                v=maybe_padded_v,
+                return_softmax_lse=return_softmax_lse,
+                softmax_scale=softmax_scale,
+                **kwargs,
+            )
+        else:
+            # Use return_attn_probs instead of return_softmax_lse for RoCM
+            attn_out = self.flash_attn_varlen_func(
+                q=q,
+                k=k,
+                v=maybe_padded_v,
+                return_attn_probs=return_softmax_lse,
+                softmax_scale=softmax_scale,
+                **kwargs,
+            )
+
+        # Unpack the output if there is multiple results,
+        # triton always returns (output, softmax_lse),
+        # vllm_flash_attn returns (output, softmax_lse) when
+        # `return_softmax_lse = True`
+        # flash_attn (RoCM) returns (output, softmax_lse, ...) when
+        # `return_attn_probs = True`
+        rest = None
+        if isinstance(attn_out, tuple):
+            attn_out, *rest = attn_out
+
+        # unpad if necessary
+        if self._pad_v:
+            attn_out = attn_out[..., :v.shape[-1]]
+
+        # Remain consistent with old `flash_attn_varlen_func` where there
+        # is only one output tensor if `return_softmax_lse` is False.
+        if return_softmax_lse:
+            assert rest is not None
+            return attn_out, rest[0]
+        return attn_out
+
     def _v_up_proj_and_o_proj(self, x):
         # Convert from (B, N, L) to (N, B, L)
         x = x.view(-1, self.num_heads, self.kv_lora_rank).transpose(0, 1)
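
The `_pad_v` flag and the new `_flash_attn_varlen_diff_headdims` wrapper above centralize the workaround the old code repeated inline: when the attention kernel cannot take a V head dim smaller than the QK head dim, V is zero-padded up to Q's head dim and the padding is sliced off the output afterwards. A minimal standalone sketch of that pad/unpad round trip (the 192/128 head dims and helper names are illustrative, not part of this diff):

import torch
import torch.nn.functional as F

def pad_v_to_q_headdim(v: torch.Tensor, q_head_dim: int) -> torch.Tensor:
    # Zero-pad only the last (head) dimension; zero columns of V contribute
    # nothing to the attention output, so the result is numerically unchanged.
    return F.pad(v, [0, q_head_dim - v.shape[-1]], value=0)

def unpad_attn_output(out: torch.Tensor, v_head_dim: int) -> torch.Tensor:
    # Slice the padded columns back off so the output has the true v head dim.
    return out[..., :v_head_dim]

q = torch.randn(4, 8, 192)   # (tokens, heads, qk_head_dim), e.g. 128 nope + 64 rope
v = torch.randn(4, 8, 128)   # (tokens, heads, v_head_dim)
v_padded = pad_v_to_q_headdim(v, q.shape[-1])
assert v_padded.shape == (4, 8, 192)
assert unpad_attn_output(v_padded, v.shape[-1]).equal(v)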
@@ -1176,40 +1240,19 @@ def _compute_prefill_context(
             k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))),
                           dim=-1)
 
-            # For MLA the v head dim is smaller than qk head dim so we pad
-            # out v with 0s to match the qk head dim
-            v_padded = torch.nn.functional.pad(v,
-                                               [0, q.shape[-1] - v.shape[-1]],
-                                               value=0)
-
-            if is_vllm_fa:
-                attn_output, attn_softmax_lse = self.flash_attn_varlen_func(
-                    q=q,
-                    k=k,
-                    v=v_padded,
-                    cu_seqlens_q=prefill_metadata.query_start_loc,
-                    cu_seqlens_k=prefill_metadata.context_chunk_cu_seq_lens[i],
-                    max_seqlen_q=prefill_metadata.max_query_len,
-                    max_seqlen_k=prefill_metadata.
-                    context_chunk_max_seq_lens[i],
-                    softmax_scale=self.scale,
-                    causal=False,  # Context is unmasked
-                    return_softmax_lse=True,
-                )
-            else:
-                attn_output, attn_softmax_lse, _ = self.flash_attn_varlen_func(
-                    q=q,
-                    k=k,
-                    v=v_padded,
-                    cu_seqlens_q=prefill_metadata.query_start_loc,
-                    cu_seqlens_k=prefill_metadata.context_chunk_cu_seq_lens[i],
-                    max_seqlen_q=prefill_metadata.max_query_len,
-                    max_seqlen_k=prefill_metadata.
-                    context_chunk_max_seq_lens[i],
-                    softmax_scale=self.scale,
-                    causal=False,  # Context is unmasked
-                    return_attn_probs=True,
-                )
+            attn_output, attn_softmax_lse = \
+                self._flash_attn_varlen_diff_headdims(
+                    q=q,
+                    k=k,
+                    v=v,
+                    cu_seqlens_q=prefill_metadata.query_start_loc,
+                    cu_seqlens_k=prefill_metadata.context_chunk_cu_seq_lens[i],
+                    max_seqlen_q=prefill_metadata.max_query_len,
+                    max_seqlen_k=prefill_metadata.context_chunk_max_seq_lens[i],
+                    softmax_scale=self.scale,
+                    causal=False,  # Context is unmasked
+                    return_softmax_lse=True,
+                )
 
             if output is None:
                 output = attn_output
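
The chunked-context path always requests `return_softmax_lse=True` because each chunk's partial output is later combined with the suffix output in `_forward_prefill`. A hedged sketch of how two partial attention results can be merged exactly using their log-sum-exp values (the function name and tensor shapes are assumptions for illustration, not vLLM's API):

import torch

def merge_attn_outputs(out_a: torch.Tensor, lse_a: torch.Tensor,
                       out_b: torch.Tensor, lse_b: torch.Tensor):
    # out_*: (tokens, heads, head_dim); lse_*: (heads, tokens), natural log.
    # Each partial output is normalized by its own softmax denominator, so
    # re-weighting by exp(lse_i - lse_total) recovers the full-sequence result.
    lse = torch.logaddexp(lse_a, lse_b)
    w_a = torch.exp(lse_a - lse).transpose(0, 1).unsqueeze(-1)
    w_b = torch.exp(lse_b - lse).transpose(0, 1).unsqueeze(-1)
    return out_a * w_a + out_b * w_b, lse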
@@ -1252,58 +1295,22 @@ def _forward_prefill(
 
         k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), dim=-1)
 
-        # For MLA the v head dim is smaller than qk head dim so we pad out
-        # v with 0s to match the qk head dim
-        v_padded = torch.nn.functional.pad(v, [0, q.shape[-1] - v.shape[-1]],
-                                           value=0)
-
-        if is_hip and envs.VLLM_USE_TRITON_FLASH_ATTN and not has_context:
-            output = self.triton_fa_func(
-                q,
-                k,
-                v_padded,
-                None,
-                prefill_metadata.query_start_loc,
-                prefill_metadata.query_start_loc,
-                prefill_metadata.max_prefill_seq_len,
-                prefill_metadata.max_prefill_seq_len,
-                True,  # causal
-                self.scale,
-                None,  # attn_mask is None unless applying ALiBi mask
-            )
-            ## triton flash attention always return 2 objects
-            if not has_context:
-                output = output[0]
-        elif is_vllm_fa:
-            output = self.flash_attn_varlen_func(
-                q=q,
-                k=k,
-                v=v_padded,
-                cu_seqlens_q=prefill_metadata.query_start_loc,
-                cu_seqlens_k=prefill_metadata.query_start_loc,
-                max_seqlen_q=prefill_metadata.max_prefill_seq_len,
-                max_seqlen_k=prefill_metadata.max_prefill_seq_len,
-                softmax_scale=self.scale,
-                causal=True,
-                return_softmax_lse=has_context,
-            )
-        else:
-            output = self.flash_attn_varlen_func(
-                q=q,
-                k=k,
-                v=v_padded,
-                cu_seqlens_q=prefill_metadata.query_start_loc,
-                cu_seqlens_k=prefill_metadata.query_start_loc,
-                max_seqlen_q=prefill_metadata.max_prefill_seq_len,
-                max_seqlen_k=prefill_metadata.max_prefill_seq_len,
-                softmax_scale=self.scale,
-                causal=True,
-                return_attn_probs=has_context,
-            )
+        output = self._flash_attn_varlen_diff_headdims(
+            q=q,
+            k=k,
+            v=v,
+            cu_seqlens_q=prefill_metadata.query_start_loc,
+            cu_seqlens_k=prefill_metadata.query_start_loc,
+            max_seqlen_q=prefill_metadata.max_prefill_seq_len,
+            max_seqlen_k=prefill_metadata.max_prefill_seq_len,
+            softmax_scale=self.scale,
+            causal=True,
+            return_softmax_lse=has_context,
+        )
 
         if has_context:
             # ROCm flash_attn_varlen_func will return 3 objects instead of 2
-            suffix_output, suffix_lse, *rest = output
+            suffix_output, suffix_lse = output
             context_output, context_lse = self._compute_prefill_context( \
                 q, kv_c_and_k_pe_cache, attn_metadata)
 
@@ -1316,12 +1323,7 @@ def _forward_prefill(
                 suffix_lse=suffix_lse,
             )
 
-        # slice by `:v.shape[-1]` in order to remove v headdim padding
-        output = output\
-            .view(-1, self.num_heads, q.shape[-1])[..., :v.shape[-1]]\
-            .reshape(-1, self.num_heads * v.shape[-1])
-
-        return self.o_proj(output)[0]
+        return self.o_proj(output.flatten(start_dim=-2))[0]
 
     @abstractmethod
     def _forward_decode(
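
Since the wrapper now strips the V padding before returning, `_forward_prefill` no longer needs the view/slice/reshape sequence: the output already has the true v head dim, and `flatten(start_dim=-2)` simply collapses the head and head-dim axes before `o_proj`. A small sketch of that equivalence (shapes are illustrative):

import torch

num_tokens, num_heads, v_head_dim = 4, 8, 128
out = torch.randn(num_tokens, num_heads, v_head_dim)  # already unpadded by the wrapper

flat = out.flatten(start_dim=-2)  # -> (tokens, heads * v_head_dim)
assert flat.shape == (num_tokens, num_heads * v_head_dim)
assert flat.equal(out.reshape(num_tokens, num_heads * v_head_dim))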

vllm/attention/backends/utils.py

Lines changed: 25 additions & 1 deletion
@@ -2,15 +2,18 @@
 """Attention backend utils"""
 from collections import defaultdict
 from contextlib import contextmanager
+from dataclasses import dataclass
 from itertools import accumulate
-from typing import TYPE_CHECKING, Any, Dict, List, Tuple, Type, TypeVar, Union
+from typing import (TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type,
+                    TypeVar, Union)
 
 import numpy as np
 import torch
 
 from vllm.attention import (AttentionMetadata, AttentionMetadataBuilder,
                             AttentionState)
 from vllm.attention.backends.abstract import AttentionType
+from vllm.config import ModelConfig
 from vllm.logger import init_logger
 from vllm.multimodal import MultiModalPlaceholderMap
 from vllm.utils import async_tensor_h2d, make_tensor_with_pad
@@ -583,3 +586,24 @@ def get_num_prefill_decode_query_kv_tokens(
 
     return (num_prefill_query_tokens, num_prefill_kv_tokens,
             num_decode_query_tokens)
+
+
+@dataclass
+class MLADims:
+    q_lora_rank: Optional[int]
+    kv_lora_rank: int
+    qk_nope_head_dim: int
+    qk_rope_head_dim: int
+    v_head_dim: int
+
+
+def get_mla_dims(model_config: ModelConfig) -> MLADims:
+    hf_text_config = model_config.hf_text_config
+
+    return MLADims(
+        q_lora_rank=getattr(hf_text_config, "q_lora_rank", None),
+        kv_lora_rank=hf_text_config.kv_lora_rank,
+        qk_nope_head_dim=hf_text_config.qk_nope_head_dim,
+        qk_rope_head_dim=hf_text_config.qk_rope_head_dim,
+        v_head_dim=hf_text_config.v_head_dim,
+    )
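
The new `MLADims` dataclass and `get_mla_dims` helper gather the MLA ranks and head dims from the HF text config in one place. A hedged usage sketch (the derived QK head-dim computation is an illustration built on these fields, not something added by this commit):

from vllm.attention.backends.utils import get_mla_dims
from vllm.config import ModelConfig

def mla_qk_head_dim(model_config: ModelConfig) -> int:
    dims = get_mla_dims(model_config)
    # The query/key head dim is the "nope" part plus the rotary part;
    # v_head_dim is typically smaller, which is why the V padding above exists.
    return dims.qk_nope_head_dim + dims.qk_rope_head_dim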
