Skip to content

Commit 2c0b84a

Browse files
ByronHsuLayssy
authored andcommitted
[PD] Support structured output (sgl-project#6560)
1 parent 1d5bfa2 commit 2c0b84a

File tree

6 files changed

+106
-13
lines changed

6 files changed

+106
-13
lines changed

python/sglang/srt/disaggregation/decode.py

Lines changed: 20 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -45,19 +45,16 @@
4545
poll_and_all_reduce,
4646
prepare_abort,
4747
)
48-
from sglang.srt.managers.schedule_batch import FINISH_ABORT
48+
from sglang.srt.managers.schedule_batch import FINISH_ABORT, ScheduleBatch
4949
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
5050
from sglang.srt.mem_cache.memory_pool import ReqToTokenPool, TokenToKVPoolAllocator
51-
from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode, ForwardMode
52-
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
51+
from sglang.srt.model_executor.forward_batch_info import ForwardMode
5352

5453
logger = logging.getLogger(__name__)
5554

5655
if TYPE_CHECKING:
57-
from sglang.srt.configs.model_config import ModelConfig
58-
from sglang.srt.managers.schedule_batch import Req, ScheduleBatch
56+
from sglang.srt.managers.schedule_batch import Req
5957
from sglang.srt.managers.scheduler import Scheduler
60-
from sglang.srt.server_args import ServerArgs
6158

6259

