
Commit fdf58ea

use local attn dp size
1 parent 494ec47 commit fdf58ea

File tree

.gitignore
python/sglang/srt/layers/dp_attention.py
python/sglang/srt/layers/logits_processor.py
python/sglang/srt/managers/data_parallel_controller.py
python/sglang/srt/models/deepseek_v2.py
python/sglang/srt/models/llama4.py

6 files changed: +51 -28 lines changed

.gitignore

Lines changed: 1 addition & 0 deletions

@@ -175,6 +175,7 @@ benchmark/llava_bench/images
 benchmark/llava_bench/mme_pack
 *.jsonl
 tmp*.txt
+core.*
 
 # Plots
 *.png

python/sglang/srt/layers/dp_attention.py

Lines changed: 36 additions & 8 deletions

@@ -26,21 +26,33 @@
 _ATTN_TP_SIZE = None
 _ATTN_DP_RANK = None
 _ATTN_DP_SIZE = None
+_LOCAL_ATTN_DP_SIZE = None
+_LOCAL_ATTN_DP_RANK = None
+
+def compute_dp_attention_world_info(enable_dp_attention, tp_rank, tp_size, dp_size):
+    if not enable_dp_attention:
+        return tp_rank, tp_size, 0
+
+    attn_tp_size = tp_size // dp_size
+    attn_dp_rank = tp_rank // attn_tp_size
+    attn_tp_rank = tp_rank % attn_tp_size
+
+    return attn_tp_rank, attn_tp_size, attn_dp_rank
 
 
-def compute_dp_attention_world_info(enable_dp_attention, tp_rank, tp_size, dp_size, moe_dense_tp_size):
+def compute_dp_attention_local_info(enable_dp_attention, tp_rank, tp_size, dp_size, moe_dense_tp_size):
     if not enable_dp_attention:
         return tp_rank, tp_size, 0
 
     local_tp_size = moe_dense_tp_size if moe_dense_tp_size else tp_size
     local_tp_rank = tp_rank % local_tp_size
     local_dp_size = dp_size // (tp_size // local_tp_size)
 
-    attn_tp_size = local_tp_size // local_dp_size
-    attn_dp_rank = local_tp_rank // attn_tp_size
-    attn_tp_rank = local_tp_rank % attn_tp_size
+    local_attn_tp_size = local_tp_size // local_dp_size
+    local_attn_dp_rank = local_tp_rank // local_attn_tp_size
+    local_attn_tp_rank = local_tp_rank % local_attn_tp_size
 
-    return attn_tp_rank, attn_tp_size, attn_dp_rank
+    return local_attn_tp_rank, local_attn_tp_size, local_attn_dp_rank
 
 
 def initialize_dp_attention(

@@ -51,20 +63,26 @@ def initialize_dp_attention(
     moe_dense_tp_size: Optional[int],
 ):
     global _ATTN_TP_GROUP, _ATTN_TP_RANK, _ATTN_TP_SIZE, _ATTN_DP_RANK, _ATTN_DP_SIZE
+    global _LOCAL_ATTN_DP_SIZE, _LOCAL_ATTN_DP_RANK
 
     from sglang.srt.layers.sampler import SYNC_TOKEN_IDS_ACROSS_TP
 
     _ATTN_TP_RANK, _ATTN_TP_SIZE, _ATTN_DP_RANK = compute_dp_attention_world_info(
+        enable_dp_attention, tp_rank, tp_size, dp_size
+    )
+    _, _, _LOCAL_ATTN_DP_RANK = compute_dp_attention_local_info(
         enable_dp_attention, tp_rank, tp_size, dp_size, moe_dense_tp_size
     )
 
     if enable_dp_attention:
+        _ATTN_DP_SIZE = dp_size
         if moe_dense_tp_size is None:
-            _ATTN_DP_SIZE = dp_size
+            _LOCAL_ATTN_DP_SIZE = _ATTN_DP_SIZE
        else:
-            _ATTN_DP_SIZE = dp_size // (tp_size // moe_dense_tp_size)
+            _LOCAL_ATTN_DP_SIZE = dp_size // (tp_size // moe_dense_tp_size)
     else:
         _ATTN_DP_SIZE = 1
+        _LOCAL_ATTN_DP_SIZE = 1
 
     logger.info(f"{(_ATTN_TP_RANK, _ATTN_TP_SIZE, _ATTN_DP_RANK, _ATTN_DP_SIZE)=}")
 

@@ -110,6 +128,16 @@ def get_attention_dp_size():
     return _ATTN_DP_SIZE
 
 
+def get_local_attention_dp_rank():
+    assert _LOCAL_ATTN_DP_RANK is not None, "dp attention not initialized!"
+    return _LOCAL_ATTN_DP_RANK
+
+
+def get_local_attention_dp_size():
+    assert _LOCAL_ATTN_DP_SIZE is not None, "dp attention not initialized!"
+    return _LOCAL_ATTN_DP_SIZE
+
+
 @contextmanager
 def disable_dp_size():
     """Patch the tp group temporarily until this function ends.

@@ -132,7 +160,7 @@ def disable_dp_size():
 
 
 def get_dp_local_info(forward_batch: ForwardBatch):
-    dp_rank = get_attention_dp_rank()
+    dp_rank = get_local_attention_dp_rank()
 
     if forward_batch.dp_local_start_pos is None:
         cumtokens = torch.cumsum(forward_batch.global_num_tokens_gpu, dim=0)
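
To see how the world-level and local helpers differ, here is a minimal sketch that calls them directly. The sizes (tp_size=8, dp_size=4, moe_dense_tp_size=2) are made up for illustration; the only assumption is that sglang with this commit is importable.

# Compare world-level vs. local DP-attention info for a hypothetical layout.
from sglang.srt.layers.dp_attention import (
    compute_dp_attention_local_info,
    compute_dp_attention_world_info,
)

tp_size, dp_size, moe_dense_tp_size = 8, 4, 2  # hypothetical configuration

for tp_rank in range(tp_size):
    world = compute_dp_attention_world_info(True, tp_rank, tp_size, dp_size)
    local = compute_dp_attention_local_info(
        True, tp_rank, tp_size, dp_size, moe_dense_tp_size
    )
    # world -> (attn_tp_rank, attn_tp_size=2, attn_dp_rank in 0..3)
    # local -> (local_attn_tp_rank, local_attn_tp_size=2, local_attn_dp_rank=0),
    # because local_dp_size = dp_size // (tp_size // moe_dense_tp_size) = 1 here.
    print(tp_rank, world, local)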

python/sglang/srt/layers/logits_processor.py

Lines changed: 5 additions & 5 deletions

@@ -30,8 +30,8 @@
 from sglang.srt.layers.dp_attention import (
     dp_gather_replicate,
     dp_scatter,
-    get_attention_dp_rank,
-    get_attention_dp_size,
+    get_local_attention_dp_rank,
+    get_local_attention_dp_size,
 )
 from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
 from sglang.srt.managers.schedule_batch import global_server_args_dict

@@ -169,7 +169,7 @@ def compute_dp_attention_metadata(self, hidden_states: torch.Tensor):
             return
 
         cumtokens = torch.cumsum(self.global_num_tokens_for_logprob_gpu, dim=0)
-        dp_rank = get_attention_dp_rank()
+        dp_rank = get_local_attention_dp_rank()
         if dp_rank == 0:
             dp_local_start_pos = torch.zeros_like(
                 self.global_num_tokens_for_logprob_gpu[0]

@@ -202,7 +202,7 @@ def __init__(
             not skip_all_gather and get_tensor_model_parallel_world_size() > 1
         )
         self.do_tensor_parallel_all_gather_dp_attn = (
-            self.do_tensor_parallel_all_gather and get_attention_dp_size() != 1
+            self.do_tensor_parallel_all_gather and get_local_attention_dp_size() != 1
         )
         self.final_logit_softcapping = getattr(
             self.config, "final_logit_softcapping", None

@@ -315,7 +315,7 @@ def forward(
 
         if self.debug_tensor_dump_output_folder:
             assert (
-                not self.do_tensor_parallel_all_gather or get_attention_dp_size() == 1
+                not self.do_tensor_parallel_all_gather or get_local_attention_dp_size() == 1
             ), "dp attention + sharded lm_head doesn't support full logits"
             full_logits = self._get_logits(hidden_states, lm_head, logits_metadata)
             dump_to_file(self.debug_tensor_dump_output_folder, "logits", full_logits)
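
A small arithmetic sketch of why the gate changed. The sizes are made up (tp_size=8, dp_size=8, moe_dense_tp_size=1, i.e. dense layers fully data parallel), and the formulas simply mirror initialize_dp_attention above.

# Hypothetical launch: tp_size=8, dp_size=8, moe_dense_tp_size=1.
tp_size, dp_size, moe_dense_tp_size = 8, 8, 1

attn_dp_size = dp_size                                          # _ATTN_DP_SIZE == 8
local_attn_dp_size = dp_size // (tp_size // moe_dense_tp_size)  # _LOCAL_ATTN_DP_SIZE == 1

do_tensor_parallel_all_gather = True  # assume TP world size > 1 and no skip_all_gather

old_gate = do_tensor_parallel_all_gather and attn_dp_size != 1        # True
new_gate = do_tensor_parallel_all_gather and local_attn_dp_size != 1  # False
# With the new gate, do_tensor_parallel_all_gather_dp_attn is False for this
# layout, so the DP-attention gather path in the logits processor is skipped;
# previously the global dp_size (8) kept it enabled.
print(old_gate, new_gate)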

python/sglang/srt/managers/data_parallel_controller.py

Lines changed: 0 additions & 1 deletion

@@ -192,7 +192,6 @@ def launch_tensor_parallel_group(
                     tp_rank,
                     server_args.tp_size,
                     server_args.dp_size,
-                    server_args.moe_dense_tp_size,
                 )
                 # compute zmq ports for this dp rank
                 rank_port_args = PortArgs.init_new(server_args, dp_rank)

python/sglang/srt/models/deepseek_v2.py

Lines changed: 5 additions & 9 deletions

@@ -38,7 +38,7 @@
 from sglang.srt.layers.dp_attention import (
     dp_gather_partial,
     dp_scatter,
-    get_attention_dp_size,
+    get_local_attention_dp_size,
     get_attention_tp_rank,
     get_attention_tp_size,
     tp_all_gather,

@@ -420,7 +420,6 @@ def __init__(
         self.v_head_dim = v_head_dim
         self.q_lora_rank = q_lora_rank
         self.kv_lora_rank = kv_lora_rank
-        self.dp_size = get_attention_dp_size()
         attn_tp_rank = get_attention_tp_rank()
         attn_tp_size = get_attention_tp_size()
 

@@ -1034,7 +1033,7 @@ def __init__(
         max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
         self.enable_dp_attention = global_server_args_dict["enable_dp_attention"]
         self.layer_id = layer_id
-        self.dp_size = get_attention_dp_size()
+        self.local_dp_size = get_local_attention_dp_size()
         self.attn_tp_size = get_attention_tp_size()
         self.attn_tp_rank = get_attention_tp_rank()
         self.self_attn = DeepseekV2AttentionMLA(

@@ -1166,7 +1165,7 @@ def forward_ffn_with_full_input(
         # Gather
         if get_tensor_model_parallel_world_size() > 1:
             # all gather and all reduce
-            if self.dp_size != 1:
+            if self.local_dp_size != 1:
                 if self.attn_tp_rank == 0:
                     hidden_states += residual
                 hidden_states, local_hidden_states = (

@@ -1197,7 +1196,7 @@ def forward_ffn_with_full_input(
 
         # TODO(ch-wan): ues reduce-scatter in MLP to avoid this scatter
         # Scatter
-        if self.dp_size != 1:
+        if self.local_dp_size != 1:
             # important: forward batch.gathered_buffer is used both after scatter and after gather.
             # be careful about this!
             hidden_states, global_hidden_states = (

@@ -1341,8 +1340,6 @@ def __init__(
         )
         self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
 
-        self.dp_size = get_attention_dp_size()
-
     def forward(
         self,
         input_ids: torch.Tensor,

@@ -1411,10 +1408,9 @@ def __init__(
             config.hidden_size,
             quant_config=quant_config,
             prefix=add_prefix("lm_head", prefix),
-            enable_tp=not _enable_moe_dense_fully_dp(),
+            enable_tp=not _enable_moe_dense_fully_dp(),  # TODO: replace it with DP attention
         )
         self.logits_processor = LogitsProcessor(config)
-        self.dp_size = get_attention_dp_size()
 
     def get_input_embeddings(self) -> nn.Embedding:
         return self.model.embed_tokens
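
For the decoder layer, the practical difference between the old self.dp_size and the new self.local_dp_size only shows up once moe_dense_tp_size is set. A short sketch with hypothetical DeepSeek-style sizes (tp_size=16, dp_size=16), using only the formulas from dp_attention.py:

# Hypothetical sizes, chosen for illustration only.
tp_size, dp_size = 16, 16

for moe_dense_tp_size in (None, 2, 1):
    local_tp_size = moe_dense_tp_size if moe_dense_tp_size else tp_size
    local_dp_size = dp_size // (tp_size // local_tp_size)
    # Old gather/scatter gate: dp_size != 1 (always True here, 16 != 1).
    # New gate: local_dp_size != 1, still True for None (16) and 2 (2),
    # but False for moe_dense_tp_size=1, so that layout would skip the
    # DP gather/scatter branch in forward_ffn_with_full_input.
    print(moe_dense_tp_size, dp_size != 1, local_dp_size != 1)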

python/sglang/srt/models/llama4.py

Lines changed: 4 additions & 5 deletions

@@ -30,7 +30,7 @@
 from sglang.srt.layers.dp_attention import (
     dp_gather_partial,
     dp_scatter,
-    get_attention_dp_size,
+    get_local_attention_dp_size,
     get_attention_tp_rank,
     get_attention_tp_size,
 )

@@ -152,7 +152,6 @@ def __init__(
         self.use_rope = int((layer_id + 1) % 4 != 0)
         self.use_qk_norm = config.use_qk_norm and self.use_rope
 
-        self.dp_size = get_attention_dp_size()
         attn_tp_rank = get_attention_tp_rank()
         attn_tp_size = get_attention_tp_size()
 

@@ -297,7 +296,7 @@ def __init__(
         rope_theta = config.rope_theta
         rope_scaling = config.rope_scaling
         max_position_embeddings = config.max_position_embeddings
-        self.dp_size = get_attention_dp_size()
+        self.local_dp_size = get_local_attention_dp_size()
         self.attn_tp_size = get_attention_tp_size()
         self.attn_tp_rank = get_attention_tp_rank()
 

@@ -360,7 +359,7 @@ def forward(
         # Gather
         if get_tensor_model_parallel_world_size() > 1:
             # all gather and all reduce
-            if self.dp_size != 1:
+            if self.local_dp_size != 1:
                 if self.attn_tp_rank == 0:
                     hidden_states += residual
                 hidden_states, local_hidden_states = (

@@ -385,7 +384,7 @@ def forward(
 
         # TODO(ch-wan): ues reduce-scatter in MLP to avoid this scatter
         # Scatter
-        if self.dp_size != 1:
+        if self.local_dp_size != 1:
             # important: forward batch.gathered_buffer is used both after scatter and after gather.
             # be careful about this!
             hidden_states, global_hidden_states = (
