[BugFix] Fix vllm_flash_attn install issues #17267


Merged 7 commits on Apr 28, 2025
Changes from 4 commits

1 change: 1 addition & 0 deletions .github/CODEOWNERS
@@ -12,6 +12,7 @@
 /vllm/model_executor/layers/quantization @mgoin @robertgshaw2-redhat @tlrmchlsmth
 /vllm/model_executor/guided_decoding @mgoin @russellb
 /vllm/multimodal @DarkLight1337 @ywang96
+/vllm/vllm_flash_attn @LucasWilkinson
 CMakeLists.txt @tlrmchlsmth
 
 # vLLM V1
2 changes: 0 additions & 2 deletions .gitignore
@@ -3,8 +3,6 @@
 
 # vllm-flash-attn built from source
 vllm/vllm_flash_attn/*
-!vllm/vllm_flash_attn/__init__.py
-!vllm/vllm_flash_attn/fa_utils.py
 
 # Byte-compiled / optimized / DLL files
 __pycache__/
23 changes: 17 additions & 6 deletions setup.py
@@ -269,11 +269,12 @@ def run(self):
         # First, run the standard build_ext command to compile the extensions
         super().run()
 
-        # copy vllm/vllm_flash_attn/*.py from self.build_lib to current
+        # copy vllm/vllm_flash_attn/**/*.py from self.build_lib to current
         # directory so that they can be included in the editable build
         import glob
-        files = glob.glob(
-            os.path.join(self.build_lib, "vllm", "vllm_flash_attn", "*.py"))
+        files = glob.glob(os.path.join(self.build_lib, "vllm",
+                                       "vllm_flash_attn", "**", "*.py"),
+                          recursive=True)
         for file in files:
             dst_file = os.path.join("vllm/vllm_flash_attn",
                                     os.path.basename(file))
@@ -377,12 +378,22 @@ def run(self) -> None:
             "vllm/_flashmla_C.abi3.so",
             "vllm/vllm_flash_attn/_vllm_fa2_C.abi3.so",
             "vllm/vllm_flash_attn/_vllm_fa3_C.abi3.so",
-            "vllm/vllm_flash_attn/flash_attn_interface.py",
             "vllm/cumem_allocator.abi3.so",
             # "vllm/_version.py", # not available in nightly wheels yet
         ]
-        file_members = filter(lambda x: x.filename in files_to_copy,
-                              wheel.filelist)
+        import re
+
+        file_members = list(
+            filter(lambda x: x.filename in files_to_copy, wheel.filelist))
+
+        # vllm_flash_attn python code:
+        # Regex from
+        # `glob.translate('vllm/vllm_flash_attn/**/*.py', recursive=True)`
+        compiled_regex = re.compile(
+            r"vllm/vllm_flash_attn/(?:[^/.][^/]*/)*(?!\.)[^/]*\.py")
+        file_members += list(
+            filter(lambda x: compiled_regex.match(x.filename),
+                   wheel.filelist))
+
         for file in file_members:
             print(f"Extracting and including {file.filename} "
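A quick standalone check of what the inlined pattern accepts may help here; this is not code from the PR, the nested example path is hypothetical, and glob.translate(), which the in-diff comment cites, only exists on Python 3.13+, which is presumably why its output is hard-coded:

import re

# The same pattern that the setup.py change above hard-codes: any .py file
# nested under vllm/vllm_flash_attn/ at any depth, excluding hidden entries.
compiled_regex = re.compile(
    r"vllm/vllm_flash_attn/(?:[^/.][^/]*/)*(?!\.)[^/]*\.py")

assert compiled_regex.match("vllm/vllm_flash_attn/flash_attn_interface.py")
assert compiled_regex.match("vllm/vllm_flash_attn/__init__.py")
# Hypothetical nested module path, just to show the recursive part matches:
assert compiled_regex.match("vllm/vllm_flash_attn/sub/module.py")
assert not compiled_regex.match("vllm/vllm_flash_attn/.hidden.py")
assert not compiled_regex.match("vllm/attention/backends/flash_attn.py")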
4 changes: 2 additions & 2 deletions vllm/attention/backends/flash_attn.py
@@ -22,13 +22,13 @@
     compute_slot_mapping_start_idx, get_num_prefill_decode_query_kv_tokens,
     get_seq_len_block_table_args, is_all_cross_attn_metadata_set,
     is_all_encoder_attn_metadata_set, is_block_tables_empty)
+from vllm.attention.utils.fa_utils import (flash_attn_supports_fp8,
+                                           get_flash_attn_version)
 from vllm.logger import init_logger
 from vllm.multimodal import MultiModalPlaceholderMap
 from vllm.utils import async_tensor_h2d, make_tensor_with_pad
 from vllm.vllm_flash_attn import (flash_attn_varlen_func,
                                   flash_attn_with_kvcache)
-from vllm.vllm_flash_attn.fa_utils import (flash_attn_supports_fp8,
-                                           get_flash_attn_version)
 
 if TYPE_CHECKING:
     from vllm.worker.model_runner import (ModelInputForGPUBuilder,
@@ -689,7 +689,7 @@
         assert output is not None, "Output tensor must be provided."
 
         # NOTE(woosuk): FlashAttention2 does not support FP8 KV cache.
         if self.vllm_flash_attn_version < 3 or output.dtype != torch.bfloat16:

Check failure on line 692 in vllm/attention/backends/flash_attn.py (GitHub Actions / pre-commit): Unsupported operand types for > ("int" and "None") [operator]
             assert (
                 layer._k_scale_float == 1.0 and layer._v_scale_float == 1.0), (
                     "key/v_scale is only supported in FlashAttention 3 with "
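The pre-commit failure above is the usual Optional comparison problem: the annotation implies self.vllm_flash_attn_version can still be None here, so comparing it directly against an int is rejected by mypy. A minimal sketch of the standard fix pattern follows; the helper name and factoring are illustrative assumptions, not what this PR necessarily ends up doing:

from typing import Optional

def needs_unit_kv_scales(fa_version: Optional[int], output_is_bf16: bool) -> bool:
    """Return True when k/v scales must stay at 1.0 (no FP8 KV-cache path)."""
    # Guard the Optional before comparing; mypy rejects `fa_version < 3`
    # while fa_version may still be None.
    if fa_version is None:
        return True
    return fa_version < 3 or not output_is_bf16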
2 changes: 1 addition & 1 deletion vllm/attention/backends/mla/common.py
@@ -205,6 +205,7 @@
     compute_slot_mapping_start_idx,
     is_block_tables_empty)
 from vllm.attention.ops.merge_attn_states import merge_attn_states
+from vllm.attention.utils.fa_utils import get_flash_attn_version
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                LinearBase, RowParallelLinear,
                                                UnquantizedLinearMethod)
@@ -214,7 +215,6 @@
 from vllm.platforms import current_platform
 from vllm.triton_utils import HAS_TRITON
 from vllm.utils import async_tensor_h2d, cdiv, make_tensor_with_pad, round_down
-from vllm.vllm_flash_attn.fa_utils import get_flash_attn_version
 
 if HAS_TRITON:
     from vllm.attention.ops.triton_flash_attention import triton_attention
vllm/vllm_flash_attn/fa_utils.py → vllm/attention/utils/fa_utils.py: file renamed without changes.
2 changes: 1 addition & 1 deletion vllm/engine/arg_utils.py
@@ -1377,7 +1377,7 @@ def _is_v1_supported_oracle(self, model_config: ModelConfig) -> bool:
             ) or envs.VLLM_ATTENTION_BACKEND == "FLASH_ATTN_VLLM_V1"
             supported = False
             if fp8_attention and will_use_fa:
-                from vllm.vllm_flash_attn.fa_utils import (
+                from vllm.attention.utils.fa_utils import (
                     flash_attn_supports_fp8)
                 supported = flash_attn_supports_fp8()
             if not supported:
4 changes: 2 additions & 2 deletions vllm/v1/attention/backends/flash_attn.py
@@ -11,11 +11,11 @@
                                                AttentionMetadata, AttentionType,
                                                is_quantized_kv_cache)
 from vllm.attention.ops.merge_attn_states import merge_attn_states
+from vllm.attention.utils.fa_utils import (flash_attn_supports_fp8,
+                                           get_flash_attn_version)
 from vllm.logger import init_logger
 from vllm.platforms import current_platform
 from vllm.utils import cdiv
-from vllm.vllm_flash_attn.fa_utils import (flash_attn_supports_fp8,
-                                           get_flash_attn_version)
 
 if TYPE_CHECKING:
     from vllm.v1.core.sched.output import SchedulerOutput
2 changes: 1 addition & 1 deletion vllm/v1/attention/backends/mla/common.py
@@ -197,14 +197,14 @@
                                                MLAAttentionImpl)
 from vllm.attention.backends.utils import get_mla_dims
 from vllm.attention.ops.merge_attn_states import merge_attn_states
+from vllm.attention.utils.fa_utils import get_flash_attn_version
 from vllm.logger import init_logger
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                LinearBase, RowParallelLinear,
                                                UnquantizedLinearMethod)
 from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
 from vllm.platforms import current_platform
 from vllm.utils import cdiv, round_down
-from vllm.vllm_flash_attn.fa_utils import get_flash_attn_version
 
 try:
     from vllm.vllm_flash_attn import flash_attn_varlen_func
22 changes: 0 additions & 22 deletions vllm/vllm_flash_attn/__init__.py

This file was deleted.

245 changes: 0 additions & 245 deletions vllm/vllm_flash_attn/flash_attn_interface.pyi

This file was deleted.
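Taken together, the diffs above move the fa_utils helpers under vllm/attention/utils and leave vllm/vllm_flash_attn as a generated package whose Python files are copied in from the vllm-flash-attn build or extracted from the wheel. A rough sketch of the resulting import layout, assuming only the call sites shown in the diffs (the None return is inferred from the pre-commit annotation, not stated elsewhere):

from vllm.attention.utils.fa_utils import (flash_attn_supports_fp8,
                                           get_flash_attn_version)

fa_version = get_flash_attn_version()  # may be None when no compatible build is found
if fa_version is not None:
    # The compiled kernels and their Python wrappers live in the generated
    # vllm/vllm_flash_attn package at build time (see the setup.py changes above).
    from vllm.vllm_flash_attn import flash_attn_varlen_func  # noqa: F401
    print(f"vllm-flash-attn v{fa_version}, fp8 supported: {flash_attn_supports_fp8()}")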