
Support FlashMLA backend cuda graph #4514


Merged
merged 11 commits into sgl-project:main on Mar 19, 2025

Conversation

sleepcoo
Collaborator

@sleepcoo sleepcoo commented Mar 17, 2025

Motivation

Support CUDA Graph for the FlashMLA backend, and optimize the index calculation by completing it in init_forward.

Modifications

  • Optimize the FlashInfer block table calculation logic so it is computed only once per forward pass (see the sketch after this list).
  • Support CUDA Graph for the FlashMLA backend.
  • Automatically set the page size to 64 when launching FlashMLA.
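
Below is a minimal sketch of the decode-path metadata setup described above, not the PR's actual code: the helper name init_forward_metadata, the req_to_token layout, and the strided page-index computation are illustrative assumptions (the real implementation uses a Triton kernel and pads unused slots).

import torch
import triton

PAGE_SIZE = 64  # assumed fixed page size required by FlashMLA

def init_forward_metadata(seq_lens_cpu: torch.Tensor, req_to_token: torch.Tensor, bs: int):
    """Illustrative: build the block table once in init_forward so every
    attention layer reuses it instead of recomputing it per layer."""
    seq_lens_cpu = seq_lens_cpu[:bs]
    # Maximum number of KV pages any request needs, taken from the host copy so
    # no CPU-GPU synchronization happens during graph capture.
    max_seqlen_pad = triton.cdiv(int(seq_lens_cpu.max()), PAGE_SIZE)
    # Assuming req_to_token maps (request, token position) -> KV-cache slot and
    # slots within a page are contiguous, sampling every PAGE_SIZE-th slot and
    # dividing by PAGE_SIZE yields the page index for each page of the request,
    # giving a rectangular (bs, max_seqlen_pad) block table.
    block_kv_indices = req_to_token[:bs, : max_seqlen_pad * PAGE_SIZE : PAGE_SIZE] // PAGE_SIZE
    return block_kv_indices, max_seqlen_pad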

Test

DeepSeek-V3 accuracy test
GSM8K Accuracy: 0.980
MMLU Average accuracy: 0.878

Todo

  • performance test
  • add unit test
  • enable speculative sampling in FlashMLA in the next PR @ispobock

@sleepcoo sleepcoo changed the title from "Optimize index calculation, complete the calculation in init_forward" to "Support FlashMLA backend cuda graph" on Mar 17, 2025
@zhyncs zhyncs self-assigned this Mar 17, 2025
Co-authored-by: yinfan98 <[email protected]>
Co-authored-by: Hongbosherlock <[email protected]>
@sleepcoo sleepcoo force-pushed the opt-flashmla-backend branch from 38f575a to c28f9bc on March 18, 2025 04:28
@sleepcoo sleepcoo force-pushed the opt-flashmla-backend branch from 4a47be2 to 4b790da on March 18, 2025 05:10
@sleepcoo sleepcoo marked this pull request as ready for review March 18, 2025 09:28
@zhyncs zhyncs merged commit b6944f9 into sgl-project:main Mar 19, 2025
1 of 18 checks passed

if forward_mode.is_decode_or_idle():
    seq_lens = seq_lens[:bs]
    max_seqlen_pad = triton.cdiv(seq_lens.max().item(), PAGE_SIZE)
Contributor

We should avoid CPU-GPU synchronization by not using seq_lens.max().item().
Can you derive this value from seq_lens_cpu?
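
A minimal sketch of the suggested change, assuming a host-side seq_lens_cpu tensor mirroring seq_lens is available in this scope:

if forward_mode.is_decode_or_idle():
    seq_lens = seq_lens[:bs]
    # Take the max from the host copy; calling .item() on the CUDA tensor
    # would force a CPU-GPU synchronization.
    max_seqlen_pad = triton.cdiv(int(seq_lens_cpu[:bs].max()), PAGE_SIZE)

seq_lens itself stays on the GPU; only the scalar needed for the padded width is read from the CPU copy.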

Collaborator Author

I found this problem. I was too busy at work today and didn't have time to modify it. I will fix it tomorrow.

@Nekofish-L

I attempted to test FlashMLA + CUDA Graph on your commit, but I was not successful. The following error occurred:

Assertion failed: /root/miniconda3/lib/python3.12/site-packages/deep_gemm/jit/../include/deep_gemm/tma_utils.cuh:80, condition: result == CUDA_SUCCESS
Fatal Python error: PyThreadState_Get: the function must be called with the GIL held, but the GIL is released (the current Python thread state is NULL)
Python runtime state: initialized

Thread 0x00007ff6337c6700 (most recent call first):
  File "/root/miniconda3/envs/sgl_main/lib/python3.11/threading.py", line 331 in wait
  File "/root/miniconda3/envs/sgl_main/lib/python3.11/threading.py", line 629 in wait
  File "/root/miniconda3/envs/sgl_main/lib/python3.11/site-packages/tqdm/_monitor.py", line 60 in run
  File "/root/miniconda3/envs/sgl_main/lib/python3.11/threading.py", line 1045 in _bootstrap_inner
  File "/root/miniconda3/envs/sgl_main/lib/python3.11/threading.py", line 1002 in _bootstrap

Current thread 0x00007ffb40851400 (most recent call first):
  File "/root/miniconda3/envs/sgl_main/lib/python3.11/site-packages/deep_gemm/jit/runtime.py", line 45 in __call__
  File "/root/miniconda3/envs/sgl_main/lib/python3.11/site-packages/deep_gemm/jit_kernels/gemm.py", line 178 in gemm_fp8_fp8_bf16_nt
  File "/data/deepseek/sglang/python/sglang/srt/layers/quantization/fp8_kernel.py", line 61 in deep_gemm_fp8_fp8_bf16_nt
  File "/root/miniconda3/envs/sgl_main/lib/python3.11/site-packages/torch/_ops.py", line 1116 in __call__
  File "/data/deepseek/sglang/python/sglang/srt/layers/quantization/fp8_kernel.py", line 782 in w8a8_block_fp8_matmul
  File "/data/deepseek/sglang/python/sglang/srt/layers/quantization/fp8_utils.py", line 144 in apply_w8a8_block_fp8_linear
  File "/data/deepseek/sglang/python/sglang/srt/layers/quantization/fp8.py", line 422 in apply
  File "/data/deepseek/sglang/python/sglang/srt/layers/linear.py", line 1277 in forward
  File "/root/miniconda3/envs/sgl_main/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1747 in _call_impl
  File "/root/miniconda3/envs/sgl_main/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1736 in _wrapped_call_impl
  File "/data/deepseek/sglang/python/sglang/srt/models/deepseek_v2.py", line 813 in forward_absorb
  File "/data/deepseek/sglang/python/sglang/srt/models/deepseek_v2.py", line 698 in forward
  File "/root/miniconda3/envs/sgl_main/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1747 in _call_impl
  File "/root/miniconda3/envs/sgl_main/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1736 in _wrapped_call_impl
  File "/data/deepseek/sglang/python/sglang/srt/models/deepseek_v2.py", line 1061 in forward
  File "/root/miniconda3/envs/sgl_main/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1747 in _call_impl
  File "/root/miniconda3/envs/sgl_main/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1736 in _wrapped_call_impl
  File "/data/deepseek/sglang/python/sglang/srt/models/deepseek_v2.py", line 1166 in forward
  File "/root/miniconda3/envs/sgl_main/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1747 in _call_impl
  File "/root/miniconda3/envs/sgl_main/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1736 in _wrapped_call_impl
  File "/data/deepseek/sglang/python/sglang/srt/models/deepseek_v2.py", line 1206 in forward
  File "/root/miniconda3/envs/sgl_main/lib/python3.11/site-packages/torch/utils/_contextlib.py", line 116 in decorate_context
  File "/data/deepseek/sglang/python/sglang/srt/model_executor/cuda_graph_runner.py", line 433 in run_once
  File "/data/deepseek/sglang/python/sglang/srt/model_executor/cuda_graph_runner.py", line 440 in capture_one_batch_size
  File "/data/deepseek/sglang/python/sglang/srt/model_executor/cuda_graph_runner.py", line 348 in capture
  File "/data/deepseek/sglang/python/sglang/srt/model_executor/cuda_graph_runner.py", line 264 in __init__
  File "/data/deepseek/sglang/python/sglang/srt/model_executor/model_runner.py", line 909 in init_cuda_graphs
  File "/data/deepseek/sglang/python/sglang/srt/model_executor/model_runner.py", line 206 in initialize
  File "/data/deepseek/sglang/python/sglang/srt/model_executor/model_runner.py", line 168 in __init__
  File "/data/deepseek/sglang/python/sglang/srt/managers/tp_worker.py", line 74 in __init__
  File "/data/deepseek/sglang/python/sglang/srt/managers/tp_worker_overlap_thread.py", line 63 in __init__
  File "/data/deepseek/sglang/python/sglang/srt/managers/scheduler.py", line 227 in __init__
  File "/data/deepseek/sglang/python/sglang/srt/managers/scheduler.py", line 1809 in run_scheduler_process
  File "/root/miniconda3/envs/sgl_main/lib/python3.11/multiprocessing/process.py", line 108 in run
  File "/root/miniconda3/envs/sgl_main/lib/python3.11/multiprocessing/process.py", line 314 in _bootstrap
  File "/root/miniconda3/envs/sgl_main/lib/python3.11/multiprocessing/spawn.py", line 135 in _main
  File "/root/miniconda3/envs/sgl_main/lib/python3.11/multiprocessing/spawn.py", line 122 in spawn_main
  File "<string>", line 1 in <module>

Extension modules: numpy.core._multiarray_umath, numpy.core._multiarray_tests, numpy.linalg._umath_linalg, numpy.fft._pocketfft_internal, numpy.random._common, numpy.random.bit_generator, numpy.random._bounded_integers, numpy.random._mt19937, numpy.random.mtrand, numpy.random._philox, numpy.random._pcg64, numpy.random._sfc64, numpy.random._generator, charset_normalizer.md, requests.packages.charset_normalizer.md, requests.packages.chardet.md, psutil._psutil_linux, psutil._psutil_posix, torch._C, torch._C._dynamo.autograd_compiler, torch._C._dynamo.eval_frame, torch._C._dynamo.guards, torch._C._dynamo.utils, torch._C._fft, torch._C._linalg, torch._C._nested, torch._C._nn, torch._C._sparse, torch._C._special, zmq.backend.cython._zmq, multidict._multidict, yarl._quoting_c, propcache._helpers_c, aiohttp._http_writer, aiohttp._http_parser, aiohttp._websocket.mask, aiohttp._websocket.reader_c, frozenlist._frozenlist, uvloop.loop, setproctitle, yaml._yaml, markupsafe._speedups, PIL._imaging, PIL._imagingft, sentencepiece._sentencepiece, msgspec._core, msgpack._cmsgpack, google._upb._message, ray._raylet, cuda_utils, regex._regex, __triton_launcher (total: 52)

The test command I used is:

python3 -m sglang.launch_server --model /data/deepseek/DeepSeek-R1 \
    --dist-init-addr $master_ip:5000 --nnodes 2 --node-rank 0 \
    --host 0.0.0.0 --port 8124 --trust-remote-code --tp 16 --enable-flashmla --disable-cuda-graph


python3 -m sglang.launch_server --model /data/deepseek/DeepSeek-R1 \
    --dist-init-addr $master_ip:5000 --nnodes 2 --node-rank 1 \
    --host 0.0.0.0 --port 8124 --trust-remote-code --tp 16 --enable-flashmla --disable-cuda-graph

However, everything works fine as long as I don't add --enable-flashmla.

Environment:

  • CentOS
  • 2 * H20 * 8
sgl-kernel                        0.0.5.post3
sglang                            0.4.4.post1         /data/deepseek/sglang/python
flash_mla                         1.0.0+b31bfe7

@tianchongchong

hi @sleepcoo Great PR! But I did some simple tests, and it seems that the performance of flashmla is not as good as that of the triton_backend. What could be the reason?

command:

python3 -m sglang.bench_serving --backend sglang --dataset-name random --num-prompts 3 --max-concurrency 1 --random-input 128 --random-output 1024 --dataset-path /models/dataset/ShareGPT_V3_unfiltered_cleaned_split.json

flashmla:

python3 -m sglang.launch_server --model-path /models/DeepSeek-R1 \
    --tensor-parallel-size 8 --disable-custom-all-reduce --host 0.0.0.0 \
    --port 30000 --trust-remote-code --mem-fraction-static 0.9 --enable-flashmla

triton_backend:

python3 -m sglang.launch_server --model-path /models/DeepSeek-R1 \
    --tensor-parallel-size 8 --disable-custom-all-reduce --host 0.0.0.0 \
    --port 30000 --trust-remote-code --mem-fraction-static 0.9

result:

flashmla:

============ Serving Benchmark Result ============
Backend: sglang
Traffic request rate: inf
Max reqeuest concurrency: 1
Successful requests: 3
Benchmark duration (s): 112.67
Total input tokens: 159
Total generated tokens: 2467
Total generated tokens (retokenized): 2463
Request throughput (req/s): 0.03
Input token throughput (tok/s): 1.41
Output token throughput (tok/s): 21.90
Total token throughput (tok/s): 23.31
Concurrency: 1.00
----------------End-to-End Latency----------------
Mean E2E Latency (ms): 37537.68
Median E2E Latency (ms): 38623.17
---------------Time to First Token----------------
Mean TTFT (ms): 320.54
Median TTFT (ms): 324.43
P99 TTFT (ms): 332.49
---------------Inter-Token Latency----------------
Mean ITL (ms): 45.31
Median ITL (ms): 45.33
P95 ITL (ms): 45.92
P99 ITL (ms): 46.17
Max ITL (ms): 50.67

triton_backend:

============ Serving Benchmark Result ============
Backend: sglang
Traffic request rate: inf
Max reqeuest concurrency: 1
Successful requests: 3
Benchmark duration (s): 110.48
Total input tokens: 159
Total generated tokens: 2467
Total generated tokens (retokenized): 2466
Request throughput (req/s): 0.03
Input token throughput (tok/s): 1.44
Output token throughput (tok/s): 22.33
Total token throughput (tok/s): 23.77
Concurrency: 1.00
----------------End-to-End Latency----------------
Mean E2E Latency (ms): 36809.89
Median E2E Latency (ms): 37824.62
---------------Time to First Token----------------
Mean TTFT (ms): 399.44
Median TTFT (ms): 350.24
P99 TTFT (ms): 523.33
---------------Inter-Token Latency----------------
Mean ITL (ms): 44.33
Median ITL (ms): 44.33
P95 ITL (ms): 44.83
P99 ITL (ms): 44.95
Max ITL (ms): 47.54

Environment:

  • Ubuntu 22.04
  • H20 * 8

@sleepcoo
Collaborator Author

@tianchongchong In this PR, we have fixed the performance issues and tested it again; in certain cases, flashmla has advantages.

@sleepcoo sleepcoo deleted the opt-flashmla-backend branch March 21, 2025 08:45