
Commit 04f2428 (parent 5c39b51)

[rollout] fix: sglang megatron backend missing generate function

The SGLang Megatron async worker lacked the token-in/token-out generate endpoint used by the agent loop. This commit adds it and extends the agent-loop tests and CI workflows to cover both the FSDP and Megatron backends.

5 files changed: +47 −12 lines

.github/workflows/sgl.yml

Lines changed: 4 additions & 0 deletions

@@ -131,4 +131,8 @@ jobs:
         run: |
           cd tests/workers/rollout
           pytest -s test_sglang_async_rollout_mcp_tools.py
+      - name: Test the latest SGLang Rollout async with agent loop
+        run: |
+          BACKEND=fsdp ROLLOUT_NAME=sglang pytest -svvv tests/experimental/agent_loop/test_basic_agent_loop.py
+          BACKEND=megatron ROLLOUT_NAME=sglang pytest -svvv tests/experimental/agent_loop/test_basic_agent_loop.py
       # Note(haibin.lin): for any new test, please update gpu_unit_tests.yaml to avoid repeated tests

.github/workflows/vllm.yml

Lines changed: 2 additions & 2 deletions

@@ -121,6 +121,6 @@ jobs:
       - name: Running multi-turn rollout tests on 8 L20 GPUs
         run: |
           pip3 install --upgrade vllm==0.8.3 tensordict==0.7.2
-          pytest -svvv tests/workers/rollout/rollout_vllm/test_vllm_chat_scheduler.py
-          ROLLOUT_NAME=vllm pytest -svvv tests/experimental/agent_loop/test_basic_agent_loop.py
+          BACKEND=fsdp ROLLOUT_NAME=vllm pytest -svvv tests/experimental/agent_loop/test_basic_agent_loop.py
+          BACKEND=megatron ROLLOUT_NAME=vllm pytest -svvv tests/experimental/agent_loop/test_basic_agent_loop.py
       # Note(haibin.lin): for any new test, please update gpu_unit_tests.yaml to avoid repeated tests
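
Together with the sgl.yml change above, the agent-loop test now runs across a 2×2 backend/rollout matrix in CI. A small bookkeeping sketch that enumerates the resulting commands (illustrative, not part of either workflow):

# Enumerate the agent-loop test matrix covered by the two workflows.
for backend in ["fsdp", "megatron"]:
    for rollout in ["sglang", "vllm"]:
        print(
            f"BACKEND={backend} ROLLOUT_NAME={rollout} "
            "pytest -svvv tests/experimental/agent_loop/test_basic_agent_loop.py"
        )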

tests/experimental/agent_loop/agent_utils.py

Lines changed: 13 additions & 4 deletions

@@ -20,14 +20,23 @@
 from verl.single_controller.ray import RayClassWithInitArgs, RayWorkerGroup
 from verl.single_controller.ray.base import create_colocated_worker_cls
 from verl.trainer.ppo.ray_trainer import ResourcePoolManager, Role
-from verl.workers.fsdp_workers import ActorRolloutRefWorker, AsyncActorRolloutRefWorker
 
 
 def init_agent_loop_manager(config: DictConfig) -> Union[AgentLoopManager, RayWorkerGroup]:
     # =========================== 1. Create hybrid ActorRollout workers ===========================
-    actor_rollout_cls = (
-        AsyncActorRolloutRefWorker if config.actor_rollout_ref.rollout.mode == "async" else ActorRolloutRefWorker
-    )
+    if config.actor_rollout_ref.actor.strategy in ["fsdp", "fsdp2"]:
+        from verl.workers.fsdp_workers import ActorRolloutRefWorker, AsyncActorRolloutRefWorker
+
+        actor_rollout_cls = (
+            AsyncActorRolloutRefWorker if config.actor_rollout_ref.rollout.mode == "async" else ActorRolloutRefWorker
+        )
+    elif config.actor_rollout_ref.actor.strategy == "megatron":
+        from verl.workers.megatron_workers import ActorRolloutRefWorker, AsyncActorRolloutRefWorker
+
+        actor_rollout_cls = (
+            AsyncActorRolloutRefWorker if config.actor_rollout_ref.rollout.mode == "async" else ActorRolloutRefWorker
+        )
+
     role_worker_mapping = {
         Role.ActorRollout: ray.remote(actor_rollout_cls),
     }
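
The duplicated branch bodies above could also be collapsed into a table-driven lookup. A minimal sketch of that alternative, assuming only the two worker modules named in the diff (this is an illustration, not the committed code):

import importlib

# Map each actor strategy to the module providing its worker classes.
_WORKER_MODULES = {
    "fsdp": "verl.workers.fsdp_workers",
    "fsdp2": "verl.workers.fsdp_workers",
    "megatron": "verl.workers.megatron_workers",
}

def resolve_actor_rollout_cls(strategy: str, mode: str):
    """Return the async or sync ActorRolloutRefWorker class for a strategy."""
    module = importlib.import_module(_WORKER_MODULES[strategy])
    name = "AsyncActorRolloutRefWorker" if mode == "async" else "ActorRolloutRefWorker"
    return getattr(module, name)

Keeping the imports local, as the commit does, matters here: importing verl.workers.megatron_workers pulls in Megatron dependencies that may not be installed in FSDP-only environments, so both the commit and this sketch defer the import until the strategy is known.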

tests/experimental/agent_loop/test_basic_agent_loop.py

Lines changed: 22 additions & 5 deletions

@@ -27,9 +27,29 @@
 from verl.utils import hf_tokenizer
 
 
+def init_fsdp_config() -> DictConfig:
+    config = OmegaConf.load("verl/trainer/config/ppo_trainer.yaml")
+    # test sleep/wake_up with fsdp offload
+    config.actor_rollout_ref.actor.fsdp_config.param_offload = True
+    config.actor_rollout_ref.actor.fsdp_config.optimizer_offload = True
+    return config
+
+
+def init_megatron_config() -> DictConfig:
+    config = OmegaConf.load("verl/trainer/config/ppo_megatron_trainer.yaml")
+    config.actor_rollout_ref.actor.megatron.tensor_model_parallel_size = 2
+    config.actor_rollout_ref.actor.megatron.pipeline_model_parallel_size = 2
+
+    # FIXME: sglang with megatron param_offload got error:
+    # "CUDA error: an illegal memory access was encountered"
+    config.actor_rollout_ref.actor.megatron.param_offload = False
+    config.actor_rollout_ref.actor.megatron.optimizer_offload = True
+    return config
+
+
 @pytest.fixture
 def init_config() -> DictConfig:
-    config = OmegaConf.load("verl/trainer/config/ppo_trainer.yaml")
+    config = init_fsdp_config() if os.getenv("BACKEND", "fsdp") == "fsdp" else init_megatron_config()
     model_path = "Qwen/Qwen2.5-1.5B-Instruct"
     config.actor_rollout_ref.model.path = model_path
     config.actor_rollout_ref.rollout.name = os.getenv("ROLLOUT_NAME", "vllm")
@@ -38,10 +58,7 @@ def init_config() -> DictConfig:
     config.actor_rollout_ref.rollout.response_length = 4096
     config.actor_rollout_ref.rollout.n = 4
     config.actor_rollout_ref.rollout.agent.num_workers = 2
-
-    # test sleep/wake_up with fsdp offload
-    config.actor_rollout_ref.actor.fsdp_config.param_offload = True
-    config.actor_rollout_ref.actor.fsdp_config.optimizer_offload = True
+    config.actor_rollout_ref.actor.optim.total_training_steps = 100
 
     return config
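
The fixture now dispatches on a BACKEND environment variable, mirroring the CI matrix. A standalone sketch of the same dispatch pattern, using an in-memory OmegaConf config instead of verl's real YAML files (placeholder keys, not the committed config):

import os
from omegaconf import OmegaConf

def init_config():
    # CI exports BACKEND=fsdp or BACKEND=megatron before invoking pytest.
    if os.getenv("BACKEND", "fsdp") == "fsdp":
        return OmegaConf.create({"strategy": "fsdp", "param_offload": True})
    return OmegaConf.create({"strategy": "megatron", "tensor_model_parallel_size": 2})

os.environ["BACKEND"] = "megatron"
print(init_config().strategy)  # megatron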

verl/workers/megatron_workers.py

Lines changed: 6 additions & 1 deletion

@@ -19,7 +19,7 @@
 import logging
 import os
 import time
-from typing import Union
+from typing import Any, Dict, List, Union
 
 import psutil
 import torch
@@ -700,6 +700,11 @@ async def chat_completion(self, json_request):
         ret = await self.rollout.chat_completion(json_request)
         return ret
 
+    @register(dispatch_mode=Dispatch.DIRECT_ROLLOUT_METHOD, blocking=False)
+    async def generate(self, prompt_ids: List[int], sampling_params: Dict[str, Any], request_id: str) -> List[int]:
+        ret = await self.rollout.generate(prompt_ids, sampling_params, request_id)
+        return ret
+
     @register(dispatch_mode=Dispatch.DIRECT_ROLLOUT_METHOD)
     async def wake_up(self):
         if self.config.rollout.free_cache_engine:
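
The new generate method forwards a list of prompt token ids plus sampling parameters to the rollout engine and returns the response token ids. A hedged sketch of that token-in/token-out contract with a stand-in engine (FakeRollout is illustrative; in verl the call lands in the SGLang rollout):

import asyncio
from typing import Any, Dict, List

class FakeRollout:
    """Stand-in for the rollout engine behind self.rollout.generate."""

    async def generate(self, prompt_ids: List[int], sampling_params: Dict[str, Any], request_id: str) -> List[int]:
        # A real engine decodes new tokens; this stub just echoes a slice.
        return prompt_ids[: sampling_params.get("max_new_tokens", len(prompt_ids))]

async def main() -> None:
    response_ids = await FakeRollout().generate(
        prompt_ids=[101, 102, 103, 104],
        sampling_params={"temperature": 0.7, "max_new_tokens": 2},
        request_id="req-0",
    )
    print(response_ids)  # [101, 102]

asyncio.run(main())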
