
[rollout] fix: sglang async fail with Multi-stage Awake feature #2365

Merged
4 changes: 2 additions & 2 deletions .github/workflows/e2e_ppo_trainer.yml
@@ -292,13 +292,13 @@ jobs:
- name: Running GSM8K E2E training tests on sglang async
run: |
ray stop --force
ENGINE=sglang ROLLOUT_MODE=async bash tests/special_e2e/ppo_trainer/run_function_reward.sh
TOTAL_TRAIN_STEPS=2 ENGINE=sglang ROLLOUT_MODE=async bash tests/special_e2e/ppo_trainer/run_function_reward.sh
- name: Running GSM8K E2E training tests on vllm async
run: |
ray stop --force
export VLLM_USE_V1=1
ray start --head
ENGINE=vllm ROLLOUT_MODE=async bash tests/special_e2e/ppo_trainer/run_function_reward.sh
TOTAL_TRAIN_STEPS=2 ENGINE=vllm ROLLOUT_MODE=async bash tests/special_e2e/ppo_trainer/run_function_reward.sh
e2e_ppo_trainer_sglang_multiturn_with_tool:
runs-on: [L20x8]
7 changes: 6 additions & 1 deletion verl/workers/megatron_workers.py
@@ -19,7 +19,7 @@
import logging
import os
import time
from typing import Optional, Union
from typing import Any, Dict, List, Optional, Union

import psutil
import torch
@@ -692,6 +692,11 @@ async def chat_completion(self, json_request):
ret = await self.rollout.chat_completion(json_request)
return ret

@register(dispatch_mode=Dispatch.DIRECT_ROLLOUT_METHOD, blocking=False)
async def generate(self, prompt_ids: List[int], sampling_params: Dict[str, Any], request_id: str) -> List[int]:
ret = await self.rollout.generate(prompt_ids, sampling_params, request_id)
return ret

@register(dispatch_mode=Dispatch.DIRECT_ROLLOUT_METHOD)
async def wake_up(self):
if self.config.rollout.free_cache_engine:
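As orientation only, a self-contained, illustrative sketch of the contract the new `generate` endpoint forwards to the rollout. Only the `generate(prompt_ids, sampling_params, request_id)` signature comes from the diff above; `DummyRollout`, the sampling values, and the driver function are hypothetical stand-ins.

```python
import asyncio
import uuid
from typing import Any, Dict, List


class DummyRollout:
    """Stand-in for the SGLang rollout held by the worker (illustrative only)."""

    async def generate(
        self, prompt_ids: List[int], sampling_params: Dict[str, Any], request_id: str
    ) -> List[int]:
        # A real rollout would schedule the request on the SGLang engine;
        # here we just echo the prompt to keep the example runnable.
        return prompt_ids


async def generate_one(rollout: DummyRollout, prompt_ids: List[int]) -> List[int]:
    # Mirrors what the worker endpoint forwards: one prompt, one sampling-params
    # dict, one request id, awaited per request instead of a blocking batch call.
    sampling_params: Dict[str, Any] = {"temperature": 1.0, "max_new_tokens": 16}
    return await rollout.generate(prompt_ids, sampling_params, uuid.uuid4().hex)


if __name__ == "__main__":
    print(asyncio.run(generate_one(DummyRollout(), [1, 2, 3])))
```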
72 changes: 16 additions & 56 deletions verl/workers/sharding_manager/fsdp_sglang.py
@@ -32,7 +32,6 @@
from verl.protocol import all_gather_data_proto
from verl.utils.device import get_device_id, get_torch_device
from verl.utils.fsdp_utils import fsdp_version, load_fsdp_model_to_gpu, offload_fsdp_model_to_cpu
from verl.utils.model import convert_weight_keys
from verl.utils.profiler import GPUMemoryLogger, log_gpu_memory_usage, simple_timer
from verl.utils.torch_functional import check_device_is_available

@@ -101,65 +100,13 @@ def __init__(
def __enter__(self):
self.timing = {}
with simple_timer("reshard", self.timing):
get_torch_device().empty_cache()

loop = asyncio.get_event_loop()

if self.device_mesh["infer_tp"].get_local_rank() == 0 and self.rollout_config.free_cache_engine:
if self.multi_stage_wake_up:
loop.run_until_complete(self.inference_engine.resume_memory_occupation(tags=["weights"]))
log_gpu_memory_usage("Before resume SGLang weights in sharding manager", logger=logger)
else:
loop.run_until_complete(self.inference_engine.resume_memory_occupation())
log_gpu_memory_usage("Before resume SGLang weights + kv_cache in sharding manager", logger=logger)
get_torch_device().empty_cache()

log_gpu_memory_usage("Before state_dict() in sharding manager memory", logger=logger)
if self.offload_param:
load_fsdp_model_to_gpu(self.module)
params = self.module.state_dict()
log_gpu_memory_usage("After state_dict() in sharding manager memory", logger=logger)
device = get_device_id() # used when fsdp2 set cpu_offload_policy
params = {
k: v.to(device, non_blocking=True) if fsdp_version(self.module) == 2 else v for k, v in params.items()
}
params = convert_weight_keys(params, getattr(self.module, "_fsdp_wrapped_module", self.module))
# Copy, not share memory
loop.run_until_complete(self.update_weights(params))
log_gpu_memory_usage("After sync model weights in sharding manager", logger=logger)

del params
if self.offload_param:
offload_fsdp_model_to_cpu(self.module)
get_torch_device().empty_cache()
log_gpu_memory_usage("After del state_dict and empty_cache in sharding manager", logger=logger)

if self.multi_stage_wake_up and self.rollout_config.free_cache_engine:
loop.run_until_complete(self.inference_engine.resume_memory_occupation(tags=["kv_cache"]))
log_gpu_memory_usage("After resume SGLang kv_cache in sharding manager", logger=logger)

# important: need to manually set the random states of each tp to be identical.
if self.device_mesh is not None:
self.torch_random_states = get_torch_device().get_rng_state()
get_torch_device().set_rng_state(self.gen_random_states)
loop.run_until_complete(self.wake_up())

@GPUMemoryLogger(role="FSDPSGLangShardingManager exit", logger=logger)
def __exit__(self, exc_type, exc_value, traceback):
if self.rollout_config.free_cache_engine:
log_gpu_memory_usage("Before SGLang offload in sharding manager", logger=logger)
loop = asyncio.get_event_loop()
loop.run_until_complete(self.release_memory())
log_gpu_memory_usage("After SGLang offload in sharding manager", logger=logger)

self.module.train()

# add empty cache after each compute
get_torch_device().empty_cache()

# restore random states
if self.device_mesh is not None:
self.gen_random_states = get_torch_device().get_rng_state()
get_torch_device().set_rng_state(self.torch_random_states)
loop = asyncio.get_event_loop()
loop.run_until_complete(self.sleep())

async def update_weights(self, params):
# Most naive implementation, can optimize a lot if it is bottleneck from sglang Engine weight update
@@ -207,6 +154,15 @@ async def wake_up(self):
params = {
k: v.to(device, non_blocking=True) if fsdp_version(self.module) == 2 else v for k, v in params.items()
}

if self.device_mesh["infer_tp"].get_local_rank() == 0 and self.rollout_config.free_cache_engine:
if self.multi_stage_wake_up:
await self.inference_engine.resume_memory_occupation(tags=["weights"])
log_gpu_memory_usage("Before resume SGLang weights in sharding manager", logger=logger)
else:
await self.inference_engine.resume_memory_occupation()
log_gpu_memory_usage("Before resume SGLang weights + kv_cache in sharding manager", logger=logger)
Comment on lines +158 to +164

critical

This block correctly re-introduces the memory occupation resumption for model weights. This is crucial for preventing the reported "illegal memory access" CUDA errors, as it ensures that the SGLang inference engine has its necessary weights loaded into GPU memory before further operations like update_weights are performed. This change aligns the wake_up method's behavior with the memory management logic already present in the __enter__ method.


# Copy, not share memory
await self.update_weights(params)
log_gpu_memory_usage("After sync model weights in sharding manager", logger=logger)
@@ -217,6 +173,10 @@ async def wake_up(self):
get_torch_device().empty_cache()
log_gpu_memory_usage("After del state_dict and empty_cache in sharding manager", logger=logger)

if self.multi_stage_wake_up and self.rollout_config.free_cache_engine:
await self.inference_engine.resume_memory_occupation(tags=["kv_cache"])
log_gpu_memory_usage("After resume SGLang kv_cache in sharding manager", logger=logger)
Comment on lines +176 to +178

critical

This addition ensures that the KV cache is also properly resumed into GPU memory when multi_stage_wake_up is enabled. This granular control over memory resumption is vital for the stability and performance of the SGLang engine, especially in scenarios where KV cache might be offloaded. This change is a necessary part of the overall fix for the memory access regression.


# important: need to manually set the random states of each tp to be identical.
if self.device_mesh is not None:
self.torch_random_states = get_torch_device().get_rng_state()
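To make the ordering described in the two review comments above easier to see at a glance, here is a small, runnable mock of the two-stage sequence. Only the call order and the `tags=["weights"]` / `tags=["kv_cache"]` arguments are taken from the diff; `MockSGLangEngine`, its printouts, and the omission of the infer_tp rank gating are simplifications, not the real SGLang API.

```python
import asyncio
from typing import Optional, Sequence


class MockSGLangEngine:
    """Mock of the engine's memory-occupation interface (illustrative only)."""

    async def resume_memory_occupation(self, tags: Optional[Sequence[str]] = None) -> None:
        print(f"resume_memory_occupation(tags={tags})")

    async def release_memory_occupation(self) -> None:
        print("release_memory_occupation()")


async def wake_up(engine: MockSGLangEngine, multi_stage_wake_up: bool, free_cache_engine: bool) -> None:
    # Stage 1: bring only the weights back so the weight sync has a live target
    # while the FSDP state_dict still occupies GPU memory.
    if free_cache_engine:
        if multi_stage_wake_up:
            await engine.resume_memory_occupation(tags=["weights"])
        else:
            await engine.resume_memory_occupation()
    print("update_weights(params)   # push fresh training weights")
    print("del params; empty_cache() # drop the temporary state_dict first")
    # Stage 2: only now resume the KV cache, after the duplicated parameters
    # are gone, so weights copy and KV cache never coexist at peak memory.
    if multi_stage_wake_up and free_cache_engine:
        await engine.resume_memory_occupation(tags=["kv_cache"])


async def sleep(engine: MockSGLangEngine, free_cache_engine: bool) -> None:
    # Mirror of the exit path: release engine memory before training resumes.
    if free_cache_engine:
        await engine.release_memory_occupation()


if __name__ == "__main__":
    engine = MockSGLangEngine()
    asyncio.run(wake_up(engine, multi_stage_wake_up=True, free_cache_engine=True))
    asyncio.run(sleep(engine, free_cache_engine=True))
```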
47 changes: 10 additions & 37 deletions verl/workers/sharding_manager/megatron_sglang.py
@@ -114,45 +114,13 @@ def __init__(
def __enter__(self):
self.timing = {}
with simple_timer("reshard", self.timing):
if self.offload_param:
load_megatron_model_to_gpu(self.actor_module)
if self.bridge is not None:
per_tensor_param = self.bridge.export_weights(self.actor_module)
else:
per_tensor_param = per_tensor_generator(
self.actor_module,
self.model_config,
self.weight_converter,
self.transformer_config,
self.layer_name_mapping,
)
loop = asyncio.get_event_loop()
loop.run_until_complete(self.update_weights(per_tensor_param))
if self.offload_param:
offload_megatron_model_to_cpu(self.actor_module)
get_torch_device().empty_cache()
# important: need to manually set the random states of each tp to be identical.
if self.device_mesh is not None:
self.torch_random_states = get_torch_device().get_rng_state()
get_torch_device().set_rng_state(self.gen_random_states)
loop.run_until_complete(self.wake_up())

@GPUMemoryLogger(role="MegatronSGLangShardingManager exit", logger=logger)
def __exit__(self, exc_type, exc_value, traceback):
if self.rollout_config.free_cache_engine:
log_gpu_memory_usage("Before SGLang offload in sharding manager", logger=logger)
loop = asyncio.get_event_loop()
loop.run_until_complete(self.release_memory())
log_gpu_memory_usage("After SGLang offload in sharding manager", logger=logger)

for model in self.actor_module:
model.train()
# add empty cache after each compute
get_torch_device().empty_cache()

# restore random states
if self.device_mesh is not None:
self.gen_random_states = get_torch_device().get_rng_state()
get_torch_device().set_rng_state(self.torch_random_states)
loop = asyncio.get_event_loop()
loop.run_until_complete(self.sleep())

async def update_weights(self, params):
if self.device_mesh["tp"].get_local_rank() == 0 and self.rollout_config.free_cache_engine:
@@ -182,8 +150,10 @@ async def release_memory(self):
if self.device_mesh["tp"].get_local_rank() == 0 and self.rollout_config.free_cache_engine:
await self.inference_engine.release_memory_occupation()

@GPUMemoryLogger(role="FSDPSGLangShardingManager enter", logger=logger)
@GPUMemoryLogger(role="MegatronSGLangShardingManager enter", logger=logger)
async def wake_up(self):
if self.offload_param:
load_megatron_model_to_gpu(self.actor_module)
if self.bridge is not None:
per_tensor_param = self.bridge.export_weights(self.actor_module)
else:
@@ -195,12 +165,15 @@ async def wake_up(self):
self.layer_name_mapping,
)
await self.update_weights(per_tensor_param)
if self.offload_param:
offload_megatron_model_to_cpu(self.actor_module)
get_torch_device().empty_cache()
# important: need to manually set the random states of each tp to be identical.
if self.device_mesh is not None:
self.torch_random_states = get_torch_device().get_rng_state()
get_torch_device().set_rng_state(self.gen_random_states)

@GPUMemoryLogger(role="FSDPSGLangShardingManager exit", logger=logger)
@GPUMemoryLogger(role="MegatronSGLangShardingManager exit", logger=logger)
async def sleep(self):
if self.rollout_config.free_cache_engine:
log_gpu_memory_usage("Before SGLang offload in sharding manager", logger=logger)
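Both sharding managers now share the same shape: the synchronous `__enter__`/`__exit__` keep only the event-loop plumbing, while the resume/sync/release work lives in the async `wake_up`/`sleep` methods that the async rollout path can also await directly. A stripped-down placeholder of that pattern (none of the bodies below are verl code):

```python
import asyncio


class ShardingManagerSketch:
    """Placeholder showing the __enter__/__exit__ -> wake_up/sleep delegation."""

    def __enter__(self):
        # Synchronous context-manager entry drives the async path on the
        # current event loop, as in the diffs above.
        loop = asyncio.get_event_loop()
        loop.run_until_complete(self.wake_up())
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        loop = asyncio.get_event_loop()
        loop.run_until_complete(self.sleep())

    async def wake_up(self):
        # Real managers load params to GPU, resume engine memory, and sync
        # weights here; the async server can call this without the context manager.
        print("wake_up")

    async def sleep(self):
        # Real managers release engine memory and restore RNG state here.
        print("sleep")


if __name__ == "__main__":
    with ShardingManagerSketch():
        print("rollout")
```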