Skip to content

Commit 95de7df

Browse files
committed
replace extend batch with idle batch
1 parent ec3675c commit 95de7df

File tree

2 files changed

+20
-16
lines changed

2 files changed

+20
-16
lines changed

python/sglang/srt/managers/data_parallel_controller.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,8 @@
2323
import setproctitle
2424
import zmq
2525

26-
from python.sglang.srt.disaggregation.utils import DisaggregationMode
27-
from python.sglang.srt.managers.schedule_batch import Req
26+
from sglang.srt.disaggregation.utils import DisaggregationMode
27+
from sglang.srt.managers.schedule_batch import Req
2828
from sglang.srt.layers.dp_attention import compute_dp_attention_world_info
2929
from sglang.srt.managers.io_struct import (
3030
TokenizedEmbeddingReqInput,

python/sglang/srt/managers/scheduler.py

Lines changed: 18 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -710,24 +710,28 @@ def event_loop_normal_disagg_decode(self):
710710
self.process_decode_queue()
711711
batch = self.get_next_disagg_decode_batch_to_run()
712712

713+
extend_batch = None
714+
if batch and batch.forward_mode.is_extend():
715+
extend_batch = batch
716+
batch = None
717+
713718
# Handle DP attention
714719
if self.server_args.enable_dp_attention or self.server_args.enable_sp_layernorm:
715720
batch, _ = self.prepare_dp_attn_batch(batch)
716-
717-
self.cur_batch = batch
721+
722+
self.cur_batch = extend_batch if extend_batch else batch
723+
724+
# Generate fake extend output.
725+
if extend_batch:
726+
# Note: Logprobs should be handled on the prefill engine.
727+
# FIXME: stream_output
728+
self.stream_output(
729+
extend_batch.reqs, False
730+
)
718731

719732
if batch:
720-
# Generate fake extend output.
721-
if batch.forward_mode.is_extend():
722-
# Note: Logprobs should be handled on the prefill engine.
723-
self.stream_output(
724-
batch.reqs, [False for _ in range(len(batch.reqs))]
725-
)
726-
result = self.run_batch(batch)
727-
self.process_batch_result(batch, result)
728-
else:
729-
result = self.run_batch(batch)
730-
self.process_batch_result(batch, result)
733+
result = self.run_batch(batch)
734+
self.process_batch_result(batch, result)
731735

732736
if batch is None and (
733737
len(self.disagg_decode_transfer_queue.queue)
@@ -738,7 +742,7 @@ def event_loop_normal_disagg_decode(self):
738742
self.check_memory()
739743
self.new_token_ratio = self.init_new_token_ratio
740744

741-
self.last_batch = batch
745+
self.last_batch = extend_batch if extend_batch else batch
742746

743747
def recv_requests(self) -> List[Req]:
744748
"""Receive results at tp_rank = 0 and broadcast it to all other TP ranks."""

0 commit comments

Comments
 (0)