@@ -32,6 +32,7 @@
     get_pp_group,
     get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
+    parallel_state,
     split_tensor_along_last_dim,
     tensor_model_parallel_all_gather,
     tensor_model_parallel_all_reduce,
@@ -54,8 +55,10 @@
     RowParallelLinear,
 )
 from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorOutput
-from sglang.srt.layers.moe.ep_moe.layer import EPMoE
+from sglang.srt.layers.moe.ep_moe.layer import DeepEPMoE, EPMoE
+from sglang.srt.layers.moe.ep_moe.token_dispatcher import DeepEPDispatcher
 from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
+from sglang.srt.layers.moe.topk import select_experts
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.layers.rotary_embedding import get_rope
@@ -65,11 +68,15 @@
     VocabParallelEmbedding,
 )
 from sglang.srt.managers.schedule_batch import global_server_args_dict
-from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
+from sglang.srt.model_executor.forward_batch_info import (
+    ForwardBatch,
+    ForwardMode,
+    PPProxyTensors,
+)
 from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.models.qwen2_moe import Qwen2MoeMLP as Qwen3MoeMLP
 from sglang.srt.models.qwen2_moe import Qwen2MoeModel
-from sglang.srt.utils import add_prefix
+from sglang.srt.utils import DeepEPMode, add_prefix
 
 Qwen3MoeConfig = None
 
@@ -92,7 +99,11 @@ def __init__(
                 f"the number of experts {config.num_experts}."
             )
 
-        MoEImpl = EPMoE if global_server_args_dict["enable_ep_moe"] else FusedMoE
+        MoEImpl = (
+            DeepEPMoE
+            if global_server_args_dict["enable_deepep_moe"]
+            else (EPMoE if global_server_args_dict["enable_ep_moe"] else FusedMoE)
+        )
 
         self.experts = MoEImpl(
             num_experts=config.num_experts,
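For reference, the backend choice above is a simple precedence rule: DeepEP wins when enabled, then plain expert parallelism, then the default Triton fused MoE. A minimal sketch of that rule, using a hypothetical server_args dict in place of sglang's global_server_args_dict:

# Illustrative stand-in for the MoEImpl selection above; not sglang code.
def pick_moe_backend(server_args: dict) -> str:
    if server_args.get("enable_deepep_moe"):
        return "DeepEPMoE"
    if server_args.get("enable_ep_moe"):
        return "EPMoE"
    return "FusedMoE"

assert pick_moe_backend({"enable_deepep_moe": True, "enable_ep_moe": True}) == "DeepEPMoE"
assert pick_moe_backend({"enable_ep_moe": True}) == "EPMoE"
assert pick_moe_backend({}) == "FusedMoE"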
@@ -102,6 +113,11 @@ def __init__(
             renormalize=config.norm_topk_prob,
             quant_config=quant_config,
             prefix=add_prefix("experts", prefix),
+            **(
+                dict(deepep_mode=DeepEPMode[global_server_args_dict["deepep_mode"]])
+                if global_server_args_dict["enable_deepep_moe"]
+                else {}
+            ),
         )
 
         self.gate = ReplicatedLinear(
@@ -112,7 +128,37 @@ def __init__(
             prefix=add_prefix("gate", prefix),
         )
 
-    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        if global_server_args_dict["enable_deepep_moe"]:
+            # TODO: we will support tp < ep in the future
+            self.ep_size = get_tensor_model_parallel_world_size()
+            self.num_experts = config.num_experts
+            self.top_k = config.num_experts_per_tok
+            self.renormalize = config.norm_topk_prob
+
+            self.deepep_dispatcher = DeepEPDispatcher(
+                group=parallel_state.get_tp_group().device_group,
+                router_topk=self.top_k,
+                permute_fusion=True,
+                num_experts=config.num_experts,
+                num_local_experts=config.num_experts // self.tp_size,
+                hidden_size=config.hidden_size,
+                params_dtype=config.torch_dtype,
+                deepep_mode=DeepEPMode[global_server_args_dict["deepep_mode"]],
+                async_finish=True,  # TODO
+                return_recv_hook=True,
+            )
+
+    def forward(
+        self, hidden_states: torch.Tensor, forward_mode: Optional[ForwardMode] = None
+    ) -> torch.Tensor:
+
+        if not global_server_args_dict["enable_deepep_moe"]:
+            return self.forward_normal(hidden_states)
+        else:
+            return self.forward_deepep(hidden_states, forward_mode)
+
+    def forward_normal(self, hidden_states: torch.Tensor) -> torch.Tensor:
+
         num_tokens, hidden_dim = hidden_states.shape
         hidden_states = hidden_states.view(-1, hidden_dim)
 
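The dispatcher is built once per layer, and because tp < ep is not yet supported (see the TODO), ep_size is simply the tensor-parallel world size, so each rank hosts num_experts // ep_size experts. A quick arithmetic sketch of that partitioning; the 128-expert / 8-rank numbers and the contiguous-slice layout are assumptions for illustration, not taken from the diff:

# Illustrative only: how num_local_experts in the dispatcher config is derived.
num_experts = 128        # assumed for the example
ep_size = 8              # assumed; in the diff this equals the TP world size
num_local_experts = num_experts // ep_size  # 16 experts resident per rank

# Assuming a contiguous split, each rank would own expert ids
# rank 0 -> 0..15, rank 1 -> 16..31, and so on.
for rank in range(ep_size):
    start = rank * num_local_experts
    end = start + num_local_experts
    assert end - start == num_local_experts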
@@ -126,6 +172,68 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
 
         return final_hidden_states.view(num_tokens, hidden_dim)
 
+    def forward_deepep(
+        self, hidden_states: torch.Tensor, forward_mode: ForwardMode
+    ) -> torch.Tensor:
+        if (
+            forward_mode is not None
+            and not forward_mode.is_idle()
+            and hidden_states.shape[0] > 0
+        ):
+            # router_logits: (num_tokens, n_experts)
+            router_logits, _ = self.gate(hidden_states)
+
+            topk_weights, topk_idx = select_experts(
+                hidden_states=hidden_states,
+                router_logits=router_logits,
+                top_k=self.top_k,
+                use_grouped_topk=False,
+                renormalize=self.renormalize,
+            )
+        else:
+            topk_idx = torch.full(
+                (0, self.top_k), -1, dtype=torch.int, device=hidden_states.device
+            )
+            topk_weights = torch.empty(
+                (0, self.top_k), dtype=torch.float32, device=hidden_states.device
+            )
+        if self.ep_size > 1:
+            # TODO(ch-wan): allow users to set num_max_dispatch_tokens_per_rank value
+            (
+                hidden_states,
+                topk_idx,
+                topk_weights,
+                reorder_topk_ids,
+                num_recv_tokens_per_expert,
+                seg_indptr,
+                masked_m,
+                expected_m,
+            ) = self.deepep_dispatcher.dispatch(
+                hidden_states,
+                topk_idx,
+                topk_weights,
+                forward_mode=forward_mode,
+            )
+        final_hidden_states = self.experts(
+            hidden_states=hidden_states,
+            topk_idx=topk_idx,
+            topk_weights=topk_weights,
+            reorder_topk_ids=reorder_topk_ids,
+            seg_indptr=seg_indptr,
+            masked_m=masked_m,
+            expected_m=expected_m,
+            num_recv_tokens_per_expert=num_recv_tokens_per_expert,
+            forward_mode=forward_mode,
+        )
+        if self.ep_size > 1:
+            final_hidden_states = self.deepep_dispatcher.combine(
+                final_hidden_states,
+                topk_idx,
+                topk_weights,
+                forward_mode,
+            )
+        return final_hidden_states
+
 
 class Qwen3MoeAttention(nn.Module):
     def __init__(
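The forward_deepep path follows the usual expert-parallel pattern: gate and select top-k experts per token, dispatch tokens to the ranks that own those experts, run the local experts, then combine the partial outputs back into token order weighted by the routing scores. The sketch below mimics that flow on a single process with dense tensors; DeepEPDispatcher, the segmented/masked metadata, and the all-to-all communication are replaced by plain PyTorch, so this is a conceptual stand-in, not sglang's implementation.

import torch

# Conceptual single-process stand-in for dispatch -> expert compute -> combine.
def toy_moe_forward(hidden_states, gate_w, expert_ws, top_k=2):
    # hidden_states: (num_tokens, hidden); gate_w: (hidden, num_experts)
    router_logits = hidden_states @ gate_w
    topk_weights, topk_idx = torch.topk(router_logits.softmax(dim=-1), top_k, dim=-1)
    topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)  # renormalize

    out = torch.zeros_like(hidden_states)
    for e, w_e in enumerate(expert_ws):                    # "local experts"
        token_ids, slot = (topk_idx == e).nonzero(as_tuple=True)  # "dispatch"
        if token_ids.numel() == 0:
            continue
        expert_out = hidden_states[token_ids] @ w_e        # expert compute
        # "combine": scatter weighted expert outputs back to their source tokens
        out.index_add_(0, token_ids, expert_out * topk_weights[token_ids, slot, None])
    return out

h = torch.randn(4, 8)
gate = torch.randn(8, 4)
experts = [torch.randn(8, 8) for _ in range(4)]
print(toy_moe_forward(h, gate, experts).shape)  # torch.Size([4, 8])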
@@ -403,7 +511,7 @@ def forward_ffn_with_full_input(
         )
 
         # Fully Connected
-        hidden_states = self.mlp(hidden_states)
+        hidden_states = self.mlp(hidden_states, forward_batch.forward_mode)
 
         # TODO: use reduce-scatter in MLP to avoid this scatter
         # Scatter
@@ -577,7 +685,13 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
             ("gate_up_proj", "up_proj", 1),
         ]
 
-        MoEImpl = EPMoE if global_server_args_dict["enable_ep_moe"] else FusedMoE
+        # Params for weights, fp8 weight scales, fp8 activation scales
+        # (param_name, weight_name, expert_id, shard_id)
+        MoEImpl = (
+            DeepEPMoE
+            if global_server_args_dict["enable_deepep_moe"]
+            else (EPMoE if global_server_args_dict["enable_ep_moe"] else FusedMoE)
+        )
 
         expert_params_mapping = MoEImpl.make_expert_params_mapping(
             ckpt_gate_proj_name="gate_proj",
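As the new comment notes, make_expert_params_mapping expands the per-expert checkpoint names into (param_name, weight_name, expert_id, shard_id) tuples, which load_weights then matches against incoming checkpoint keys. A rough sketch of that matching loop with invented names and without the fp8 scale entries, just to show the shape of the mapping; it is not sglang's actual loader:

# Hypothetical illustration of an expert-params mapping; names are invented.
num_experts = 2
expert_params_mapping = [
    # (param_name, weight_name, expert_id, shard_id)
    ("experts.w13_weight", f"experts.{e}.{ckpt}.weight", e, shard)
    for e in range(num_experts)
    for ckpt, shard in [("gate_proj", "w1"), ("up_proj", "w3")]
] + [
    ("experts.w2_weight", f"experts.{e}.down_proj.weight", e, "w2")
    for e in range(num_experts)
]

def route(ckpt_name):
    for param_name, weight_name, expert_id, shard_id in expert_params_mapping:
        if weight_name in ckpt_name:
            return param_name, expert_id, shard_id  # hand off to the weight loader
    return None

print(route("model.layers.0.mlp.experts.1.up_proj.weight"))
# ('experts.w13_weight', 1, 'w3')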