[Bugfix][Kernel] fix potential cuda graph broken for merge_attn_states kernel #16693
Conversation
Signed-off-by: DefTruth <[email protected]>
👋 Hi! Thank you for contributing to the vLLM project. 💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels. Just a reminder: PRs do not trigger a full CI run by default; only a limited set of checks runs automatically. Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.
Looks good.
Could you paste the test plan in the PR description?
@houseroad Local unit tests for this kernel pass: python3 -m pytest -s test_merge_attn_states.py
INFO 04-16 12:59:54 [__init__.py:239] Automatically detected platform cuda.
/usr/local/lib/python3.10/dist-packages/pytest_asyncio/plugin.py:208: PytestDeprecationWarning: The configuration option "asyncio_default_fixture_loop_scope" is unset.
The event loop scope for asynchronous fixtures will default to the fixture caching scope. Future versions of pytest-asyncio will default the loop scope for asynchronous fixtures to function scope. Set the default fixture loop scope explicitly in order to avoid unexpected behavior in the future. Valid fixture loop scopes are: "function", "class", "module", "package", "session"
warnings.warn(PytestDeprecationWarning(_DEFAULT_FIXTURE_LOOP_SCOPE_UNSET))
================================================================================== test session starts ===================================================================================
platform linux -- Python 3.10.12, pytest-8.3.3, pluggy-1.5.0
rootdir: /workspace/dev/vipshop/vllm
configfile: pyproject.toml
plugins: anyio-4.9.0, langsmith-0.3.18, forked-1.6.0, shard-0.1.2, buildkite-test-collector-0.1.9, mock-3.14.0, asyncio-0.24.0, rerunfailures-14.0
asyncio: mode=strict, default_loop_scope=None
collected 648 items
Running 648 items in this shard
test_merge_attn_states.py
NUM_TOKENS:256, NUM_HEADS:4, HEAD_SIZE:32, DTYPE: torch.float32, Device: NVIDIA L20
Torch time: 0.164301ms
Triton time: 0.062115ms
CUDA time: 0.018589ms, Performance: 3.34154x
----------------------------------------------------------------------------------------------------
Output all match, max abs diff:
(Triton vs Torch) : 4.76837158203125e-07
(CUDA vs Torch) : 2.384185791015625e-07
(CUDA vs Triton): 4.76837158203125e-07
----------------------------------------------------------------------------------------------------
Output LSE all match, max abs diff:
(Triton vs Torch) : 1.1920928955078125e-07
(CUDA vs Torch) : 0.0
(CUDA vs Triton): 1.1920928955078125e-07
----------------------------------------------------------------------------------------------------
All output values test passed! All inf values are correctly replaced with -inf.
----------------------------------------------------------------------------------------------------
.
NUM_TOKENS:512, NUM_HEADS:4, HEAD_SIZE:32, DTYPE: torch.float32, Device: NVIDIA L20
Torch time: 0.179862ms
Triton time: 0.059288ms
CUDA time: 0.017562ms, Performance: 3.37600x
----------------------------------------------------------------------------------------------------
Output all match, max abs diff:
(Triton vs Torch) : 4.76837158203125e-07
(CUDA vs Torch) : 2.384185791015625e-07
(CUDA vs Triton): 4.76837158203125e-07
----------------------------------------------------------------------------------------------------
Output LSE all match, max abs diff:
(Triton vs Torch) : 2.384185791015625e-07
(CUDA vs Torch) : 0.0
(CUDA vs Triton): 2.384185791015625e-07
----------------------------------------------------------------------------------------------------
All output values test passed! All inf values are correctly replaced with -inf.
----------------------------------------------------------------------------------------------------
.
NUM_TOKENS:613, NUM_HEADS:4, HEAD_SIZE:32, DTYPE: torch.float32, Device: NVIDIA L20
Torch time: 0.165590ms
Triton time: 0.058877ms
CUDA time: 0.020069ms, Performance: 2.93375x
----------------------------------------------------------------------------------------------------
Output all match, max abs diff:
(Triton vs Torch) : 4.76837158203125e-07
(CUDA vs Torch) : 2.384185791015625e-07
(CUDA vs Triton): 4.76837158203125e-07
----------------------------------------------------------------------------------------------------
Output LSE all match, max abs diff:
(Triton vs Torch) : 2.384185791015625e-07
(CUDA vs Torch) : 0.0
(CUDA vs Triton): 2.384185791015625e-07
----------------------------------------------------------------------------------------------------
All output values test passed! All inf values are correctly replaced with -inf.
----------------------------------------------------------------------------------------------------
// ......
============================================================================ 648 passed in 123.44s (0:02:03) =============================================================================

some performance results:
For CUDA graphs, I only tested this on sglang (see sgl-project/sglang#5419). In vLLM, the attention part for prefill/chunked-prefill appears to run in eager mode, and CUDA graph is currently enabled for decoding only, which is why vLLM does not hit the same CUDA graph error for the merge_attn_states kernel that sglang did. However, since merge_attn_states is a fundamental kernel that is active in many scenarios, it is better to bind its launch to the current CUDA stream, as CUDA graph capture requires. This binding does not affect performance. (See vllm/vllm/attention/backends/mla/common.py, lines 452 to 460 at 44fa4d5.)
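As a rough illustration of the stream binding described above, here is a minimal sketch, not the actual vLLM kernel: the kernel body and signature are simplified placeholders, and only the launch site matters. The launch obtains the current PyTorch CUDA stream via at::cuda::getCurrentCUDAStream() and passes it explicitly instead of relying on the implicit default stream:

```cpp
// Minimal sketch (simplified signature/body, not the real vLLM kernel):
// launch merge_attn_states on the current PyTorch CUDA stream instead of
// the implicit default stream, so it can be recorded during graph capture.
#include <ATen/cuda/CUDAContext.h>  // at::cuda::getCurrentCUDAStream
#include <c10/cuda/CUDAGuard.h>     // c10::cuda::OptionalCUDAGuard
#include <torch/extension.h>

__global__ void merge_attn_states_kernel(float* out, const float* prefix,
                                         const float* suffix, int n) {
  // Placeholder body; the real kernel merges partial attention outputs
  // using their log-sum-exp values.
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) out[i] = prefix[i] + suffix[i];
}

void merge_attn_states(torch::Tensor out, torch::Tensor prefix,
                       torch::Tensor suffix) {
  const c10::cuda::OptionalCUDAGuard device_guard(out.device());
  // Bind the launch to the stream PyTorch is currently using. Under graph
  // capture this is the capturing stream, so the kernel is recorded into
  // the graph instead of being submitted to the legacy default stream.
  cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  const int n = static_cast<int>(out.numel());
  const int threads = 256;
  const int blocks = (n + threads - 1) / threads;
  merge_attn_states_kernel<<<blocks, threads, 0, stream>>>(
      out.data_ptr<float>(), prefix.data_ptr<float>(),
      suffix.data_ptr<float>(), n);
}
```

Passing the stream in the launch configuration is a no-cost change for eager execution, since the current stream is the default stream in that case.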
Fix a potential CUDA graph break for the merge_attn_states kernel. A CUDA graph error related to merge_state was observed in sglang (sgl-project/sglang#5404) and fixed in sgl-project/sglang#5419. Since merge_attn_states is a fundamental kernel that is active in many scenarios, it is better to bind its launch to the current CUDA stream, as CUDA graph capture requires. This binding does not affect performance.
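For context on why the binding is required, a minimal standalone sketch (hypothetical dummy_kernel, unrelated to vLLM or sglang code): CUDA graph capture records only the work submitted to the capturing stream, and a kernel launched on the legacy default stream while a capture is in progress produces a stream-capture error, which is the kind of failure observed in sglang:

```cpp
// Minimal standalone sketch of CUDA graph capture (not vLLM/sglang code):
// only work submitted to the capturing stream is recorded; a launch on the
// legacy default stream during capture errors out and invalidates capture.
#include <cuda_runtime.h>
#include <cstdio>

__global__ void dummy_kernel(float* x) { x[threadIdx.x] += 1.0f; }

int main() {
  float* d = nullptr;
  cudaMalloc(&d, 32 * sizeof(float));

  cudaStream_t stream;
  cudaStreamCreate(&stream);  // capture requires a non-default stream

  cudaGraph_t graph;
  cudaStreamBeginCapture(stream, cudaStreamCaptureModeGlobal);

  // OK: launched on the capturing stream, recorded into the graph.
  dummy_kernel<<<1, 32, 0, stream>>>(d);

  // Broken: `dummy_kernel<<<1, 32>>>(d);` (implicit default stream) would
  // return a stream-capture error here -- the failure mode this PR avoids.

  cudaError_t status = cudaStreamEndCapture(stream, &graph);
  printf("capture status: %s\n", cudaGetErrorString(status));

  cudaGraphExec_t exec;
  cudaGraphInstantiateWithFlags(&exec, graph, 0);  // requires CUDA >= 11.4
  cudaGraphLaunch(exec, stream);
  cudaStreamSynchronize(stream);

  cudaGraphExecDestroy(exec);
  cudaGraphDestroy(graph);
  cudaStreamDestroy(stream);
  cudaFree(d);
  return 0;
}
```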