
Commit 53dcf38

Introduce moe_dense_tp_size to fix dense layer errors in DeepSeek V3 + 4x8xH100 (#4836)
1 parent: 1effba4
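In short, this commit adds a moe_dense_tp_size server argument; when it is set to 1, the dense (non-MoE) MLP layers of DeepSeek V2/V3 are built fully data-parallel instead of tensor-parallel, which avoids GEMM dimension errors at large TP sizes such as 4x8xH100. As a minimal sketch of how model code can read the new value once the diffs below have propagated it into global_server_args_dict (the helper name is hypothetical; the import path follows how deepseek_v2.py accesses the dict):

# Hypothetical helper; assumes global_server_args_dict is exposed by
# sglang.srt.managers.schedule_batch, as the DeepSeek model code uses it.
from sglang.srt.managers.schedule_batch import global_server_args_dict


def moe_dense_is_fully_dp() -> bool:
    # True when the server was launched with --moe-dense-tp-size 1.
    return global_server_args_dict.get("moe_dense_tp_size") == 1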

File tree

4 files changed, +31 -1 lines changed

python/sglang/srt/managers/schedule_batch.py

Lines changed: 1 addition & 0 deletions
@@ -78,6 +78,7 @@
     "speculative_accept_threshold_acc": ServerArgs.speculative_accept_threshold_acc,
     "disable_radix_cache": ServerArgs.disable_radix_cache,
     "flashinfer_mla_disable_ragged": ServerArgs.flashinfer_mla_disable_ragged,
+    "moe_dense_tp_size": ServerArgs.moe_dense_tp_size,
     "chunked_prefill_size": ServerArgs.chunked_prefill_size,
     "n_share_experts_fusion": ServerArgs.n_share_experts_fusion,
     "disable_shared_experts_fusion": ServerArgs.disable_shared_experts_fusion,

python/sglang/srt/model_executor/model_runner.py

Lines changed: 1 addition & 0 deletions
@@ -159,6 +159,7 @@ def __init__(
     "speculative_accept_threshold_acc": server_args.speculative_accept_threshold_acc,
     "disable_radix_cache": server_args.disable_radix_cache,
     "flashinfer_mla_disable_ragged": server_args.flashinfer_mla_disable_ragged,
+    "moe_dense_tp_size": server_args.moe_dense_tp_size,
     "debug_tensor_dump_output_folder": server_args.debug_tensor_dump_output_folder,
     "debug_tensor_dump_inject": server_args.debug_tensor_dump_inject,
     "n_share_experts_fusion": server_args.n_share_experts_fusion,

python/sglang/srt/models/deepseek_v2.py

Lines changed: 17 additions & 1 deletion
@@ -1066,12 +1066,18 @@ def __init__(
                 prefix=add_prefix("mlp", prefix),
             )
         else:
+            if self._enable_moe_dense_fully_dp():
+                mlp_tp_rank, mlp_tp_size = 0, 1
+            else:
+                mlp_tp_rank, mlp_tp_size = None, None
             self.mlp = DeepseekV2MLP(
                 hidden_size=config.hidden_size,
                 intermediate_size=config.intermediate_size,
                 hidden_act=config.hidden_act,
                 quant_config=quant_config,
                 prefix=add_prefix("mlp", prefix),
+                tp_rank=mlp_tp_rank,
+                tp_size=mlp_tp_size,
             )

         self.input_is_scattered = (

@@ -1084,6 +1090,10 @@ def __init__(
             config.hidden_size, eps=config.rms_norm_eps
         )

+    @staticmethod
+    def _enable_moe_dense_fully_dp():
+        return global_server_args_dict["moe_dense_tp_size"] == 1
+
     @staticmethod
     def _compute_info(config: PretrainedConfig, layer_id: int, is_nextn: bool):
         is_sparse = is_nextn or (

@@ -1094,6 +1104,7 @@ def _compute_info(config: PretrainedConfig, layer_id: int, is_nextn: bool):
         ffn_input_mode = (
             _FFNInputMode.SCATTERED
             if (global_server_args_dict["enable_deepep_moe"] and is_sparse)
+            or (DeepseekV2DecoderLayer._enable_moe_dense_fully_dp() and not is_sparse)
             else _FFNInputMode.FULL
         )
         return _DecoderLayerInfo(is_sparse=is_sparse, ffn_input_mode=ffn_input_mode)

@@ -1240,7 +1251,12 @@ def forward_ffn_with_scattered_input(
             hidden_states, residual
         )

-        hidden_states = self.mlp(hidden_states, forward_batch.forward_mode)
+        if not (
+            self._enable_moe_dense_fully_dp()
+            and (not self.info.is_sparse)
+            and hidden_states.shape[0] == 0
+        ):
+            hidden_states = self.mlp(hidden_states, forward_batch.forward_mode)

         if self.is_last_layer and self.attn_tp_size != 1:
             hidden_states += residual
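Condensed into a standalone sketch, the deepseek_v2.py changes do two things: pick an unsharded (tp_rank=0, tp_size=1) configuration for dense MLP layers when moe_dense_tp_size is 1, and let a rank whose scattered input is empty skip the MLP call. The function names and the mlp callable below are illustrative stand-ins, not code from the commit:

# Illustrative sketch of the dense-MLP handling added above; names are
# hypothetical, and torch is used only for the empty-batch check.
from typing import Callable, Optional, Tuple

import torch


def resolve_dense_mlp_tp(
    moe_dense_tp_size: Optional[int],
) -> Tuple[Optional[int], Optional[int]]:
    # moe_dense_tp_size == 1 means fully data-parallel dense MLPs: each rank
    # keeps a full copy, so no weight dimension is sharded below the minimum
    # size the GEMM kernels accept. None keeps the default TP sharding.
    if moe_dense_tp_size == 1:
        return 0, 1
    return None, None


def run_dense_mlp(
    mlp: Callable[[torch.Tensor], torch.Tensor],
    hidden_states: torch.Tensor,
    fully_dp: bool,
) -> torch.Tensor:
    # Mirrors forward_ffn_with_scattered_input: with fully-DP dense layers, a
    # rank that received zero scattered tokens skips the MLP call entirely.
    if fully_dp and hidden_states.shape[0] == 0:
        return hidden_states
    return mlp(hidden_states)

The trade-off is that the dense MLP weights are replicated on every rank rather than sharded, spending some extra memory per GPU to keep each GEMM above the minimum supported dimension.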

python/sglang/srt/server_args.py

Lines changed: 12 additions & 0 deletions
@@ -181,6 +181,7 @@ class ServerArgs:
     hicache_ratio: float = 2.0
     flashinfer_mla_disable_ragged: bool = False
     warmups: Optional[str] = None
+    moe_dense_tp_size: Optional[int] = None
     n_share_experts_fusion: int = 0
     disable_shared_experts_fusion: bool = False
     disable_chunked_prefix_cache: bool = False

@@ -252,6 +253,11 @@ def __post_init__(self):

         assert self.chunked_prefill_size % self.page_size == 0

+        assert self.moe_dense_tp_size in {
+            1,
+            None,
+        }, f"moe_dense_tp_size only support 1 and None currently"
+
         if self.attention_backend == "flashmla":
             logger.warning(
                 "FlashMLA only supports a page_size of 64, change page_size to 64."

@@ -1101,6 +1107,12 @@ def add_cli_args(parser: argparse.ArgumentParser):
             action="store_true",
             help="Enabling DeepEP MoE implementation for EP MoE.",
         )
+        parser.add_argument(
+            "--moe-dense-tp-size",
+            type=int,
+            default=ServerArgs.moe_dense_tp_size,
+            help="TP size for MoE dense MLP layers. This flag is useful when, with large TP size, there are errors caused by weights in MLP layers having dimension smaller than the min dimension GEMM supports.",
+        )
         parser.add_argument(
             "--deepep-mode",
             type=str,
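To see how the flag is consumed, the following self-contained snippet mirrors the argparse wiring and the __post_init__ assertion from the hunks above; it is a stripped-down stand-in, not sglang's actual CLI:

# Stripped-down stand-in for the new CLI flag and its validation.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--moe-dense-tp-size",
    type=int,
    default=None,  # matches ServerArgs.moe_dense_tp_size's default
    help="TP size for MoE dense MLP layers.",
)

args = parser.parse_args(["--moe-dense-tp-size", "1"])

# Same constraint as ServerArgs.__post_init__: only 1 (fully data-parallel
# dense MLPs) or None (follow the model's TP size) is accepted for now.
assert args.moe_dense_tp_size in {1, None}, (
    "moe_dense_tp_size only support 1 and None currently"
)
print(args.moe_dense_tp_size)  # prints 1

In practice this amounts to appending --moe-dense-tp-size 1 to an otherwise unchanged launch command for a large-TP DeepSeek V3 deployment.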
