Improve overlap scheduling #5788


Merged (6 commits) on Apr 28, 2025

python/sglang/srt/disaggregation/prefill.py (8 changes: 6 additions & 2 deletions)
@@ -20,6 +20,7 @@
from __future__ import annotations

import logging
import threading
from collections import deque
from typing import TYPE_CHECKING, List, Optional

@@ -256,7 +257,10 @@ def event_loop_overlap_disagg_prefill(self: Scheduler):
self.running_batch.batch_is_full = False

def process_batch_result_disagg_prefill(
self: Scheduler, batch: ScheduleBatch, result: GenerationBatchResult
self: Scheduler,
batch: ScheduleBatch,
result: GenerationBatchResult,
launch_done: Optional[threading.Event] = None,
) -> None:
"""
Transfer kv for prefill completed requests and add it into disagg_prefill_inflight_queue
@@ -280,7 +284,7 @@ def process_batch_result_disagg_prefill(
# Transfer kv for prefill completed requests and add it into disagg_prefill_inflight_queue
if self.enable_overlap:
# wait
_, next_token_ids = self.tp_worker.resolve_batch_result(bid)
_, next_token_ids = self.tp_worker.resolve_last_batch_result(launch_done)
else:
next_token_ids = result.next_token_ids.tolist()

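The change in this file passes the launch event of the just-submitted batch into the prefill result processing, which then resolves the previous batch's outputs in launch order rather than by batch id. A minimal sketch of the resulting call pattern, assuming the `enable_overlap` flag and `resolve_last_batch_result` method shown in the diff above; the helper function itself is hypothetical:

```python
import threading
from typing import Optional


def resolve_prefill_tokens(scheduler, result, launch_done: Optional[threading.Event] = None):
    # Hypothetical helper mirroring the call-site change above: in overlap mode,
    # block on the previously launched batch's outputs; otherwise read the
    # result of this batch directly.
    if scheduler.enable_overlap:
        _, next_token_ids = scheduler.tp_worker.resolve_last_batch_result(launch_done)
    else:
        next_token_ids = result.next_token_ids.tolist()
    return next_token_ids
```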
python/sglang/srt/managers/schedule_batch.py (8 changes: 8 additions & 0 deletions)
@@ -35,6 +35,7 @@
import copy
import dataclasses
import logging
import threading
from typing import TYPE_CHECKING, List, Optional, Set, Tuple, Union

import numpy as np
@@ -724,6 +725,9 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
# This is an optimization to reduce the overhead of the prefill check.
batch_is_full: bool = False

# Events
launch_done: Optional[threading.Event] = None

# Sampling info
sampling_info: SamplingBatchInfo = None
next_batch_sampling_info: SamplingBatchInfo = None
@@ -1565,6 +1569,7 @@ def get_model_worker_batch(self) -> ModelWorkerBatch:
)
),
extend_input_logprob_token_ids=self.extend_input_logprob_token_ids,
launch_done=self.launch_done,
)

def copy(self):
@@ -1647,6 +1652,9 @@ class ModelWorkerBatch:
# If set, the output of the batch contains the hidden states of the run.
capture_hidden_mode: CaptureHiddenMode = None

# Overlap event
launch_done: Optional[threading.Event] = None


@triton.jit
def write_req_to_token_pool_triton(
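These two hunks thread a `threading.Event` through the batch objects so the worker can signal it once the forward pass for that batch has been launched. A simplified sketch of how the field might travel from the schedule batch into the worker batch; both dataclasses here are stand-ins, not the real `ScheduleBatch`/`ModelWorkerBatch`:

```python
import dataclasses
import threading
from typing import Optional


@dataclasses.dataclass
class WorkerBatchSketch:
    # Stand-in for ModelWorkerBatch: the launch event travels with the batch.
    launch_done: Optional[threading.Event] = None


@dataclasses.dataclass
class ScheduleBatchSketch:
    # Stand-in for ScheduleBatch.
    launch_done: Optional[threading.Event] = None

    def get_model_worker_batch(self) -> WorkerBatchSketch:
        # Hand the same event reference to the worker batch, so the worker
        # thread can set() it after the forward pass has been issued.
        return WorkerBatchSketch(launch_done=self.launch_done)
```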
python/sglang/srt/managers/scheduler.py (15 changes: 10 additions & 5 deletions)
@@ -645,6 +645,7 @@ def event_loop_overlap(self):
self.cur_batch = batch

if batch:
batch.launch_done = threading.Event()
result = self.run_batch(batch)
self.result_queue.append((batch.copy(), result))

@@ -656,15 +657,18 @@
forward_mode=ForwardMode.DUMMY_FIRST,
next_batch_sampling_info=self.tp_worker.cur_sampling_info,
)
self.process_batch_result(tmp_batch, None)
self.process_batch_result(tmp_batch, None, batch.launch_done)

if self.last_batch:
# Process the results of the last batch
tmp_batch, tmp_result = self.result_queue.popleft()
tmp_batch.next_batch_sampling_info = (
self.tp_worker.cur_sampling_info if batch else None
)
self.process_batch_result(tmp_batch, tmp_result)
# NOTE: we should use the currently launched batch's launch_done event instead of the last batch's
self.process_batch_result(
tmp_batch, tmp_result, batch.launch_done if batch else None
)
elif batch is None:
# When the server is idle, do self-check and re-init some states
self.check_memory()
@@ -1417,14 +1421,15 @@ def process_batch_result(
self,
batch: ScheduleBatch,
result: Union[GenerationBatchResult, EmbeddingBatchResult],
launch_done: Optional[threading.Event] = None,
):
if batch.forward_mode.is_decode():
self.process_batch_result_decode(batch, result)
self.process_batch_result_decode(batch, result, launch_done)
elif batch.forward_mode.is_extend():
self.process_batch_result_prefill(batch, result)
self.process_batch_result_prefill(batch, result, launch_done)
elif batch.forward_mode.is_idle():
if self.enable_overlap:
self.tp_worker.resolve_batch_result(result.bid)
self.tp_worker.resolve_last_batch_result(launch_done)
if batch.next_batch_sampling_info:
batch.next_batch_sampling_info.update_regex_vocab_mask()
self.current_stream.synchronize()
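The core of the overlap loop is that the result of the previous batch is processed using the launch event of the current batch: the scheduler enqueues the new forward pass first, then drains the previous result while it runs, and the event keeps the drain from synchronizing before the new batch is actually in flight. A condensed sketch of that ordering, assuming a `run_batch`/`process_batch_result` interface like the one in the diff; the loop body is simplified and omits idle handling:

```python
import threading
from collections import deque


def overlap_loop_sketch(scheduler):
    # Condensed sketch of event_loop_overlap: launch batch N, then process the
    # result of batch N-1 while N is still running on the device.
    result_queue = deque()
    last_batch = None
    while True:
        batch = scheduler.get_next_batch_to_run()
        if batch:
            batch.launch_done = threading.Event()  # set by the worker after launch
            result = scheduler.run_batch(batch)    # non-blocking in overlap mode
            result_queue.append((batch.copy(), result))

        if last_batch:
            prev_batch, prev_result = result_queue.popleft()
            # Use the *current* batch's launch_done, not the previous one's:
            # it proves the new forward pass is in flight before we synchronize.
            scheduler.process_batch_result(
                prev_batch, prev_result, batch.launch_done if batch else None
            )
        last_batch = batch
```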
python/sglang/srt/managers/scheduler_output_processor_mixin.py (34 changes: 25 additions & 9 deletions)
@@ -1,5 +1,6 @@
from __future__ import annotations

import threading
from typing import TYPE_CHECKING, List, Optional, Tuple, Union

from sglang.srt.layers.logits_processor import LogitsProcessorOutput
@@ -11,6 +12,7 @@
EmbeddingBatchResult,
GenerationBatchResult,
ScheduleBatch,
Scheduler,
)


@@ -21,9 +23,10 @@ class SchedulerOutputProcessorMixin:
"""

def process_batch_result_prefill(
self,
self: Scheduler,
batch: ScheduleBatch,
result: Union[GenerationBatchResult, EmbeddingBatchResult],
launch_done: Optional[threading.Event] = None,
):
skip_stream_req = None

@@ -43,7 +46,11 @@ def process_batch_result_prefill(
)

if self.enable_overlap:
logits_output, next_token_ids = self.tp_worker.resolve_batch_result(bid)
logits_output, next_token_ids = (
self.tp_worker.resolve_last_batch_result(
launch_done,
)
)
else:
# Move next_token_ids and logprobs to cpu
next_token_ids = next_token_ids.tolist()
@@ -175,9 +182,10 @@ def process_batch_result_prefill(
self.stream_output(batch.reqs, batch.return_logprob, skip_stream_req)

def process_batch_result_decode(
self,
self: Scheduler,
batch: ScheduleBatch,
result: GenerationBatchResult,
launch_done: Optional[threading.Event] = None,
):
logits_output, next_token_ids, bid = (
result.logits_output,
@@ -187,7 +195,9 @@ def process_batch_result_decode(
self.num_generated_tokens += len(batch.reqs)

if self.enable_overlap:
logits_output, next_token_ids = self.tp_worker.resolve_batch_result(bid)
logits_output, next_token_ids = self.tp_worker.resolve_last_batch_result(
launch_done
)
next_token_logprobs = logits_output.next_token_logprobs
elif batch.spec_algorithm.is_none():
# spec decoding handles output logprobs inside verify process.
@@ -271,7 +281,7 @@ def process_batch_result_decode(
self.log_decode_stats()

def add_input_logprob_return_values(
self,
self: Scheduler,
i: int,
req: Req,
output: LogitsProcessorOutput,
@@ -405,7 +415,7 @@ def add_input_logprob_return_values(
assert len(req.input_token_ids_logprobs_idx) == relevant_tokens_len

def add_logprob_return_values(
self,
self: Scheduler,
i: int,
req: Req,
pt: int,
@@ -436,7 +446,10 @@ def add_logprob_return_values(
return num_input_logprobs

def stream_output(
self, reqs: List[Req], return_logprob: bool, skip_req: Optional[Req] = None
self: Scheduler,
reqs: List[Req],
return_logprob: bool,
skip_req: Optional[Req] = None,
):
"""Stream the output to detokenizer."""
if self.is_generation:
@@ -445,7 +458,10 @@ def stream_output(
self.stream_output_embedding(reqs)

def stream_output_generation(
self, reqs: List[Req], return_logprob: bool, skip_req: Optional[Req] = None
self: Scheduler,
reqs: List[Req],
return_logprob: bool,
skip_req: Optional[Req] = None,
):
rids = []
finished_reasons: List[BaseFinishReason] = []
@@ -593,7 +609,7 @@ def stream_output_generation(
)
)

def stream_output_embedding(self, reqs: List[Req]):
def stream_output_embedding(self: Scheduler, reqs: List[Req]):
rids = []
finished_reasons: List[BaseFinishReason] = []

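Apart from the new `launch_done` parameter, this file mostly tightens the mixin signatures to `self: Scheduler`, so type checkers can resolve scheduler attributes such as `tp_worker` and `enable_overlap` inside the mixin methods. A toy illustration of that annotation pattern; all class and method names below are invented for the example:

```python
class Engine:
    enable_overlap: bool = True

    def run(self) -> str:
        return "forward launched"


class OutputMixin:
    # Annotating `self` with the concrete class lets static checkers resolve
    # Engine attributes inside mixin methods without inheriting from Engine.
    def describe(self: "Engine") -> str:
        return f"overlap={self.enable_overlap}, {self.run()}"


class FullEngine(OutputMixin, Engine):
    pass


print(FullEngine().describe())  # overlap=True, forward launched
```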
python/sglang/srt/managers/tp_worker.py (6 changes: 3 additions & 3 deletions)
@@ -170,13 +170,13 @@ def get_memory_pool(self):
def forward_batch_generation(
self,
model_worker_batch: ModelWorkerBatch,
launch_done: Optional[threading.Event] = None,
skip_sample: bool = False,
) -> Tuple[LogitsProcessorOutput, Optional[torch.Tensor]]:
forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
logits_output = self.model_runner.forward(forward_batch)
if launch_done:
launch_done.set()

if model_worker_batch.launch_done is not None:
model_worker_batch.launch_done.set()

if skip_sample:
next_token_ids = None
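With the event now stored on the batch, the synchronous TP worker drops its separate `launch_done` argument and instead signals whatever event the batch carries once the forward pass has been issued. A rough sketch of that behavior with the model call stubbed out; `run_forward` is a placeholder, not a real API:

```python
def forward_generation_sketch(model_worker_batch, run_forward):
    # Sketch of the tp_worker change: issue the forward pass, then signal the
    # event carried on the batch so the scheduler knows the launch happened.
    logits_output = run_forward(model_worker_batch)
    if model_worker_batch.launch_done is not None:
        model_worker_batch.launch_done.set()
    return logits_output
```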
python/sglang/srt/managers/tp_worker_overlap_thread.py (13 changes: 9 additions & 4 deletions)
@@ -132,7 +132,6 @@ def forward_thread_func_(self):
batch_pt += 1

# Create event
self.launch_done = threading.Event()
copy_done = torch.get_device_module(self.device).Event()

# Resolve future tokens in the input
Expand All @@ -141,7 +140,7 @@ def forward_thread_func_(self):

# Run forward
logits_output, next_token_ids = self.worker.forward_batch_generation(
model_worker_batch, self.launch_done
model_worker_batch
)

# Update the future token ids map
@@ -168,10 +167,16 @@

self.output_queue.put((copy_done, logits_output, next_token_ids))

def resolve_batch_result(self, bid: int):
def resolve_last_batch_result(self, launch_done: Optional[threading.Event] = None):
"""
This function is called to resolve the last batch result and
wait for the current batch to be launched. Used in overlap mode.
"""
copy_done, logits_output, next_token_ids = self.output_queue.get()

if launch_done is not None:
launch_done.wait()
copy_done.synchronize()
self.launch_done.wait()

if logits_output.next_token_logprobs is not None:
logits_output.next_token_logprobs = (
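On the overlap worker thread, `resolve_batch_result(bid)` becomes `resolve_last_batch_result(launch_done)`: results are consumed strictly in launch order from the output queue, the caller-supplied event is awaited so the next batch is known to be launched, and only then is the device-side copy synchronized. A rough sketch of that flow, assuming a `queue.Queue` of `(copy_done, logits_output, next_token_ids)` tuples as in the diff; `copy_done` stands in for a device event with a `synchronize()` method:

```python
import queue
import threading
from typing import Optional, Tuple


def resolve_last_batch_result_sketch(
    output_queue: "queue.Queue",
    launch_done: Optional[threading.Event] = None,
) -> Tuple[object, object]:
    # Pop the oldest finished batch, make sure the *next* batch has already
    # been launched, then wait for this batch's device-to-host copy to finish.
    copy_done, logits_output, next_token_ids = output_queue.get()
    if launch_done is not None:
        launch_done.wait()
    copy_done.synchronize()  # e.g. a torch.cuda.Event recorded after the copy
    return logits_output, next_token_ids
```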