Skip to content

Commit bfe59ad

Browse files
committed
runnable
1 parent 2fc22c2 commit bfe59ad

File tree

2 files changed

+25
-8
lines changed

2 files changed

+25
-8
lines changed

python/sglang/srt/managers/data_parallel_controller.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,8 @@
2323
import setproctitle
2424
import zmq
2525

26+
from python.sglang.srt.disaggregation.utils import DisaggregationMode
27+
from python.sglang.srt.managers.schedule_batch import Req
2628
from sglang.srt.layers.dp_attention import compute_dp_attention_world_info
2729
from sglang.srt.managers.io_struct import (
2830
TokenizedEmbeddingReqInput,
@@ -220,9 +222,12 @@ def launch_tensor_parallel_group(
220222
self.max_total_num_tokens = scheduler_info[0]["max_total_num_tokens"]
221223
self.max_req_input_len = scheduler_info[0]["max_req_input_len"]
222224

223-
def round_robin_scheduler(self, req):
224-
self.workers[self.round_robin_counter].send_pyobj(req)
225-
self.round_robin_counter = (self.round_robin_counter + 1) % len(self.workers)
225+
def round_robin_scheduler(self, req: Req):
226+
if self.server_args.disaggregation_mode == DisaggregationMode.NULL:
227+
self.workers[self.round_robin_counter].send_pyobj(req)
228+
self.round_robin_counter = (self.round_robin_counter + 1) % len(self.workers)
229+
else:
230+
self.workers[req.bootstrap_room % len(self.workers)].send_pyobj(req)
226231

227232
def shortest_queue_scheduler(self, input_requests):
228233
raise NotImplementedError()

python/sglang/srt/managers/scheduler.py

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -484,7 +484,7 @@ def init_memory_pool_and_cache(self):
484484
self.tree_cache = HiRadixCache(
485485
req_to_token_pool=self.req_to_token_pool,
486486
token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
487-
tp_cache_group=self.tp_worker.get_tp_cpu_group(),
487+
tp_cache_group=self.tp_cpu_group,
488488
page_size=self.page_size,
489489
hicache_ratio=server_args.hicache_ratio,
490490
)
@@ -553,7 +553,7 @@ def init_disaggregation(self):
553553

554554
# The decode requests polling kv cache
555555
self.disagg_decode_transfer_queue = DecodeTransferQueue(
556-
gloo_group=self.tp_worker.get_attention_tp_cpu_group(),
556+
gloo_group=self.attn_tp_cpu_group,
557557
req_to_metadata_buffer_idx_allocator=req_to_metadata_buffer_idx_allocator,
558558
metadata_buffers=metadata_buffers,
559559
)
@@ -568,7 +568,7 @@ def init_disaggregation(self):
568568
scheduler=self,
569569
transfer_queue=self.disagg_decode_transfer_queue,
570570
tree_cache=self.tree_cache,
571-
gloo_group=self.tp_worker.get_attention_tp_cpu_group(),
571+
gloo_group=self.attn_tp_cpu_group,
572572
tp_rank=self.tp_rank,
573573
tp_size=self.tp_size,
574574
bootstrap_port=self.server_args.disaggregation_bootstrap_port,
@@ -597,7 +597,7 @@ def init_disaggregation(self):
597597
tp_rank=self.tp_rank,
598598
tp_size=self.tp_size,
599599
bootstrap_port=self.server_args.disaggregation_bootstrap_port,
600-
gloo_group=self.tp_worker.get_attention_tp_cpu_group(),
600+
gloo_group=self.attn_tp_cpu_group,
601601
transfer_backend=self.transfer_backend,
602602
scheduler=self,
603603
)
@@ -676,6 +676,11 @@ def event_loop_normal_disagg_prefill(self):
676676
)
677677
self.process_prefill_chunk()
678678
batch = self.get_new_batch_prefill()
679+
680+
# Handle DP attention
681+
if self.server_args.enable_dp_attention or self.server_args.enable_sp_layernorm:
682+
batch, _ = self.prepare_dp_attn_batch(batch)
683+
679684
self.cur_batch = batch
680685

681686
if batch:
@@ -704,11 +709,18 @@ def event_loop_normal_disagg_decode(self):
704709
# polling and allocating kv cache
705710
self.process_decode_queue()
706711
batch = self.get_next_disagg_decode_batch_to_run()
712+
713+
# Handle DP attention
714+
if self.server_args.enable_dp_attention or self.server_args.enable_sp_layernorm:
715+
batch, do_extend = self.prepare_dp_attn_batch(batch)
716+
elif batch:
717+
do_extend = batch.forward_mode.is_extend()
718+
707719
self.cur_batch = batch
708720

709721
if batch:
710722
# Generate fake extend output.
711-
if batch.forward_mode.is_extend():
723+
if do_extend:
712724
# Note: Logprobs should be handled on the prefill engine.
713725
self.stream_output(
714726
batch.reqs, [False for _ in range(len(batch.reqs))]

0 commit comments

Comments
 (0)