Support qwen3 deepep #6120

Merged: 11 commits, May 22, 2025
5 changes: 4 additions & 1 deletion python/sglang/srt/models/qwen2_moe.py
@@ -607,7 +607,10 @@ def forward(
            )
        else:
            if hidden_states.shape[0] != 0:
-                hidden_states, _ = self.norm(hidden_states, residual)
+                if residual is None:
+                    hidden_states = self.norm(hidden_states)
+                else:
+                    hidden_states, _ = self.norm(hidden_states, residual)
        return hidden_states
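
For context, this change relies on the fused-RMSNorm calling convention used in sglang: the norm returns a (normed, residual) tuple only when a residual is passed in, and a bare tensor otherwise, which is why the new call site needs both branches. A minimal illustrative sketch of that convention (not sglang's actual implementation):

import torch

class RMSNormSketch(torch.nn.Module):
    def __init__(self, hidden_size: int, eps: float = 1e-6):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.ones(hidden_size))
        self.eps = eps

    def forward(self, x: torch.Tensor, residual: torch.Tensor = None):
        if residual is not None:
            x = x + residual  # fused add of the skip connection
            residual = x
        normed = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) * self.weight
        # Tuple when a residual was supplied, bare tensor otherwise.
        return normed if residual is None else (normed, residual)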


128 changes: 121 additions & 7 deletions python/sglang/srt/models/qwen3_moe.py
@@ -32,6 +32,7 @@
    get_pp_group,
    get_tensor_model_parallel_rank,
    get_tensor_model_parallel_world_size,
+    parallel_state,
    split_tensor_along_last_dim,
    tensor_model_parallel_all_gather,
    tensor_model_parallel_all_reduce,
@@ -54,8 +55,10 @@
    RowParallelLinear,
)
from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorOutput
-from sglang.srt.layers.moe.ep_moe.layer import EPMoE
+from sglang.srt.layers.moe.ep_moe.layer import DeepEPMoE, EPMoE
+from sglang.srt.layers.moe.ep_moe.token_dispatcher import DeepEPDispatcher
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
+from sglang.srt.layers.moe.topk import select_experts
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.layers.rotary_embedding import get_rope
@@ -65,11 +68,15 @@
    VocabParallelEmbedding,
)
from sglang.srt.managers.schedule_batch import global_server_args_dict
-from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
+from sglang.srt.model_executor.forward_batch_info import (
+    ForwardBatch,
+    ForwardMode,
+    PPProxyTensors,
+)
from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.models.qwen2_moe import Qwen2MoeMLP as Qwen3MoeMLP
from sglang.srt.models.qwen2_moe import Qwen2MoeModel
-from sglang.srt.utils import add_prefix
+from sglang.srt.utils import DeepEPMode, add_prefix

Qwen3MoeConfig = None

@@ -92,7 +99,11 @@ def __init__(
                f"the number of experts {config.num_experts}."
            )

-        MoEImpl = EPMoE if global_server_args_dict["enable_ep_moe"] else FusedMoE
+        MoEImpl = (
+            DeepEPMoE
+            if global_server_args_dict["enable_deepep_moe"]
+            else (EPMoE if global_server_args_dict["enable_ep_moe"] else FusedMoE)
+        )

        self.experts = MoEImpl(
            num_experts=config.num_experts,
@@ -102,6 +113,11 @@ def __init__(
            renormalize=config.norm_topk_prob,
            quant_config=quant_config,
            prefix=add_prefix("experts", prefix),
+            **(
+                dict(deepep_mode=DeepEPMode[global_server_args_dict["deepep_mode"]])
+                if global_server_args_dict["enable_deepep_moe"]
+                else {}
+            ),
        )

        self.gate = ReplicatedLinear(
@@ -112,7 +128,37 @@ def __init__(
            prefix=add_prefix("gate", prefix),
        )

-    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        if global_server_args_dict["enable_deepep_moe"]:
+            # TODO: we will support tp < ep in the future
+            self.ep_size = get_tensor_model_parallel_world_size()
+            self.num_experts = config.num_experts
+            self.top_k = config.num_experts_per_tok
+            self.renormalize = config.norm_topk_prob
+
+            self.deepep_dispatcher = DeepEPDispatcher(
+                group=parallel_state.get_tp_group().device_group,
+                router_topk=self.top_k,
+                permute_fusion=True,
+                num_experts=config.num_experts,
+                num_local_experts=config.num_experts // self.tp_size,
+                hidden_size=config.hidden_size,
+                params_dtype=config.torch_dtype,
+                deepep_mode=DeepEPMode[global_server_args_dict["deepep_mode"]],
+                async_finish=True,  # TODO
+                return_recv_hook=True,
+            )
+
+    def forward(
+        self, hidden_states: torch.Tensor, forward_mode: Optional[ForwardMode] = None
+    ) -> torch.Tensor:
+
+        if not global_server_args_dict["enable_deepep_moe"]:
+            return self.forward_normal(hidden_states)
+        else:
+            return self.forward_deepep(hidden_states, forward_mode)
+
+    def forward_normal(self, hidden_states: torch.Tensor) -> torch.Tensor:
+
        num_tokens, hidden_dim = hidden_states.shape
        hidden_states = hidden_states.view(-1, hidden_dim)

@@ -126,6 +172,68 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:

        return final_hidden_states.view(num_tokens, hidden_dim)

+    def forward_deepep(
+        self, hidden_states: torch.Tensor, forward_mode: ForwardMode
+    ) -> torch.Tensor:
+        if (
+            forward_mode is not None
+            and not forward_mode.is_idle()
+            and hidden_states.shape[0] > 0
+        ):
+            # router_logits: (num_tokens, n_experts)
+            router_logits, _ = self.gate(hidden_states)
+
+            topk_weights, topk_idx = select_experts(
+                hidden_states=hidden_states,
+                router_logits=router_logits,
+                top_k=self.top_k,
+                use_grouped_topk=False,
+                renormalize=self.renormalize,
+            )
+        else:
+            topk_idx = torch.full(
+                (0, self.top_k), -1, dtype=torch.int, device=hidden_states.device
+            )
+            topk_weights = torch.empty(
+                (0, self.top_k), dtype=torch.float32, device=hidden_states.device
+            )
+        if self.ep_size > 1:
+            # TODO(ch-wan): allow users to set num_max_dispatch_tokens_per_rank value
+            (
+                hidden_states,
+                topk_idx,
+                topk_weights,
+                reorder_topk_ids,
+                num_recv_tokens_per_expert,
+                seg_indptr,
+                masked_m,
+                expected_m,
+            ) = self.deepep_dispatcher.dispatch(
+                hidden_states,
+                topk_idx,
+                topk_weights,
+                forward_mode=forward_mode,
+            )
+        final_hidden_states = self.experts(
+            hidden_states=hidden_states,
+            topk_idx=topk_idx,
+            topk_weights=topk_weights,
+            reorder_topk_ids=reorder_topk_ids,
+            seg_indptr=seg_indptr,
+            masked_m=masked_m,
+            expected_m=expected_m,
+            num_recv_tokens_per_expert=num_recv_tokens_per_expert,
+            forward_mode=forward_mode,
+        )
+        if self.ep_size > 1:
+            final_hidden_states = self.deepep_dispatcher.combine(
+                final_hidden_states,
+                topk_idx,
+                topk_weights,
+                forward_mode,
+            )
+        return final_hidden_states
+

class Qwen3MoeAttention(nn.Module):
    def __init__(
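
The DeepEP path above follows a fixed sequence on every rank: route the local tokens, dispatch them to the ranks that own the selected experts, run the local experts, and combine the results back; ranks with no tokens still enter dispatch and combine with zero-token placeholders so the collectives stay aligned. A toy sketch of that control flow, with stand-in callables rather than the real DeepEPDispatcher API:

import torch

def moe_step_sketch(hidden_states, gate, top_k, dispatch, experts, combine):
    if hidden_states.shape[0] > 0:
        # Route: pick top-k experts per token from the gate logits.
        probs = torch.softmax(gate(hidden_states), dim=-1)
        topk_weights, topk_idx = torch.topk(probs, top_k, dim=-1)
    else:
        # Idle rank: zero-token placeholders keep the collective calls in sync.
        topk_idx = torch.full((0, top_k), -1, dtype=torch.int, device=hidden_states.device)
        topk_weights = torch.empty((0, top_k), dtype=torch.float32, device=hidden_states.device)
    hidden_states, topk_idx, topk_weights = dispatch(hidden_states, topk_idx, topk_weights)
    out = experts(hidden_states, topk_idx, topk_weights)  # local expert compute
    return combine(out, topk_idx, topk_weights)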
@@ -403,7 +511,7 @@ def forward_ffn_with_full_input(
        )

        # Fully Connected
-        hidden_states = self.mlp(hidden_states)
+        hidden_states = self.mlp(hidden_states, forward_batch.forward_mode)

        # TODO: use reduce-scatter in MLP to avoid this scatter
        # Scatter
@@ -577,7 +685,13 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
            ("gate_up_proj", "up_proj", 1),
        ]

-        MoEImpl = EPMoE if global_server_args_dict["enable_ep_moe"] else FusedMoE
+        # Params for weights, fp8 weight scales, fp8 activation scales
+        # (param_name, weight_name, expert_id, shard_id)
+        MoEImpl = (
+            DeepEPMoE
+            if global_server_args_dict["enable_deepep_moe"]
+            else (EPMoE if global_server_args_dict["enable_ep_moe"] else FusedMoE)
+        )

        expert_params_mapping = MoEImpl.make_expert_params_mapping(
            ckpt_gate_proj_name="gate_proj",
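
For orientation, a hypothetical illustration of the (param_name, weight_name, expert_id, shard_id) tuples described in the comment above; the exact strings produced by MoEImpl.make_expert_params_mapping may differ, this only sketches how per-expert checkpoint weights are routed into the fused expert parameters:

def expert_params_mapping_sketch(num_experts: int):
    # Hypothetical names: w13 is the fused gate/up weight, w2 the down projection.
    mapping = []
    for expert_id in range(num_experts):
        for shard_id, ckpt_name, fused_param in [
            ("w1", "gate_proj", "experts.w13_weight"),
            ("w3", "up_proj", "experts.w13_weight"),
            ("w2", "down_proj", "experts.w2_weight"),
        ]:
            mapping.append((fused_param, f"experts.{expert_id}.{ckpt_name}.", expert_id, shard_id))
    return mapping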