
[Performance] Use the max num_tokens per DP rank as the CUDA graph batch size #6092


Closed · wants to merge 1 commit

Conversation


xpmemeda commented May 7, 2025

Motivation

With the --enable-dp-attention flag, cuda_graph_runner captures CUDA graphs using the batch size of each DP rank, but selects and replays them using the sum of tokens across all DP ranks, which results in an excessively large batch size being selected.

See issue #5527.

Modifications

Change the CUDA graph selection logic at replay time from the sum of tokens across DP ranks to the max of tokens per DP rank, as sketched below.
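A rough sketch of the selection change; the variable names are illustrative, not the actual cuda_graph_runner code.

```python
# Illustrative sketch of the replay batch-size selection change; the variable
# names are hypothetical, not the actual cuda_graph_runner code.
tokens_per_dp_rank = [8, 8, 7, 8, 8, 6, 8, 8]  # tokens held by each DP rank

# Before: replay selected the graph captured for the total token count across
# all DP ranks, far larger than any single rank's batch.
replay_bs_before = sum(tokens_per_dp_rank)  # 61

# After this PR: replay selects the graph for the largest per-rank token
# count, matching the per-rank batch size used when the graphs were captured.
replay_bs_after = max(tokens_per_dp_rank)   # 8
```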

Benchmarks

H20 × 16. Server launch command: python -u -m sglang.launch_server --model-path /sgl-workspace/DeepSeek-R1 --nnodes 2 --trust-remote-code --served-model-name DeepSeek-R1 --dist-init-addr 29.226.64.238:5000 --node-rank 0 --host 0.0.0.0 --port 8000 --tp 16 --disable-radix-cache --schedule-policy fcfs --chunked-prefill-size 32768 --disable-overlap-schedule --mem-fraction-static 0.79 --attention-backend flashinfer --enable-dp-attention --dp-size 16

python -m sglang.bench_one_batch_server --model None --base-url http://127.0.0.1:8000 --batch-size 128 --input-len 1000 --output-len 1000

Before
[benchmark screenshot]

After
[benchmark screenshot]



ch-wan commented May 7, 2025

@xpmemeda I observe that you are partially reverting the changes in #4390. However, using the sum of tokens is necessary to keep all intermediate tensors static. As an extreme case, if we have 8 DP workers and their processing batch sizes are [120, 1, 1, 1, 1, 1, 1, 1], the input of GroupedGeMM is of size [127, hidden_dim]. After this PR, it becomes [120 * 8, hidden_dim], and most of the input tensor is zero padding. The correct way to resolve this issue is to set a larger batch size for the CUDA graph, or to use DP FFN and LM head. We are updating our PRs to support the latter feature.
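A shape-only sketch of the padding concern above; hidden_dim and the batch split are hypothetical example values, not taken from sglang code.

```python
# Shape-only illustration of the padding argument; the numbers are made up.
dp_size = 8
hidden_dim = 7168
per_rank_bs = [120, 1, 1, 1, 1, 1, 1, 1]

# With sum-based sizing, the GroupedGeMM input covers exactly the effective
# tokens gathered from all ranks:
grouped_gemm_in_sum = (sum(per_rank_bs), hidden_dim)            # (127, 7168)

# With max-based sizing, every rank is padded to the largest rank's batch,
# so the gathered input grows to max * dp_size rows, mostly zero padding:
grouped_gemm_in_max = (max(per_rank_bs) * dp_size, hidden_dim)  # (960, 7168)
```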


xpmemeda commented May 8, 2025

> if we have 8 DP workers and their processing batch sizes are [120, 1, 1, 1, 1, 1, 1, 1], the input of GroupedGeMM is of size [127, hidden_dim].

@ch-wan As things stand, the FFN input size becomes [127 * 8, hidden_dim], not [127, hidden_dim].

xpmemeda (Author) commented:

Gentle ping. @ch-wan @zhyncs @merrymercy


ch-wan commented May 10, 2025

> if we have 8 DP workers and their processing batch sizes are [120, 1, 1, 1, 1, 1, 1, 1], the input of GroupedGeMM is of size [127, hidden_dim].
>
> @ch-wan As things stand, the FFN input size becomes [127 * 8, hidden_dim], not [127, hidden_dim].

Sorry, I don't understand your response. Could you please add more details?


xpmemeda commented May 10, 2025

> if we have 8 DP workers and their processing batch sizes are [120, 1, 1, 1, 1, 1, 1, 1], the input of GroupedGeMM is of size [127, hidden_dim].
>
> @ch-wan As things stand, the FFN input size becomes [127 * 8, hidden_dim], not [127, hidden_dim].
>
> Sorry, I don't understand your response. Could you please add more details?

@ch-wan Let's first discuss the logic of capture_one_batch: is the bs argument defined per individual DP worker, or as the sum over all DP workers? Either way, capture and replay must interpret bs consistently, and that is independent of the model architecture.

Hypothesis 1: bs is the number of requests held by a single DP worker.

Then replay should use max to stay consistent with capture, which is exactly the change in this PR.

Hypothesis 2: bs is the total number of requests held by all DP workers.

Then capture has a bug: the input fed to ForwardBatch here is wrong, and the consequence is that after the gather, the non-DP FFN input becomes [sum * dp_size, hidden_size] (sketched below).

Moreover, for a given fixed bs, there is no way to infer how many requests each DP worker actually holds, so the existing logic does not really hold up.
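A small arithmetic sketch of the concern under hypothesis 2; the concrete numbers are illustrative only.

```python
# Illustrative arithmetic for hypothesis 2 (bs = total requests across all DP
# workers); the concrete numbers are made up for the example.
dp_size = 8
bs = 128  # global batch size the graph was captured for

# If, at capture time, every DP rank's ForwardBatch is fed `bs` tokens, then
# after the gather the non-DP FFN sees bs * dp_size rows instead of bs rows.
ffn_rows_expected = bs                 # 128
ffn_rows_after_gather = bs * dp_size   # 1024
```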


ch-wan commented May 10, 2025

The current implementation was checked internally. It is correct.

To answer your question, bs in this line represents the global batch size so that it can reuse the input buffer of the FFN. This design does not incur redundant computation because the sequence lengths of the padded queries are 0.

Your concern regarding the excessive tensor shape after all-gather is reasonable but not valid. We only copy the first several effective tokens to the communication buffer. See this.
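A minimal sketch of the "only the effective tokens are copied" idea; the buffer and variable names below are hypothetical, not the actual sglang buffers.

```python
import torch

# Minimal sketch: only the effective tokens are written into the padded,
# graph-captured communication buffer; names and sizes are hypothetical.
hidden_dim = 7168
global_bs = 128        # padded global batch size baked into the CUDA graph
num_effective = 13     # tokens this DP rank actually holds in this step

comm_buffer = torch.zeros(global_bs, hidden_dim)       # static buffer reused by the graph
local_hidden = torch.randn(num_effective, hidden_dim)  # this rank's real activations

# Only the first `num_effective` rows are copied; the remaining rows stay zero
# and their sequence lengths are 0, so they add no effective work downstream.
comm_buffer[:num_effective].copy_(local_hidden)
```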

xpmemeda (Author) commented:

> The current implementation was checked internally. It is correct.
>
> To answer your question, bs in this line represents the global batch size so that it can reuse the input buffer of the FFN. This design does not incur redundant computation because the sequence lengths of the padded queries are 0.
>
> Your concern regarding the excessive tensor shape after all-gather is reasonable but not valid. We only copy the first several effective tokens to the communication buffer. See this.

@ch-wan Okay, I understand now. Thanks for the reply.

xpmemeda closed this May 10, 2025
xpmemeda (Author) commented:

> The current implementation was checked internally. It is correct.
>
> To answer your question, bs in this line represents the global batch size so that it can reuse the input buffer of the FFN. This design does not incur redundant computation because the sequence lengths of the padded queries are 0.
>
> Your concern regarding the excessive tensor shape after all-gather is reasonable but not valid. We only copy the first several effective tokens to the communication buffer. See this.

@ch-wan This raises another question, though: if no extra computation is introduced, why does DP performance drop so much with sum? That is exactly what the issue describes, and my benchmark above confirms it.


ch-wan commented May 10, 2025

@xpmemeda Probably because the slow run did not activate the CUDA graph. Please consider increasing --cuda-graph-max-bs.


xpmemeda commented May 11, 2025

> @xpmemeda Probably because the slow run did not activate the CUDA graph. Please consider increasing --cuda-graph-max-bs.

@ch-wan I confirmed that the CUDA graph is active. An Nsight trace shows that with sum most kernels take longer, so it may be worth double-checking whether sum really introduces no extra computation.

With sum (before this PR):
[Nsight screenshot]

With max (after this PR):
[Nsight screenshot]
