
[Attention] Update to latest FA3 code #13111


Merged: 20 commits merged into vllm-project:main on Apr 17, 2025

Conversation

@LucasWilkinson (Collaborator) commented on Feb 11, 2025

NOTE: Tested MLA on AMD; V0 is working, V1 is broken, but V1 is also broken on main.

Perf: https://docs.google.com/spreadsheets/d/1U5lsoCKuWq99Cz1QbWkc0dBn1bij1Ifb3tphE2UXJj0/edit?usp=sharing

Main:

--------------------------------------
Full Command:
VLLM_USE_V1=0 lm_eval --model vllm --model_args pretrained=deepseek-ai/DeepSeek-V2-Lite-Chat,tensor_parallel_size=2,dtype=auto,gpu_memory_utilization=0.9,trust_remote_code=True,max_model_len=16384,max_num_batched_tokens=1024,enable_chunked_prefill=1 --task gsm8k --num_fewshot 5 --limit 10

Extracted Result Table:
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value|   |Stderr|
|-----|------:|----------------|-----:|-----------|---|----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |  0.8|±  |0.1333|
|     |       |strict-match    |     5|exact_match|↑  |  0.8|±  |0.1333|
Log file saved at: logs/deepseek_v0_chunked_20250326_005420.log

--------------------------------------
Full Command:
VLLM_USE_V1=0 lm_eval --model vllm --model_args pretrained=deepseek-ai/DeepSeek-V2-Lite-Chat,tensor_parallel_size=2,dtype=auto,gpu_memory_utilization=0.9,trust_remote_code=True,max_model_len=16384 --task gsm8k --num_fewshot 5 --limit 10

Extracted Result Table:
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value|   |Stderr|
|-----|------:|----------------|-----:|-----------|---|----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |  0.8|±  |0.1333|
|     |       |strict-match    |     5|exact_match|↑  |  0.8|±  |0.1333|
Log file saved at: logs/deepseek_v0_nchunked_20250326_005531.log

--------------------------------------
Full Command:
VLLM_USE_V1=1 lm_eval --model vllm --model_args pretrained=deepseek-ai/DeepSeek-V2-Lite-Chat,tensor_parallel_size=2,dtype=auto,gpu_memory_utilization=0.9,trust_remote_code=True,max_model_len=16384,max_num_batched_tokens=1024,enable_chunked_prefill=1 --task gsm8k --num_fewshot 5 --limit 10

Extracted Result Table:
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value|   |Stderr|
|-----|------:|----------------|-----:|-----------|---|----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |  0.8|±  |0.1333|
|     |       |strict-match    |     5|exact_match|↑  |  0.8|±  |0.1333|
Log file saved at: logs/deepseek_v1_chunked_20250326_005722.log

--------------------------------------
Full Command:
VLLM_USE_V1=1 lm_eval --model vllm --model_args pretrained=deepseek-ai/DeepSeek-V2-Lite-Chat,tensor_parallel_size=2,dtype=auto,gpu_memory_utilization=0.9,trust_remote_code=True,max_model_len=16384 --task gsm8k --num_fewshot 5 --limit 10

Extracted Result Table:
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value|   |Stderr|
|-----|------:|----------------|-----:|-----------|---|----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |  0.8|±  |0.1333|
|     |       |strict-match    |     5|exact_match|↑  |  0.8|±  |0.1333|
Log file saved at: logs/deepseek_v1_nchunked_20250326_005836.log

--------------------------------------
Full Command:
VLLM_USE_V1=0 lm_eval --model vllm --model_args pretrained=meta-llama/Meta-Llama-3-8B,tensor_parallel_size=2,dtype=auto,gpu_memory_utilization=0.9,trust_remote_code=True,max_model_len=8192 --task gsm8k --num_fewshot 5 --limit 10

Extracted Result Table:
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value|   |Stderr|
|-----|------:|----------------|-----:|-----------|---|----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |  0.5|±  |0.1667|
|     |       |strict-match    |     5|exact_match|↑  |  0.5|±  |0.1667|
Log file saved at: logs/metalla_v0_20250326_005934.log

--------------------------------------
Full Command:
VLLM_USE_V1=1 lm_eval --model vllm --model_args pretrained=meta-llama/Meta-Llama-3-8B,tensor_parallel_size=2,dtype=auto,gpu_memory_utilization=0.9,trust_remote_code=True,max_model_len=8192 --task gsm8k --num_fewshot 5 --limit 10

Extracted Result Table:
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value|   |Stderr|
|-----|------:|----------------|-----:|-----------|---|----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |  0.5|±  |0.1667|
|     |       |strict-match    |     5|exact_match|↑  |  0.5|±  |0.1667|
Log file saved at: logs/metalla_v1_20250326_010112.log

This PR:

--------------------------------------
Full Command:
VLLM_USE_V1=0 lm_eval --model vllm --model_args pretrained=deepseek-ai/DeepSeek-V2-Lite-Chat,tensor_parallel_size=2,dtype=auto,gpu_memory_utilization=0.9,trust_remote_code=True,max_model_len=16384,max_num_batched_tokens=1024,enable_chunked_prefill=1 --task gsm8k --num_fewshot 5 --limit 10

Extracted Result Table:
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value|   |Stderr|
|-----|------:|----------------|-----:|-----------|---|----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |  0.8|±  |0.1333|
|     |       |strict-match    |     5|exact_match|↑  |  0.8|±  |0.1333|
Log file saved at: logs/deepseek_v0_chunked_20250326_012034.log

--------------------------------------
Full Command:
VLLM_USE_V1=0 lm_eval --model vllm --model_args pretrained=deepseek-ai/DeepSeek-V2-Lite-Chat,tensor_parallel_size=2,dtype=auto,gpu_memory_utilization=0.9,trust_remote_code=True,max_model_len=16384 --task gsm8k --num_fewshot 5 --limit 10

Extracted Result Table:
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value|   |Stderr|
|-----|------:|----------------|-----:|-----------|---|----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |  0.8|±  |0.1333|
|     |       |strict-match    |     5|exact_match|↑  |  0.8|±  |0.1333|
Log file saved at: logs/deepseek_v0_nchunked_20250326_012150.log

--------------------------------------
Full Command:
VLLM_USE_V1=1 lm_eval --model vllm --model_args pretrained=deepseek-ai/DeepSeek-V2-Lite-Chat,tensor_parallel_size=2,dtype=auto,gpu_memory_utilization=0.9,trust_remote_code=True,max_model_len=16384,max_num_batched_tokens=1024,enable_chunked_prefill=1 --task gsm8k --num_fewshot 5 --limit 10

Extracted Result Table:
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value|   |Stderr|
|-----|------:|----------------|-----:|-----------|---|----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |  0.8|±  |0.1333|
|     |       |strict-match    |     5|exact_match|↑  |  0.8|±  |0.1333|
Log file saved at: logs/deepseek_v1_chunked_20250326_012340.log

--------------------------------------
Full Command:
VLLM_USE_V1=1 lm_eval --model vllm --model_args pretrained=deepseek-ai/DeepSeek-V2-Lite-Chat,tensor_parallel_size=2,dtype=auto,gpu_memory_utilization=0.9,trust_remote_code=True,max_model_len=16384 --task gsm8k --num_fewshot 5 --limit 10

Extracted Result Table:
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value|   |Stderr|
|-----|------:|----------------|-----:|-----------|---|----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |  0.8|±  |0.1333|
|     |       |strict-match    |     5|exact_match|↑  |  0.8|±  |0.1333|
Log file saved at: logs/deepseek_v1_nchunked_20250326_012458.log

--------------------------------------
Full Command:
VLLM_USE_V1=0 lm_eval --model vllm --model_args pretrained=meta-llama/Meta-Llama-3-8B,tensor_parallel_size=2,dtype=auto,gpu_memory_utilization=0.9,trust_remote_code=True,max_model_len=8192 --task gsm8k --num_fewshot 5 --limit 10

Extracted Result Table:
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value|   |Stderr|
|-----|------:|----------------|-----:|-----------|---|----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |  0.5|±  |0.1667|
|     |       |strict-match    |     5|exact_match|↑  |  0.5|±  |0.1667|
Log file saved at: logs/metalla_v0_20250326_012559.log

--------------------------------------
Full Command:
VLLM_USE_V1=1 lm_eval --model vllm --model_args pretrained=meta-llama/Meta-Llama-3-8B,tensor_parallel_size=2,dtype=auto,gpu_memory_utilization=0.9,trust_remote_code=True,max_model_len=8192 --task gsm8k --num_fewshot 5 --limit 10

Extracted Result Table:
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value|   |Stderr|
|-----|------:|----------------|-----:|-----------|---|----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |  0.5|±  |0.1633|
|     |       |strict-match    |     5|exact_match|↑  |  0.5|±  |0.1633|
Log file saved at: logs/metalla_v1_20250326_012743.log

Accuracy drops slightly for:

VLLM_USE_V1=1 lm_eval --model vllm --model_args pretrained=meta-llama/Meta-Llama-3-8B,tensor_parallel_size=2,dtype=auto,gpu_memory_utilization=0.9,trust_remote_code=True,max_model_len=8192 --task gsm8k --num_fewshot 5 --limit 10

This is due to the dynamic split scheduler.


👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, covering a small and essential subset of tests to quickly catch errors. You can run additional CI tests on top of those by going to your fastcheck build in the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either: Add ready label to the PR or enable auto-merge.

🚀

@mergify (bot) added the ci/build label on Feb 11, 2025
@LucasWilkinson changed the title from "[WIP][Attention] Update to latest FA3 code that supports different K and V head dims" to "[Attention] Update to latest FA3 code that supports different K and V head dims" on Feb 11, 2025
@LucasWilkinson marked this pull request as ready for review on February 11, 2025 20:56
@LucasWilkinson (Collaborator, Author) commented:

@khluu can we run the perf CI on this? Would be nice to check for regressions since there are a lot of FA changes.

if has_context:
    if not current_platform.is_cuda():
        raise NotImplementedError(
            "Chunked Prefill for MLA is not currently supported on "
            "non-cuda platforms")
-   output = self.flash_attn_varlen_func(
+   output = self.flash_attn_varlen_diff_headdims(
Collaborator:

do we need to handle the fact that flash_attn_varlen_diff_headdims returns both output and *rest in this case?

Collaborator Author:

We did, but the return is handled by slicing the tensor.
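
For readers outside the thread, here is a minimal sketch of what a "different head dims" wrapper can look like. It illustrates the fallback idea (zero-pad V up to Q/K's head dim, call the kernel, slice the output back, and pass through any extra return values such as the softmax LSE); the function and argument names are hypothetical and this is not the PR's exact implementation.

```python
# Hedged sketch (hypothetical names): a fallback wrapper for attention kernels
# that require K and V to share a head dim. In MLA, V's head dim is smaller
# than Q/K's, so V is zero-padded up to Q/K's head dim, the kernel is run,
# and the padded output is sliced back to V's original head dim.
import torch.nn.functional as F


def flash_attn_varlen_diff_headdims_fallback(attn_fn, q, k, v, **kwargs):
    v_dim = v.shape[-1]
    pad = q.shape[-1] - v_dim
    if pad > 0:
        v = F.pad(v, (0, pad), value=0.0)
    result = attn_fn(q=q, k=k, v=v, **kwargs)
    if isinstance(result, tuple):
        # e.g. (output, softmax_lse, ...) when the LSE was requested:
        # only the attention output needs the head-dim slice.
        output, *rest = result
        return (output[..., :v_dim], *rest)
    return result[..., :v_dim]
```

When the underlying FA3 build natively supports differing K and V head dims, a wrapper like this can skip the padding entirely and simply forward the call.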

@tlrmchlsmth (Collaborator) left a comment:

LGTM


mergify bot commented Feb 27, 2025

This pull request has merge conflicts that must be resolved before it can be merged. Please rebase the PR, @LucasWilkinson.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

    causal=True,
    return_attn_probs=has_context,
)
if has_context and not current_platform.is_cuda():
Contributor:

Why does this not work for ROCm?

Collaborator Author:

Sorry, this was cruft from an earlier Slack discussion that cast doubt on whether return_softmax_lse was supported on ROCm.
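
As background on why the softmax LSE is requested when has_context is true: in a chunked-prefill setting, attention over the cached context and attention over the new tokens are typically computed separately and then combined using their log-sum-exp values. Below is a minimal sketch of such a merge, assuming flash-attn-style shapes (LSE as [num_heads, num_tokens], output as [num_tokens, num_heads, head_dim]); the helper name is illustrative and not vLLM's actual function.

```python
# Hedged sketch (illustrative name): combine two partial attention outputs
# using their log-sum-exp values. Assumed shapes follow flash-attn
# conventions: lse_* is [num_heads, num_tokens], out_* is
# [num_tokens, num_heads, head_dim].
import torch


def merge_attn_outputs(out_a: torch.Tensor, lse_a: torch.Tensor,
                       out_b: torch.Tensor, lse_b: torch.Tensor) -> torch.Tensor:
    lse = torch.logaddexp(lse_a, lse_b)                          # [heads, tokens]
    w_a = torch.exp(lse_a - lse).transpose(0, 1).unsqueeze(-1)   # [tokens, heads, 1]
    w_b = torch.exp(lse_b - lse).transpose(0, 1).unsqueeze(-1)
    return out_a * w_a + out_b * w_b
```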

Signed-off-by: Lucas Wilkinson <[email protected]>
@simon-mo merged commit 183dad7 into vllm-project:main on Apr 17, 2025
42 of 48 checks passed
@chaunceyjiang (Contributor) commented:

Hi, @LucasWilkinson @simon-mo. This PR uses a new function get_scheduler_metadata:

https://github.com/vllm-project/flash-attention/blob/main/vllm_flash_attn/flash_attn_interface.py#L78

It seems that a new release of vllm_flash_attn is needed.

@chaunceyjiang (Contributor) commented:

https://pypi.org/project/vllm-flash-attn/

It’s been a long time since vllm-flash-attn had a new release. Should we consider publishing a new version?

@LucasWilkinson (Collaborator, Author) commented:

> Hi, @LucasWilkinson @simon-mo. This PR uses a new function get_scheduler_metadata:
>
> https://github.com/vllm-project/flash-attention/blob/main/vllm_flash_attn/flash_attn_interface.py#L78
>
> It seems that a new release of vllm_flash_attn is needed.

We currently ship vllm_flash_attn inside the vLLM wheel, so you will likely need to rebuild from scratch: #16813 (comment)
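
For anyone unsure whether their installed vLLM build already bundles an FA3 version exposing the new entry point, a quick probe like the following may help; the import path is an assumption about how the bundled vllm_flash_attn package is laid out and can differ across builds.

```python
# Hedged sketch: probe the bundled vllm_flash_attn for the new FA3 entry
# point mentioned above. The import path is an assumption and may differ
# depending on how vLLM was built or installed.
def has_get_scheduler_metadata() -> bool:
    try:
        from vllm.vllm_flash_attn import flash_attn_interface
    except ImportError:
        return False
    return hasattr(flash_attn_interface, "get_scheduler_metadata")


if __name__ == "__main__":
    print("get_scheduler_metadata available:", has_get_scheduler_metadata())
```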

@nnding commented on Apr 21, 2025

Hi, @LucasWilkinson. Does the latest FA3 still fail on Lovelace GPUs due to shared memory limits for some shapes?

yangw-dev pushed a commit to yangw-dev/vllm that referenced this pull request Apr 21, 2025
jikunshang pushed a commit to jikunshang/vllm that referenced this pull request Apr 29, 2025
lk-chen pushed a commit to lk-chen/vllm that referenced this pull request Apr 29, 2025
adobrzyn pushed a commit to HabanaAI/vllm-fork that referenced this pull request Apr 30, 2025
Signed-off-by: Lucas Wilkinson <[email protected]>
Signed-off-by: Agata Dobrzyniewicz <[email protected]>
RichardoMrMu pushed a commit to RichardoMrMu/vllm that referenced this pull request May 12, 2025
Labels: ci/build, ready, v1
7 participants