Improve DP attention #4390

Merged 3 commits into main on Mar 13, 2025

Conversation

@merrymercy (Contributor) commented Mar 13, 2025

  • Use a better padding strategy for CUDA graphs. With TP=8 and DP=8, the previous implementation padded a batch of size 1 to a global batch size of 8. The new implementation allows running a global batch size of 1, so it is faster in the low-batch-size range: it is 1.15x faster than the old implementation for deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct at TP=8 and bs=1. A minimal sketch of the two padding strategies follows the example command below.
  • Support TP != DP. You now need to explicitly specify --dp and --tp, with the constraint that --dp should be smaller than --tp. You can first set --tp to the total number of GPUs you have, then tune --dp to trade off latency against KV cache capacity (or throughput). For example, to achieve better latency at small batch sizes, use --tp 8 --dp 2; to allow more KV cache capacity at larger batch sizes, use --tp 8 --dp 8. An example command:
python3 -m sglang.launch_server --model-path deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct --trust-remote-code --tp 8 --enable-dp-attention --dp 2
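
A minimal sketch (not the actual sglang code) of the padding change described in the first bullet. It assumes local_bs holds the per-DP-rank batch sizes and that CUDA graphs are captured for the sizes in capture_sizes:

def old_padded_global_bs(local_bs):
    # Old strategy: every DP rank is padded to the same per-rank size,
    # so 8 ranks with at most one request become a global batch of 8.
    per_rank = max(local_bs)
    return per_rank * len(local_bs)

def new_padded_global_bs(local_bs, capture_sizes):
    # New strategy: only the summed global batch is padded, up to the
    # smallest captured CUDA graph size that fits it.
    total = sum(local_bs)
    return min(s for s in sorted(capture_sizes) if s >= total)

# With TP=8, DP=8 and a single request on one rank:
# old_padded_global_bs([1, 0, 0, 0, 0, 0, 0, 0])                 -> 8
# new_padded_global_bs([1, 0, 0, 0, 0, 0, 0, 0], [1, 2, 4, 8])   -> 1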

99% of the code was done by @dhou-xai.

Co-authored-by: dhou-xai <[email protected]>
Co-authored-by: SangBin Cho <[email protected]>

@merrymercy merged commit 8e66fbe into main Mar 13, 2025
36 of 39 checks passed
@merrymercy deleted the pr-dp branch March 13, 2025 15:23
@xihuai18 (Contributor)

Can we run 671B models with --dp 2 --tp 8 on 16 x H100?

hebiao064 pushed a commit to hebiao064/sglang that referenced this pull request Mar 13, 2025
Co-authored-by: dhou-xai <[email protected]>
Co-authored-by: SangBin Cho <[email protected]>
@merrymercy (Contributor, Author)

@xihuai18 Yes. You can use --tp 16 --dp 2 on 16 x H100.
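
A hypothetical two-node launch for that setup (a minimal sketch: the model path, address, and port are placeholders, not taken from this thread, and it assumes sglang's multi-node flags --nnodes, --node-rank, and --dist-init-addr):

# node 0
python3 -m sglang.launch_server --model-path deepseek-ai/DeepSeek-V3 --trust-remote-code --tp 16 --enable-dp-attention --dp 2 --nnodes 2 --node-rank 0 --dist-init-addr 10.0.0.1:5000
# node 1
python3 -m sglang.launch_server --model-path deepseek-ai/DeepSeek-V3 --trust-remote-code --tp 16 --enable-dp-attention --dp 2 --nnodes 2 --node-rank 1 --dist-init-addr 10.0.0.1:5000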

@merrymercy (Contributor, Author)

Please share the command here and in the docs once you finish testing.

@binarycrayon (Contributor)

We should add a hyperparameter-tuning best-practices section to the documentation:

"""
You can first set --tp to the number of total GPUs you have, then tune --dp to trade off between latency and KV cache capacity (or throughput). For example, to achieve better latency for small bs, you can do --tp 8 --dp 2. To allow more KV cache capacity for larger bs, you can do --tp 8 --dp 8.
"""

@Wesley-Jzy commented Mar 13, 2025

I also tried to run it with the tp16 dp2 setting. I found that capturing the CUDA graph causes a segmentation fault. I can run it with --disable-cuda-graph or by updating NCCL. Also, for older sglang versions, a lower NCCL version is fine. Is it necessary to update NCCL for this version update?

@xihuai18 (Contributor)

[2025-03-14 14:50:57 DP0 TP4] Scheduler hit an exception: Traceback (most recent call last):
File "/path/to/sglang/python/sglang/srt/managers/scheduler.py", line 1748, in run_scheduler_process
scheduler = Scheduler(server_args, port_args, gpu_id, tp_rank, dp_rank)
File "/path/to/sglang/python/sglang/srt/managers/scheduler.py", line 230, in init
self.draft_worker = EAGLEWorker(
File "/path/to/sglang/python/sglang/srt/speculative/eagle_worker.py", line 102, in init
self.init_cuda_graphs()
File "/path/to/sglang/python/sglang/srt/speculative/eagle_worker.py", line 153, in init_cuda_graphs
self.cuda_graph_runner = EAGLEDraftCudaGraphRunner(self)
File "/path/to/sglang/python/sglang/srt/speculative/eagle_draft_cuda_graph_runner.py", line 78, in init
self.capture()
File "/path/to/sglang/python/sglang/srt/speculative/eagle_draft_cuda_graph_runner.py", line 99, in capture
CudaGraphRunner.capture(self)
File "/path/to/sglang/python/sglang/srt/model_executor/cuda_graph_runner.py", line 336, in capture
) = self.capture_one_batch_size(bs, forward)
File "/path/to/sglang/python/sglang/srt/speculative/eagle_draft_cuda_graph_runner.py", line 162, in capture_one_batch_size
run_once()
File "/path/to/sglang/python/sglang/srt/speculative/eagle_draft_cuda_graph_runner.py", line 152, in run_once
ret = self.eagle_worker.draft_forward(forward_batch)
File "/path/to/sglang/python/sglang/srt/speculative/eagle_worker.py", line 325, in draft_forward
logits_output = self.model_runner.model.forward(
File "/usr/local/conda/lib/python3.9/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
return func(*args, **kwargs)
File "/path/to/sglang/python/sglang/srt/models/deepseek_nextn.py", line 154, in forward
hidden_states = self.model(input_ids, positions, forward_batch)
File "/usr/local/conda/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/usr/local/conda/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
return forward_call(*args, **kwargs)
File "/path/to/sglang/python/sglang/srt/models/deepseek_nextn.py", line 105, in forward
hidden_states, residual = self.decoder(
File "/usr/local/conda/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/usr/local/conda/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
return forward_call(*args, **kwargs)
File "/path/to/sglang/python/sglang/srt/models/deepseek_v2.py", line 951, in forward
forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
TypeError: 'NoneType' object is not subscriptable

Not compatible with MTP; will it be supported in the future?

@xihuai18 (Contributor)

Please share the command here and in the docs once you finish testing.

--tp 16 --enable-dp-attention --dp 2 works for running on 16 x H100 (fp8) or 16 x A100 (int8).

The following options were tested but failed:

  • --enable-torch-compile: always OOM
  • --speculative-algo EAGLE --speculative-draft $NEXTN_PATH (MTP): not compatible

@Wesley-Jzy

I also hit OOM.

@jokerwyt (Contributor)

Do we still need self.chunked_prefill_size = self.chunked_prefill_size // self.dp_size now? @merrymercy
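
For reference, a minimal illustration (the numbers are made up, not from this PR) of what that line does: it splits the configured chunked prefill budget evenly across DP ranks.

chunked_prefill_size = 8192                         # illustrative value passed via --chunked-prefill-size
dp_size = 2
per_rank_chunk = chunked_prefill_size // dp_size    # each DP-rank scheduler then chunks prefill at 4096 tokens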