6360
@dataclass
@@ -531,7 +528,18 @@ def event_loop_overlap_disagg_decode(self: Scheduler):
531528
self.prepare_dp_attn_batch(batch)
532529
result = self.run_batch(batch)
533530
result_queue.append((batch.copy(), result))
531+
532+
if (self.last_batch is None) or (not self.last_batch_in_queue):
533+
# Create a dummy first batch to start the pipeline for overlap schedule.
534+
# It is now used for triggering the sampling_info_done event.
535+
tmp_batch = ScheduleBatch(
536+
reqs=None,
537+
forward_mode=ForwardMode.DUMMY_FIRST,
538+
next_batch_sampling_info=self.tp_worker.cur_sampling_info,
539+
)
540+
self.set_next_batch_sampling_info_done(tmp_batch)
534541
last_batch_in_queue = True
542+
535543
elif prepare_dp_attn_flag:
536544
batch, result = self._prepare_idle_batch_and_run(
537545
None, delay_process=True
@@ -543,6 +551,9 @@ def event_loop_overlap_disagg_decode(self: Scheduler):
543551
# Process the results of the previous batch but skip if the last batch is extend
544552
if self.last_batch and self.last_batch_in_queue:
545553
tmp_batch, tmp_result = result_queue.popleft()
554+
tmp_batch.next_batch_sampling_info = (
555+
self.tp_worker.cur_sampling_info if batch else None
556+
)
546557
self.process_batch_result(tmp_batch, tmp_result)
547558

548559
if batch is None and (
@@ -591,6 +602,9 @@ def get_next_disagg_decode_batch_to_run(
591602

592603
def get_new_prebuilt_batch(self: Scheduler) -> Optional[ScheduleBatch]:
593604
"""Create a schedulebatch for fake completed prefill"""
605+
if self.grammar_queue:
606+
self.move_ready_grammar_requests()
607+
594608
if len(self.waiting_queue) == 0:
595609
return None
596610

@@ -616,8 +630,6 @@ def get_new_prebuilt_batch(self: Scheduler) -> Optional[ScheduleBatch]:
616630
self.waiting_queue = waiting_queue
617631
if len(can_run_list) == 0:
618632
return None
619-
# local import to avoid circular import
620-
from sglang.srt.managers.schedule_batch import ScheduleBatch
621633

622634
# construct a schedule batch with those requests and mark as decode
623635
new_batch = ScheduleBatch.init_new(

python/sglang/srt/disaggregation/decode_schedule_batch_mixin.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,9 @@ def process_prebuilt_extend(
101101
for req in self.reqs:
102102
self.output_ids.append(req.output_ids[-1])
103103
self.tree_cache.cache_unfinished_req(req)
104+
if req.grammar is not None:
105+
req.grammar.accept_token(req.output_ids[-1])
106+
req.grammar.finished = req.finished()
104107
self.output_ids = torch.tensor(self.output_ids, device=self.device)
105108

106109
# Simulate the eagle run. We add mock data to hidden states for the

python/sglang/srt/disaggregation/prefill.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@
4343
prepare_abort,
4444
)
4545
from sglang.srt.managers.schedule_batch import FINISH_LENGTH, Req, ScheduleBatch
46+
from sglang.srt.model_executor.forward_batch_info import ForwardMode
4647

4748
if TYPE_CHECKING:
4849
from torch.distributed import ProcessGroup
@@ -143,6 +144,10 @@ def add(self, req: Req) -> None:
143144
self._process_req(req)
144145
self.queue.append(req)
145146

147+
def extend(self, reqs: List[Req]) -> None:
148+
for req in reqs:
149+
self.add(req)
150+
146151
def _process_req(self, req: Req) -> None:
147152
"""
148153
Set max_new_tokens = 1, so PrefillAdder memory estimation is accurate
@@ -269,6 +274,16 @@ def event_loop_overlap_disagg_prefill(self: Scheduler):
269274
result = self.run_batch(batch)
270275
self.result_queue.append((batch.copy(), result))
271276

277+
if self.last_batch is None:
278+
# Create a dummy first batch to start the pipeline for overlap schedule.
279+
# It is now used for triggering the sampling_info_done event.
280+
tmp_batch = ScheduleBatch(
281+
reqs=None,
282+
forward_mode=ForwardMode.DUMMY_FIRST,
283+
next_batch_sampling_info=self.tp_worker.cur_sampling_info,
284+
)
285+
self.set_next_batch_sampling_info_done(tmp_batch)
286+
272287
if self.last_batch:
273288
tmp_batch, tmp_result = self.result_queue.popleft()
274289
self.process_batch_result_disagg_prefill(tmp_batch, tmp_result)

python/sglang/srt/managers/scheduler.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1087,8 +1087,11 @@ def _add_request_to_queue(self, req: Req):
10871087
else:
10881088
self.waiting_queue.append(req)
10891089

1090-
def _extend_requests_to_queue(self, reqs: List[Req], is_retracted: bool = False):
1091-
if self.disaggregation_mode == DisaggregationMode.DECODE:
1090+
def _extend_requests_to_queue(self, reqs: List[Req]):
1091+
if self.disaggregation_mode == DisaggregationMode.PREFILL:
1092+
self.disagg_prefill_bootstrap_queue.extend(reqs)
1093+
elif self.disaggregation_mode == DisaggregationMode.DECODE:
1094+
# If this is a decode server, we put the request to the decode pending prealloc queue
10921095
self.disagg_decode_prealloc_queue.extend(reqs)
10931096
else:
10941097
self.waiting_queue.extend(reqs)
Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
import json
2+
3+
import requests
4+
5+
port = 8000
6+
7+
json_schema = json.dumps(
8+
{
9+
"type": "object",
10+
"properties": {
11+
"name": {"type": "string", "pattern": "^[\\w]+$"},
12+
"population": {"type": "integer"},
13+
},
14+
"required": ["name", "population"],
15+
}
16+
)
17+
18+
# JSON
19+
response = requests.post(
20+
f"http://localhost:{port}/generate",
21+
json={
22+
"text": "Here is the information of the capital of France in the JSON format.\n",
23+
"sampling_params": {
24+
"temperature": 0,
25+
"max_new_tokens": 64,
26+
"json_schema": json_schema,
27+
},
28+
},
29+
)
30+
31+
print(response.json())
32+
33+
34+
# python3 -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --trust-remote-code --disaggregation-mode prefill --tp 2 --disaggregation-ib-device mlx5_roce0,mlx5_roce1 --speculative-algorithm EAGLE --speculative-draft-model-path lmsys/sglang-EAGLE-llama2-chat-7B --speculative-num-steps 3 --speculative-eagle-topk 4 --speculative-num-draft-tokens 16 --cuda-graph-max-bs 8 --host 127.0.0.1 --port 8100

test/srt/test_disaggregation.py

Lines changed: 29 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import json
12
import os
23
import subprocess
34
import time
@@ -17,12 +18,9 @@
1718
DEFAULT_URL_FOR_TEST,
1819
CustomTestCase,
1920
popen_launch_pd_server,
20-
run_with_timeout,
2121
)
2222

2323

24-
# skip the test because we have different_tp test
25-
@unittest.skip("skip the test because we have different_tp test")
2624
class TestDisaggregationAccuracy(CustomTestCase):
2725
@classmethod
2826
def setUpClass(cls):
@@ -172,6 +170,34 @@ def test_logprob(self):
172170
len(input_logprobs) > 0
173171
), f"input_logprobs should have at least one token, but got {len(input_logprobs)}"
174172

173+
def test_structured_output(self):
174+
json_schema = json.dumps(
175+
{
176+
"type": "object",
177+
"properties": {
178+
"name": {"type": "string", "pattern": "^[\\w]+$"},
179+
"population": {"type": "integer"},
180+
},
181+
"required": ["name", "population"],
182+
}
183+
)
184+
185+
# JSON
186+
response = requests.post(
187+
f"{self.lb_url}/generate",
188+
json={
189+
"text": "Here is the information of the capital of France in the JSON format.\n",
190+
"sampling_params": {
191+
"temperature": 0,
192+
"max_new_tokens": 64,
193+
"json_schema": json_schema,
194+
},
195+
},
196+
)
197+
output = response.json()["text"]
198+
# ensure the output is a valid JSON
199+
json.loads(output)
200+
175201

176202
class TestDisaggregationMooncakeFailure(CustomTestCase):
177203
@classmethod

0 commit comments

Comments
 (0)