Skip to content

Commit c4f4e55

Browse files
ch-wanjimoosciuc
authored and committed
Fix several minor issues in PD disaggregation (sgl-project#5444)
1 parent e01c024 commit c4f4e55

File tree

3 files changed

+67
-69
lines changed

3 files changed

+67
-69
lines changed

python/sglang/srt/disaggregation/decode.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -419,6 +419,38 @@ def process_prebuilt_extend(
419419

420420
class SchedulerDisaggregationDecodeMixin:
421421

422+
@torch.no_grad()
423+
def event_loop_normal_disagg_decode(self):
424+
"""A normal scheduler loop for decode worker in disaggregation mode."""
425+
426+
while True:
427+
recv_reqs = self.recv_requests()
428+
self.process_input_requests(recv_reqs)
429+
# polling and allocating kv cache
430+
self.process_decode_queue()
431+
batch = self.get_next_disagg_decode_batch_to_run()
432+
self.cur_batch = batch
433+
434+
if batch:
435+
# Generate fake extend output.
436+
if batch.forward_mode.is_extend():
437+
# Note: Logprobs should be handled on the prefill engine.
438+
self.stream_output(batch.reqs, False)
439+
else:
440+
result = self.run_batch(batch)
441+
self.process_batch_result(batch, result)
442+
443+
if batch is None and (
444+
len(self.disagg_decode_transfer_queue.queue)
445+
+ len(self.disagg_decode_prealloc_queue.queue)
446+
== 0
447+
):
448+
# When the server is idle, do self-check and re-init some states
449+
self.check_memory()
450+
self.new_token_ratio = self.init_new_token_ratio
451+
452+
self.last_batch = batch
453+
422454
def get_next_disagg_decode_batch_to_run(
423455
self: Scheduler,
424456
) -> Optional[Tuple[ScheduleBatch, bool]]:

python/sglang/srt/disaggregation/prefill.py

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -171,6 +171,36 @@ class SchedulerDisaggregationPrefillMixin:
171171
Mixin for Scheduler to handle disaggregation prefill
172172
"""
173173

174+
@torch.no_grad()
175+
def event_loop_normal_disagg_prefill(self):
176+
"""A normal scheduler loop for prefill worker in disaggregation mode."""
177+
178+
while True:
179+
recv_reqs = self.recv_requests()
180+
self.process_input_requests(recv_reqs)
181+
self.waiting_queue.extend(
182+
self.disagg_prefill_pending_queue.pop_bootstrapped()
183+
)
184+
self.process_prefill_chunk()
185+
batch = self.get_new_batch_prefill()
186+
self.cur_batch = batch
187+
188+
if batch:
189+
result = self.run_batch(batch)
190+
self.process_batch_result_disagg_prefill(batch, result)
191+
192+
if len(self.disagg_prefill_inflight_queue) > 0:
193+
self.process_disagg_prefill_inflight_queue()
194+
195+
if batch is None and len(self.disagg_prefill_inflight_queue) == 0:
196+
self.check_memory()
197+
self.new_token_ratio = self.init_new_token_ratio
198+
199+
self.last_batch = batch
200+
# HACK (byronhsu): reset the batch_is_full flag because we never enter update_running_batch which resets it
201+
# Otherwise, it hangs under high concurrency
202+
self.running_batch.batch_is_full = False
203+
174204
def process_batch_result_disagg_prefill(
175205
self: Scheduler, batch: ScheduleBatch, result: GenerationBatchResult
176206
) -> None:
@@ -210,7 +240,7 @@ def process_disagg_prefill_inflight_queue(self: Scheduler) -> None:
210240

211241
polls = poll_and_all_reduce(
212242
[req.disagg_kv_sender for req in self.disagg_prefill_inflight_queue],
213-
self.tp_worker.get_tp_cpu_group(),
243+
self.attn_tp_cpu_group,
214244
)
215245

216246
undone_reqs: List[Req] = []

python/sglang/srt/managers/scheduler.py

Lines changed: 4 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -503,7 +503,7 @@ def init_memory_pool_and_cache(self):
503503
self.tree_cache = HiRadixCache(
504504
req_to_token_pool=self.req_to_token_pool,
505505
token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
506-
tp_cache_group=self.tp_worker.get_tp_cpu_group(),
506+
tp_cache_group=self.tp_cpu_group,
507507
page_size=self.page_size,
508508
hicache_ratio=server_args.hicache_ratio,
509509
)
@@ -572,7 +572,7 @@ def init_disaggregation(self):
572572

573573
# The decode requests polling kv cache
574574
self.disagg_decode_transfer_queue = DecodeTransferQueue(
575-
gloo_group=self.tp_worker.get_attention_tp_cpu_group(),
575+
gloo_group=self.attn_tp_cpu_group,
576576
req_to_metadata_buffer_idx_allocator=req_to_metadata_buffer_idx_allocator,
577577
metadata_buffers=metadata_buffers,
578578
)
@@ -587,7 +587,7 @@ def init_disaggregation(self):
587587
scheduler=self,
588588
transfer_queue=self.disagg_decode_transfer_queue,
589589
tree_cache=self.tree_cache,
590-
gloo_group=self.tp_worker.get_attention_tp_cpu_group(),
590+
gloo_group=self.attn_tp_cpu_group,
591591
tp_rank=self.tp_rank,
592592
tp_size=self.tp_size,
593593
bootstrap_port=self.server_args.disaggregation_bootstrap_port,
@@ -616,7 +616,7 @@ def init_disaggregation(self):
616616
tp_rank=self.tp_rank,
617617
tp_size=self.tp_size,
618618
bootstrap_port=self.server_args.disaggregation_bootstrap_port,
619-
gloo_group=self.tp_worker.get_attention_tp_cpu_group(),
619+
gloo_group=self.attn_tp_cpu_group,
620620
transfer_backend=self.transfer_backend,
621621
scheduler=self,
622622
)
@@ -683,70 +683,6 @@ def event_loop_overlap(self):
683683

684684
self.last_batch = batch
685685

686-
@torch.no_grad()
687-
def event_loop_normal_disagg_prefill(self):
688-
"""A normal scheduler loop for prefill worker in disaggregation mode."""
689-
690-
while True:
691-
recv_reqs = self.recv_requests()
692-
self.process_input_requests(recv_reqs)
693-
self.waiting_queue.extend(
694-
self.disagg_prefill_pending_queue.pop_bootstrapped()
695-
)
696-
self.process_prefill_chunk()
697-
batch = self.get_new_batch_prefill()
698-
self.cur_batch = batch
699-
700-
if batch:
701-
result = self.run_batch(batch)
702-
self.process_batch_result_disagg_prefill(batch, result)
703-
704-
if len(self.disagg_prefill_inflight_queue) > 0:
705-
self.process_disagg_prefill_inflight_queue()
706-
707-
if batch is None and len(self.disagg_prefill_inflight_queue) == 0:
708-
self.check_memory()
709-
self.new_token_ratio = self.init_new_token_ratio
710-
711-
self.last_batch = batch
712-
# HACK (byronhsu): reset the batch_is_full flag because we never enter update_running_batch which resets it
713-
# Otherwise, it hangs under high concurrency
714-
self.running_batch.batch_is_full = False
715-
716-
@torch.no_grad()
717-
def event_loop_normal_disagg_decode(self):
718-
"""A normal scheduler loop for decode worker in disaggregation mode."""
719-
720-
while True:
721-
recv_reqs = self.recv_requests()
722-
self.process_input_requests(recv_reqs)
723-
# polling and allocating kv cache
724-
self.process_decode_queue()
725-
batch = self.get_next_disagg_decode_batch_to_run()
726-
self.cur_batch = batch
727-
728-
if batch:
729-
# Generate fake extend output.
730-
if batch.forward_mode.is_extend():
731-
# Note: Logprobs should be handled on the prefill engine.
732-
self.stream_output(
733-
batch.reqs, [False for _ in range(len(batch.reqs))]
734-
)
735-
else:
736-
result = self.run_batch(batch)
737-
self.process_batch_result(batch, result)
738-
739-
if batch is None and (
740-
len(self.disagg_decode_transfer_queue.queue)
741-
+ len(self.disagg_decode_prealloc_queue.queue)
742-
== 0
743-
):
744-
# When the server is idle, do self-check and re-init some states
745-
self.check_memory()
746-
self.new_token_ratio = self.init_new_token_ratio
747-
748-
self.last_batch = batch
749-
750686
def recv_requests(self) -> List[Req]:
751687
"""Receive results at tp_rank = 0 and broadcast it to all other TP ranks."""
752688
if self.attn_tp_rank == 0:

0 commit comments

Comments (0)