
Commit e4abef0

sleepcoo authored and Layssy committed
Support qwen3 deepep (sgl-project#6120)
1 parent 0f81015 commit e4abef0

File tree

2 files changed (+125, -8 lines)


python/sglang/srt/models/qwen2_moe.py

Lines changed: 4 additions & 1 deletion
@@ -607,7 +607,10 @@ def forward(
             )
         else:
             if hidden_states.shape[0] != 0:
-                hidden_states, _ = self.norm(hidden_states, residual)
+                if residual is None:
+                    hidden_states = self.norm(hidden_states)
+                else:
+                    hidden_states, _ = self.norm(hidden_states, residual)
         return hidden_states

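Note on the fix above: fused-add RMSNorm layers of the kind used here typically return a bare tensor when called without a residual and a (hidden_states, residual) tuple when one is passed, so unpacking with "hidden_states, _ = ..." fails when residual is None. A minimal standalone sketch of that calling convention (plain PyTorch, not the sglang RMSNorm implementation):

import torch
import torch.nn as nn

class FusedAddRMSNorm(nn.Module):
    """Toy norm that mimics the dual return convention described above."""

    def __init__(self, hidden_size: int, eps: float = 1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.eps = eps

    def forward(self, x, residual=None):
        if residual is not None:
            x = x + residual          # fused residual add
            residual = x              # keep the pre-norm sum for the caller
        variance = x.pow(2).mean(-1, keepdim=True)
        x = x * torch.rsqrt(variance + self.eps) * self.weight
        # Bare tensor without a residual, tuple with one.
        return x if residual is None else (x, residual)

norm = FusedAddRMSNorm(8)
h = torch.randn(4, 8)
out_no_res = norm(h)                  # tensor
out_with_res, new_res = norm(h, h)    # (tensor, tensor) tuple
print(out_no_res.shape, out_with_res.shape, new_res.shape)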
python/sglang/srt/models/qwen3_moe.py

Lines changed: 121 additions & 7 deletions
@@ -32,6 +32,7 @@
     get_pp_group,
     get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
+    parallel_state,
     split_tensor_along_last_dim,
     tensor_model_parallel_all_gather,
     tensor_model_parallel_all_reduce,
@@ -54,8 +55,10 @@
     RowParallelLinear,
 )
 from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorOutput
-from sglang.srt.layers.moe.ep_moe.layer import EPMoE
+from sglang.srt.layers.moe.ep_moe.layer import DeepEPMoE, EPMoE
+from sglang.srt.layers.moe.ep_moe.token_dispatcher import DeepEPDispatcher
 from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
+from sglang.srt.layers.moe.topk import select_experts
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.layers.rotary_embedding import get_rope
@@ -65,11 +68,15 @@
     VocabParallelEmbedding,
 )
 from sglang.srt.managers.schedule_batch import global_server_args_dict
-from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
+from sglang.srt.model_executor.forward_batch_info import (
+    ForwardBatch,
+    ForwardMode,
+    PPProxyTensors,
+)
 from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.models.qwen2_moe import Qwen2MoeMLP as Qwen3MoeMLP
 from sglang.srt.models.qwen2_moe import Qwen2MoeModel
-from sglang.srt.utils import add_prefix
+from sglang.srt.utils import DeepEPMode, add_prefix

 Qwen3MoeConfig = None

@@ -92,7 +99,11 @@ def __init__(
                 f"the number of experts {config.num_experts}."
             )

-        MoEImpl = EPMoE if global_server_args_dict["enable_ep_moe"] else FusedMoE
+        MoEImpl = (
+            DeepEPMoE
+            if global_server_args_dict["enable_deepep_moe"]
+            else (EPMoE if global_server_args_dict["enable_ep_moe"] else FusedMoE)
+        )

         self.experts = MoEImpl(
             num_experts=config.num_experts,
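A quick illustration of how the three-way MoEImpl selection above resolves, as a standalone sketch. The classes below are placeholders, not the real sglang implementations; only the flag names mirror the server args used in the diff:

# Standalone sketch of the MoEImpl selection logic above.
class FusedMoE: ...
class EPMoE: ...
class DeepEPMoE: ...

def pick_moe_impl(server_args: dict) -> type:
    # DeepEP takes priority, then plain expert parallelism, then the fused default.
    return (
        DeepEPMoE
        if server_args["enable_deepep_moe"]
        else (EPMoE if server_args["enable_ep_moe"] else FusedMoE)
    )

assert pick_moe_impl({"enable_deepep_moe": True, "enable_ep_moe": False}) is DeepEPMoE
assert pick_moe_impl({"enable_deepep_moe": False, "enable_ep_moe": True}) is EPMoE
assert pick_moe_impl({"enable_deepep_moe": False, "enable_ep_moe": False}) is FusedMoE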
@@ -102,6 +113,11 @@ def __init__(
             renormalize=config.norm_topk_prob,
             quant_config=quant_config,
             prefix=add_prefix("experts", prefix),
+            **(
+                dict(deepep_mode=DeepEPMode[global_server_args_dict["deepep_mode"]])
+                if global_server_args_dict["enable_deepep_moe"]
+                else {}
+            ),
         )

         self.gate = ReplicatedLinear(
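The "**( ... if ... else {})" splat above forwards the deepep_mode keyword only when DeepEP is enabled, so the non-DeepEP constructors never receive an argument they do not accept. A tiny standalone sketch of the pattern; make_layer and the "auto" value are illustrative only:

def make_layer(size: int, **kwargs) -> dict:
    # Stand-in for a constructor that only sometimes accepts extra kwargs.
    return {"size": size, **kwargs}

enable_deepep = True
layer = make_layer(
    1024,
    **({"deepep_mode": "auto"} if enable_deepep else {}),
)
print(layer)  # {'size': 1024, 'deepep_mode': 'auto'} when the flag is on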
@@ -112,7 +128,37 @@ def __init__(
             prefix=add_prefix("gate", prefix),
         )

-    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        if global_server_args_dict["enable_deepep_moe"]:
+            # TODO: we will support tp < ep in the future
+            self.ep_size = get_tensor_model_parallel_world_size()
+            self.num_experts = config.num_experts
+            self.top_k = config.num_experts_per_tok
+            self.renormalize = config.norm_topk_prob
+
+            self.deepep_dispatcher = DeepEPDispatcher(
+                group=parallel_state.get_tp_group().device_group,
+                router_topk=self.top_k,
+                permute_fusion=True,
+                num_experts=config.num_experts,
+                num_local_experts=config.num_experts // self.tp_size,
+                hidden_size=config.hidden_size,
+                params_dtype=config.torch_dtype,
+                deepep_mode=DeepEPMode[global_server_args_dict["deepep_mode"]],
+                async_finish=True,  # TODO
+                return_recv_hook=True,
+            )
+
+    def forward(
+        self, hidden_states: torch.Tensor, forward_mode: Optional[ForwardMode] = None
+    ) -> torch.Tensor:
+
+        if not global_server_args_dict["enable_deepep_moe"]:
+            return self.forward_normal(hidden_states)
+        else:
+            return self.forward_deepep(hidden_states, forward_mode)
+
+    def forward_normal(self, hidden_states: torch.Tensor) -> torch.Tensor:
+
         num_tokens, hidden_dim = hidden_states.shape
         hidden_states = hidden_states.view(-1, hidden_dim)

@@ -126,6 +172,68 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:

         return final_hidden_states.view(num_tokens, hidden_dim)

+    def forward_deepep(
+        self, hidden_states: torch.Tensor, forward_mode: ForwardMode
+    ) -> torch.Tensor:
+        if (
+            forward_mode is not None
+            and not forward_mode.is_idle()
+            and hidden_states.shape[0] > 0
+        ):
+            # router_logits: (num_tokens, n_experts)
+            router_logits, _ = self.gate(hidden_states)
+
+            topk_weights, topk_idx = select_experts(
+                hidden_states=hidden_states,
+                router_logits=router_logits,
+                top_k=self.top_k,
+                use_grouped_topk=False,
+                renormalize=self.renormalize,
+            )
+        else:
+            topk_idx = torch.full(
+                (0, self.top_k), -1, dtype=torch.int, device=hidden_states.device
+            )
+            topk_weights = torch.empty(
+                (0, self.top_k), dtype=torch.float32, device=hidden_states.device
+            )
+        if self.ep_size > 1:
+            # TODO(ch-wan): allow users to set num_max_dispatch_tokens_per_rank value
+            (
+                hidden_states,
+                topk_idx,
+                topk_weights,
+                reorder_topk_ids,
+                num_recv_tokens_per_expert,
+                seg_indptr,
+                masked_m,
+                expected_m,
+            ) = self.deepep_dispatcher.dispatch(
+                hidden_states,
+                topk_idx,
+                topk_weights,
+                forward_mode=forward_mode,
+            )
+        final_hidden_states = self.experts(
+            hidden_states=hidden_states,
+            topk_idx=topk_idx,
+            topk_weights=topk_weights,
+            reorder_topk_ids=reorder_topk_ids,
+            seg_indptr=seg_indptr,
+            masked_m=masked_m,
+            expected_m=expected_m,
+            num_recv_tokens_per_expert=num_recv_tokens_per_expert,
+            forward_mode=forward_mode,
+        )
+        if self.ep_size > 1:
+            final_hidden_states = self.deepep_dispatcher.combine(
+                final_hidden_states,
+                topk_idx,
+                topk_weights,
+                forward_mode,
+            )
+        return final_hidden_states
+

 class Qwen3MoeAttention(nn.Module):
     def __init__(
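For intuition, the routing performed by select_experts in forward_deepep (with use_grouped_topk=False) amounts to a softmax over the router logits, a per-token top-k, and an optional renormalization of the selected weights. A plain-PyTorch sketch of that computation, not the fused sglang kernel:

import torch

def naive_select_experts(router_logits: torch.Tensor, top_k: int, renormalize: bool):
    # router_logits: (num_tokens, num_experts)
    probs = torch.softmax(router_logits, dim=-1, dtype=torch.float32)
    topk_weights, topk_idx = torch.topk(probs, top_k, dim=-1)
    if renormalize:
        # Make the selected experts' weights sum to 1 for each token.
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
    return topk_weights, topk_idx

logits = torch.randn(4, 8)             # 4 tokens, 8 experts
w, idx = naive_select_experts(logits, top_k=2, renormalize=True)
print(w.sum(dim=-1))                   # ~1.0 for every token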
@@ -403,7 +511,7 @@ def forward_ffn_with_full_input(
            )

        # Fully Connected
-        hidden_states = self.mlp(hidden_states)
+        hidden_states = self.mlp(hidden_states, forward_batch.forward_mode)

        # TODO: use reduce-scatter in MLP to avoid this scatter
        # Scatter
@@ -577,7 +685,13 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
             ("gate_up_proj", "up_proj", 1),
         ]

-        MoEImpl = EPMoE if global_server_args_dict["enable_ep_moe"] else FusedMoE
+        # Params for weights, fp8 weight scales, fp8 activation scales
+        # (param_name, weight_name, expert_id, shard_id)
+        MoEImpl = (
+            DeepEPMoE
+            if global_server_args_dict["enable_deepep_moe"]
+            else (EPMoE if global_server_args_dict["enable_ep_moe"] else FusedMoE)
+        )

         expert_params_mapping = MoEImpl.make_expert_params_mapping(
             ckpt_gate_proj_name="gate_proj",
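For readers unfamiliar with load_weights, the mapping returned by make_expert_params_mapping is a list of (param_name, weight_name, expert_id, shard_id) tuples, as the new comment notes, used to redirect per-expert checkpoint tensors into fused expert parameters. A hypothetical, simplified construction of such a mapping; the fused parameter names and shard ids below are assumptions for illustration, not the actual sglang helper:

from typing import List, Tuple

def make_expert_params_mapping_sketch(
    ckpt_gate_proj_name: str,
    ckpt_down_proj_name: str,
    ckpt_up_proj_name: str,
    num_experts: int,
) -> List[Tuple[str, str, int, str]]:
    mapping = []
    for expert_id in range(num_experts):
        for weight_name, (param_name, shard_id) in {
            ckpt_gate_proj_name: ("experts.w13_weight", "w1"),
            ckpt_up_proj_name: ("experts.w13_weight", "w3"),
            ckpt_down_proj_name: ("experts.w2_weight", "w2"),
        }.items():
            # (param_name, weight_name, expert_id, shard_id)
            mapping.append(
                (param_name, f"experts.{expert_id}.{weight_name}.", expert_id, shard_id)
            )
    return mapping

print(make_expert_params_mapping_sketch("gate_proj", "down_proj", "up_proj", 2)[:3])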
