
Commit 1b48922

ch-wan, liusy58, and 颉沆 authored and committed
Fix two issues related to --moe-dense-tp-size=1 (sgl-project#5657)
Co-authored-by: liusy58 <[email protected]>
Co-authored-by: 颉沆 <[email protected]>
1 parent ee68e7e commit 1b48922

File tree: 6 files changed, +119 −45 lines


python/sglang/srt/layers/dp_attention.py

Lines changed: 68 additions & 18 deletions
@@ -24,41 +24,71 @@
 _ATTN_TP_GROUP = None
 _ATTN_TP_RANK = None
 _ATTN_TP_SIZE = None
-_DP_RANK = None
-_DP_SIZE = None
+_ATTN_DP_RANK = None
+_ATTN_DP_SIZE = None
+_LOCAL_ATTN_DP_SIZE = None
+_LOCAL_ATTN_DP_RANK = None


 def compute_dp_attention_world_info(enable_dp_attention, tp_rank, tp_size, dp_size):
     if not enable_dp_attention:
         return tp_rank, tp_size, 0

     attn_tp_size = tp_size // dp_size
-    dp_rank = tp_rank // attn_tp_size
+    attn_dp_rank = tp_rank // attn_tp_size
     attn_tp_rank = tp_rank % attn_tp_size
-    return attn_tp_rank, attn_tp_size, dp_rank
+
+    return attn_tp_rank, attn_tp_size, attn_dp_rank
+
+
+def compute_dp_attention_local_info(
+    enable_dp_attention, tp_rank, tp_size, dp_size, moe_dense_tp_size
+):
+    if not enable_dp_attention:
+        return tp_rank, tp_size, 0
+
+    local_tp_size = moe_dense_tp_size if moe_dense_tp_size else tp_size
+    local_tp_rank = tp_rank % local_tp_size
+    local_dp_size = max(1, dp_size // (tp_size // local_tp_size))
+
+    local_attn_tp_size = local_tp_size // local_dp_size
+    local_attn_dp_rank = local_tp_rank // local_attn_tp_size
+    local_attn_tp_rank = local_tp_rank % local_attn_tp_size
+
+    return local_attn_tp_rank, local_attn_tp_size, local_attn_dp_rank


 def initialize_dp_attention(
     enable_dp_attention: bool,
     tp_rank: int,
     tp_size: int,
     dp_size: int,
+    moe_dense_tp_size: int,
     pp_size: int,
 ):
-    global _ATTN_TP_GROUP, _ATTN_TP_RANK, _ATTN_TP_SIZE, _DP_RANK, _DP_SIZE
+    global _ATTN_TP_GROUP, _ATTN_TP_RANK, _ATTN_TP_SIZE, _ATTN_DP_RANK, _ATTN_DP_SIZE
+    global _LOCAL_ATTN_DP_SIZE, _LOCAL_ATTN_DP_RANK

     from sglang.srt.layers.sampler import SYNC_TOKEN_IDS_ACROSS_TP

-    _ATTN_TP_RANK, _ATTN_TP_SIZE, _DP_RANK = compute_dp_attention_world_info(
+    _ATTN_TP_RANK, _ATTN_TP_SIZE, _ATTN_DP_RANK = compute_dp_attention_world_info(
         enable_dp_attention, tp_rank, tp_size, dp_size
     )
+    _, _, _LOCAL_ATTN_DP_RANK = compute_dp_attention_local_info(
+        enable_dp_attention, tp_rank, tp_size, dp_size, moe_dense_tp_size
+    )

     if enable_dp_attention:
         local_rank = tp_rank % (tp_size // dp_size)
-        _DP_SIZE = dp_size
+        _ATTN_DP_SIZE = dp_size
+        if moe_dense_tp_size is None:
+            _LOCAL_ATTN_DP_SIZE = _ATTN_DP_SIZE
+        else:
+            _LOCAL_ATTN_DP_SIZE = max(1, dp_size // (tp_size // moe_dense_tp_size))
     else:
         local_rank = tp_rank
-        _DP_SIZE = 1
+        _ATTN_DP_SIZE = 1
+        _LOCAL_ATTN_DP_SIZE = 1

     tp_group = get_tp_group()
     _ATTN_TP_GROUP = GroupCoordinator(
@@ -93,13 +123,33 @@ def get_attention_tp_size():


 def get_attention_dp_rank():
-    assert _DP_RANK is not None, "dp attention not initialized!"
-    return _DP_RANK
+    assert _ATTN_DP_RANK is not None, "dp attention not initialized!"
+    return _ATTN_DP_RANK


 def get_attention_dp_size():
-    assert _DP_SIZE is not None, "dp attention not initialized!"
-    return _DP_SIZE
+    assert _ATTN_DP_SIZE is not None, "dp attention not initialized!"
+    return _ATTN_DP_SIZE
+
+
+def get_local_attention_dp_rank():
+    assert _LOCAL_ATTN_DP_RANK is not None, "dp attention not initialized!"
+    return _LOCAL_ATTN_DP_RANK
+
+
+def get_local_attention_dp_size():
+    assert _LOCAL_ATTN_DP_SIZE is not None, "dp attention not initialized!"
+    return _LOCAL_ATTN_DP_SIZE
+
+
+def get_local_attention_dp_rank():
+    assert _LOCAL_ATTN_DP_RANK is not None, "dp attention not initialized!"
+    return _LOCAL_ATTN_DP_RANK
+
+
+def get_local_attention_dp_size():
+    assert _LOCAL_ATTN_DP_SIZE is not None, "dp attention not initialized!"
+    return _LOCAL_ATTN_DP_SIZE


 @contextmanager
@@ -112,19 +162,19 @@ def disable_dp_size():
     Args:
         tp_group (GroupCoordinator): the tp group coordinator
     """
-    global _DP_SIZE
-    assert _DP_SIZE is not None, "dp attention not initialized!"
+    global _ATTN_DP_SIZE
+    assert _ATTN_DP_SIZE is not None, "dp attention not initialized!"

-    old_dp_size = _DP_SIZE
-    _DP_SIZE = 1
+    old_dp_size = _ATTN_DP_SIZE
+    _ATTN_DP_SIZE = 1
     try:
         yield
     finally:
-        _DP_SIZE = old_dp_size
+        _ATTN_DP_SIZE = old_dp_size


 def get_dp_local_info(forward_batch: ForwardBatch):
-    dp_rank = get_attention_dp_rank()
+    dp_rank = get_local_attention_dp_rank()

     if forward_batch.dp_local_start_pos is None:
         cumtokens = torch.cumsum(forward_batch.global_num_tokens_gpu, dim=0)
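
For orientation, here is a standalone sketch (not part of the commit) of the rank math in the two helpers above; tp_size=8, dp_size=8, moe_dense_tp_size=1 is an assumed example configuration.

# Illustration only: mirrors the arithmetic of the two helpers above for an
# assumed configuration (tp_size=8, dp_size=8, moe_dense_tp_size=1).

def world_info(tp_rank, tp_size, dp_size):
    # Mirrors compute_dp_attention_world_info in the DP-attention case.
    attn_tp_size = tp_size // dp_size
    return tp_rank % attn_tp_size, attn_tp_size, tp_rank // attn_tp_size

def local_info(tp_rank, tp_size, dp_size, moe_dense_tp_size):
    # Mirrors compute_dp_attention_local_info: the "local" view used by the
    # dense (non-MoE) layers when --moe-dense-tp-size is set.
    local_tp_size = moe_dense_tp_size if moe_dense_tp_size else tp_size
    local_tp_rank = tp_rank % local_tp_size
    local_dp_size = max(1, dp_size // (tp_size // local_tp_size))
    local_attn_tp_size = local_tp_size // local_dp_size
    return (
        local_tp_rank % local_attn_tp_size,
        local_attn_tp_size,
        local_tp_rank // local_attn_tp_size,
    )

for tp_rank in range(8):
    print(tp_rank, world_info(tp_rank, 8, 8), local_info(tp_rank, 8, 8, 1))
# Each rank gets global attention-DP rank == tp_rank (size 8), but local
# attention-DP rank 0 with size 1, which is why _LOCAL_ATTN_DP_SIZE becomes
# max(1, dp_size // (tp_size // moe_dense_tp_size)) == 1 here.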

python/sglang/srt/layers/logits_processor.py

Lines changed: 17 additions & 3 deletions
@@ -30,9 +30,10 @@
     attn_tp_all_gather,
     dp_gather_replicate,
     dp_scatter,
-    get_attention_dp_rank,
     get_attention_dp_size,
     get_attention_tp_size,
+    get_local_attention_dp_rank,
+    get_local_attention_dp_size,
 )
 from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
 from sglang.srt.managers.schedule_batch import global_server_args_dict
@@ -46,6 +47,18 @@
 logger = logging.getLogger(__name__)


+from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
+from sglang.srt.managers.schedule_batch import global_server_args_dict
+from sglang.srt.model_executor.forward_batch_info import (
+    CaptureHiddenMode,
+    ForwardBatch,
+    ForwardMode,
+)
+from sglang.srt.utils import dump_to_file
+
+logger = logging.getLogger(__name__)
+
+
 @dataclasses.dataclass
 class LogitsProcessorOutput:
     ## Part 1: This part will be assigned in python/sglang/srt/layers/logits_processor.py::LogitsProcessor
@@ -170,7 +183,7 @@ def compute_dp_attention_metadata(self, hidden_states: torch.Tensor):
            return

        cumtokens = torch.cumsum(self.global_num_tokens_for_logprob_gpu, dim=0)
-        dp_rank = get_attention_dp_rank()
+        dp_rank = get_local_attention_dp_rank()
        if dp_rank == 0:
            dp_local_start_pos = torch.zeros_like(
                self.global_num_tokens_for_logprob_gpu[0]
@@ -324,7 +337,8 @@ def forward(

        if self.debug_tensor_dump_output_folder:
            assert (
-                not self.do_tensor_parallel_all_gather or get_attention_dp_size() == 1
+                not self.do_tensor_parallel_all_gather
+                or get_local_attention_dp_size() == 1
            ), "dp attention + sharded lm_head doesn't support full logits"
            full_logits = self._get_logits(hidden_states, lm_head, logits_metadata)
            dump_to_file(self.debug_tensor_dump_output_folder, "logits", full_logits)
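
A rough sketch of the metadata path touched above (an illustration with assumed values, not the commit's code): when moe_dense_tp_size=1 and enable_dp_lm_head are enabled, the scheduler publishes a single-entry token list, and indexing it with the local attention-DP rank (always 0 here) selects the whole local slice.

# Illustration (assumed values): single-entry token list + local DP rank 0.
import torch

global_num_tokens_for_logprob = torch.tensor([13])   # one entry: this rank's tokens
cumtokens = torch.cumsum(global_num_tokens_for_logprob, dim=0)

dp_rank = 0                                          # get_local_attention_dp_rank()
if dp_rank == 0:
    dp_local_start_pos = torch.zeros_like(global_num_tokens_for_logprob[0])
else:
    dp_local_start_pos = cumtokens[dp_rank - 1]
dp_local_num_tokens = global_num_tokens_for_logprob[dp_rank]

print(dp_local_start_pos.item(), dp_local_num_tokens.item())   # 0 13 -> the whole local batch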

python/sglang/srt/managers/scheduler.py

Lines changed: 21 additions & 11 deletions
@@ -207,7 +207,8 @@ def __init__(
         self.page_size = server_args.page_size

         # Distributed rank info
-        self.attn_tp_rank, self.attn_tp_size, self.dp_rank = (
+        self.dp_size = server_args.dp_size
+        self.attn_tp_rank, self.attn_tp_size, self.attn_dp_rank = (
             compute_dp_attention_world_info(
                 server_args.enable_dp_attention,
                 self.tp_rank,
@@ -768,7 +769,7 @@ def event_loop_pp(self):
                )

                # send out reqs to the next stage
-                dp_offset = self.dp_rank * self.attn_tp_size
+                dp_offset = self.attn_dp_rank * self.attn_tp_size
                if self.attn_tp_rank == 0:
                    point_to_point_pyobj(
                        recv_reqs,
@@ -815,7 +816,7 @@ def recv_requests(self) -> List[Req]:
                recv_reqs = None
        else:
            if self.attn_tp_rank == 0:
-                dp_offset = self.dp_rank * self.attn_tp_size
+                dp_offset = self.attn_dp_rank * self.attn_tp_size
                recv_reqs = point_to_point_pyobj(
                    [],
                    self.pp_rank * self.tp_size + dp_offset,
@@ -1610,6 +1611,7 @@ def prepare_dp_attn_batch(self, local_batch: ScheduleBatch):
            local_batch,
            dp_size=self.server_args.dp_size,
            attn_tp_size=self.attn_tp_size,
+            moe_dense_tp_size=self.server_args.moe_dense_tp_size,
            tp_cpu_group=self.tp_cpu_group,
            get_idle_batch=self.get_idle_batch,
            disable_cuda_graph=self.server_args.disable_cuda_graph,
@@ -1622,6 +1624,7 @@ def prepare_dp_attn_batch_raw(
        local_batch: ScheduleBatch,
        dp_size,
        attn_tp_size: int,
+        moe_dense_tp_size: Optional[int],
        tp_cpu_group,
        get_idle_batch,
        disable_cuda_graph: bool,
@@ -1631,15 +1634,15 @@ def prepare_dp_attn_batch_raw(
        # Check if other DP workers have running batches
        if local_batch is None:
            num_tokens = 0
-            global_num_tokens_for_logprob = 0
+            num_tokens_for_logprob = 0
        elif local_batch.forward_mode.is_decode():
            num_tokens = local_batch.batch_size()
            if not spec_algorithm.is_none() and spec_algorithm.is_eagle():
                num_tokens = num_tokens * speculative_num_draft_tokens
-            global_num_tokens_for_logprob = num_tokens
+            num_tokens_for_logprob = num_tokens
        else:
            num_tokens = local_batch.extend_num_tokens
-            global_num_tokens_for_logprob = sum(
+            num_tokens_for_logprob = sum(
                [
                    # We should have at least 1 token for sample in every case.
                    max(extend_len - logprob_start_len, 1)
@@ -1666,7 +1669,7 @@ def prepare_dp_attn_batch_raw(
            [
                num_tokens,
                can_cuda_graph,
-                global_num_tokens_for_logprob,
+                num_tokens_for_logprob,
                is_extend_in_batch,
            ],
            dtype=torch.int64,
@@ -1689,8 +1692,15 @@ def prepare_dp_attn_batch_raw(
            local_batch = get_idle_batch()

        if local_batch is not None:
-            local_batch.global_num_tokens = global_num_tokens
-            local_batch.global_num_tokens_for_logprob = global_num_tokens_for_logprob
+            # TODO: handle the case when moe_dense_tp_size != 1
+            if moe_dense_tp_size == 1 and global_server_args_dict["enable_dp_lm_head"]:
+                local_batch.global_num_tokens = [num_tokens]
+                local_batch.global_num_tokens_for_logprob = [num_tokens_for_logprob]
+            else:
+                local_batch.global_num_tokens = global_num_tokens
+                local_batch.global_num_tokens_for_logprob = (
+                    global_num_tokens_for_logprob
+                )

        # Check forward mode for cuda graph
        if not disable_cuda_graph:
@@ -2177,8 +2187,8 @@ def close_session(self, recv_req: CloseSessionReqInput):

     def get_print_prefix(self):
         prefix = ""
-        if self.dp_rank is not None:
-            prefix += f" DP{self.dp_rank}"
+        if self.attn_dp_rank is not None:
+            prefix += f" DP{self.attn_dp_rank}"
         if self.server_args.tp_size > 1:
             prefix += f" TP{self.tp_rank}"
         if self.pp_size > 1:
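
To see what the renamed attn_dp_rank addresses in the two dp_offset sites above, here is a small sketch with assumed sizes (tp_size=8, dp_size=2, pp_rank=0); the offsets point at the first attention-TP rank of each attention-DP group.

# Illustration only (assumed sizes): dp_offset = attn_dp_rank * attn_tp_size
# selects the leading attention-TP rank of each attention-DP group, i.e. the
# rank that relays requests via point_to_point_pyobj.
tp_size, dp_size, pp_rank = 8, 2, 0
attn_tp_size = tp_size // dp_size              # 4 attention-TP ranks per group

for attn_dp_rank in range(dp_size):
    dp_offset = attn_dp_rank * attn_tp_size
    dst_rank = pp_rank * tp_size + dp_offset   # destination rank used above
    print(attn_dp_rank, dp_offset, dst_rank)   # prints "0 0 0" and "1 4 4"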

python/sglang/srt/model_executor/model_runner.py

Lines changed: 1 addition & 0 deletions
@@ -401,6 +401,7 @@ def init_torch_distributed(self):
             tp_rank=self.tp_rank,
             tp_size=self.tp_size,
             dp_size=self.server_args.dp_size,
+            moe_dense_tp_size=self.server_args.moe_dense_tp_size,
             pp_size=self.server_args.pp_size,
         )

python/sglang/srt/models/deepseek_v2.py

Lines changed: 8 additions & 8 deletions
@@ -40,9 +40,9 @@
     attn_tp_reduce_scatter,
     dp_gather_partial,
     dp_scatter,
-    get_attention_dp_size,
     get_attention_tp_rank,
     get_attention_tp_size,
+    get_local_attention_dp_size,
 )
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.linear import (
@@ -438,7 +438,6 @@ def __init__(
         self.v_head_dim = v_head_dim
         self.q_lora_rank = q_lora_rank
         self.kv_lora_rank = kv_lora_rank
-        self.dp_size = get_attention_dp_size()
         attn_tp_rank = get_attention_tp_rank()
         attn_tp_size = get_attention_tp_size()

@@ -1133,7 +1132,7 @@ def __init__(
         max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
         self.enable_dp_attention = global_server_args_dict["enable_dp_attention"]
         self.layer_id = layer_id
-        self.dp_size = get_attention_dp_size()
+        self.local_dp_size = get_local_attention_dp_size()
         self.attn_tp_size = get_attention_tp_size()
         self.attn_tp_rank = get_attention_tp_rank()
         self.self_attn = DeepseekV2AttentionMLA(
@@ -1184,7 +1183,8 @@ def __init__(
         )

         self.input_is_scattered = (
-            previous_layer_info.ffn_input_mode == _FFNInputMode.SCATTERED
+            layer_id > 0
+            and previous_layer_info.ffn_input_mode == _FFNInputMode.SCATTERED
         )
         self.is_last_layer = self.layer_id == config.num_hidden_layers - 1

@@ -1264,7 +1264,7 @@ def forward_ffn_with_full_input(
        # Gather
        if get_tensor_model_parallel_world_size() > 1:
            # all gather and all reduce
-            if self.dp_size != 1:
+            if self.local_dp_size != 1:
                if self.attn_tp_rank == 0:
                    hidden_states += residual
                hidden_states, local_hidden_states = (
@@ -1289,7 +1289,7 @@ def forward_ffn_with_full_input(

        # TODO(ch-wan): use reduce-scatter in MLP to avoid this scatter
        # Scatter
-        if self.dp_size != 1:
+        if self.local_dp_size != 1:
            # important: forward batch.gathered_buffer is used both after scatter and after gather.
            # be careful about this!
            hidden_states, global_hidden_states = (
@@ -1413,7 +1413,7 @@ def __init__(
         )
         self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

-        self.dp_size = get_attention_dp_size()
+        self.dp_size = get_local_attention_dp_size()

     def get_input_embeddings(self) -> torch.Tensor:
         return self.embed_tokens
@@ -1478,7 +1478,7 @@ def __init__(
            use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"],
        )
        self.logits_processor = LogitsProcessor(config)
-        self.dp_size = get_attention_dp_size()
+        self.dp_size = get_local_attention_dp_size()

     def determine_n_share_experts_fusion(
         self, architecture: str = "DeepseekV3ForCausalLM"
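
The layer_id > 0 guard added to input_is_scattered matters because of Python's negative indexing; the following toy sketch assumes previous_layer_info is looked up by layer_id - 1, which this diff does not show.

# Toy illustration (assumption: the previous layer is looked up by layer_id - 1).
# Index -1 wraps to the last layer, so without the guard layer 0 could wrongly
# inherit the final layer's SCATTERED mode.
from enum import Enum, auto

class FFNInputMode(Enum):          # stand-in for the real _FFNInputMode
    SCATTERED = auto()
    FULL = auto()

modes = [FFNInputMode.FULL, FFNInputMode.SCATTERED, FFNInputMode.SCATTERED]

def input_is_scattered(layer_id, guard):
    previous_mode = modes[layer_id - 1]            # layer 0 wraps to modes[-1]
    scattered = previous_mode == FFNInputMode.SCATTERED
    return (layer_id > 0 and scattered) if guard else scattered

print([input_is_scattered(i, guard=False) for i in range(3)])   # [True, False, True]
print([input_is_scattered(i, guard=True) for i in range(3)])    # [False, False, True]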
