[Attention] Update to latest FA3 code #13111

Merged (20 commits) on Apr 17, 2025
2 changes: 1 addition & 1 deletion cmake/external_projects/vllm_flash_attn.cmake
@@ -38,7 +38,7 @@ else()
FetchContent_Declare(
vllm-flash-attn
GIT_REPOSITORY https://github.com/vllm-project/flash-attention.git
GIT_TAG dc9d410b3e2d6534a4c70724c2515f4def670a22
GIT_TAG 0a721daebe4fa7149f06ecf3d3eabeb6dcd0f1fa
GIT_PROGRESS TRUE
# Don't share the vllm-flash-attn build between build types
BINARY_DIR ${CMAKE_BINARY_DIR}/vllm-flash-attn
182 changes: 92 additions & 90 deletions vllm/attention/backends/mla/common.py
@@ -1043,8 +1043,8 @@ def __init__(
self.q_proj = q_proj
self.kv_b_proj = kv_b_proj
self.o_proj = o_proj
self.triton_fa_func = triton_attention

self.triton_fa_func = triton_attention
# Handle the differences between the flash_attn_varlen from flash_attn
# and the one from vllm_flash_attn. The former is used on ROCm and the
# latter has an additional parameter to control FA2 vs FA3
@@ -1055,6 +1055,70 @@ def __init__(
functools.partial(flash_attn_varlen_func,
fa_version=self.vllm_flash_attn_version)

# For MLA the v head dim is smaller than qk head dim so we pad out
# v with 0s to match the qk head dim for attention backends that do
# not support different headdims
# We don't need to pad V if we are on a hopper system with FA3
self._pad_v = self.vllm_flash_attn_version is None or not (
self.vllm_flash_attn_version == 3
and current_platform.get_device_capability()[0] == 9)

def _flash_attn_varlen_diff_headdims(self, q, k, v, softmax_scale,
return_softmax_lse, **kwargs):
maybe_padded_v = v
if self._pad_v:
maybe_padded_v = torch.nn.functional.pad(
v, [0, q.shape[-1] - v.shape[-1]], value=0)

if is_hip and envs.VLLM_USE_TRITON_FLASH_ATTN \
and not return_softmax_lse:
attn_out = self.triton_fa_func(
q,
k,
maybe_padded_v,
**kwargs,
)
elif is_vllm_fa:
attn_out = self.flash_attn_varlen_func(
q=q,
k=k,
v=maybe_padded_v,
return_softmax_lse=return_softmax_lse,
softmax_scale=softmax_scale,
**kwargs,
)
else:
# Use return_attn_probs instead of return_softmax_lse for ROCm
attn_out = self.flash_attn_varlen_func(
q=q,
k=k,
v=maybe_padded_v,
return_attn_probs=return_softmax_lse,
softmax_scale=softmax_scale,
**kwargs,
)

# Unpack the output if there are multiple results:
# triton always returns (output, softmax_lse),
# vllm_flash_attn returns (output, softmax_lse) when
# `return_softmax_lse = True`,
# flash_attn (ROCm) returns (output, softmax_lse, ...) when
# `return_attn_probs = True`
rest = None
if isinstance(attn_out, tuple):
attn_out, *rest = attn_out

# unpad if necessary
if self._pad_v:
attn_out = attn_out[..., :v.shape[-1]]

# Remain consistent with old `flash_attn_varlen_func` where there
# is only one output tensor if `return_softmax_lse` is False.
if return_softmax_lse:
assert rest is not None
return attn_out, rest[0]
return attn_out
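
The padding handled by this helper can be checked in isolation. Below is a minimal, standalone sketch (not part of this diff) showing why zero-padding V up to the QK head dim and slicing the output back is safe: the padded columns never receive any contribution from the softmax-weighted sum.

```python
# Standalone sketch (not from this PR): padding V with zeros up to the QK head
# dim and slicing the result back is equivalent to attention with the smaller
# V head dim, because the zero columns contribute nothing to the weighted sum.
import torch
import torch.nn.functional as F

B, H, S, qk_dim, v_dim = 2, 4, 16, 192, 128  # MLA-style: v_dim < qk_dim
q = torch.randn(B, H, S, qk_dim)
k = torch.randn(B, H, S, qk_dim)
v = torch.randn(B, H, S, v_dim)

ref = F.scaled_dot_product_attention(q, k, v)       # true (smaller) V head dim
v_padded = F.pad(v, [0, qk_dim - v_dim], value=0)   # pad V up to qk_dim
out = F.scaled_dot_product_attention(q, k, v_padded)[..., :v_dim]

assert torch.allclose(ref, out, atol=1e-5)
```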

def _v_up_proj_and_o_proj(self, x):
# Convert from (B, N, L) to (N, B, L)
x = x.view(-1, self.num_heads, self.kv_lora_rank).transpose(0, 1)
@@ -1176,40 +1240,19 @@ def _compute_prefill_context(
k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))),
dim=-1)

# For MLA the v head dim is smaller than qk head dim so we pad
# out v with 0s to match the qk head dim
v_padded = torch.nn.functional.pad(v,
[0, q.shape[-1] - v.shape[-1]],
value=0)

if is_vllm_fa:
attn_output, attn_softmax_lse = self.flash_attn_varlen_func(
q=q,
k=k,
v=v_padded,
cu_seqlens_q=prefill_metadata.query_start_loc,
cu_seqlens_k=prefill_metadata.context_chunk_cu_seq_lens[i],
max_seqlen_q=prefill_metadata.max_query_len,
max_seqlen_k=prefill_metadata.
context_chunk_max_seq_lens[i],
softmax_scale=self.scale,
causal=False, # Context is unmasked
return_softmax_lse=True,
)
else:
attn_output, attn_softmax_lse, _ = self.flash_attn_varlen_func(
q=q,
k=k,
v=v_padded,
cu_seqlens_q=prefill_metadata.query_start_loc,
cu_seqlens_k=prefill_metadata.context_chunk_cu_seq_lens[i],
max_seqlen_q=prefill_metadata.max_query_len,
max_seqlen_k=prefill_metadata.
context_chunk_max_seq_lens[i],
softmax_scale=self.scale,
causal=False, # Context is unmasked
return_attn_probs=True,
)
attn_output, attn_softmax_lse = \
self._flash_attn_varlen_diff_headdims(
q=q,
k=k,
v=v,
cu_seqlens_q=prefill_metadata.query_start_loc,
cu_seqlens_k=prefill_metadata.context_chunk_cu_seq_lens[i],
max_seqlen_q=prefill_metadata.max_query_len,
max_seqlen_k=prefill_metadata.context_chunk_max_seq_lens[i],
softmax_scale=self.scale,
causal=False, # Context is unmasked
return_softmax_lse=True,
)
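
For readers unfamiliar with the varlen interface used above: `query_start_loc` and `context_chunk_cu_seq_lens[i]` are cumulative start offsets describing how ragged sequences are packed into a single tensor. A small illustrative sketch (not from this PR) of that metadata layout:

```python
# Illustrative only: varlen attention packs ragged sequences into one tensor
# and describes them with cumulative start offsets plus a max length.
import torch

seq_lens = torch.tensor([3, 5, 2], dtype=torch.int32)
cu_seqlens = torch.zeros(len(seq_lens) + 1, dtype=torch.int32)
cu_seqlens[1:] = torch.cumsum(seq_lens, dim=0)
max_seqlen = int(seq_lens.max())
print(cu_seqlens.tolist(), max_seqlen)  # [0, 3, 8, 10] 5
```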

if output is None:
output = attn_output
@@ -1252,58 +1295,22 @@ def _forward_prefill(

k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), dim=-1)

# For MLA the v head dim is smaller than qk head dim so we pad out
# v with 0s to match the qk head dim
v_padded = torch.nn.functional.pad(v, [0, q.shape[-1] - v.shape[-1]],
value=0)

if is_hip and envs.VLLM_USE_TRITON_FLASH_ATTN and not has_context:
output = self.triton_fa_func(
q,
k,
v_padded,
None,
prefill_metadata.query_start_loc,
prefill_metadata.query_start_loc,
prefill_metadata.max_prefill_seq_len,
prefill_metadata.max_prefill_seq_len,
True, # causal
self.scale,
None, # attn_mask is None unless applying ALiBi mask
)
## triton flash attention always return 2 objects
if not has_context:
output = output[0]
elif is_vllm_fa:
output = self.flash_attn_varlen_func(
q=q,
k=k,
v=v_padded,
cu_seqlens_q=prefill_metadata.query_start_loc,
cu_seqlens_k=prefill_metadata.query_start_loc,
max_seqlen_q=prefill_metadata.max_prefill_seq_len,
max_seqlen_k=prefill_metadata.max_prefill_seq_len,
softmax_scale=self.scale,
causal=True,
return_softmax_lse=has_context,
)
else:
output = self.flash_attn_varlen_func(
q=q,
k=k,
v=v_padded,
cu_seqlens_q=prefill_metadata.query_start_loc,
cu_seqlens_k=prefill_metadata.query_start_loc,
max_seqlen_q=prefill_metadata.max_prefill_seq_len,
max_seqlen_k=prefill_metadata.max_prefill_seq_len,
softmax_scale=self.scale,
causal=True,
return_attn_probs=has_context,
)
output = self._flash_attn_varlen_diff_headdims(
q=q,
k=k,
v=v,
cu_seqlens_q=prefill_metadata.query_start_loc,
cu_seqlens_k=prefill_metadata.query_start_loc,
max_seqlen_q=prefill_metadata.max_prefill_seq_len,
max_seqlen_k=prefill_metadata.max_prefill_seq_len,
softmax_scale=self.scale,
causal=True,
return_softmax_lse=has_context,
)

if has_context:
# ROCm flash_attn_varlen_func will return 3 objects instead of 2
suffix_output, suffix_lse, *rest = output
suffix_output, suffix_lse = output
context_output, context_lse = self._compute_prefill_context( \
q, kv_c_and_k_pe_cache, attn_metadata)

@@ -1316,12 +1323,7 @@ def _forward_prefill(
suffix_lse=suffix_lse,
)

# slice by `:v.shape[-1]` in order to remove v headdim padding
output = output\
.view(-1, self.num_heads, q.shape[-1])[..., :v.shape[-1]]\
.reshape(-1, self.num_heads * v.shape[-1])

return self.o_proj(output)[0]
return self.o_proj(output.flatten(start_dim=-2))[0]
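
The `has_context` path above stitches the chunked-context result and the suffix result together using their per-query log-sum-exp values. A minimal PyTorch sketch of that merge idea (`merge_partial_attn` is a hypothetical helper, not vLLM's actual kernel; real tensor layouts may differ):

```python
# Minimal sketch (not vLLM's kernel): combining two partial attention results
# computed over disjoint key sets, given their per-query log-sum-exp values.
import torch

def merge_partial_attn(o1: torch.Tensor, lse1: torch.Tensor,
                       o2: torch.Tensor, lse2: torch.Tensor):
    # o*: [tokens, heads, head_dim], lse*: [tokens, heads] (assumed layout)
    lse = torch.logaddexp(lse1, lse2)          # combined softmax normalizer
    w1 = torch.exp(lse1 - lse).unsqueeze(-1)   # weight of the first partial
    w2 = torch.exp(lse2 - lse).unsqueeze(-1)   # weight of the second partial
    return w1 * o1 + w2 * o2, lse
```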

@abstractmethod
def _forward_decode(
26 changes: 25 additions & 1 deletion vllm/attention/backends/utils.py
@@ -2,15 +2,18 @@
"""Attention backend utils"""
from collections import defaultdict
from contextlib import contextmanager
from dataclasses import dataclass
from itertools import accumulate
from typing import TYPE_CHECKING, Any, Dict, List, Tuple, Type, TypeVar, Union
from typing import (TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type,
TypeVar, Union)

import numpy as np
import torch

from vllm.attention import (AttentionMetadata, AttentionMetadataBuilder,
AttentionState)
from vllm.attention.backends.abstract import AttentionType
from vllm.config import ModelConfig
from vllm.logger import init_logger
from vllm.multimodal import MultiModalPlaceholderMap
from vllm.utils import async_tensor_h2d, make_tensor_with_pad
@@ -583,3 +586,24 @@ def get_num_prefill_decode_query_kv_tokens(

return (num_prefill_query_tokens, num_prefill_kv_tokens,
num_decode_query_tokens)


@dataclass
class MLADims:
q_lora_rank: Optional[int]
kv_lora_rank: int
qk_nope_head_dim: int
qk_rope_head_dim: int
v_head_dim: int


def get_mla_dims(model_config: ModelConfig) -> MLADims:
hf_text_config = model_config.hf_text_config

return MLADims(
q_lora_rank=getattr(hf_text_config, "q_lora_rank", None),
kv_lora_rank=hf_text_config.kv_lora_rank,
qk_nope_head_dim=hf_text_config.qk_nope_head_dim,
qk_rope_head_dim=hf_text_config.qk_rope_head_dim,
v_head_dim=hf_text_config.v_head_dim,
)
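
A hypothetical usage sketch for the new helper (the stub config and the DeepSeek-V2-style numbers are illustrative only): `get_mla_dims` reads plain attributes off `model_config.hf_text_config`, so a simple namespace is enough to exercise it.

```python
# Hypothetical usage; stub_model_config stands in for a real ModelConfig and
# uses DeepSeek-V2-style dimensions purely for illustration.
from types import SimpleNamespace

stub_model_config = SimpleNamespace(hf_text_config=SimpleNamespace(
    q_lora_rank=1536,
    kv_lora_rank=512,
    qk_nope_head_dim=128,
    qk_rope_head_dim=64,
    v_head_dim=128,
))

dims = get_mla_dims(stub_model_config)  # type: ignore[arg-type]
# With MLA, the per-token KV-cache entry is kv_lora_rank + qk_rope_head_dim wide.
print(dims.kv_lora_rank + dims.qk_rope_head_dim)  # 576
```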