Support FlashMLA backend cuda graph #4514
Conversation
Force-pushed from 38f575a to c28f9bc (Co-authored-by: yinfan98 <[email protected]>, Hongbosherlock <[email protected]>)
Force-pushed from 4a47be2 to 4b790da (Co-authored-by: yinfan98 <[email protected]>, Hongbosherlock <[email protected]>)
if forward_mode.is_decode_or_idle():
    seq_lens = seq_lens[:bs]
    max_seqlen_pad = triton.cdiv(seq_lens.max().item(), PAGE_SIZE)
We should avoid CPU-GPU synchronization by avoiding the use of seq_lens.max().item(). Can you derive this value from seq_lens_cpu instead?
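A minimal sketch of the suggested change, assuming seq_lens_cpu is a host-side copy of the sequence lengths that is already available when the metadata is built (the helper name and PAGE_SIZE value below are illustrative, not the actual sglang code):

```python
import torch
import triton

PAGE_SIZE = 64  # illustrative page size, not the real constant


def compute_max_seqlen_pad(seq_lens_cpu: torch.Tensor, bs: int) -> int:
    """Derive the padded max sequence length from the CPU-side copy of
    seq_lens, so building the decode metadata never syncs with the GPU."""
    # seq_lens_cpu lives on the host, so .max().item() is a pure CPU read
    # and does not stall the stream the CUDA graph will replay on.
    return triton.cdiv(int(seq_lens_cpu[:bs].max().item()), PAGE_SIZE)


# Usage in the decode/idle branch quoted above (sketch):
# if forward_mode.is_decode_or_idle():
#     seq_lens = seq_lens[:bs]
#     max_seqlen_pad = compute_max_seqlen_pad(seq_lens_cpu, bs)
```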
I found this problem too; I was too busy at work today and didn't have time to fix it. I will fix it tomorrow.
I attempted to test FlashMLA + CUDA Graph on your commit, but I was not successful. The following error occurred:
The test command I used is:
However, everything works fine as long as I don't add that flag.
Environment:
hi @sleepcoo Great PR! But I did some simple tests, and it seems the performance of flashmla is not as good as that of triton_backend. What could be the reason?
command:
python3 -m sglang.bench_serving --backend sglang --dataset-name random --num-prompts 3 --max-concurrency 1 --random-input 128 --random-output 1024 --dataset-path /models/dataset/ShareGPT_V3_unfiltered_cleaned_split.json
flashmla:
python3 -m sglang.launch_server --model-path /models/DeepSeek-R1
triton_backend:
python3 -m sglang.launch_server --model-path /models/DeepSeek-R1
result:
flashmla: ============ Serving Benchmark Result ============
triton_backend: ============ Serving Benchmark Result ============
Environment:
In this PR, we have fixed the performance issues and tested them. In certain cases, FlashMLA has an advantage.
Motivation
Support CUDA graph for the FlashMLA backend. Optimize the index calculation by completing it in init_forward; a hypothetical sketch of the idea is shown below.
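The index-calculation change itself is not quoted in this thread; as a rough, hypothetical illustration of the idea (the names, shapes, and page-table layout below are assumptions, not the actual sglang implementation), precomputing the per-request page indices once in the forward-metadata init, so the captured CUDA graph only reads the result, might look like:

```python
import torch

PAGE_SIZE = 64  # illustrative page size


def init_forward_page_table(req_to_token: torch.Tensor,
                            req_pool_indices: torch.Tensor,
                            seq_lens_cpu: torch.Tensor) -> torch.Tensor:
    """Build the padded per-request page table once when the forward batch
    is initialized, instead of recomputing indices on every decode step."""
    # Width of the padded table, derived from the host-side seq lens so no
    # GPU sync is needed here either.
    max_pages = (int(seq_lens_cpu.max().item()) + PAGE_SIZE - 1) // PAGE_SIZE
    # First token slot of every page for each request; dividing by PAGE_SIZE
    # turns token indices into page (block) ids for a FlashMLA-style kernel.
    # Assumes req_to_token has at least max_pages * PAGE_SIZE columns.
    token_cols = torch.arange(0, max_pages * PAGE_SIZE, PAGE_SIZE,
                              device=req_to_token.device)
    page_table = req_to_token[req_pool_indices][:, token_cols] // PAGE_SIZE
    return page_table  # shape: [batch_size, max_pages]
```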
Modifications
Test
DeepSeek-V3 accuracy test
GSM8K Accuracy: 0.980
MMLU Average accuracy: 0.878
todo