
[RL] add pause and continue generation for async rl training #7419


Merged
merged 9 commits on Jul 5, 2025
20 changes: 20 additions & 0 deletions python/sglang/srt/entrypoints/http_server.py
@@ -712,6 +712,26 @@ async def separate_reasoning_request(obj: SeparateReasoningReqInput, request: Request):
    return ORJSONResponse(content=response_data, status_code=200)


@app.post("/pause_generation")
async def pause_generation(request: Request):
"""Pause generation."""
await _global_state.tokenizer_manager.pause_generation()
return ORJSONResponse(
content={"message": "Generation paused successfully.", "status": "ok"},
status_code=200,
)


@app.post("/continue_generation")
async def continue_generation(request: Request):
"""Continue generation."""
await _global_state.tokenizer_manager.continue_generation()
return ORJSONResponse(
content={"message": "Generation continued successfully.", "status": "ok"},
status_code=200,
)


##### OpenAI-compatible API endpoints #####


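For context beyond the diff: an async RL trainer would typically bracket a weight update with these two endpoints, so rollout generation is quiescent while new weights load. A hedged sketch follows; the pause/continue routes come from this PR and `/update_weights_from_disk` from SGLang's existing API, while the server address and payload shape are assumptions.

```python
# Sketch: wrap a weight update with pause/continue (assumed server address).
import requests

BASE = "http://localhost:30000"  # assumed SGLang server URL

def update_policy_weights(model_path: str) -> None:
    # Stop new generation requests and abort in-flight ones.
    requests.post(f"{BASE}/pause_generation").raise_for_status()
    try:
        # Swap in the freshly trained checkpoint while generation is idle.
        requests.post(
            f"{BASE}/update_weights_from_disk",
            json={"model_path": model_path},  # assumed payload shape
        ).raise_for_status()
    finally:
        # Unblock queued requests even if the update fails.
        requests.post(f"{BASE}/continue_generation").raise_for_status()
```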
15 changes: 15 additions & 0 deletions python/sglang/srt/managers/tokenizer_manager.py
@@ -203,6 +203,8 @@ def __init__(
        self.is_image_gen = self.model_config.is_image_gen
        self.context_len = self.model_config.context_len
        self.image_token_id = self.model_config.image_token_id
        # Pause/continue gate: generate_request blocks on this condition
        # whenever generation is paused (see pause_generation below).
        self._updating = False
        self._cond = asyncio.Condition()

        if self.model_config.is_multimodal:
            import_processors()
@@ -421,6 +423,9 @@ async def generate_request(
        request: Optional[fastapi.Request] = None,
    ):
        created_time = time.time()
        # Hold new requests here while generation is paused for a weight update.
        async with self._cond:
            await self._cond.wait_for(lambda: not self._updating)

        self.auto_create_handle_loop()
        obj.normalize_batch_and_arguments()

@@ -902,6 +907,16 @@ async def dump_expert_distribution_record(self):
        self.auto_create_handle_loop()
        await self.expert_distribution_communicator(ExpertDistributionReq.DUMP_RECORD)

    async def pause_generation(self):
        async with self._cond:
            self._updating = True
            # Abort everything in flight; new requests now block in
            # generate_request until continue_generation() is called.
            self.abort_request(abort_all=True)

    async def continue_generation(self):
        async with self._cond:
            self._updating = False
            self._cond.notify_all()  # wake all requests parked at the gate

    async def update_weights_from_disk(
        self,
        obj: UpdateWeightFromDiskReqInput,
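The synchronization here is a standard `asyncio.Condition` gate: `generate_request` parks until `_updating` is false, `pause_generation` flips the flag (and aborts in-flight work), and `continue_generation` clears it and wakes every waiter. Below is a minimal, self-contained sketch of the same pattern; the class and method names are illustrative, not the TokenizerManager API.

```python
import asyncio

class PausableGate:
    """Pause/resume gate mirroring the pattern in this PR."""

    def __init__(self) -> None:
        self._updating = False
        self._cond = asyncio.Condition()

    async def pause(self) -> None:
        async with self._cond:
            self._updating = True  # new work now blocks in wait_for()

    async def resume(self) -> None:
        async with self._cond:
            self._updating = False
            self._cond.notify_all()  # wake every coroutine parked in run()

    async def run(self, i: int) -> None:
        async with self._cond:
            # Re-checked on every wakeup, so spurious notifies are safe.
            await self._cond.wait_for(lambda: not self._updating)
        print(f"request {i} proceeding")

async def main() -> None:
    gate = PausableGate()
    await gate.pause()
    tasks = [asyncio.create_task(gate.run(i)) for i in range(3)]
    await asyncio.sleep(0.1)  # all three tasks are blocked here
    await gate.resume()       # ...and all proceed after this
    await asyncio.gather(*tasks)

asyncio.run(main())
```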