
[Bugfix][Kernel] fix potential cuda graph broken for merge_attn_states kernel #16693


Merged: 1 commit merged into vllm-project:main on Apr 16, 2025

Conversation

@DefTruth (Contributor) commented Apr 16, 2025

Fix a potential CUDA graph break for the merge_attn_states kernel. A CUDA graph error related to merge_state was observed in sglang (sgl-project/sglang#5404) and fixed in sgl-project/sglang#5419. Since merge_attn_states is a fundamental kernel used in many scenarios, it is safer to bind its launch to the current CUDA stream, as CUDA graph capture requires. This binding does not affect performance.
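For illustration, here is a minimal sketch of the stream-binding pattern described above, written as a PyTorch C++/CUDA extension. The kernel and function names (`merge_rows_kernel`, `merge_rows`) are hypothetical placeholders, not the actual vLLM code; the point is only that the launch passes `at::cuda::getCurrentCUDAStream()` instead of implicitly using the legacy default stream.

```cuda
// Hypothetical sketch: launch a kernel on PyTorch's current CUDA stream so
// that the launch is recorded correctly during CUDA graph capture.
#include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h>

__global__ void merge_rows_kernel(const float* a, const float* b,
                                  float* out, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) out[i] = a[i] + b[i];  // placeholder for the real merge math
}

void merge_rows(torch::Tensor a, torch::Tensor b, torch::Tensor out) {
  const int n = static_cast<int>(out.numel());
  const int threads = 256;
  const int blocks = (n + threads - 1) / threads;
  // Key point: query the current stream and pass it as the 4th launch
  // parameter instead of defaulting to the legacy stream (stream 0).
  const at::cuda::CUDAStream stream = at::cuda::getCurrentCUDAStream();
  merge_rows_kernel<<<blocks, threads, 0, stream.stream()>>>(
      a.data_ptr<float>(), b.data_ptr<float>(), out.data_ptr<float>(), n);
}
```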


👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, which covers a small and essential subset of tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build on the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either add the ready label to the PR or enable auto-merge.

🚀

@houseroad (Collaborator) left a comment


Looks good.

Could you paste the test plan in the PR description?

@houseroad added the ready (ONLY add when PR is ready to merge/full CI is needed) label on Apr 16, 2025
@DefTruth (Contributor, Author) commented Apr 16, 2025

> Looks good.
>
> Could you paste the test plan in the PR description?

@houseroad The local unit tests for this kernel passed:

python3 -m pytest -s test_merge_attn_states.py
INFO 04-16 12:59:54 [__init__.py:239] Automatically detected platform cuda.
/usr/local/lib/python3.10/dist-packages/pytest_asyncio/plugin.py:208: PytestDeprecationWarning: The configuration option "asyncio_default_fixture_loop_scope" is unset.
The event loop scope for asynchronous fixtures will default to the fixture caching scope. Future versions of pytest-asyncio will default the loop scope for asynchronous fixtures to function scope. Set the default fixture loop scope explicitly in order to avoid unexpected behavior in the future. Valid fixture loop scopes are: "function", "class", "module", "package", "session"

  warnings.warn(PytestDeprecationWarning(_DEFAULT_FIXTURE_LOOP_SCOPE_UNSET))
================================================================================== test session starts ===================================================================================
platform linux -- Python 3.10.12, pytest-8.3.3, pluggy-1.5.0
rootdir: /workspace/dev/vipshop/vllm
configfile: pyproject.toml
plugins: anyio-4.9.0, langsmith-0.3.18, forked-1.6.0, shard-0.1.2, buildkite-test-collector-0.1.9, mock-3.14.0, asyncio-0.24.0, rerunfailures-14.0
asyncio: mode=strict, default_loop_scope=None
collected 648 items
Running 648 items in this shard

test_merge_attn_states.py
NUM_TOKENS:256, NUM_HEADS:4, HEAD_SIZE:32, DTYPE: torch.float32, Device: NVIDIA L20
 Torch time: 0.164301ms
Triton time: 0.062115ms
  CUDA time: 0.018589ms, Performance: 3.34154x
----------------------------------------------------------------------------------------------------
Output all match, max abs diff:
(Triton vs Torch) : 4.76837158203125e-07
  (CUDA vs Torch) : 2.384185791015625e-07
  (CUDA vs Triton): 4.76837158203125e-07
----------------------------------------------------------------------------------------------------
Output LSE all match, max abs diff:
(Triton vs Torch) : 1.1920928955078125e-07
  (CUDA vs Torch) : 0.0
  (CUDA vs Triton): 1.1920928955078125e-07
----------------------------------------------------------------------------------------------------
All output values test passed! All inf values are correctly replaced with -inf.
----------------------------------------------------------------------------------------------------
.
NUM_TOKENS:512, NUM_HEADS:4, HEAD_SIZE:32, DTYPE: torch.float32, Device: NVIDIA L20
 Torch time: 0.179862ms
Triton time: 0.059288ms
  CUDA time: 0.017562ms, Performance: 3.37600x
----------------------------------------------------------------------------------------------------
Output all match, max abs diff:
(Triton vs Torch) : 4.76837158203125e-07
  (CUDA vs Torch) : 2.384185791015625e-07
  (CUDA vs Triton): 4.76837158203125e-07
----------------------------------------------------------------------------------------------------
Output LSE all match, max abs diff:
(Triton vs Torch) : 2.384185791015625e-07
  (CUDA vs Torch) : 0.0
  (CUDA vs Triton): 2.384185791015625e-07
----------------------------------------------------------------------------------------------------
All output values test passed! All inf values are correctly replaced with -inf.
----------------------------------------------------------------------------------------------------
.
NUM_TOKENS:613, NUM_HEADS:4, HEAD_SIZE:32, DTYPE: torch.float32, Device: NVIDIA L20
 Torch time: 0.165590ms
Triton time: 0.058877ms
  CUDA time: 0.020069ms, Performance: 2.93375x
----------------------------------------------------------------------------------------------------
Output all match, max abs diff:
(Triton vs Torch) : 4.76837158203125e-07
  (CUDA vs Torch) : 2.384185791015625e-07
  (CUDA vs Triton): 4.76837158203125e-07
----------------------------------------------------------------------------------------------------
Output LSE all match, max abs diff:
(Triton vs Torch) : 2.384185791015625e-07
  (CUDA vs Torch) : 0.0
  (CUDA vs Triton): 2.384185791015625e-07
----------------------------------------------------------------------------------------------------
All output values test passed! All inf values are correctly replaced with -inf.
----------------------------------------------------------------------------------------------------
// ......
============================================================================ 648 passed in 123.44s (0:02:03) =============================================================================

Some performance results:

| tokens | heads | head size | dtype    | device | Torch     | Triton    | CUDA      | speedup (CUDA vs Triton) |
|--------|-------|-----------|----------|--------|-----------|-----------|-----------|--------------------------|
| 256    | 4     | 32        | float32  | L20    | 0.16430ms | 0.06212ms | 0.01859ms | 3.3415x |
| 512    | 4     | 32        | float32  | L20    | 0.17986ms | 0.05929ms | 0.01756ms | 3.3760x |
| 613    | 4     | 32        | float32  | L20    | 0.16559ms | 0.05888ms | 0.02007ms | 2.9337x |
| 1024   | 4     | 32        | float32  | L20    | 0.16312ms | 0.05745ms | 0.01756ms | 3.2713x |
| 1536   | 4     | 32        | float32  | L20    | 0.16476ms | 0.05872ms | 0.01874ms | 3.1334x |
| 4096   | 4     | 32        | float32  | L20    | 0.16675ms | 0.06226ms | 0.01705ms | 3.6520x |
| 256    | 8     | 32        | float32  | L20    | 0.19216ms | 0.05703ms | 0.01746ms | 3.2658x |
| 256    | 32    | 256       | bfloat16 | L20    | 0.16143ms | 0.05386ms | 0.01649ms | 3.2657x |
| 512    | 32    | 256       | bfloat16 | L20    | 0.18058ms | 0.05392ms | 0.01684ms | 3.2010x |
| 613    | 32    | 256       | bfloat16 | L20    | 0.19149ms | 0.05704ms | 0.01736ms | 3.2855x |
| 1024   | 32    | 256       | bfloat16 | L20    | 0.33562ms | 0.06523ms | 0.01916ms | 3.4053x |
| 1536   | 32    | 256       | bfloat16 | L20    | 0.50728ms | 0.07685ms | 0.02422ms | 3.1729x |
| 4096   | 32    | 256       | bfloat16 | L20    | 1.32142ms | 0.32629ms | 0.30771ms | 1.0604x |
| 256    | 48    | 256       | bfloat16 | L20    | 0.16998ms | 0.05412ms | 0.01736ms | 3.1181x |
| 512    | 48    | 256       | bfloat16 | L20    | 0.21401ms | 0.06036ms | 0.01720ms | 3.5087x |
| 613    | 48    | 256       | bfloat16 | L20    | 0.29475ms | 0.06297ms | 0.01803ms | 3.4921x |
| 1024   | 48    | 256       | bfloat16 | L20    | 0.50677ms | 0.07680ms | 0.02417ms | 3.1778x |
| 1536   | 48    | 256       | bfloat16 | L20    | 0.79488ms | 0.20915ms | 0.16789ms | 1.2458x |
| 4096   | 48    | 256       | bfloat16 | L20    | 1.91892ms | 0.52148ms | 0.45199ms | 1.1537x |
| 256    | 64    | 256       | bfloat16 | L20    | 0.18099ms | 0.05652ms | 0.01726ms | 3.2747x |
| 512    | 64    | 256       | bfloat16 | L20    | 0.33525ms | 0.06589ms | 0.01947ms | 3.3851x |
| 613    | 64    | 256       | bfloat16 | L20    | 0.42850ms | 0.06983ms | 0.02104ms | 3.3190x |

For CUDA graph, I only tested it on sglang; see sgl-project/sglang#5419. The attention path for prefill/chunked-prefill in vLLM appears to run in eager mode, and CUDA graph is currently enabled for decoding only, which is why vLLM does not hit the same CUDA graph error for the merge_attn_states kernel that sglang did. But since merge_attn_states is a fundamental kernel used in many scenarios, it is better to bind it to the current CUDA stream, as CUDA graph capture requires. This binding does not affect performance. (A minimal capture sketch follows the quoted code below.)

# Whether or not if cuda graph is enabled.
# Cuda-graph is currently enabled for decoding only.
# TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention.
use_cuda_graph: bool
# New for MLA (compared to FlashAttention)
# Input positions for rotary embeddings since for MLA the rotary
# position embeddings are applied inside the attention backend
input_positions: torch.Tensor
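
For context, here is a minimal, self-contained sketch of why the stream matters during capture, using the plain CUDA runtime API with a toy kernel (this is not vLLM or sglang code). Everything issued between cudaStreamBeginCapture and cudaStreamEndCapture must go to the capturing stream; a kernel launched on the legacy default stream at that point would typically invalidate the capture instead of being recorded.

```cuda
// Toy example: capture a kernel launch into a CUDA graph and replay it.
#include <cuda_runtime.h>
#include <cstdio>

__global__ void scale_kernel(float* x, float s, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) x[i] *= s;
}

int main() {
  const int n = 1 << 20;
  float* x = nullptr;
  cudaMalloc(&x, n * sizeof(float));

  cudaStream_t stream;
  cudaStreamCreate(&stream);

  cudaGraph_t graph;
  cudaGraphExec_t graph_exec;

  // All work between Begin/EndCapture must be issued to `stream`.
  // Launching on the legacy default stream here (omitting the stream
  // argument) would typically invalidate the capture.
  cudaStreamBeginCapture(stream, cudaStreamCaptureModeGlobal);
  scale_kernel<<<(n + 255) / 256, 256, 0, stream>>>(x, 2.0f, n);
  cudaStreamEndCapture(stream, &graph);

  cudaGraphInstantiate(&graph_exec, graph, nullptr, nullptr, 0);
  cudaGraphLaunch(graph_exec, stream);  // replay the captured work
  cudaStreamSynchronize(stream);
  printf("graph replay done\n");

  cudaGraphExecDestroy(graph_exec);
  cudaGraphDestroy(graph);
  cudaStreamDestroy(stream);
  cudaFree(x);
  return 0;
}
```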

@vllm-bot merged commit e82ee40 into vllm-project:main on Apr 16, 2025
82 of 89 checks passed
lionelvillard pushed a commit to lionelvillard/vllm that referenced this pull request Apr 17, 2025
yangw-dev pushed a commit to yangw-dev/vllm that referenced this pull request Apr 21, 2025
jikunshang pushed a commit to jikunshang/vllm that referenced this pull request Apr 29, 2025
lk-chen pushed a commit to lk-chen/vllm that referenced this pull request Apr 29, 2025
adobrzyn pushed a commit to HabanaAI/vllm-fork that referenced this pull request Apr 30, 2025
RichardoMrMu pushed a commit to RichardoMrMu/vllm that referenced this pull request May 12, 2025
@DefTruth deleted the fix-potential-cuda-graph-broken branch on July 2, 2025 05:31
Labels: ready (ONLY add when PR is ready to merge/full CI is needed)
3 participants