
Commit fb860f3

Author: Leos
Commit message: feat: two batch overlap
1 parent bcc16a0 · commit fb860f3

File tree

5 files changed: +1776 additions, -167 deletions

python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py

Lines changed: 202 additions & 3 deletions
@@ -23,14 +23,17 @@
 )
 from sglang.srt.model_executor.forward_batch_info import ForwardMode
 
+import logging
+
+logger = logging.getLogger(__name__)
+
 
 class DeepEPDispatchMode(IntEnum):
     NORMAL = auto()
     LOW_LATENCY = auto()
 
 
 class DeepEPBuffer:
-
     _buffer = None
     _dispatch_mode: Optional[DeepEPDispatchMode] = None
     _hidden_size: Optional[int] = None
@@ -145,6 +148,20 @@ def __init__(
         self.num_max_dispatch_tokens_per_rank = 128
 
         self.handle = None
+        self.event = None
+
+    def launch_dispatch(
+        self,
+        hidden_states: torch.Tensor,
+        topk_idx: torch.Tensor,
+        topk_weights: torch.Tensor,
+        num_experts: int,
+        num_max_dispatch_tokens_per_rank: int,
+    ) -> torch.Tensor:
+        raise NotImplementedError
+
+    def wait_dispatch(self):
+        raise NotImplementedError
 
     def dispatch_a(
         self,
@@ -157,6 +174,17 @@ def dispatch_a(
     def dispatch_b(self, *args, **kwargs):
         raise NotImplementedError
 
+    def launch_combine(
+        self,
+        hidden_states: torch.Tensor,
+        topk_idx: torch.Tensor,
+        topk_weights: torch.Tensor,
+    ):
+        raise NotImplementedError
+
+    def wait_combine(self):
+        raise NotImplementedError
+
     def combine_a(
         self,
         hidden_states: torch.Tensor,
@@ -179,6 +207,60 @@ def __init__(self, async_finish: bool, **kwargs):
         self.async_finish = async_finish
         self.src2dst = None
 
+    def launch_dispatch(
+        self,
+        hidden_states: torch.Tensor,
+        topk_idx: torch.Tensor,
+        topk_weights: torch.Tensor,
+        num_experts: int,
+        num_max_dispatch_tokens_per_rank: int,
+    ) -> torch.Tensor:
+        topk_idx = topk_idx.to(torch.int64)
+        previous_event = Buffer.capture() if self.async_finish else None
+        (
+            self.hidden_states,
+            self.topk_idx,
+            self.topk_weights,
+            self.event,
+        ) = self._dispatch_core(
+            hidden_states, topk_idx, topk_weights, previous_event
+        )
+        return self.hidden_states
+
+    def wait_dispatch(self):
+        self.event.current_stream_wait() if self.async_finish else ()
+
+        hidden_states = self.hidden_states
+        topk_idx = self.topk_idx
+        topk_weights = self.topk_weights
+
+        if self.hidden_states.shape[0] > 0:
+            reorder_topk_ids, seg_indptr, hidden_states = self._deepep_permute(
+                hidden_states, topk_idx, fp8_dtype=hidden_states.dtype
+            )
+        else:
+            reorder_topk_ids = torch.empty(
+                (0,), device=hidden_states.device, dtype=torch.int64
+            )
+            seg_indptr = torch.zeros(
+                (self.num_experts + 1,), device=hidden_states.device, dtype=torch.int64
+            )
+
+        masked_m = expected_m = None
+        self.event = None
+        self.hidden_states = None
+        self.topk_idx = None
+        self.topk_weights = None
+        return (
+            hidden_states,
+            topk_idx,
+            topk_weights,
+            reorder_topk_ids,
+            seg_indptr,
+            masked_m,
+            expected_m,
+        )
+
     def dispatch_a(
         self,
         hidden_states: torch.Tensor,
@@ -242,7 +324,7 @@ def _dispatch_core(
             async_finish=self.async_finish,
             allocate_on_comm_stream=previous_event is not None,
         )
-
+
         # FIXME: `handle` should be transmitted with tokens from dispatch to combine.
         # However, doing this would incur an unknown synchronization error, but keeping
         # `handle` as a member variable works.
@@ -266,7 +348,7 @@ def _dispatch_core(
             async_finish=self.async_finish,
             allocate_on_comm_stream=(previous_event is not None) and self.async_finish,
         )
-
+
         return (
             recv_x,
             recv_topk_idx,
@@ -313,6 +395,45 @@ def _deepep_permute(
         )
         return reorder_topk_ids, seg_indptr, gateup_input
 
+    def launch_combine(
+        self,
+        hidden_states: torch.Tensor,
+        topk_idx: torch.Tensor,
+        topk_weights: torch.Tensor,
+    ):
+        if hidden_states.shape[0] > 0:
+            num_tokens = self.src2dst.shape[0] // self.router_topk
+            output = torch.empty(
+                (num_tokens, hidden_states.shape[1]),
+                device=hidden_states.device,
+                dtype=hidden_states.dtype,
+            )
+            deepep_post_reorder_triton_kernel[(num_tokens,)](
+                hidden_states,
+                output,
+                self.src2dst,
+                topk_idx,
+                topk_weights,
+                self.router_topk,
+                hidden_states.shape[1],
+                BLOCK_SIZE=512,
+            )
+        else:
+            output = torch.zeros(
+                (0, hidden_states.shape[1]),
+                device=hidden_states.device,
+                dtype=hidden_states.dtype,
+            )
+        previous_event = Buffer.capture() if self.async_finish else None
+        self.hidden_states, self.event = self._combine_core(output, previous_event)
+
+    def wait_combine(self):
+        self.event.current_stream_wait() if self.async_finish else ()
+        self.handle = None
+        self.src2dst = None
+        self.event = None
+        return self.hidden_states
+
     def combine_a(
         self,
         hidden_states: torch.Tensor,
@@ -385,6 +506,40 @@ def __init__(self, return_recv_hook: bool, **kwargs):
         """
         self.return_recv_hook = return_recv_hook
 
+    def launch_dispatch(
+        self,
+        hidden_states: torch.Tensor,
+        topk_idx: torch.Tensor,
+        topk_weights: torch.Tensor,
+        num_experts: int,
+        num_max_dispatch_tokens_per_rank: int,
+    ) -> torch.Tensor:
+        (
+            self.hidden_states,
+            self.topk_idx,
+            self.topk_weights,
+            self.masked_m,
+            self.expected_m,
+            self.event,
+            self.hook,
+        ) = self.dispatch_a(hidden_states=hidden_states, topk_idx=topk_idx, topk_weights=topk_weights)
+        return hidden_states
+
+    def wait_dispatch(self):
+        self.hook() if self.return_recv_hook else self.event.current_stream_wait()
+
+        reorder_topk_ids = seg_indptr = None
+
+        return (
+            self.hidden_states,
+            self.topk_idx,
+            self.topk_weights,
+            reorder_topk_ids,
+            seg_indptr,
+            self.masked_m,
+            self.expected_m,
+        )
+
     def dispatch_a(
         self,
         hidden_states: torch.Tensor,
@@ -491,6 +646,19 @@ def _dispatch_core(
         )
         return packed_recv_hidden, packed_recv_count, event, hook
 
+    def launch_combine(
+        self,
+        hidden_states: torch.Tensor,
+        topk_idx: torch.Tensor,
+        topk_weights: torch.Tensor,
+    ):
+        self.hidden_states, self.event, self.hook = self.combine_a(hidden_states=hidden_states, topk_idx=topk_idx,
+                                                                    topk_weights=topk_weights)
+
+    def wait_combine(self):
+        self.hook() if self.return_recv_hook else self.event.current_stream_wait()
+        return self.hidden_states
+
     def combine_a(
         self,
         hidden_states: torch.Tensor,
@@ -576,6 +744,33 @@ def __init__(
             **common_kwargs,
         )
 
+    def launch_dispatch(
+        self,
+        hidden_states: torch.Tensor,
+        topk_idx: torch.Tensor,
+        topk_weights: torch.Tensor,
+        num_experts: int,
+        num_max_dispatch_tokens_per_rank: Optional[int] = 128,
+        forward_mode: Optional[ForwardMode] = None,
+    ) -> torch.Tensor:
+        return self._get_impl(forward_mode).launch_dispatch(hidden_states, topk_idx, topk_weights, num_experts,
+                                                             num_max_dispatch_tokens_per_rank)
+
+    def wait_dispatch(self, forward_mode: ForwardMode = None) -> Tuple:
+        return self._get_impl(forward_mode).wait_dispatch()
+
+    def launch_combine(
+        self,
+        hidden_states: torch.Tensor,
+        topk_idx: torch.Tensor,
+        topk_weights: torch.Tensor,
+        forward_mode: ForwardMode = None,
+    ):
+        self._get_impl(forward_mode).launch_combine(hidden_states, topk_idx, topk_weights)
+
+    def wait_combine(self, forward_mode: ForwardMode = None) -> Tuple:
+        return self._get_impl(forward_mode).wait_combine()
+
     def dispatch(self, *args, **kwargs) -> Tuple:
         self.dispatch_a(*args, **kwargs)
         return self.dispatch_b()
@@ -630,3 +825,7 @@ def _get_impl(self, forward_mode: ForwardMode) -> _DeepEPDispatcherImplBase:
             return self._low_latency_dispatcher
         else:
             raise ValueError(f"Invalid deepep_mode: {self.deepep_mode}")
+
+
+DeepEPDispatcherImplLowLatency = _DeepEPDispatcherImplLowLatency
+DeepEPDispatcherImplNormal = _DeepEPDispatcherImplNormal
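
Note (not part of the commit): the new launch_*/wait_* pairs split DeepEP dispatch and combine into an asynchronous launch phase and a blocking wait phase, so the all-to-all communication of one micro-batch can overlap with the expert computation of the other ("two batch overlap"). The sketch below is purely illustrative of how these calls could be interleaved. It assumes a hypothetical run_experts helper and one dispatcher per micro-batch (each implementation caches its in-flight tensors on self, so a single instance only tracks one outstanding dispatch/combine at a time); the real integration lives in the MoE/model files of this commit that are not shown here.

def run_experts(hidden_states, reorder_topk_ids, seg_indptr, masked_m, expected_m):
    # Hypothetical placeholder for the grouped-GEMM expert forward.
    # Identity keeps the sketch self-contained; a real MoE layer computes expert outputs here.
    return hidden_states


def moe_forward_two_batch_overlap(dispatchers, batches, num_experts, forward_mode=None):
    # Illustrative only: `dispatchers` is a pair of DeepEPDispatcher instances,
    # `batches` a pair of (hidden_states, topk_idx, topk_weights) micro-batches.

    # 1) Launch dispatch for both micro-batches; communication proceeds asynchronously.
    for disp, (hidden_states, topk_idx, topk_weights) in zip(dispatchers, batches):
        disp.launch_dispatch(
            hidden_states,
            topk_idx,
            topk_weights,
            num_experts=num_experts,
            num_max_dispatch_tokens_per_rank=128,
            forward_mode=forward_mode,
        )

    # 2) As each dispatch completes, run the experts for that micro-batch and
    #    immediately launch its combine, so the combine all-to-all of batch i
    #    overlaps with the expert compute of batch i+1.
    for disp in dispatchers:
        (
            hidden_states,
            topk_idx,
            topk_weights,
            reorder_topk_ids,
            seg_indptr,
            masked_m,
            expected_m,
        ) = disp.wait_dispatch(forward_mode)
        expert_out = run_experts(hidden_states, reorder_topk_ids, seg_indptr, masked_m, expected_m)
        disp.launch_combine(expert_out, topk_idx, topk_weights, forward_mode)

    # 3) Collect the combined outputs in submission order.
    return [disp.wait_combine(forward_mode) for disp in dispatchers]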

python/sglang/srt/managers/schedule_batch.py

Lines changed: 1 addition & 0 deletions
@@ -86,6 +86,7 @@
     "n_share_experts_fusion": ServerArgs.n_share_experts_fusion,
     "disable_shared_experts_fusion": ServerArgs.disable_shared_experts_fusion,
     "disable_chunked_prefix_cache": ServerArgs.disable_chunked_prefix_cache,
+    "enable_micro_batch_overlap": ServerArgs.enable_micro_batch_overlap,
 }
 
 logger = logging.getLogger(__name__)

python/sglang/srt/model_executor/model_runner.py

Lines changed: 1 addition & 0 deletions
@@ -169,6 +169,7 @@ def __init__(
                 "disable_shared_experts_fusion": server_args.disable_shared_experts_fusion,
                 "disable_chunked_prefix_cache": server_args.disable_chunked_prefix_cache,
                 "use_mla_backend": self.use_mla_backend,
+                "enable_micro_batch_overlap": server_args.enable_micro_batch_overlap,
             }
         )
 
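
Note (not part of the commit): these two one-line changes register the new server argument in the dictionary that is mirrored into global_server_args_dict, which is the usual way sglang model code reads server-level switches. A hedged illustration of how downstream MoE code could presumably consume the flag (the actual consumer is in files of this commit not shown above):

# Illustrative only: reading the new flag from the mirrored server-args dict.
from sglang.srt.managers.schedule_batch import global_server_args_dict

enable_micro_batch_overlap = global_server_args_dict.get("enable_micro_batch_overlap", False)

if enable_micro_batch_overlap:
    # Use the split-phase launch_*/wait_* dispatcher path (two-batch overlap).
    pass
else:
    # Fall back to the existing dispatch()/combine() path.
    pass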
