Commit 30ea560

zhuzilin authored and lifuhuang committed
[RL] add pause and continue generation for async rl training (#7419)
1 parent 98d5ce9 commit 30ea560

2 files changed: 35 additions & 0 deletions


python/sglang/srt/entrypoints/http_server.py

Lines changed: 20 additions & 0 deletions
@@ -712,6 +712,26 @@ async def separate_reasoning_request(obj: SeparateReasoningReqInput, request: Re
     return ORJSONResponse(content=response_data, status_code=200)


+@app.post("/pause_generation")
+async def pause_generation(request: Request):
+    """Pause generation."""
+    await _global_state.tokenizer_manager.pause_generation()
+    return ORJSONResponse(
+        content={"message": "Generation paused successfully.", "status": "ok"},
+        status_code=200,
+    )
+
+
+@app.post("/continue_generation")
+async def continue_generation(request: Request):
+    """Continue generation."""
+    await _global_state.tokenizer_manager.continue_generation()
+    return ORJSONResponse(
+        content={"message": "Generation continued successfully.", "status": "ok"},
+        status_code=200,
+    )
+
+
 ##### OpenAI-compatible API endpoints #####

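As a usage note (not part of the diff): an async RL trainer would typically call these endpoints over HTTP around a weight update. A minimal client-side sketch follows; the server address http://localhost:30000 and the use of the requests library are assumptions for illustration, not anything specified by this commit.

import requests

BASE_URL = "http://localhost:30000"  # assumed address of a locally running SGLang server

def pause_generation():
    # Server side: aborts in-flight requests and makes new /generate calls wait.
    resp = requests.post(f"{BASE_URL}/pause_generation")
    resp.raise_for_status()
    return resp.json()  # {"message": "Generation paused successfully.", "status": "ok"}

def continue_generation():
    # Server side: lifts the pause so queued /generate calls can proceed.
    resp = requests.post(f"{BASE_URL}/continue_generation")
    resp.raise_for_status()
    return resp.json()

if __name__ == "__main__":
    pause_generation()
    # ... push updated policy weights to the inference server here ...
    continue_generation()
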
python/sglang/srt/managers/tokenizer_manager.py

Lines changed: 15 additions & 0 deletions
@@ -203,6 +203,8 @@ def __init__(
         self.is_image_gen = self.model_config.is_image_gen
         self.context_len = self.model_config.context_len
         self.image_token_id = self.model_config.image_token_id
+        self._updating = False
+        self._cond = asyncio.Condition()

         if self.model_config.is_multimodal:
             import_processors()

@@ -421,6 +423,9 @@ async def generate_request(
         request: Optional[fastapi.Request] = None,
     ):
         created_time = time.time()
+        async with self._cond:
+            await self._cond.wait_for(lambda: not self._updating)
+
         self.auto_create_handle_loop()
         obj.normalize_batch_and_arguments()

@@ -908,6 +913,16 @@ async def dump_expert_distribution_record(self):
         self.auto_create_handle_loop()
         await self.expert_distribution_communicator(ExpertDistributionReq.DUMP_RECORD)

+    async def pause_generation(self):
+        async with self._cond:
+            self._updating = True
+            self.abort_request(abort_all=True)
+
+    async def continue_generation(self):
+        async with self._cond:
+            self._updating = False
+            self._cond.notify_all()
+
     async def update_weights_from_disk(
         self,
         obj: UpdateWeightFromDiskReqInput,
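
The mechanism behind these endpoints is a boolean flag guarded by an asyncio.Condition: generate_request waits until _updating is False, pause_generation sets the flag (after aborting in-flight requests), and continue_generation clears it and wakes all waiters via notify_all. The following self-contained sketch illustrates that pattern with hypothetical names (PauseGate, wait_if_paused); it is an illustration of the idiom, not code from this commit.

import asyncio

class PauseGate:
    """Toy version of the flag-plus-Condition pattern used above."""

    def __init__(self):
        self._updating = False
        self._cond = asyncio.Condition()

    async def pause(self):
        async with self._cond:
            self._updating = True  # new work must now wait

    async def resume(self):
        async with self._cond:
            self._updating = False
            self._cond.notify_all()  # wake every coroutine blocked in wait_if_paused()

    async def wait_if_paused(self):
        async with self._cond:
            # Returns immediately when not paused; otherwise releases the lock and
            # sleeps until resume() flips the flag and notifies.
            await self._cond.wait_for(lambda: not self._updating)

async def main():
    gate = PauseGate()
    await gate.pause()
    waiter = asyncio.create_task(gate.wait_if_paused())
    await asyncio.sleep(0.1)
    assert not waiter.done()  # still blocked while paused
    await gate.resume()
    await waiter              # completes once resumed

asyncio.run(main())

Note that the wait happens only at request admission in generate_request, which is why pause_generation also aborts requests that are already running.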
