[rollout] fix: sglang async fail with Multi-stage Awake feature #2365
Changes from all commits
@@ -32,7 +32,6 @@
from verl.protocol import all_gather_data_proto
from verl.utils.device import get_device_id, get_torch_device
from verl.utils.fsdp_utils import fsdp_version, load_fsdp_model_to_gpu, offload_fsdp_model_to_cpu
from verl.utils.model import convert_weight_keys
from verl.utils.profiler import GPUMemoryLogger, log_gpu_memory_usage, simple_timer
from verl.utils.torch_functional import check_device_is_available
@@ -101,65 +100,13 @@ def __init__(
    def __enter__(self):
        self.timing = {}
        with simple_timer("reshard", self.timing):
            get_torch_device().empty_cache()

            loop = asyncio.get_event_loop()

            if self.device_mesh["infer_tp"].get_local_rank() == 0 and self.rollout_config.free_cache_engine:
                if self.multi_stage_wake_up:
                    loop.run_until_complete(self.inference_engine.resume_memory_occupation(tags=["weights"]))
                    log_gpu_memory_usage("Before resume SGLang weights in sharding manager", logger=logger)
                else:
                    loop.run_until_complete(self.inference_engine.resume_memory_occupation())
                    log_gpu_memory_usage("Before resume SGLang weights + kv_cache in sharding manager", logger=logger)
                get_torch_device().empty_cache()

            log_gpu_memory_usage("Before state_dict() in sharding manager memory", logger=logger)
            if self.offload_param:
                load_fsdp_model_to_gpu(self.module)
            params = self.module.state_dict()
            log_gpu_memory_usage("After state_dict() in sharding manager memory", logger=logger)
            device = get_device_id()  # used when fsdp2 set cpu_offload_policy
            params = {
                k: v.to(device, non_blocking=True) if fsdp_version(self.module) == 2 else v for k, v in params.items()
            }
            params = convert_weight_keys(params, getattr(self.module, "_fsdp_wrapped_module", self.module))
            # Copy, not share memory
            loop.run_until_complete(self.update_weights(params))
            log_gpu_memory_usage("After sync model weights in sharding manager", logger=logger)

            del params
            if self.offload_param:
                offload_fsdp_model_to_cpu(self.module)
            get_torch_device().empty_cache()
            log_gpu_memory_usage("After del state_dict and empty_cache in sharding manager", logger=logger)

            if self.multi_stage_wake_up and self.rollout_config.free_cache_engine:
                loop.run_until_complete(self.inference_engine.resume_memory_occupation(tags=["kv_cache"]))
                log_gpu_memory_usage("After resume SGLang kv_cache in sharding manager", logger=logger)

            # important: need to manually set the random states of each tp to be identical.
            if self.device_mesh is not None:
                self.torch_random_states = get_torch_device().get_rng_state()
                get_torch_device().set_rng_state(self.gen_random_states)
            loop.run_until_complete(self.wake_up())

    @GPUMemoryLogger(role="FSDPSGLangShardingManager exit", logger=logger)
    def __exit__(self, exc_type, exc_value, traceback):
        if self.rollout_config.free_cache_engine:
            log_gpu_memory_usage("Before SGLang offload in sharding manager", logger=logger)
            loop = asyncio.get_event_loop()
            loop.run_until_complete(self.release_memory())
            log_gpu_memory_usage("After SGLang offload in sharding manager", logger=logger)

        self.module.train()

        # add empty cache after each compute
        get_torch_device().empty_cache()

        # restore random states
        if self.device_mesh is not None:
            self.gen_random_states = get_torch_device().get_rng_state()
            get_torch_device().set_rng_state(self.torch_random_states)
        loop = asyncio.get_event_loop()
        loop.run_until_complete(self.sleep())

    async def update_weights(self, params):
        # Most naive implementation, can optimize a lot if it is bottleneck from sglang Engine weight update
@@ -207,6 +154,15 @@ async def wake_up(self):
        params = {
            k: v.to(device, non_blocking=True) if fsdp_version(self.module) == 2 else v for k, v in params.items()
        }

        if self.device_mesh["infer_tp"].get_local_rank() == 0 and self.rollout_config.free_cache_engine:
            if self.multi_stage_wake_up:
                await self.inference_engine.resume_memory_occupation(tags=["weights"])
                log_gpu_memory_usage("Before resume SGLang weights in sharding manager", logger=logger)
            else:
                await self.inference_engine.resume_memory_occupation()
                log_gpu_memory_usage("Before resume SGLang weights + kv_cache in sharding manager", logger=logger)
Comment on lines +158 to +164

This block correctly re-introduces the memory occupation resumption for model weights. It is crucial for preventing the reported "illegal memory access" CUDA errors, as it ensures that the SGLang inference engine has its weights loaded into GPU memory before further operations such as the weight synchronization in update_weights.

chenhaiq marked this conversation as resolved.
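For context, here is a minimal sketch of the ordering this comment describes. It assumes the engine exposes the resume_memory_occupation call with the "weights"/"kv_cache" tags shown in the diff; the helper name and signature below are illustrative, not part of this PR:

async def multi_stage_wake_up_sketch(engine, module, update_weights):
    # Stage 1: bring only the model weights back onto the GPU.
    await engine.resume_memory_occupation(tags=["weights"])

    # Sync the trainer's FSDP parameters into the inference engine while the
    # KV cache is still released, keeping peak GPU memory low.
    params = module.state_dict()
    await update_weights(params)
    del params

    # Stage 2: only now re-materialize the KV cache for generation.
    await engine.resume_memory_occupation(tags=["kv_cache"])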
        # Copy, not share memory
        await self.update_weights(params)
        log_gpu_memory_usage("After sync model weights in sharding manager", logger=logger)
@@ -217,6 +173,10 @@ async def wake_up(self):
        get_torch_device().empty_cache()
        log_gpu_memory_usage("After del state_dict and empty_cache in sharding manager", logger=logger)

        if self.multi_stage_wake_up and self.rollout_config.free_cache_engine:
            await self.inference_engine.resume_memory_occupation(tags=["kv_cache"])
            log_gpu_memory_usage("After resume SGLang kv_cache in sharding manager", logger=logger)
Comment on lines +176 to +178

This addition ensures that the KV cache is also properly resumed into GPU memory when multi_stage_wake_up is enabled and free_cache_engine is set, after the model weights have been synchronized.
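To illustrate why this matters for the async rollout path, here is a hypothetical usage sketch. The wake_up() and sleep() methods come from the code above, while the surrounding function and the rollout.generate call are assumptions for illustration only:

async def generate_with_resharding(sharding_manager, rollout, prompts):
    # wake_up() now performs the full multi-stage resume: weights -> weight sync -> kv_cache.
    await sharding_manager.wake_up()
    try:
        # Hypothetical generation call standing in for the async rollout.
        return await rollout.generate(prompts)
    finally:
        # Release weights and KV cache back to the memory pool after generation.
        await sharding_manager.sleep()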
        # important: need to manually set the random states of each tp to be identical.
        if self.device_mesh is not None:
            self.torch_random_states = get_torch_device().get_rng_state()