[DeepEP] Eliminate unnecessary DP cudagraph padding #5557

Closed
wants to merge 3 commits

Conversation

yuleil
Contributor

@yuleil yuleil commented Apr 19, 2025

Motivation

In DP attention, the CUDA graph batch size is padded to the sum of the batch sizes across all DP ranks so that the global tokens can be all_gathered before the MLP forward pass. Since decoding is memory-bound, this padding does not introduce significant compute overhead. Under DeepEP, however, the padded tokens from every DP rank also participate in dispatch/combine, increasing communication cost by roughly a factor of the DP size.

DeepEP's MoE computation uses a fixed shape with masking instead of gathering tokens from all DP ranks, so padding the batch size up to the global token count is unnecessary.
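To make the cost concrete, here is a small illustrative sketch (not SGLang internals; dp_size, local_batch_size, and the numbers are hypothetical) of how padding every rank to the global token count multiplies the dispatch/combine payload that DeepEP's fixed-shape, masked MoE path does not actually need:

    # Illustrative sketch only (not SGLang internals): why padding each DP rank's
    # batch to the global token count inflates DeepEP dispatch/combine traffic.
    dp_size = 16
    local_batch_size = 8                     # tokens decoded on this rank (hypothetical)
    global_batch_size = dp_size * local_batch_size

    # DP-attention CUDA graph padding: every rank pads its input to the global
    # token count, so each rank pushes global_batch_size tokens through dispatch/combine.
    padded_tokens_per_rank = global_batch_size

    # DeepEP's masked, fixed-shape MoE path only needs the rank-local tokens.
    needed_tokens_per_rank = local_batch_size

    print(f"communication inflation: {padded_tokens_per_rank // needed_tokens_per_rank}x")
    # -> 16x, i.e. roughly a dp_size-fold increase in dispatch/combine volume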

Before this fix:
[profiler screenshot]

After this fix:
[profiler screenshot]

With H20, EP16, and 120-batch decoding under CUDA graph, combine time is reduced by about 10x and TPOT drops from 300 ms to ~100 ms.

Modifications

Checklist

@ch-wan
Collaborator

ch-wan commented Apr 19, 2025

Even if DeepEP is enabled, this padding is necessary for cuda graph because DeepSeek-V3 contains several dense FFNs. Could you please add a barrier before the communication operator to confirm whether the communication volume is redundant under DeepEP? This is weird to me because low-latency dispatch cannot handle input over num_max_dispatch_tokens_per_rank (128 by default). The performance gain you observe may come from a weaker launching condition for cuda graph. See #5527.
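A minimal sketch of the measurement being suggested, assuming an initialized process group and a placeholder dispatch_fn standing in for the actual DeepEP dispatch/combine call (this is not DeepEP's or SGLang's API):

    import torch
    import torch.distributed as dist

    def timed_dispatch(dispatch_fn, hidden_states):
        # Align all ranks and drain prior kernels so the timing below reflects
        # only the communication operator itself.
        dist.barrier()
        torch.cuda.synchronize()
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        start.record()
        out = dispatch_fn(hidden_states)   # placeholder for the dispatch/combine call
        end.record()
        torch.cuda.synchronize()
        print(f"dispatch time: {start.elapsed_time(end):.3f} ms")
        return out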

@yuleil
Contributor Author

yuleil commented Apr 21, 2025

this padding is necessary for cuda graph because DeepSeek-V3 contains several dense FFNs

I will continue to investigate the situation with the dense FFN.

Could you please add a barrier before the communication operator to confirm whether the communication volume is redundant under DeepEP

Okay, I’ll give it a try.

This is weird to me because low-latency dispatch cannot handle input over num_max_dispatch_tokens_per_rank (128 by default)

The low-latency dispatch returns a tensor of shape [num_local_experts, num_max_dispatch_tokens_per_rank * num_ranks, hidden]. Therefore, as long as GBS (120 in my experiments) is less than num_max_dispatch_tokens_per_rank, it can be handled correctly.
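A small sketch of the shape argument above; num_local_experts and hidden are illustrative placeholders rather than values taken from the model config, while 128 is the DeepEP low-latency default and 120 is the GBS from the experiment:

    # Sketch of the shape argument above (placeholder values).
    num_ranks = 16
    num_local_experts = 16
    num_max_dispatch_tokens_per_rank = 128   # DeepEP low-latency default
    hidden = 7168

    # Low-latency dispatch returns a buffer sized for the worst case per rank:
    recv_shape = (num_local_experts, num_max_dispatch_tokens_per_rank * num_ranks, hidden)

    gbs = 120   # global batch size in the experiment
    # Condition stated above: the unpadded batch fits as long as no rank sends
    # more than num_max_dispatch_tokens_per_rank tokens.
    assert gbs < num_max_dispatch_tokens_per_rank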

The performance gain you observe may come from a weaker launching condition for cuda graph.

The slow combine is occurring under CUDA Graph.

[profiler screenshot]

@TianyuZhang1214
Contributor

Hi @yuleil ,

I've been testing this PR together with #5435 and encountered a CUDA memory error during batch decoding. Here are the details:

Reproduction Steps:

  1. Apply this PR alongside Integrating PD disaggregation with DP attention and DeepEP #5435.
  2. Launch SGLang with the parameters mentioned in Integrating PD disaggregation with DP attention and DeepEP #5435.
  3. Run the following bench_serving command:
    python3 -m sglang.bench_serving \
        --port 8000 \
        --backend sglang \
        --dataset-name random \
        --num-prompt 128 \
        --random-input 4096 \
        --random-output 1500 \
        --random-range-ratio 1 \
        --dataset-path /path/to/ShareGPT_V3_unfiltered_cleaned_split.json \
        --max-concurrency 128

Error:

[2025-04-22 02:22:22 DP2 TP4] Scheduler hit an exception: Traceback (most recent call last):
  File "/home/nas/code/sglang/python/sglang/srt/managers/scheduler.py", line 2015, in run_scheduler_process
    scheduler.event_loop_normal_disagg_decode()
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
  File "/home/nas/code/sglang/python/sglang/srt/disaggregation/decode.py", line 456, in event_loop_normal_disagg_decode
    self.process_batch_result(batch, result)
  File "/home/nas/code/sglang/python/sglang/srt/managers/scheduler.py", line 1408, in process_batch_result
    self.process_batch_result_decode(batch, result)
  File "/home/nas/code/sglang/python/sglang/srt/managers/scheduler_output_processor_mixin.py", line 194, in process_batch_result_decode
    next_token_ids = next_token_ids.tolist()
RuntimeError: CUDA error: an illegal memory access was encountered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.

Environment:

# python3 -m sglang.check_env        
Python: 3.10.12 (main, Feb  4 2025, 14:57:36) [GCC 11.4.0]
CUDA available: True
GPU 0,1,2,3,4,5,6,7: NVIDIA H20
GPU 0,1,2,3,4,5,6,7 Compute Capability: 9.0
CUDA_HOME: /usr/local/cuda
NVCC: Cuda compilation tools, release 12.4, V12.4.131
CUDA Driver Version: 550.127.08
PyTorch: 2.5.1+cu124
sglang: 0.4.5.post1
sgl_kernel: 0.0.9.post2
flashinfer: Module Not Found
triton: 3.1.0
transformers: 4.51.1
torchao: 0.9.0
numpy: 2.2.4
aiohttp: 3.11.14
fastapi: 0.115.12
hf_transfer: 0.1.9
huggingface_hub: 0.30.2
interegular: 0.3.3
modelscope: 1.24.1
orjson: 3.10.16
outlines: 0.1.11
packaging: 24.2
psutil: 7.0.0
pydantic: 2.11.1
multipart: Module Not Found
zmq: Module Not Found
uvicorn: 0.34.0
uvloop: 0.21.0
vllm: Module Not Found
xgrammar: 0.1.17
openai: 1.69.0
tiktoken: 0.9.0
anthropic: 0.49.0
litellm: 1.65.0
decord: 0.6.0
NVIDIA Topology: 
	GPU0	GPU1	GPU2	GPU3	GPU4	GPU5	GPU6	GPU7	NIC0	NIC1	NIC2	NIC3	CPU Affinity	NUMA Affinity	GPU NUMA ID
GPU0	 X 	NV18	NV18	NV18	NV18	NV18	NV18	NV18	NODE	NODE	SYS	SYS	0-47,96-143	0	N/A
GPU1	NV18	 X 	NV18	NV18	NV18	NV18	NV18	NV18	PIX	NODE	SYS	SYS	0-47,96-143	0	N/A
GPU2	NV18	NV18	 X 	NV18	NV18	NV18	NV18	NV18	NODE	NODE	SYS	SYS	0-47,96-143	0	N/A
GPU3	NV18	NV18	NV18	 X 	NV18	NV18	NV18	NV18	NODE	PIX	SYS	SYS	0-47,96-143	0	N/A
GPU4	NV18	NV18	NV18	NV18	 X 	NV18	NV18	NV18	SYS	SYS	PIX	NODE	48-95,144-191	1	N/A
GPU5	NV18	NV18	NV18	NV18	NV18	 X 	NV18	NV18	SYS	SYS	NODE	NODE	48-95,144-191	1	N/A
GPU6	NV18	NV18	NV18	NV18	NV18	NV18	 X 	NV18	SYS	SYS	NODE	PIX	48-95,144-191	1	N/A
GPU7	NV18	NV18	NV18	NV18	NV18	NV18	NV18	 X 	SYS	SYS	NODE	NODE	48-95,144-191	1	N/A
NIC0	NODE	PIX	NODE	NODE	SYS	SYS	SYS	SYS	 X 	NODE	SYS	SYS				
NIC1	NODE	NODE	NODE	PIX	SYS	SYS	SYS	SYS	NODE	 X 	SYS	SYS				
NIC2	SYS	SYS	SYS	SYS	PIX	NODE	NODE	NODE	SYS	SYS	 X 	NODE				
NIC3	SYS	SYS	SYS	SYS	NODE	NODE	PIX	NODE	SYS	SYS	NODE	 X 				

Legend:

  X    = Self
  SYS  = Connection traversing PCIe as well as the SMP interconnect between NUMA nodes (e.g., QPI/UPI)
  NODE = Connection traversing PCIe as well as the interconnect between PCIe Host Bridges within a NUMA node
  PHB  = Connection traversing PCIe as well as a PCIe Host Bridge (typically the CPU)
  PXB  = Connection traversing multiple PCIe bridges (without traversing the PCIe Host Bridge)
  PIX  = Connection traversing at most a single PCIe bridge
  NV#  = Connection traversing a bonded set of # NVLinks

NIC Legend:

  NIC0: mlx5_bond_0
  NIC1: mlx5_bond_1
  NIC2: mlx5_bond_2
  NIC3: mlx5_bond_3


ulimit soft: 1048576

@yuleil yuleil force-pushed the fix_deepep_batch branch from 9c6ff3a to 8f9586f on April 23, 2025 04:41
@yuleil yuleil force-pushed the fix_deepep_batch branch from 8f9586f to 349df6d on April 23, 2025 04:44
@yuleil
Contributor Author

yuleil commented Apr 23, 2025

Hi @yuleil, I've been testing this PR together with #5435 and encountered a CUDA memory error during batch decoding. […]

According to @ch-wan's response, the first three dense layers and the lm_head of DeepSeek-V3 still use TP and require gathering, and therefore rely on padding to sum(global_num_tokens_cpu). Setting moe_dense_tp_size=1 prevents the dense layers from using TP, but making the lm_head non-TP depends on another patch.

I have verified that after making both the dense layers and lm_head fully DP, this fix can run correctly.
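For clarity, a toy sketch of the constraint described above (names and numbers are made up, not SGLang code): TP layers gather tokens from every DP rank and therefore need the graph batch padded to the global sum, while fully-DP layers only see local tokens:

    # Toy sketch, not SGLang code: which batch size the CUDA graph must cover.
    global_num_tokens = [8, 8, 8, 8]        # tokens per DP rank (hypothetical)
    local_num_tokens = global_num_tokens[0]

    def required_graph_batch(dense_layers_use_tp: bool, lm_head_uses_tp: bool) -> int:
        if dense_layers_use_tp or lm_head_uses_tp:
            # TP path gathers all ranks' tokens before the layer -> pad to the sum.
            return sum(global_num_tokens)
        # Fully-DP path (e.g. moe_dense_tp_size=1 plus a DP lm_head) only needs local tokens.
        return local_num_tokens

    print(required_graph_batch(True, True))     # 32: padding still required
    print(required_graph_batch(False, False))   # 8: the padding this PR removes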

@ch-wan
Collaborator

ch-wan commented Apr 23, 2025

@TianyuZhang1214 Please check these two PRs: #5558 and #5657

@TianyuZhang1214
Contributor

According to @ch-wan's response, the first three dense layers and the lm_head of DeepSeek-V3 still use TP and require gathering, and therefore rely on padding to sum(global_num_tokens_cpu). Setting moe_dense_tp_size=1 prevents the dense layers from using TP, but making the lm_head non-TP depends on another patch.

I have verified that after making both the dense layers and lm_head fully DP, this fix can run correctly.

Thanks for your reply! Your PR has consistently resolved our issues with remarkable efficiency!

I'm currently at Ant Group and have just sent a detailed email to your Gmail account regarding potential collaboration opportunities. Could you kindly review it at your earliest convenience?

@TianyuZhang1214
Contributor

@TianyuZhang1214 Please check these two PRs: #5558 and #5657

@ch-wan Thank you for sharing PRs #5558 and #5657. I've tested them by deploying SGLang across 4 H20 nodes (8×96G) configured as:

  • 2 prefill nodes
  • 2 decode nodes
  • running DeepSeek-R1 model.

I then encountered the following error on a prefill node:

    return self.model.forward(
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
  File "/home/nas/code/sglang/python/sglang/srt/models/deepseek_v2.py", line 1501, in forward
    hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1750, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/nas/code/sglang/python/sglang/srt/models/deepseek_v2.py", line 1425, in forward
    hidden_states, residual = layer(
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1750, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/nas/code/sglang/python/sglang/srt/models/deepseek_v2.py", line 1214, in forward
    return self.forward_ffn_with_full_input(
  File "/home/nas/code/sglang/python/sglang/srt/models/deepseek_v2.py", line 1260, in forward_ffn_with_full_input
    dp_gather_partial(hidden_states, local_hidden_states, forward_batch)
  File "/home/nas/code/sglang/python/sglang/srt/layers/dp_attention.py", line 270, in dp_gather_partial
    _dp_gather(global_tokens, local_tokens, forward_batch, is_partial=True)
  File "/home/nas/code/sglang/python/sglang/srt/layers/dp_attention.py", line 237, in _dp_gather
    local_start_pos, local_num_tokens = get_dp_local_info(forward_batch)
  File "/home/nas/code/sglang/python/sglang/srt/layers/dp_attention.py", line 177, in get_dp_local_info
    cumtokens = torch.cumsum(forward_batch.global_num_tokens_gpu, dim=0)
TypeError: cumsum() received an invalid combination of arguments - got (NoneType, dim=int), but expected one of:
 * (Tensor input, int dim, *, torch.dtype dtype = None, Tensor out = None)
 * (Tensor input, name dim, *, torch.dtype dtype = None, Tensor out = None)
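The failure reduces to forward_batch.global_num_tokens_gpu being None when get_dp_local_info reaches torch.cumsum; a standalone snippet that reproduces just the final TypeError (purely illustrative, outside SGLang):

    import torch

    global_num_tokens_gpu = None                # what the prefill path apparently carries here
    torch.cumsum(global_num_tokens_gpu, dim=0)  # raises the TypeError shown above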

Launch script:

# Prefill Node 0 (1 is the same)
MOONCAKE_CONFIG_PATH=./prefill_node_0.json SUPPORT_CUTLASS_BLOCK_FP8=1 python3 -m sglang.launch_server \
--model-path /home/moyun.zty/models/deepseek-ai__DeepSeek-R1 \
--disaggregation-mode prefill \
--host 10.13.3.156 \
--port 30001 \
--trust-remote-code \
--dist-init-addr 10.13.3.156:50000 \
--nnodes 2 \
--node-rank 0 \
--tp-size 16 \
--dp-size 16 \
--enable-dp-attention \
--enable-deepep-moe \
--deepep-mode normal \
--mem-fraction-static 0.9 \
--quantization fp8 \
--log-level debug \
--chunked-prefill-size 8196 \
--disable-radix-cache \
--context-length 65535 \
--max-running-requests 128 \
--stream-output \
--log-requests \
--attention-backend flashinfer \
--enable-mixed-chunk \
--flashinfer-mla-disable-ragged \
> sglang-prefill.log 2>&1 &
# Decode Node 0 (1 is the same)
MOONCAKE_CONFIG_PATH=./prefill_node_0.json SUPPORT_CUTLASS_BLOCK_FP8=1 python3 -m sglang.launch_server \
--model-path /home/moyun.zty/models/deepseek-ai__DeepSeek-R1 \
--disaggregation-mode decode \
--host 10.13.3.169 \
--port 30001 \
--trust-remote-code \
--dist-init-addr 10.13.3.169:50001 \
--nnodes 2 \
--node-rank 0 \
--tp-size 16 \
--dp-size 16 \
--enable-dp-attention \
--enable-deepep-moe \
--deepep-mode low_latency \
--moe-dense-tp-size 1 \
--mem-fraction-static 0.85 \
--quantization fp8 \
--log-level debug \
--disable-radix-cache \
--context-length 65535 \
--max-running-requests 128 \
--stream-output \
--log-requests \
--attention-backend flashinfer \
--enable-mixed-chunk \
--flashinfer-mla-disable-ragged \
> sglang-decode.log 2>&1 &

@liusy58
Contributor

liusy58 commented May 7, 2025

@TianyuZhang1214 Hi, could you please provide the MOONCAKE_CONFIG for me?

@yuleil yuleil closed this May 15, 2025