)
from sglang.srt.model_executor.forward_batch_info import ForwardMode

+import logging
+
+logger = logging.getLogger(__name__)
+

class DeepEPDispatchMode(IntEnum):
    NORMAL = auto()
    LOW_LATENCY = auto()


class DeepEPBuffer:
-
    _buffer = None
    _dispatch_mode: Optional[DeepEPDispatchMode] = None
    _hidden_size: Optional[int] = None
@@ -145,6 +148,20 @@ def __init__(
        self.num_max_dispatch_tokens_per_rank = 128

        self.handle = None
+        self.event = None
+
+    def launch_dispatch(
+        self,
+        hidden_states: torch.Tensor,
+        topk_idx: torch.Tensor,
+        topk_weights: torch.Tensor,
+        num_experts: int,
+        num_max_dispatch_tokens_per_rank: int,
+    ) -> torch.Tensor:
+        raise NotImplementedError
+
+    def wait_dispatch(self):
+        raise NotImplementedError

    def dispatch_a(
        self,
@@ -157,6 +174,17 @@ def dispatch_a(
    def dispatch_b(self, *args, **kwargs):
        raise NotImplementedError

+    def launch_combine(
+        self,
+        hidden_states: torch.Tensor,
+        topk_idx: torch.Tensor,
+        topk_weights: torch.Tensor,
+    ):
+        raise NotImplementedError
+
+    def wait_combine(self):
+        raise NotImplementedError
+
    def combine_a(
        self,
        hidden_states: torch.Tensor,
@@ -179,6 +207,60 @@ def __init__(self, async_finish: bool, **kwargs):
        self.async_finish = async_finish
        self.src2dst = None

+    def launch_dispatch(
+        self,
+        hidden_states: torch.Tensor,
+        topk_idx: torch.Tensor,
+        topk_weights: torch.Tensor,
+        num_experts: int,
+        num_max_dispatch_tokens_per_rank: int,
+    ) -> torch.Tensor:
+        topk_idx = topk_idx.to(torch.int64)
+        previous_event = Buffer.capture() if self.async_finish else None
+        (
+            self.hidden_states,
+            self.topk_idx,
+            self.topk_weights,
+            self.event,
+        ) = self._dispatch_core(
+            hidden_states, topk_idx, topk_weights, previous_event
+        )
+        return self.hidden_states
+
+    def wait_dispatch(self):
+        self.event.current_stream_wait() if self.async_finish else ()
+
+        hidden_states = self.hidden_states
+        topk_idx = self.topk_idx
+        topk_weights = self.topk_weights
+
+        if self.hidden_states.shape[0] > 0:
+            reorder_topk_ids, seg_indptr, hidden_states = self._deepep_permute(
+                hidden_states, topk_idx, fp8_dtype=hidden_states.dtype
+            )
+        else:
+            reorder_topk_ids = torch.empty(
+                (0,), device=hidden_states.device, dtype=torch.int64
+            )
+            seg_indptr = torch.zeros(
+                (self.num_experts + 1,), device=hidden_states.device, dtype=torch.int64
+            )
+
+        masked_m = expected_m = None
+        self.event = None
+        self.hidden_states = None
+        self.topk_idx = None
+        self.topk_weights = None
+        return (
+            hidden_states,
+            topk_idx,
+            topk_weights,
+            reorder_topk_ids,
+            seg_indptr,
+            masked_m,
+            expected_m,
+        )
+
    def dispatch_a(
        self,
        hidden_states: torch.Tensor,
@@ -242,7 +324,7 @@ def _dispatch_core(
            async_finish=self.async_finish,
            allocate_on_comm_stream=previous_event is not None,
        )
-
+
        # FIXME: `handle` should be transmitted with tokens from dispatch to combine.
        # However, doing this would incur an unknown synchronization error, but keeping
        # `handle` as a member variable works.
@@ -266,7 +348,7 @@ def _dispatch_core(
            async_finish=self.async_finish,
            allocate_on_comm_stream=(previous_event is not None) and self.async_finish,
        )
-
+
        return (
            recv_x,
            recv_topk_idx,
@@ -313,6 +395,45 @@ def _deepep_permute(
        )
        return reorder_topk_ids, seg_indptr, gateup_input

+    def launch_combine(
+        self,
+        hidden_states: torch.Tensor,
+        topk_idx: torch.Tensor,
+        topk_weights: torch.Tensor,
+    ):
+        if hidden_states.shape[0] > 0:
+            num_tokens = self.src2dst.shape[0] // self.router_topk
+            output = torch.empty(
+                (num_tokens, hidden_states.shape[1]),
+                device=hidden_states.device,
+                dtype=hidden_states.dtype,
+            )
+            deepep_post_reorder_triton_kernel[(num_tokens,)](
+                hidden_states,
+                output,
+                self.src2dst,
+                topk_idx,
+                topk_weights,
+                self.router_topk,
+                hidden_states.shape[1],
+                BLOCK_SIZE=512,
+            )
+        else:
+            output = torch.zeros(
+                (0, hidden_states.shape[1]),
+                device=hidden_states.device,
+                dtype=hidden_states.dtype,
+            )
+        previous_event = Buffer.capture() if self.async_finish else None
+        self.hidden_states, self.event = self._combine_core(output, previous_event)
+
+    def wait_combine(self):
+        self.event.current_stream_wait() if self.async_finish else ()
+        self.handle = None
+        self.src2dst = None
+        self.event = None
+        return self.hidden_states
+
    def combine_a(
        self,
        hidden_states: torch.Tensor,
@@ -385,6 +506,40 @@ def __init__(self, return_recv_hook: bool, **kwargs):
        """
        self.return_recv_hook = return_recv_hook

+    def launch_dispatch(
+        self,
+        hidden_states: torch.Tensor,
+        topk_idx: torch.Tensor,
+        topk_weights: torch.Tensor,
+        num_experts: int,
+        num_max_dispatch_tokens_per_rank: int,
+    ) -> torch.Tensor:
+        (
+            self.hidden_states,
+            self.topk_idx,
+            self.topk_weights,
+            self.masked_m,
+            self.expected_m,
+            self.event,
+            self.hook,
+        ) = self.dispatch_a(hidden_states=hidden_states, topk_idx=topk_idx, topk_weights=topk_weights)
+        return hidden_states
+
+    def wait_dispatch(self):
+        self.hook() if self.return_recv_hook else self.event.current_stream_wait()
+
+        reorder_topk_ids = seg_indptr = None
+
+        return (
+            self.hidden_states,
+            self.topk_idx,
+            self.topk_weights,
+            reorder_topk_ids,
+            seg_indptr,
+            self.masked_m,
+            self.expected_m,
+        )
+
    def dispatch_a(
        self,
        hidden_states: torch.Tensor,
@@ -491,6 +646,19 @@ def _dispatch_core(
        )
        return packed_recv_hidden, packed_recv_count, event, hook

+    def launch_combine(
+        self,
+        hidden_states: torch.Tensor,
+        topk_idx: torch.Tensor,
+        topk_weights: torch.Tensor,
+    ):
+        self.hidden_states, self.event, self.hook = self.combine_a(
+            hidden_states=hidden_states, topk_idx=topk_idx, topk_weights=topk_weights)
+
+    def wait_combine(self):
+        self.hook() if self.return_recv_hook else self.event.current_stream_wait()
+        return self.hidden_states
+
    def combine_a(
        self,
        hidden_states: torch.Tensor,
@@ -576,6 +744,33 @@ def __init__(
            **common_kwargs,
        )

+    def launch_dispatch(
+        self,
+        hidden_states: torch.Tensor,
+        topk_idx: torch.Tensor,
+        topk_weights: torch.Tensor,
+        num_experts: int,
+        num_max_dispatch_tokens_per_rank: Optional[int] = 128,
+        forward_mode: Optional[ForwardMode] = None,
+    ) -> torch.Tensor:
+        return self._get_impl(forward_mode).launch_dispatch(
+            hidden_states, topk_idx, topk_weights, num_experts, num_max_dispatch_tokens_per_rank)
+
+    def wait_dispatch(self, forward_mode: ForwardMode = None) -> Tuple:
+        return self._get_impl(forward_mode).wait_dispatch()
+
+    def launch_combine(
+        self,
+        hidden_states: torch.Tensor,
+        topk_idx: torch.Tensor,
+        topk_weights: torch.Tensor,
+        forward_mode: ForwardMode = None,
+    ):
+        self._get_impl(forward_mode).launch_combine(hidden_states, topk_idx, topk_weights)
+
+    def wait_combine(self, forward_mode: ForwardMode = None) -> Tuple:
+        return self._get_impl(forward_mode).wait_combine()
+
    def dispatch(self, *args, **kwargs) -> Tuple:
        self.dispatch_a(*args, **kwargs)
        return self.dispatch_b()
@@ -630,3 +825,7 @@ def _get_impl(self, forward_mode: ForwardMode) -> _DeepEPDispatcherImplBase:
            return self._low_latency_dispatcher
        else:
            raise ValueError(f"Invalid deepep_mode: {self.deepep_mode}")
+
+
+DeepEPDispatcherImplLowLatency = _DeepEPDispatcherImplLowLatency
+DeepEPDispatcherImplNormal = _DeepEPDispatcherImplNormal
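
Taken together, these changes split dispatch and combine into a launch phase and a wait phase so that a caller can overlap other GPU work with DeepEP communication. The sketch below is not part of the PR; it only illustrates how the new DeepEPDispatcher methods might be driven from a MoE forward pass. `dispatcher` is assumed to be an already-constructed DeepEPDispatcher, and `run_experts` is a hypothetical placeholder for the expert computation that normally sits between dispatch and combine.

# Usage sketch (assumed caller code, not from the PR): overlap work with communication.
def moe_forward_split(dispatcher, hidden_states, topk_idx, topk_weights, num_experts, forward_mode=None):
    # Start the token dispatch; with async_finish / return_recv_hook enabled the call
    # returns before the communication has actually completed.
    dispatcher.launch_dispatch(
        hidden_states, topk_idx, topk_weights, num_experts,
        num_max_dispatch_tokens_per_rank=128, forward_mode=forward_mode,
    )
    # ... independent work (e.g. another micro-batch) could be scheduled here ...

    # Block until the dispatched tokens are available, then run the experts.
    (hidden_states, topk_idx, topk_weights,
     reorder_topk_ids, seg_indptr, masked_m, expected_m) = dispatcher.wait_dispatch(forward_mode)
    expert_output = run_experts(hidden_states, reorder_topk_ids, seg_indptr, masked_m, expected_m)

    # Send expert outputs back to their source ranks, again split into launch + wait.
    dispatcher.launch_combine(expert_output, topk_idx, topk_weights, forward_mode)
    # ... more overlap opportunity here ...
    return dispatcher.wait_combine(forward_mode)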