
Commit abf1a6d

ch-wan authored and committed
Fix two issues related to --moe-dense-tp-size=1 (sgl-project#5657)
Co-authored-by: liusy58 <[email protected]>
Co-authored-by: 颉沆 <[email protected]>
1 parent c7981de commit abf1a6d

File tree: 6 files changed (+119, −45 lines)

python/sglang/srt/layers/dp_attention.py (68 additions, 18 deletions)

@@ -24,41 +24,71 @@
 _ATTN_TP_GROUP = None
 _ATTN_TP_RANK = None
 _ATTN_TP_SIZE = None
-_DP_RANK = None
-_DP_SIZE = None
+_ATTN_DP_RANK = None
+_ATTN_DP_SIZE = None
+_LOCAL_ATTN_DP_SIZE = None
+_LOCAL_ATTN_DP_RANK = None


 def compute_dp_attention_world_info(enable_dp_attention, tp_rank, tp_size, dp_size):
     if not enable_dp_attention:
         return tp_rank, tp_size, 0

     attn_tp_size = tp_size // dp_size
-    dp_rank = tp_rank // attn_tp_size
+    attn_dp_rank = tp_rank // attn_tp_size
     attn_tp_rank = tp_rank % attn_tp_size
-    return attn_tp_rank, attn_tp_size, dp_rank
+
+    return attn_tp_rank, attn_tp_size, attn_dp_rank
+
+
+def compute_dp_attention_local_info(
+    enable_dp_attention, tp_rank, tp_size, dp_size, moe_dense_tp_size
+):
+    if not enable_dp_attention:
+        return tp_rank, tp_size, 0
+
+    local_tp_size = moe_dense_tp_size if moe_dense_tp_size else tp_size
+    local_tp_rank = tp_rank % local_tp_size
+    local_dp_size = max(1, dp_size // (tp_size // local_tp_size))
+
+    local_attn_tp_size = local_tp_size // local_dp_size
+    local_attn_dp_rank = local_tp_rank // local_attn_tp_size
+    local_attn_tp_rank = local_tp_rank % local_attn_tp_size
+
+    return local_attn_tp_rank, local_attn_tp_size, local_attn_dp_rank


 def initialize_dp_attention(
     enable_dp_attention: bool,
     tp_rank: int,
     tp_size: int,
     dp_size: int,
+    moe_dense_tp_size: int,
     pp_size: int,
 ):
-    global _ATTN_TP_GROUP, _ATTN_TP_RANK, _ATTN_TP_SIZE, _DP_RANK, _DP_SIZE
+    global _ATTN_TP_GROUP, _ATTN_TP_RANK, _ATTN_TP_SIZE, _ATTN_DP_RANK, _ATTN_DP_SIZE
+    global _LOCAL_ATTN_DP_SIZE, _LOCAL_ATTN_DP_RANK

     from sglang.srt.layers.sampler import SYNC_TOKEN_IDS_ACROSS_TP

-    _ATTN_TP_RANK, _ATTN_TP_SIZE, _DP_RANK = compute_dp_attention_world_info(
+    _ATTN_TP_RANK, _ATTN_TP_SIZE, _ATTN_DP_RANK = compute_dp_attention_world_info(
         enable_dp_attention, tp_rank, tp_size, dp_size
     )
+    _, _, _LOCAL_ATTN_DP_RANK = compute_dp_attention_local_info(
+        enable_dp_attention, tp_rank, tp_size, dp_size, moe_dense_tp_size
+    )

     if enable_dp_attention:
         local_rank = tp_rank % (tp_size // dp_size)
-        _DP_SIZE = dp_size
+        _ATTN_DP_SIZE = dp_size
+        if moe_dense_tp_size is None:
+            _LOCAL_ATTN_DP_SIZE = _ATTN_DP_SIZE
+        else:
+            _LOCAL_ATTN_DP_SIZE = max(1, dp_size // (tp_size // moe_dense_tp_size))
     else:
         local_rank = tp_rank
-        _DP_SIZE = 1
+        _ATTN_DP_SIZE = 1
+        _LOCAL_ATTN_DP_SIZE = 1

     tp_group = get_tp_group()
     _ATTN_TP_GROUP = GroupCoordinator(

@@ -93,13 +123,33 @@ def get_attention_tp_size():


 def get_attention_dp_rank():
-    assert _DP_RANK is not None, "dp attention not initialized!"
-    return _DP_RANK
+    assert _ATTN_DP_RANK is not None, "dp attention not initialized!"
+    return _ATTN_DP_RANK


 def get_attention_dp_size():
-    assert _DP_SIZE is not None, "dp attention not initialized!"
-    return _DP_SIZE
+    assert _ATTN_DP_SIZE is not None, "dp attention not initialized!"
+    return _ATTN_DP_SIZE
+
+
+def get_local_attention_dp_rank():
+    assert _LOCAL_ATTN_DP_RANK is not None, "dp attention not initialized!"
+    return _LOCAL_ATTN_DP_RANK
+
+
+def get_local_attention_dp_size():
+    assert _LOCAL_ATTN_DP_SIZE is not None, "dp attention not initialized!"
+    return _LOCAL_ATTN_DP_SIZE
+
+
+def get_local_attention_dp_rank():
+    assert _LOCAL_ATTN_DP_RANK is not None, "dp attention not initialized!"
+    return _LOCAL_ATTN_DP_RANK
+
+
+def get_local_attention_dp_size():
+    assert _LOCAL_ATTN_DP_SIZE is not None, "dp attention not initialized!"
+    return _LOCAL_ATTN_DP_SIZE


 @contextmanager

@@ -112,19 +162,19 @@ def disable_dp_size():
     Args:
         tp_group (GroupCoordinator): the tp group coordinator
     """
-    global _DP_SIZE
-    assert _DP_SIZE is not None, "dp attention not initialized!"
+    global _ATTN_DP_SIZE
+    assert _ATTN_DP_SIZE is not None, "dp attention not initialized!"

-    old_dp_size = _DP_SIZE
-    _DP_SIZE = 1
+    old_dp_size = _ATTN_DP_SIZE
+    _ATTN_DP_SIZE = 1
     try:
         yield
     finally:
-        _DP_SIZE = old_dp_size
+        _ATTN_DP_SIZE = old_dp_size


 def get_dp_local_info(forward_batch: ForwardBatch):
-    dp_rank = get_attention_dp_rank()
+    dp_rank = get_local_attention_dp_rank()

     if forward_batch.dp_local_start_pos is None:
         cumtokens = torch.cumsum(forward_batch.global_num_tokens_gpu, dim=0)

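The difference between the global and the local DP rank is easiest to see by evaluating the two helpers side by side. The following is an illustration only, not part of the commit: it copies compute_dp_attention_world_info and compute_dp_attention_local_info from the hunk above and runs them for an assumed configuration (tp_size=4, dp_size=4, --moe-dense-tp-size=2). The global attention-DP rank tracks the full tp_rank, while the local one wraps around inside each dense TP group.

def compute_dp_attention_world_info(enable_dp_attention, tp_rank, tp_size, dp_size):
    if not enable_dp_attention:
        return tp_rank, tp_size, 0
    attn_tp_size = tp_size // dp_size
    attn_dp_rank = tp_rank // attn_tp_size
    attn_tp_rank = tp_rank % attn_tp_size
    return attn_tp_rank, attn_tp_size, attn_dp_rank


def compute_dp_attention_local_info(
    enable_dp_attention, tp_rank, tp_size, dp_size, moe_dense_tp_size
):
    if not enable_dp_attention:
        return tp_rank, tp_size, 0
    local_tp_size = moe_dense_tp_size if moe_dense_tp_size else tp_size
    local_tp_rank = tp_rank % local_tp_size
    local_dp_size = max(1, dp_size // (tp_size // local_tp_size))
    local_attn_tp_size = local_tp_size // local_dp_size
    local_attn_dp_rank = local_tp_rank // local_attn_tp_size
    local_attn_tp_rank = local_tp_rank % local_attn_tp_size
    return local_attn_tp_rank, local_attn_tp_size, local_attn_dp_rank


if __name__ == "__main__":
    tp_size, dp_size, moe_dense_tp_size = 4, 4, 2  # assumed example values
    for tp_rank in range(tp_size):
        world = compute_dp_attention_world_info(True, tp_rank, tp_size, dp_size)
        local = compute_dp_attention_local_info(
            True, tp_rank, tp_size, dp_size, moe_dense_tp_size
        )
        # Each tuple is (attn_tp_rank, attn_tp_size, attn_dp_rank).
        print(f"tp_rank={tp_rank}  world={world}  local={local}")
    # Prints:
    #   tp_rank=0  world=(0, 1, 0)  local=(0, 1, 0)
    #   tp_rank=1  world=(0, 1, 1)  local=(0, 1, 1)
    #   tp_rank=2  world=(0, 1, 2)  local=(0, 1, 0)
    #   tp_rank=3  world=(0, 1, 3)  local=(0, 1, 1)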
python/sglang/srt/layers/logits_processor.py (17 additions, 3 deletions)

@@ -30,9 +30,10 @@
     attn_tp_all_gather,
     dp_gather_replicate,
     dp_scatter,
-    get_attention_dp_rank,
     get_attention_dp_size,
     get_attention_tp_size,
+    get_local_attention_dp_rank,
+    get_local_attention_dp_size,
 )
 from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
 from sglang.srt.managers.schedule_batch import global_server_args_dict

@@ -46,6 +47,18 @@
 logger = logging.getLogger(__name__)


+from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
+from sglang.srt.managers.schedule_batch import global_server_args_dict
+from sglang.srt.model_executor.forward_batch_info import (
+    CaptureHiddenMode,
+    ForwardBatch,
+    ForwardMode,
+)
+from sglang.srt.utils import dump_to_file
+
+logger = logging.getLogger(__name__)
+
+
 @dataclasses.dataclass
 class LogitsProcessorOutput:
     ## Part 1: This part will be assigned in python/sglang/srt/layers/logits_processor.py::LogitsProcessor

@@ -170,7 +183,7 @@ def compute_dp_attention_metadata(self, hidden_states: torch.Tensor):
             return

         cumtokens = torch.cumsum(self.global_num_tokens_for_logprob_gpu, dim=0)
-        dp_rank = get_attention_dp_rank()
+        dp_rank = get_local_attention_dp_rank()
         if dp_rank == 0:
             dp_local_start_pos = torch.zeros_like(
                 self.global_num_tokens_for_logprob_gpu[0]

@@ -324,7 +337,8 @@ def forward(

         if self.debug_tensor_dump_output_folder:
             assert (
-                not self.do_tensor_parallel_all_gather or get_attention_dp_size() == 1
+                not self.do_tensor_parallel_all_gather
+                or get_local_attention_dp_size() == 1
             ), "dp attention + sharded lm_head doesn't support full logits"
             full_logits = self._get_logits(hidden_states, lm_head, logits_metadata)
             dump_to_file(self.debug_tensor_dump_output_folder, "logits", full_logits)

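For context on the call sites switched to get_local_attention_dp_rank above: compute_dp_attention_metadata (like get_dp_local_info in dp_attention.py) slices each rank's share out of the concatenated DP batch using a cumulative sum over per-rank token counts, and the change only affects which DP rank index selects that slice. A minimal plain-Python sketch of the bookkeeping, not part of the commit (dp_local_slice is a hypothetical name):

from typing import List, Tuple


def dp_local_slice(global_num_tokens: List[int], dp_rank: int) -> Tuple[int, int]:
    # Start position and length of this DP rank's tokens in the gathered batch,
    # mirroring the torch.cumsum-based logic in compute_dp_attention_metadata.
    start = sum(global_num_tokens[:dp_rank])  # cumsum[dp_rank - 1], or 0 for rank 0
    length = global_num_tokens[dp_rank]
    return start, length


# Example: four DP ranks contributed [5, 3, 7, 2] tokens; rank 2 owns tokens [8, 15).
assert dp_local_slice([5, 3, 7, 2], 2) == (8, 7)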
python/sglang/srt/managers/scheduler.py (21 additions, 11 deletions)

@@ -211,7 +211,8 @@ def __init__(
         )

         # Distributed rank info
-        self.attn_tp_rank, self.attn_tp_size, self.dp_rank = (
+        self.dp_size = server_args.dp_size
+        self.attn_tp_rank, self.attn_tp_size, self.attn_dp_rank = (
             compute_dp_attention_world_info(
                 server_args.enable_dp_attention,
                 self.tp_rank,

@@ -789,7 +790,7 @@ def event_loop_pp(self):
                 )

                 # send out reqs to the next stage
-                dp_offset = self.dp_rank * self.attn_tp_size
+                dp_offset = self.attn_dp_rank * self.attn_tp_size
                 if self.attn_tp_rank == 0:
                     point_to_point_pyobj(
                         recv_reqs,

@@ -836,7 +837,7 @@ def recv_requests(self) -> List[Req]:
                 recv_reqs = None
         else:
             if self.attn_tp_rank == 0:
-                dp_offset = self.dp_rank * self.attn_tp_size
+                dp_offset = self.attn_dp_rank * self.attn_tp_size
                 recv_reqs = point_to_point_pyobj(
                     [],
                     self.pp_rank * self.tp_size + dp_offset,

@@ -1654,6 +1655,7 @@ def prepare_dp_attn_batch(self, local_batch: ScheduleBatch):
             local_batch,
             dp_size=self.server_args.dp_size,
             attn_tp_size=self.attn_tp_size,
+            moe_dense_tp_size=self.server_args.moe_dense_tp_size,
             tp_cpu_group=self.tp_cpu_group,
             get_idle_batch=self.get_idle_batch,
             disable_cuda_graph=self.server_args.disable_cuda_graph,

@@ -1666,6 +1668,7 @@ def prepare_dp_attn_batch_raw(
         local_batch: ScheduleBatch,
         dp_size,
         attn_tp_size: int,
+        moe_dense_tp_size: Optional[int],
         tp_cpu_group,
         get_idle_batch,
         disable_cuda_graph: bool,

@@ -1675,15 +1678,15 @@ def prepare_dp_attn_batch_raw(
         # Check if other DP workers have running batches
         if local_batch is None:
             num_tokens = 0
-            global_num_tokens_for_logprob = 0
+            num_tokens_for_logprob = 0
         elif local_batch.forward_mode.is_decode():
             num_tokens = local_batch.batch_size()
             if not spec_algorithm.is_none() and spec_algorithm.is_eagle():
                 num_tokens = num_tokens * speculative_num_draft_tokens
-            global_num_tokens_for_logprob = num_tokens
+            num_tokens_for_logprob = num_tokens
         else:
             num_tokens = local_batch.extend_num_tokens
-            global_num_tokens_for_logprob = sum(
+            num_tokens_for_logprob = sum(
                 [
                     # We should have at least 1 token for sample in every case.
                     max(extend_len - logprob_start_len, 1)

@@ -1710,7 +1713,7 @@ def prepare_dp_attn_batch_raw(
             [
                 num_tokens,
                 can_cuda_graph,
-                global_num_tokens_for_logprob,
+                num_tokens_for_logprob,
                 is_extend_in_batch,
             ],
             dtype=torch.int64,

@@ -1733,8 +1736,15 @@ def prepare_dp_attn_batch_raw(
             local_batch = get_idle_batch()

         if local_batch is not None:
-            local_batch.global_num_tokens = global_num_tokens
-            local_batch.global_num_tokens_for_logprob = global_num_tokens_for_logprob
+            # TODO: handle the case when moe_dense_tp_size != 1
+            if moe_dense_tp_size == 1 and global_server_args_dict["enable_dp_lm_head"]:
+                local_batch.global_num_tokens = [num_tokens]
+                local_batch.global_num_tokens_for_logprob = [num_tokens_for_logprob]
+            else:
+                local_batch.global_num_tokens = global_num_tokens
+                local_batch.global_num_tokens_for_logprob = (
+                    global_num_tokens_for_logprob
+                )

         # Check forward mode for cuda graph
         if not disable_cuda_graph:

@@ -2226,8 +2236,8 @@ def close_session(self, recv_req: CloseSessionReqInput):

     def get_print_prefix(self):
         prefix = ""
-        if self.dp_rank is not None:
-            prefix += f" DP{self.dp_rank}"
+        if self.attn_dp_rank is not None:
+            prefix += f" DP{self.attn_dp_rank}"
         if self.server_args.tp_size > 1:
             prefix += f" TP{self.tp_rank}"
         if self.pp_size > 1:

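The heart of the scheduler change is the branch deciding what ends up in local_batch.global_num_tokens. A minimal sketch of that decision, not part of the commit and using a hypothetical helper name (choose_global_num_tokens):

from typing import List, Optional


def choose_global_num_tokens(
    num_tokens: int,
    gathered_num_tokens: List[int],
    moe_dense_tp_size: Optional[int],
    enable_dp_lm_head: bool,
) -> List[int]:
    # With --moe-dense-tp-size=1 and a DP-sharded lm_head, each attention-DP rank
    # keeps only its own token count; otherwise the list gathered from all DP
    # ranks is used, as before. The moe_dense_tp_size != 1 case is left as a TODO
    # in the commit.
    if moe_dense_tp_size == 1 and enable_dp_lm_head:
        return [num_tokens]
    return gathered_num_tokens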
python/sglang/srt/model_executor/model_runner.py (1 addition, 0 deletions)

@@ -402,6 +402,7 @@ def init_torch_distributed(self):
             tp_rank=self.tp_rank,
             tp_size=self.tp_size,
             dp_size=self.server_args.dp_size,
+            moe_dense_tp_size=self.server_args.moe_dense_tp_size,
             pp_size=self.server_args.pp_size,
         )

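With this one-line change, the call in init_torch_distributed matches the new initialize_dp_attention signature shown in dp_attention.py above. An illustrative standalone invocation, with argument values that are assumptions rather than anything taken from the commit:

from sglang.srt.layers.dp_attention import initialize_dp_attention

initialize_dp_attention(
    enable_dp_attention=True,
    tp_rank=0,                # this process's TP rank (assumed)
    tp_size=8,                # assumed --tp-size
    dp_size=4,                # assumed --dp-size
    moe_dense_tp_size=1,      # forwarded from --moe-dense-tp-size
    pp_size=1,
)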
python/sglang/srt/models/deepseek_v2.py (8 additions, 8 deletions)

@@ -40,9 +40,9 @@
     attn_tp_reduce_scatter,
     dp_gather_partial,
     dp_scatter,
-    get_attention_dp_size,
     get_attention_tp_rank,
     get_attention_tp_size,
+    get_local_attention_dp_size,
 )
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.linear import (

@@ -438,7 +438,6 @@ def __init__(
         self.v_head_dim = v_head_dim
         self.q_lora_rank = q_lora_rank
         self.kv_lora_rank = kv_lora_rank
-        self.dp_size = get_attention_dp_size()
         attn_tp_rank = get_attention_tp_rank()
         attn_tp_size = get_attention_tp_size()

@@ -1133,7 +1132,7 @@ def __init__(
         max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
         self.enable_dp_attention = global_server_args_dict["enable_dp_attention"]
         self.layer_id = layer_id
-        self.dp_size = get_attention_dp_size()
+        self.local_dp_size = get_local_attention_dp_size()
         self.attn_tp_size = get_attention_tp_size()
         self.attn_tp_rank = get_attention_tp_rank()
         self.self_attn = DeepseekV2AttentionMLA(

@@ -1184,7 +1183,8 @@ def __init__(
         )

         self.input_is_scattered = (
-            previous_layer_info.ffn_input_mode == _FFNInputMode.SCATTERED
+            layer_id > 0
+            and previous_layer_info.ffn_input_mode == _FFNInputMode.SCATTERED
         )
         self.is_last_layer = self.layer_id == config.num_hidden_layers - 1

@@ -1264,7 +1264,7 @@ def forward_ffn_with_full_input(
         # Gather
         if get_tensor_model_parallel_world_size() > 1:
             # all gather and all reduce
-            if self.dp_size != 1:
+            if self.local_dp_size != 1:
                 if self.attn_tp_rank == 0:
                     hidden_states += residual
                 hidden_states, local_hidden_states = (

@@ -1289,7 +1289,7 @@ def forward_ffn_with_full_input(

         # TODO(ch-wan): use reduce-scatter in MLP to avoid this scatter
         # Scatter
-        if self.dp_size != 1:
+        if self.local_dp_size != 1:
             # important: forward_batch.gathered_buffer is used both after scatter and after gather.
             # be careful about this!
             hidden_states, global_hidden_states = (

@@ -1413,7 +1413,7 @@ def __init__(
         )
         self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

-        self.dp_size = get_attention_dp_size()
+        self.dp_size = get_local_attention_dp_size()

     def get_input_embeddings(self) -> torch.Tensor:
         return self.embed_tokens

@@ -1478,7 +1478,7 @@ def __init__(
             use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"],
         )
         self.logits_processor = LogitsProcessor(config)
-        self.dp_size = get_attention_dp_size()
+        self.dp_size = get_local_attention_dp_size()

     def determine_n_share_experts_fusion(
         self, architecture: str = "DeepseekV3ForCausalLM"

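The input_is_scattered hunk above guards the first decoder layer: it has no previous layer whose FFN output could have left its input scattered, so the flag now also requires layer_id > 0. A minimal sketch of the guarded expression, not part of the commit; the enum below is a stand-in (only SCATTERED is referenced in the diff, OTHER is a placeholder member) and is_scattered_input is a hypothetical name:

from enum import Enum, auto


class _FFNInputMode(Enum):
    # Stand-in for sglang's _FFNInputMode; OTHER is a placeholder member.
    SCATTERED = auto()
    OTHER = auto()


def is_scattered_input(layer_id: int, previous_ffn_input_mode: _FFNInputMode) -> bool:
    # Layer 0 cannot inherit a scattered input from a previous decoder layer.
    return layer_id > 0 and previous_ffn_input_mode == _FFNInputMode.SCATTERED


assert is_scattered_input(0, _FFNInputMode.SCATTERED) is False
assert is_scattered_input(5, _FFNInputMode.SCATTERED) is True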
0 commit comments