diff --git a/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py b/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py
index 8bb9224e37..6b67f6cea8 100644
--- a/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py
+++ b/python/sglang/srt/layers/moe/ep_moe/token_dispatcher.py
@@ -5,7 +5,6 @@
 except ImportError:
     use_deepep = False
 
-import os
 from typing import Optional, Tuple
 
 import torch
@@ -101,6 +100,7 @@ def __init__(
         num_local_experts: int = None,
         hidden_size: int = None,
         params_dtype: torch.dtype = None,
+        async_finish: bool = False,
     ):
         self.group = group
         self.router_topk = router_topk
@@ -117,6 +117,7 @@ def __init__(
         self.token_probs = None
         # Handle used for combine operation
         self.handle = None
+        self.async_finish = async_finish
 
         # `num_max_dispatch_tokens_per_rank` (the actual batch size in the decoding engine) should be less than 256
         # https://github.com/deepseek-ai/DeepEP?tab=readme-ov-file#example-use-in-inference-decoding
@@ -182,7 +183,6 @@ def dispatch(
         topk_weights: torch.Tensor,
         num_experts: int,
         forward_mode: ForwardMode,
-        previous_event=None,
         num_max_dispatch_tokens_per_rank: int = 128,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         topk_idx = topk_idx.to(torch.int64)
@@ -195,9 +195,7 @@ def dispatch(
                 num_recv_tokens_per_expert_list,
                 handle,
                 event,
-            ) = self.dispatch_normal(
-                hidden_states, topk_idx, topk_weights, num_experts, previous_event
-            )
+            ) = self.dispatch_normal(hidden_states, topk_idx, topk_weights, num_experts)
             self.tokens_per_expert = torch.tensor(
                 num_recv_tokens_per_expert_list,
                 device=hidden_states.device,
@@ -213,6 +211,10 @@ def dispatch(
                 )
             )
             self.recv_expert_count = recv_expert_count
+
+        if self.async_finish:
+            event.current_stream_wait()
+
         self.handle = handle
         self.topk_idx = topk_idx
         self.topk_weights = topk_weights
@@ -235,8 +237,9 @@ def dispatch_normal(
         topk_idx: torch.Tensor,
         topk_weights: torch.Tensor,
         num_experts: int,
-        previous_event=None,
     ):
+        previous_event = Buffer.capture() if self.async_finish else None
+
         (
             num_tokens_per_rank,
             num_tokens_per_rdma_rank,
@@ -247,8 +250,8 @@ def dispatch_normal(
             topk_idx,
             num_experts,
             previous_event=previous_event,
-            async_finish=False,
-            allocate_on_comm_stream=False,
+            async_finish=self.async_finish,
+            allocate_on_comm_stream=previous_event is not None,
         )
 
         (
@@ -267,8 +270,8 @@ def dispatch_normal(
             is_token_in_rank=is_token_in_rank,
             num_tokens_per_expert=num_tokens_per_expert,
             previous_event=previous_event,
-            async_finish=False,
-            allocate_on_comm_stream=False,
+            async_finish=self.async_finish,
+            allocate_on_comm_stream=(previous_event is not None) and self.async_finish,
         )
 
         return (
@@ -333,7 +336,7 @@ def dispatch_low_latency(
                 topk_idx,
                 num_max_dispatch_tokens_per_rank,
                 num_experts,
-                async_finish=False,
+                async_finish=self.async_finish,
                 return_recv_hook=False,  # True for double-batch overlapping, need call hook()
             )
         )
@@ -373,16 +376,22 @@ def combine(
             hidden_states, event, hook = self.combine_low_latency(
                 hidden_states, self.topk_idx, self.topk_weights, self.handle
             )
+
+        if self.async_finish:
+            event.current_stream_wait()
+
         self.handle = None
         return hidden_states
 
-    def combine_normal(self, x: torch.Tensor, handle: Tuple, previous_event=None):
+    def combine_normal(self, x: torch.Tensor, handle: Tuple):
+        previous_event = Buffer.capture() if self.async_finish else None
+
         combined_x, _, event = self.buffer_normal.combine(
             x,
             handle,
-            async_finish=False,
+            async_finish=self.async_finish,
             previous_event=previous_event,
-            allocate_on_comm_stream=False,
+            allocate_on_comm_stream=previous_event is not None,
         )
         return combined_x, event
 
@@ -399,7 +408,7 @@ def combine_low_latency(
                 topk_idx,
                 topk_weights,
                 handle,
-                async_finish=False,
+                async_finish=self.async_finish,
                 return_recv_hook=False,  # True for double-batch overlapping, need call hook()
             )
         )
diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py
index 2214030c8f..f847b601c4 100644
--- a/python/sglang/srt/models/deepseek_v2.py
+++ b/python/sglang/srt/models/deepseek_v2.py
@@ -239,6 +239,7 @@ def __init__(
             num_local_experts=config.n_routed_experts // self.tp_size,
             hidden_size=config.hidden_size,
             params_dtype=config.torch_dtype,
+            async_finish=True,  # TODO
         )
 
     def forward(
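
The pattern this patch wires through `dispatch_normal`, `combine_normal`, and the low-latency paths is capture → async launch → wait: record an event on the current stream (`Buffer.capture()`), let DeepEP run on its communication stream without blocking the host (`async_finish=True`, with `allocate_on_comm_stream` enabled only when a previous event is chained), then make the current stream wait on the returned event (`event.current_stream_wait()`) before the received tensors are consumed. A minimal sketch of that flow is below; `chained_comm` and `comm_fn` are hypothetical names for illustration only, while the event calls and keyword arguments are the ones used in the diff.

```python
from deep_ep import Buffer  # assumed installed, matching the guarded import in token_dispatcher.py


def chained_comm(comm_fn, async_finish: bool = True):
    """Sketch: run one DeepEP communication call with event chaining.

    `comm_fn` is a hypothetical placeholder for a bound call such as
    `functools.partial(buffer_normal.combine, x, handle)`; like the calls in
    the diff, it must accept `previous_event`, `async_finish`, and
    `allocate_on_comm_stream`, and return an event as its last result.
    """
    # Record an event on the current stream so the communication kernels are
    # ordered after all compute already queued on it.
    previous_event = Buffer.capture() if async_finish else None

    *outputs, event = comm_fn(
        previous_event=previous_event,
        async_finish=async_finish,
        allocate_on_comm_stream=previous_event is not None,
    )

    # The host does not block; the current stream waits on the communication
    # event right before the outputs are consumed.
    if async_finish:
        event.current_stream_wait()

    return outputs
```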