Commit 5a7368a

ByronHsu authored and tarinkk committed
[PD] Support decode overlap schedule (sgl-project#5608)
1 parent 7ad6cf7 commit 5a7368a

File tree

3 files changed: 49 additions & 5 deletions

python/sglang/srt/disaggregation/decode.py

Lines changed: 43 additions & 0 deletions
@@ -21,6 +21,7 @@
 from __future__ import annotations
 
 import logging
+from collections import deque
 from dataclasses import dataclass
 from typing import TYPE_CHECKING, List, Optional, Tuple
 
@@ -475,6 +476,48 @@ def event_loop_normal_disagg_decode(self):
 
             self.last_batch = batch
 
+    @torch.no_grad()
+    def event_loop_overlap_disagg_decode(self):
+        result_queue = deque()
+        self.last_batch: Optional[ScheduleBatch] = None
+        self.last_batch_is_extend = False  # the last batch is modified in-place, so we need another variable to track whether it was an extend batch
+
+        while True:
+            recv_reqs = self.recv_requests()
+            self.process_input_requests(recv_reqs)
+            # polling and allocating kv cache
+            self.process_decode_queue()
+            batch = self.get_next_disagg_decode_batch_to_run()
+            self.cur_batch = batch
+            last_batch_is_extend = False
+
+            if batch:
+                # Generate fake extend output.
+                if batch.forward_mode.is_extend():
+                    # Note: Logprobs should be handled on the prefill engine.
+                    self.stream_output(batch.reqs, False)
+                    last_batch_is_extend = True
+                else:
+                    result = self.run_batch(batch)
+                    result_queue.append((batch.copy(), result))
+
+            # Process the results of the previous batch, but skip if the last batch was an extend batch
+            if self.last_batch and not self.last_batch_is_extend:
+                tmp_batch, tmp_result = result_queue.popleft()
+                self.process_batch_result(tmp_batch, tmp_result)
+
+            if batch is None and (
+                len(self.disagg_decode_transfer_queue.queue)
+                + len(self.disagg_decode_prealloc_queue.queue)
+                == 0
+            ):
+                # When the server is idle, do a self-check and re-initialize some states
+                self.check_memory()
+                self.new_token_ratio = self.init_new_token_ratio
+
+            self.last_batch = batch
+            self.last_batch_is_extend = last_batch_is_extend
+
     def get_next_disagg_decode_batch_to_run(
         self: Scheduler,
     ) -> Optional[Tuple[ScheduleBatch, bool]]:
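
The new loop is a one-step pipeline: while the GPU executes batch N, the CPU consumes the result of batch N-1 from `result_queue`. Fake extend batches (whose KV cache was already produced on the prefill server) are streamed out immediately and never enter the queue, which is why the extra `last_batch_is_extend` flag is needed. Below is a minimal, self-contained sketch of the same deferral pattern; `get_next_batch`, `launch`, and `process_result` are hypothetical stand-ins for `get_next_disagg_decode_batch_to_run`, `run_batch`, and `process_batch_result`, and unlike the real scheduler this sketch exits when idle so the example terminates.

```python
from collections import deque
from typing import Callable, Optional

def overlap_loop(
    get_next_batch: Callable[[], Optional[int]],
    launch: Callable[[int], int],
    process_result: Callable[[int, int], None],
) -> None:
    """One-step pipelined loop: the result of batch N is processed on the
    iteration that launches batch N+1. `launch` is assumed to return a
    handle immediately, like an asynchronous kernel launch."""
    result_queue = deque()
    last_batch = None
    while True:
        batch = get_next_batch()
        if batch is not None:
            # Launch batch N without waiting for its result.
            result_queue.append((batch, launch(batch)))
        if last_batch is not None:
            # Consume the result of batch N-1 while batch N is in flight.
            prev_batch, prev_result = result_queue.popleft()
            process_result(prev_batch, prev_result)
        if batch is None and last_batch is None:
            break  # idle: nothing launched and nothing pending
        last_batch = batch

# Example run: three trivial "batches"; launch computes eagerly here.
batches = iter([1, 2, 3])
overlap_loop(
    get_next_batch=lambda: next(batches, None),
    launch=lambda b: b * b,
    process_result=lambda b, r: print(f"batch {b} -> {r}"),
)
# batch 1 -> 1
# batch 2 -> 4
# batch 3 -> 9
```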

python/sglang/srt/managers/scheduler.py

Lines changed: 4 additions & 1 deletion
@@ -2016,7 +2016,10 @@ def run_scheduler_process(
         elif disaggregation_mode == DisaggregationMode.PREFILL:
             scheduler.event_loop_normal_disagg_prefill()
         elif disaggregation_mode == DisaggregationMode.DECODE:
-            scheduler.event_loop_normal_disagg_decode()
+            if scheduler.enable_overlap:
+                scheduler.event_loop_overlap_disagg_decode()
+            else:
+                scheduler.event_loop_normal_disagg_decode()
 
     except Exception:
         traceback = get_exception_traceback()
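
For context, the dispatch keys on `scheduler.enable_overlap`; the assumption here (not shown in this diff) is that it is the negation of the `--disable-overlap-schedule` server argument, which the `server_args.py` change below stops forcing on for decode servers. A hedged sketch of the selection logic, with `MiniScheduler` as a hypothetical stand-in:

```python
class MiniScheduler:
    """Hypothetical stand-in exposing only what the dispatch uses."""

    def __init__(self, disable_overlap_schedule: bool):
        # Assumption: enable_overlap is derived from the
        # --disable-overlap-schedule server argument.
        self.enable_overlap = not disable_overlap_schedule

    def event_loop_overlap_disagg_decode(self) -> None:
        print("running overlap decode loop")

    def event_loop_normal_disagg_decode(self) -> None:
        print("running normal decode loop")

def run_decode_loop(scheduler: MiniScheduler) -> None:
    # Mirrors the new branch in run_scheduler_process.
    if scheduler.enable_overlap:
        scheduler.event_loop_overlap_disagg_decode()
    else:
        scheduler.event_loop_normal_disagg_decode()

run_decode_loop(MiniScheduler(disable_overlap_schedule=False))  # overlap loop
run_decode_loop(MiniScheduler(disable_overlap_schedule=True))   # normal loop
```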

python/sglang/srt/server_args.py

Lines changed: 2 additions & 4 deletions
@@ -387,14 +387,12 @@ def __post_init__(self):
         # PD disaggregation
         if self.disaggregation_mode == "prefill":
             self.disable_cuda_graph = True
-            logger.warning("KV cache is forced as chunk cache for decode server")
+            logger.warning("Cuda graph is disabled for prefill server")
             self.disable_overlap_schedule = True
             logger.warning("Overlap scheduler is disabled for prefill server")
         elif self.disaggregation_mode == "decode":
             self.disable_radix_cache = True
-            logger.warning("Cuda graph is disabled for prefill server")
-            self.disable_overlap_schedule = True
-            logger.warning("Overlap scheduler is disabled for decode server")
+            logger.warning("KV cache is forced as chunk cache for decode server")
 
         os.environ["SGLANG_ENABLE_TORCH_COMPILE"] = (
             "1" if self.enable_torch_compile else "0"
