Support overlapping two batches #4068


Merged
1,357 commits merged on May 25, 2025
241 changes: 241 additions & 0 deletions python/sglang/srt/layers/attention/tbo_backend.py
@@ -0,0 +1,241 @@
from typing import TYPE_CHECKING, Callable, List, Optional, Union

import torch

from sglang.srt import two_batch_overlap
from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput

if TYPE_CHECKING:
    from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode


class TboAttnBackend(AttentionBackend):
    def __init__(self, primary: AttentionBackend, children: List[AttentionBackend]):
        super().__init__()
        self.primary = primary
        self.children = children

    @classmethod
    def init_new(cls, creator: Callable[[], AttentionBackend]):
        return cls(
            primary=creator(),
            children=[creator() for _ in range(2)],
        )

    def init_forward_metadata(self, forward_batch: "ForwardBatch"):
        self.primary.init_forward_metadata(forward_batch=forward_batch)
        if forward_batch.tbo_children is not None:
            for child, forward_batch_child in zip(
                self.children, forward_batch.tbo_children, strict=True
            ):
                if forward_batch_child.batch_size > 0:
                    child.init_forward_metadata(forward_batch=forward_batch_child)

    def init_cuda_graph_state(self, max_bs: int):
        self.primary.init_cuda_graph_state(max_bs=max_bs)
        for item in self.children:
            # TODO: for children, a *smaller* max_bs may suffice and save memory
            item.init_cuda_graph_state(max_bs=max_bs)

    def init_forward_metadata_capture_cuda_graph(
        self,
        bs: int,
        num_tokens: int,
        req_pool_indices: torch.Tensor,
        seq_lens: torch.Tensor,
        encoder_lens: Optional[torch.Tensor],
        forward_mode: "ForwardMode",
        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
    ):
        self.primary.init_forward_metadata_capture_cuda_graph(
            bs=bs,
            num_tokens=num_tokens,
            req_pool_indices=req_pool_indices,
            seq_lens=seq_lens,
            encoder_lens=encoder_lens,
            forward_mode=forward_mode,
            spec_info=spec_info,
        )

        self._init_forward_metadata_cuda_graph_children(
            fn_name="init_forward_metadata_capture_cuda_graph",
            bs=bs,
            req_pool_indices=req_pool_indices,
            seq_lens=seq_lens,
            encoder_lens=encoder_lens,
            forward_mode=forward_mode,
            spec_info=spec_info,
            capture_num_tokens=num_tokens,
        )

    def init_forward_metadata_replay_cuda_graph(
        self,
        bs: int,
        req_pool_indices: torch.Tensor,
        seq_lens: torch.Tensor,
        seq_lens_sum: int,
        encoder_lens: Optional[torch.Tensor],
        forward_mode: "ForwardMode",
        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
        seq_lens_cpu: Optional[torch.Tensor],
    ):
        self.primary.init_forward_metadata_replay_cuda_graph(
            bs=bs,
            req_pool_indices=req_pool_indices,
            seq_lens=seq_lens,
            seq_lens_sum=seq_lens_sum,
            encoder_lens=encoder_lens,
            forward_mode=forward_mode,
            spec_info=spec_info,
            seq_lens_cpu=seq_lens_cpu,
        )

        self._init_forward_metadata_cuda_graph_children(
            fn_name="init_forward_metadata_replay_cuda_graph",
            bs=bs,
            req_pool_indices=req_pool_indices,
            seq_lens=seq_lens,
            encoder_lens=encoder_lens,
            forward_mode=forward_mode,
            spec_info=spec_info,
            replay_seq_lens_sum=seq_lens_sum,
            replay_seq_lens_cpu=seq_lens_cpu,
        )

    def _init_forward_metadata_cuda_graph_children(
        self,
        fn_name: str,
        # common args
        bs: int,
        req_pool_indices: torch.Tensor,
        seq_lens: torch.Tensor,
        encoder_lens: Optional[torch.Tensor],
        forward_mode: "ForwardMode",
        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
        # capture args
        capture_num_tokens: Optional[int] = None,
        # replay args
        replay_seq_lens_sum: Optional[int] = None,
        replay_seq_lens_cpu: Optional[torch.Tensor] = None,
    ):
        from sglang.srt.model_executor.forward_batch_info import ForwardMode

        if fn_name == "init_forward_metadata_capture_cuda_graph":
            assert capture_num_tokens == bs, "Only support num_tokens==bs currently"
        num_tokens = bs

        forward_mode_for_tbo_split = (
            forward_mode if forward_mode != ForwardMode.IDLE else ForwardMode.DECODE
        )
        tbo_split_seq_index = two_batch_overlap.compute_split_seq_index(
            forward_mode=forward_mode_for_tbo_split,
            num_tokens=num_tokens,
            extend_lens=None,
        )
        tbo_split_token_index = two_batch_overlap.compute_split_token_index(
            split_seq_index=tbo_split_seq_index,
            forward_mode=forward_mode_for_tbo_split,
            extend_seq_lens=None,
        )

        num_tokens_child_left = tbo_split_token_index
        num_tokens_child_right = num_tokens - tbo_split_token_index
        bs_child_left = num_tokens_child_left
        bs_child_right = num_tokens_child_right

        assert (
            num_tokens_child_left > 0 and num_tokens_child_right > 0
        ), f"{num_tokens_child_left=} {num_tokens_child_right=} {forward_mode=} {num_tokens=}"

        common_pre_split_args = dict(
            fn_name=fn_name,
            bs=bs,
            req_pool_indices=req_pool_indices,
            seq_lens=seq_lens,
            encoder_lens=encoder_lens,
            forward_mode=forward_mode,
            spec_info=spec_info,
            capture_num_tokens=capture_num_tokens,
            replay_seq_lens_sum=replay_seq_lens_sum,
            replay_seq_lens_cpu=replay_seq_lens_cpu,
        )

        args_left = _init_forward_metadata_cuda_graph_split(
            output_bs=bs_child_left,
            seq_slice=slice(None, tbo_split_seq_index),
            **common_pre_split_args,
        )
        args_right = _init_forward_metadata_cuda_graph_split(
            output_bs=bs_child_right,
            seq_slice=slice(tbo_split_seq_index, None),
            **common_pre_split_args,
        )

        child_left, child_right = self.children
        getattr(child_left, fn_name)(**args_left)
        getattr(child_right, fn_name)(**args_right)

    def get_cuda_graph_seq_len_fill_value(self):
        ans = self.primary.get_cuda_graph_seq_len_fill_value()
        for child in self.children:
            assert ans == child.get_cuda_graph_seq_len_fill_value()
        return ans

    def forward_extend(self, *args, **kwargs):
        return self.primary.forward_extend(*args, **kwargs)

    def forward_decode(self, *args, **kwargs):
        return self.primary.forward_decode(*args, **kwargs)


def _init_forward_metadata_cuda_graph_split(
    fn_name: str,
    seq_slice: slice,
    output_bs: int,
    # common args
    bs: int,
    req_pool_indices: torch.Tensor,
    seq_lens: torch.Tensor,
    encoder_lens: Optional[torch.Tensor],
    forward_mode: "ForwardMode",
    spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
    # capture args
    capture_num_tokens: Optional[int] = None,
    # replay args
    replay_seq_lens_sum: Optional[int] = None,
    replay_seq_lens_cpu: Optional[torch.Tensor] = None,
):
    assert encoder_lens is None, "encoder_lens is not supported yet"
    assert spec_info is None, "spec_info is not supported yet"

    ans = dict(
        bs=output_bs,
        req_pool_indices=req_pool_indices[seq_slice],
        seq_lens=seq_lens[seq_slice],
        # forwarded directly
        forward_mode=forward_mode,
        # ignored for now
        encoder_lens=None,
        spec_info=None,
    )

    if fn_name == "init_forward_metadata_capture_cuda_graph":
        assert capture_num_tokens == bs, "Only support num_tokens==bs currently"
        ans.update(dict(num_tokens=output_bs))
    elif fn_name == "init_forward_metadata_replay_cuda_graph":
        output_seq_lens_cpu = replay_seq_lens_cpu[seq_slice]
        ans.update(
            dict(
                seq_lens_sum=output_seq_lens_cpu.sum().item(),
                seq_lens_cpu=output_seq_lens_cpu,
            )
        )
    else:
        raise NotImplementedError

    return ans
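The backend above is mostly delegation: the primary backend sees the whole batch, while the two children each receive one half of a split decode batch. A minimal, framework-free sketch of that pattern (the stub class names below are illustrative, not the real sglang API):

```python
class StubBackend:
    """Stand-in for an attention backend; records the batch sizes it sees."""

    def __init__(self):
        self.seen = []

    def init_forward_metadata(self, batch_size):
        self.seen.append(batch_size)


class TboSketch:
    """One primary backend for the full batch, two children for the halves."""

    def __init__(self, creator):
        self.primary = creator()
        self.children = [creator() for _ in range(2)]

    def init_forward_metadata(self, batch_size):
        # Primary always handles the full batch.
        self.primary.init_forward_metadata(batch_size)
        # Decode case: split the sequences roughly in half between children.
        split = batch_size // 2
        for child, bs in zip(self.children, (split, batch_size - split)):
            if bs > 0:
                child.init_forward_metadata(bs)
```

A batch of 7 would give the primary all 7 sequences and the children 3 and 4 respectively, mirroring how `compute_split_seq_index` splits a decode batch.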
13 changes: 13 additions & 0 deletions python/sglang/srt/layers/quantization/deep_gemm.py
@@ -391,3 +391,16 @@ def __patched_func(self, *args, **kwargs):
    RuntimeCache.get = __patched_func
    yield
    RuntimeCache.get = origin_func


@contextmanager
def configure_deep_gemm_num_sms(num_sms):
    if num_sms is None:
        yield
    else:
        original_num_sms = deep_gemm.get_num_sms()
        deep_gemm.set_num_sms(num_sms)
        try:
            yield
        finally:
            deep_gemm.set_num_sms(original_num_sms)
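The context manager follows the usual save/set/restore pattern around a module-level setting, with `try/finally` so the original value is restored even if the body raises. A self-contained sketch of the same pattern against a stand-in for the deep_gemm module (the `_FakeDeepGemm` class and its default of 132 SMs are made up for illustration):

```python
from contextlib import contextmanager


class _FakeDeepGemm:
    """Stand-in for deep_gemm's global SM-count state."""

    def __init__(self):
        self._num_sms = 132  # illustrative default

    def get_num_sms(self):
        return self._num_sms

    def set_num_sms(self, n):
        self._num_sms = n


deep_gemm = _FakeDeepGemm()


@contextmanager
def configure_num_sms(num_sms):
    if num_sms is None:
        yield  # no-op: leave the global setting untouched
    else:
        original = deep_gemm.get_num_sms()
        deep_gemm.set_num_sms(num_sms)
        try:
            yield
        finally:
            deep_gemm.set_num_sms(original)  # restored even on exception
```

The `finally` clause is what makes the manager safe to use around kernels that may throw: the SM count never leaks past the `with` block.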
8 changes: 8 additions & 0 deletions python/sglang/srt/managers/schedule_batch.py
@@ -78,6 +78,7 @@
    "disable_radix_cache": ServerArgs.disable_radix_cache,
    "enable_deepep_moe": ServerArgs.enable_deepep_moe,
    "enable_dp_attention": ServerArgs.enable_dp_attention,
    "enable_two_batch_overlap": ServerArgs.enable_two_batch_overlap,
    "enable_dp_lm_head": ServerArgs.enable_dp_lm_head,
    "enable_ep_moe": ServerArgs.enable_ep_moe,
    "deepep_config": ServerArgs.deepep_config,
@@ -831,6 +832,8 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
    global_num_tokens: Optional[List[int]] = None
    global_num_tokens_for_logprob: Optional[List[int]] = None
    can_run_dp_cuda_graph: bool = False
    tbo_split_seq_index: Optional[int] = None
    global_forward_mode: Optional[ForwardMode] = None

    # For processing logprobs
    return_logprob: bool = False
@@ -1624,6 +1627,7 @@ def get_model_worker_batch(self) -> ModelWorkerBatch:
            or global_server_args_dict["attention_backend"] == "flashmla"
            or global_server_args_dict["attention_backend"] == "fa3"
            or global_server_args_dict["attention_backend"] == "cutlass_mla"
            or global_server_args_dict["enable_two_batch_overlap"]
        ):
            seq_lens_cpu = self.seq_lens.cpu()
        else:
@@ -1651,6 +1655,8 @@ def get_model_worker_batch(self) -> ModelWorkerBatch:
            global_num_tokens=self.global_num_tokens,
            global_num_tokens_for_logprob=self.global_num_tokens_for_logprob,
            can_run_dp_cuda_graph=self.can_run_dp_cuda_graph,
            tbo_split_seq_index=self.tbo_split_seq_index,
            global_forward_mode=self.global_forward_mode,
            seq_lens_cpu=seq_lens_cpu,
            extend_num_tokens=self.extend_num_tokens,
            extend_seq_lens=extend_seq_lens,
@@ -1729,6 +1735,8 @@ class ModelWorkerBatch:
    global_num_tokens: Optional[List[int]]
    global_num_tokens_for_logprob: Optional[List[int]]
    can_run_dp_cuda_graph: bool
    tbo_split_seq_index: Optional[int]
    global_forward_mode: Optional[ForwardMode]

    # For extend
    extend_num_tokens: Optional[int]
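The widened condition in `get_model_worker_batch` decides whether `seq_lens` must be copied to the CPU; enabling two-batch overlap now also forces the copy. A toy sketch of the visible part of that gate (the real check has additional branches hidden above the hunk, and the function name here is made up):

```python
def needs_seq_lens_cpu(args):
    """Sketch: a CPU copy of seq_lens is needed for certain attention
    backends and, after this PR, whenever two-batch overlap is on."""
    return (
        args["attention_backend"] in ("flashmla", "fa3", "cutlass_mla")
        or args["enable_two_batch_overlap"]
    )
```

The copy exists because the TBO split logic (and some backends) read sequence lengths on the host when partitioning the batch.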
26 changes: 25 additions & 1 deletion python/sglang/srt/managers/scheduler.py
@@ -34,6 +34,7 @@
from torch.distributed import barrier

from sglang.global_config import global_config
from sglang.srt import two_batch_overlap
from sglang.srt.configs.model_config import ModelConfig
from sglang.srt.constrained.base_grammar_backend import create_grammar_backend
from sglang.srt.disaggregation.decode import (
@@ -132,7 +133,9 @@
from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
from sglang.srt.two_batch_overlap import TboDPAttentionPreparer
from sglang.srt.utils import (
    DeepEPMode,
    DynamicGradMode,
    broadcast_pyobj,
    configure_logger,
@@ -1648,6 +1651,9 @@ def prepare_dp_attn_batch(self, local_batch: ScheduleBatch):
            disable_cuda_graph=self.server_args.disable_cuda_graph,
            spec_algorithm=self.spec_algorithm,
            speculative_num_draft_tokens=self.server_args.speculative_num_draft_tokens,
            enable_two_batch_overlap=self.server_args.enable_two_batch_overlap,
            enable_deepep_moe=self.server_args.enable_deepep_moe,
            deepep_mode=DeepEPMode[self.server_args.deepep_mode],
        )

    @staticmethod
@@ -1661,6 +1667,9 @@ def prepare_dp_attn_batch_raw(
        disable_cuda_graph: bool,
        spec_algorithm,
        speculative_num_draft_tokens,
        enable_two_batch_overlap: bool,
        enable_deepep_moe: bool,
        deepep_mode: DeepEPMode,
    ):
        # Check if other DP workers have running batches
        if local_batch is None:
@@ -1696,17 +1705,26 @@ def prepare_dp_attn_batch_raw(
        is_extend_in_batch = (
            local_batch.forward_mode.is_extend() if local_batch else False
        )

        tbo_preparer = TboDPAttentionPreparer()

        local_info = torch.tensor(
            [
                num_tokens,
                can_cuda_graph,
                num_tokens_for_logprob,
                is_extend_in_batch,
                *tbo_preparer.prepare_all_gather(
                    local_batch,
                    deepep_mode,
                    enable_deepep_moe,
                    enable_two_batch_overlap,
                ),
            ],
            dtype=torch.int64,
        )
        global_info = torch.empty(
            (dp_size, attn_tp_size, 6),  # widened from 4 to carry the two TBO fields
            dtype=torch.int64,
        )
        torch.distributed.all_gather_into_tensor(
@@ -1719,6 +1737,10 @@ def prepare_dp_attn_batch_raw(
        global_num_tokens_for_logprob = global_info[:, 0, 2].tolist()
        is_extend_in_batch = global_info[:, 0, 3].tolist()

        tbo_split_seq_index, global_forward_mode = tbo_preparer.compute_output(
            global_info[:, :, 4:6]
        )

        if local_batch is None and max(global_num_tokens) > 0:
            local_batch = get_idle_batch()

@@ -1732,6 +1754,8 @@ def prepare_dp_attn_batch_raw(
                local_batch.global_num_tokens_for_logprob = (
                    global_num_tokens_for_logprob
                )
                local_batch.tbo_split_seq_index = tbo_split_seq_index
                local_batch.global_forward_mode = global_forward_mode

        # Check forward mode for cuda graph
        if not disable_cuda_graph:
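Each DP rank now packs six int64 slots instead of four, the last two contributed by the TBO preparer, and per-column decisions are read back after the all-gather. A pure-Python sketch of that packing, with `torch.distributed` replaced by a plain list of per-rank rows and all values made up for illustration:

```python
def pack_local_info(num_tokens, can_cuda_graph, num_tokens_for_logprob,
                    is_extend_in_batch, tbo_fields):
    """Mirror of the local_info layout: four original slots plus two TBO slots."""
    return [num_tokens, int(can_cuda_graph), num_tokens_for_logprob,
            int(is_extend_in_batch), *tbo_fields]


# Simulated all_gather across 2 DP ranks (attn_tp_size == 1): each rank
# contributes one row, and every rank ends up with the full table.
global_info = [
    pack_local_info(8, True, 8, False, (1, 4)),
    pack_local_info(6, True, 6, False, (1, 3)),
]

# Column-wise reads, analogous to global_info[:, 0, k] in the real code.
global_num_tokens = [row[0] for row in global_info]
can_cuda_graph = min(row[1] for row in global_info)  # graph only if all ranks can
tbo_columns = [row[4:6] for row in global_info]  # fed to compute_output() in the PR
```

Packing booleans and small ints into one tensor keeps the synchronization down to a single collective per scheduling step rather than one per flag.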