
Conversation

@chenhaiq
Collaborator

@chenhaiq chenhaiq commented May 30, 2025

Changed the sglang rollout pipeline to an async method for better performance.

Resolves issue #1721

Checklist Before Starting

  • [ done ] Search for similar PR(s).

What does this PR do?

In the previous version, sglang's async_generate was called from a sync Ray actor through many sync functions, which resulted in poor performance (GPU SM utilization was around 20% with TP2).

This PR changes the whole pipeline to an async method.

Performance comparison to the previous "sglang_async" mode:

  | sglang_async (old) | async (new) | % faster
-- | -- | -- | --
timing_s/gen | 95 | 25 | 73.68%
timing_s/step | 170 | 90 | 47.06%
perf/throughput | 2700 | 4000 | 48.15%
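For readers skimming the thread, a minimal sketch of the pattern this change moves to: requests are fanned out to the engine's async generate API and awaited concurrently instead of being funneled through blocking sync calls. All names below are illustrative stand-ins, not the PR's actual classes:

import asyncio

# Illustrative stand-in for an engine exposing an async generation API
# (e.g. sglang's async_generate); not the PR's AsyncEngine.
class FakeAsyncEngine:
    async def async_generate(self, prompt: str) -> str:
        await asyncio.sleep(0.01)  # stands in for actual decoding work
        return prompt + " -> response"

async def generate_batch(engine: FakeAsyncEngine, prompts: list[str]) -> list[str]:
    # Fan all requests out at once and let the engine batch them internally,
    # instead of looping over one blocking call per prompt.
    return await asyncio.gather(*(engine.async_generate(p) for p in prompts))

if __name__ == "__main__":
    outs = asyncio.run(generate_batch(FakeAsyncEngine(), [f"prompt {i}" for i in range(4)]))
    print(outs)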

High-Level Design

see #1698

This is a follow-up task from the above PR.

Usage Example

examples/grpo_trainer/run_qwen2-7b_seq_balance.sh

Test

.github/workflows/e2e_ppo_trainer.yml

Additional Info.

Checklist Before Submitting

  • [ done ] Read the Contribute Guide.
  • [ done ] Apply pre-commit checks.
  • [ done ] Add [BREAKING] to the PR title if it breaks any API.
  • [ done ] Update the documentation about your changes in the docs.
  • [ done ] Add CI test(s) if necessary.

Changed sglang rollout pipeline to async method to have better
performance.

resolved issue volcengine#1721
@chenhaiq chenhaiq requested review from SwordFaith and wuxibin89 May 30, 2025 05:32
@chenhaiq
Collaborator Author

@ocss884 I had to add a lot of patch functions to sglang in order to get it to work.

You can see those code in verl/workers/rollout/sglang_rollout/async_sglang_rollout.py

# patch to avoid issue https://github.com/sgl-project/sglang/issues/6723
def _set_envs_and_config(server_args: ServerArgs):

Do you have a better solution?
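For anyone following along, the rough shape of such a patch is plain module-level monkey-patching. The module path and the body below are assumptions for illustration only, not the PR's exact code (see async_sglang_rollout.py for the real version):

# Hypothetical sketch of the monkey-patch approach; the replacement body is
# only indicative of the kind of change needed, not the actual fix.
import sglang.srt.entrypoints.engine as sgl_engine
from sglang.srt.server_args import ServerArgs

def _patched_set_envs_and_config(server_args: ServerArgs):
    # Re-implement the upstream helper, skipping whatever step conflicts with
    # running inside a Ray worker (per sgl-project/sglang#6723).
    ...

# Swap in the patched symbol before the engine is constructed.
sgl_engine._set_envs_and_config = _patched_set_envs_and_config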

@chenhaiq
Collaborator Author

Found 2 bugs:

  1. eos token has no effect in async_generate
  2. ray actor is incorrect in a multi-node env

I am working on these 2 errors.

@chenhaiq chenhaiq changed the title [rollout] feat: ChatScheduler requests sglang fully async [DONOTMERGE][rollout] feat: ChatScheduler requests sglang fully async May 30, 2025
@chenhaiq
Collaborator Author

chenhaiq commented Jun 3, 2025

Found 2 bugs:

  1. eos token has no effect in async_generate
  2. ray actor is incorrect in a multi-node env

I am working on these 2 errors.

The above 2 bugs have been fixed.

@chenhaiq chenhaiq changed the title [DONOTMERGE][rollout] feat: ChatScheduler requests sglang fully async [rollout] feat: ChatScheduler requests sglang fully async Jun 3, 2025
@tongyx361
Collaborator

The motivation and the experiment look good to me! I can help review if needed.

@chenhaiq
Collaborator Author

chenhaiq commented Jun 3, 2025

The motivation and the experiment look good to me! I can help review if needed.

Yes, please review.

@chenhaiq chenhaiq requested a review from tongyx361 June 3, 2025 12:03
def mock_get_actor_side_effect(name, namespace=None):  # Changed 'actor_name_arg' to 'name'
    # Create a new mock actor for each call
    actor_mock = MagicMock()
    actor_mock.execute_method.remote = AsyncMock(return_value={"content": "mocked response"})
Collaborator

It appears that execute_method has been superseded by chat_completion.

Collaborator Author

It appears that execute_method has been superseded by chat_completion.

Updated the test case to remove the unused mock.

print(f"[DP={self.vllm_dp_rank},TP={self.vllm_tp_rank}] execute_method: {method if isinstance(method, str) else 'Callable'}")
return self.rollout.execute_method(method, *args, **kwargs)

@register(dispatch_mode=Dispatch.DIRECT_ROLLOUT_METHOD, blocking=False)
Collaborator

Is it possible to use the rollout API directly instead of adding a glue layer at the FSDP worker level?

Collaborator Author

Is it possible to use the rollout API directly instead of adding a glue layer at the FSDP worker level?

I have tried this approach, but without success yet, because sglang needs to start a process for each TP worker; otherwise it gets stuck at "params = self.module.state_dict()".

Do you know how to get it to work?

Collaborator

After investigation, it seems the base Worker class includes a routing mechanism through a registration system, which may hinder rollout from supporting the dispatch method via registration. Therefore, the current logic is essential.

current_rank = int(matched_actor["name"].split(":")[-1])
fields = matched_actor["name"].split(":")
assert len(fields) == 2, f"invalid actor name: {matched_actor['name']}"
pg_index, local_rank = int(fields[0].split("_")[-1]), int(fields[1])
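A small worked example of the actor-name format this parsing assumes; the "<prefix>_pg_<index>:<local_rank>" pattern is inferred from the snippet above, not stated elsewhere in the thread:

# Hypothetical actor name following the pattern parsed above.
matched_actor = {"name": "actor_rollout_pg_2:5"}

fields = matched_actor["name"].split(":")
assert len(fields) == 2, f"invalid actor name: {matched_actor['name']}"
pg_index, local_rank = int(fields[0].split("_")[-1]), int(fields[1])
print(pg_index, local_rank)  # -> 2 5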
Collaborator

Does the current implementation support cross-node inference?

Collaborator Author

Yes, I have tested this case.

@chenhaiq chenhaiq force-pushed the async_llm_sglang2 branch from 938ba9d to 1fa47c3 Compare June 4, 2025 02:03
Collaborator

@SwordFaith SwordFaith left a comment

LGTM! Excited to dive into exploring interactions with the tool registry and the locality-aware tool, especially with the faster rollout in place!

print(f"[DP={self.vllm_dp_rank},TP={self.vllm_tp_rank}] execute_method: {method if isinstance(method, str) else 'Callable'}")
return self.rollout.execute_method(method, *args, **kwargs)

@register(dispatch_mode=Dispatch.DIRECT_ROLLOUT_METHOD, blocking=False)
Collaborator

After investigation, it seems the base Worker class includes a routing mechanism through a registration system, which may hinder rollout from supporting the dispatch method via registration. Therefore, the current logic is essential.

@tongyx361 tongyx361 merged commit fe23634 into volcengine:main Jun 6, 2025
36 checks passed
yellowbee686 pushed a commit to yellowbee686/verl that referenced this pull request Jun 6, 2025
…#1769)

@yuleiqin

yuleiqin commented Jun 9, 2025

With this async implementation, I find that an error immediately occurs at the beginning of training (step 1), right after the rollout process.
see verl/verl/trainer/ppo/ray_trainer.py
@chenhaiq @SwordFaith

batch = batch.union(gen_batch_output)

The union operation fails with an inconsistent batch size:

AssertionError: Two tensor dict must have identical batch size. Got torch.Size([8192]) and torch.Size([1024])
[36m(TaskRunner pid=59667)[0m >>>>>>>>>>>>>>>>>>>>>>>>>>gen_batch_output: fit<<<<<<<<<<<<<<<<<<<<<<<<<<
Error executing job with overrides: ['algorithm.adv_estimator=grpo', 'data.train_batch_size=128', 'data.max_prompt_length=16384', 'data.max_response_length=4096', 'data.filter_overlong_prompts=False', 'data.truncation=error', 'data.return_raw_chat=True', 'actor_rollout_ref.model.path=/cfs/yuleiqin/models/Qwen2.5-32B-Instruct_Qwen', 'actor_rollout_ref.actor.optim.lr=1e-6', 'actor_rollout_ref.model.use_remove_padding=True', 'actor_rollout_ref.actor.ppo_mini_batch_size=4', 'actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=1', 'actor_rollout_ref.actor.use_kl_loss=True', 'actor_rollout_ref.actor.kl_loss_coef=0.001', 'actor_rollout_ref.actor.kl_loss_type=low_var_kl', 'actor_rollout_ref.actor.entropy_coeff=0', 'actor_rollout_ref.model.enable_gradient_checkpointing=True', 'actor_rollout_ref.actor.fsdp_config.param_offload=False', 'actor_rollout_ref.actor.fsdp_config.optimizer_offload=False', 'actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=32', 'actor_rollout_ref.rollout.tensor_model_parallel_size=4', 'actor_rollout_ref.rollout.name=sglang', 'actor_rollout_ref.rollout.gpu_memory_utilization=0.6', 'actor_rollout_ref.rollout.n=8', 'actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=32', 'actor_rollout_ref.ref.fsdp_config.param_offload=True', 'algorithm.use_kl_in_reward=False', 'trainer.val_before_train=False', 'trainer.critic_warmup=0', 'trainer.logger=[console,wandb]', 'trainer.project_name=retool_async_rl', 'trainer.experiment_name=Qwen2.5-32b_function-retool-dapomath-async-sgl-multi-w-tool-code-n8-32gpu_tp4_20250606_231400', 'trainer.n_gpus_per_node=8', 'trainer.nnodes=4', 'trainer.save_freq=5', 'trainer.test_freq=20', 'trainer.total_epochs=20', 'trainer.val_only=False', 'trainer.val_before_train=False', 'custom_reward_function.path=/cfs_turbo/yuleiqin/VERL/verl/examples/grpo_trainer_tool/rewards/reward_score.py', 'custom_reward_function.name=default_compute_score', 'reward_model.reward_manager=agent', 'data.train_files=/cfs_turbo/yuleiqin/VERL/DATA/retool_dapo_math/train.parquet', 'data.val_files=/cfs_turbo/yuleiqin/VERL/DATA/retool_dapo_math/test.parquet', 'actor_rollout_ref.rollout.multi_turn.tool_config_path=/cfs_turbo/yuleiqin/VERL/verl/examples/sglang_multiturn/config/tool_config/python_tool_config.yaml']
Traceback (most recent call last):
File "/usr/lib/python3.10/runpy.py", line 196, in _run_module_as_main
    return _run_code(code, main_globals, None,
File "/usr/lib/python3.10/runpy.py", line 86, in _run_code
    exec(code, run_globals)
File "/cfs_turbo/yuleiqin/VERL/verl/verl/trainer/main_ppo.py", line 225, in <module>
    main()
File "/usr/local/lib/python3.10/dist-packages/hydra/main.py", line 94, in decorated_main
    _run_hydra(
File "/usr/local/lib/python3.10/dist-packages/hydra/_internal/utils.py", line 394, in _run_hydra
    _run_app(
File "/usr/local/lib/python3.10/dist-packages/hydra/_internal/utils.py", line 457, in _run_app
    run_and_report(
File "/usr/local/lib/python3.10/dist-packages/hydra/_internal/utils.py", line 223, in run_and_report
    raise ex
File "/usr/local/lib/python3.10/dist-packages/hydra/_internal/utils.py", line 220, in run_and_report
    return func()
File "/usr/local/lib/python3.10/dist-packages/hydra/_internal/utils.py", line 458, in <lambda>
    lambda: hydra.run(
File "/usr/local/lib/python3.10/dist-packages/hydra/_internal/hydra.py", line 132, in run
    _ = ret.return_value
File "/usr/local/lib/python3.10/dist-packages/hydra/core/utils.py", line 260, in return_value
    raise self._return_value
File "/usr/local/lib/python3.10/dist-packages/hydra/core/utils.py", line 186, in run_job
    ret.return_value = task_function(task_cfg)
File "/cfs_turbo/yuleiqin/VERL/verl/verl/trainer/main_ppo.py", line 27, in main
    run_ppo(config)
File "/cfs_turbo/yuleiqin/VERL/verl/verl/trainer/main_ppo.py", line 42, in run_ppo
    ray.get(runner.run.remote(config))
File "/usr/local/lib/python3.10/dist-packages/ray/_private/auto_init_hook.py", line 21, in auto_init_wrapper
    return fn(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/ray/_private/client_mode_hook.py", line 103, in wrapper
    return func(*args, **kwargs)
File "/usr/local/lib/python3.10/dist-packages/ray/_private/worker.py", line 2822, in get
    values, debugger_breakpoint = worker.get_objects(object_refs, timeout=timeout)
File "/usr/local/lib/python3.10/dist-packages/ray/_private/worker.py", line 930, in get_objects
    raise value.as_instanceof_cause()
ray.exceptions.RayTaskError(AssertionError): [36mray::TaskRunner.run()[39m (pid=59667, ip=train-1337910252737936640-a60e1gtpdlvk-master-0, actor_id=aac7e2d4fc0f32314a7ae70002000000, repr=<main_ppo.TaskRunner object at 0x7f233abc4070>)
File "/cfs_turbo/yuleiqin/VERL/verl/verl/trainer/main_ppo.py", line 162, in run
    trainer.fit()
File "/cfs_turbo/yuleiqin/VERL/verl/verl/trainer/ppo/ray_trainer.py", line 1125, in fit
    **batch = batch.union(gen_batch_output)**
File "/cfs_turbo/yuleiqin/VERL/verl/verl/protocol.py", line 588, in union
    self.batch = union_tensor_dict(self.batch, other.batch)
File "/cfs_turbo/yuleiqin/VERL/verl/verl/protocol.py", line 106, in union_tensor_dict
assert tensor_dict1.batch_size == tensor_dict2.batch_size, f"Two tensor dict must have identical batch size. Got {tensor_dict1.batch_size} and {tensor_dict2.batch_size}"
AssertionError: Two tensor dict must have identical batch size. Got torch.Size([8192]) and torch.Size([1024])
[36m(WorkerDict pid=1360, ip=train-1337910252737936640-a60e1gtpdlvk-worker-1)[0m /cfs_turbo/yuleiqin/VERL/verl/verl/workers/rollout/sglang_rollout/utils.py:49: UserWarning: The given NumPy array is not writable, and PyTorch does not support non-writable tensors. This means writing to this tensor will result in undefined behavior. You may want to copy the array to protect its data or make it writable before converting it to a tensor. This type of warning will be suppressed for the rest of this program. (Triggered internally at /pytorch/torch/csrc/utils/tensor_numpy.cpp:203.)[32m [repeated 7x across cluster][0m
[36m(WorkerDict pid=1360, ip=train-1337910252737936640-a60e1gtpdlvk-worker-1)[0m   tensor_data = torch.ByteTensor([32m [repeated 7x across cluster][0m
2025-06-09 01:02:53,464	ERR cli.py:73 -- [31m---------------------------------------[39m
2025-06-09 01:02:53,464	ERR cli.py:74 -- [31mJob 'raysubmit_TDRFGkdGcXw68Tx2' failed[39m
2025-06-09 01:02:53,464	ERR cli.py:75 -- [31m---------------------------------------[39m
2025-06-09 01:02:53,464	INFO cli.py:88 -- Status message: Job entrypoint command failed with exit code 1, last available logs (truncated to 20,000 chars):
    trainer.fit()
File "/cfs_turbo/yuleiqin/VERL/verl/verl/trainer/ppo/ray_trainer.py", line 1125, in fit
    batch = batch.union(gen_batch_output)
File "/cfs_turbo/yuleiqin/VERL/verl/verl/protocol.py", line 588, in union
    self.batch = union_tensor_dict(self.batch, other.batch)
File "/cfs_turbo/yuleiqin/VERL/verl/verl/protocol.py", line 106, in union_tensor_dict
assert tensor_dict1.batch_size == tensor_dict2.batch_size, f"Two tensor dict must have identical batch size. Got {tensor_dict1.batch_size} and {tensor_dict2.batch_size}"
AssertionError: Two tensor dict must have identical batch size. Got torch.Size([8192]) and torch.Size([1024])

@chenhaiq
Collaborator Author

chenhaiq commented Jun 9, 2025

With this async implementation, I find that an error immediately occurs at the beginning of training (step 1), right after the rollout process. see verl/verl/trainer/ppo/ray_trainer.py @chenhaiq @SwordFaith

batch = batch.union(gen_batch_output)

The union operation fails with an inconsistent batch size:

AssertionError: Two tensor dict must have identical batch size. Got torch.Size([8192]) and torch.Size([1024])

@yuleiqin
I cannot reproduce it using settings similar to those in your log. Can you narrow down the steps to reproduce it?

  1. Use 1 node instead of 4 with a smaller model, like Qwen2.5-1.5B-Instruct.
  2. Remove the custom_xx configs.
  3. Remove actor_rollout_ref.rollout.multi_turn.tool_config_path.
  4. Use the gsm8k dataset.
  5. Pull the latest code, and let me know the latest git commit id in your env.

Please provide a training script that is as simple as possible so that I can reproduce it.

@ShichaoSun

@chenhaiq I have tried to run the example via bash examples/grpo_trainer/run_qwen2-7b_seq_balance.sh, but I got an error as below:
Exception: sgl-kernel is installed with version 0.1.0, which is less than the minimum required version 0.1.1. Please reinstall the latest version with 'pip install sgl-kernel --force-reinstall'. Note that I created my conda env based on the specification in verl install-dependencies.

I fixed it via pip install sgl-kernel==0.1.0. But I got another error

  File "/home/ubuntu/anaconda3/envs/verl/lib/python3.10/concurrent/futures/_base.py", line 445, in result
    return self.__get_result()
  File "/home/ubuntu/anaconda3/envs/verl/lib/python3.10/concurrent/futures/_base.py", line 390, in __get_result
    raise self._exception
  File "/home/ubuntu/verl/verl/single_controller/ray/base.py", line 645, in func
    return getattr(self.worker_dict[key], name)(*args, **kwargs)
  File "/home/ubuntu/verl/verl/single_controller/base/decorator.py", line 540, in inner
    return func(*args, **kwargs)
  File "/home/ubuntu/verl/verl/workers/fsdp_workers.py", line 537, in init_model
    self.rollout, self.rollout_sharding_manager = self._build_rollout(trust_remote_code=self.config.model.get("trust_remote_code", False))
  File "/home/ubuntu/verl/verl/workers/fsdp_workers.py", line 1408, in _build_rollout
    rollout, rollout_sharding_manager = super()._build_rollout(trust_remote_code)
  File "/home/ubuntu/verl/verl/workers/fsdp_workers.py", line 447, in _build_rollout
    rollout = SGLangRollout(
  File "/home/ubuntu/verl/verl/workers/rollout/sglang_rollout/sglang_rollout.py", line 276, in __init__
    self._init_inference_engine(trust_remote_code, actor_module, port)
  File "/home/ubuntu/verl/verl/workers/rollout/sglang_rollout/sglang_rollout.py", line 359, in _init_inference_engine
    self._engine = AsyncEngine(
  File "/home/ubuntu/verl/verl/workers/rollout/sglang_rollout/sglang_rollout.py", line 138, in __init__
    super().__init__(**kwargs)
  File "/home/ubuntu/anaconda3/envs/verl/lib/python3.10/site-packages/sglang/srt/entrypoints/engine.py", line 122, in __init__
    tokenizer_manager, scheduler_info = _launch_subprocesses(
  File "/home/ubuntu/anaconda3/envs/verl/lib/python3.10/site-packages/sglang/srt/entrypoints/engine.py", line 582, in _launch_subprocesses
    tokenizer_manager = TokenizerManager(server_args, port_args)
  File "/home/ubuntu/anaconda3/envs/verl/lib/python3.10/site-packages/sglang/srt/managers/tokenizer_manager.py", line 228, in __init__
    self.model_update_lock = RWLock()
  File "/home/ubuntu/anaconda3/envs/verl/lib/python3.10/site-packages/sglang/srt/aio_rwlock.py", line 10, in __init__
    self._cond = asyncio.Condition(self._lock)
  File "/home/ubuntu/anaconda3/envs/verl/lib/python3.10/asyncio/locks.py", line 234, in __init__
    raise ValueError("loop argument must agree with lock")
ValueError: loop argument must agree with lock

Could you please provide some hints to help me solve it?

@chenhaiq
Collaborator Author

@chenhaiq I have tried to run the example via bash examples/grpo_trainer/run_qwen2-7b_seq_balance.sh , but I got an error as below: Exception: sgl-kernel is installed with version 0.1.0, which is less than the minimum required version 0.1.1. Please reinstall the latest version with 'pip install sgl-kernel --force-reinstall'. Note that I created my conda env based on the specification verl install-dependencies.

I fixed it via pip install sgl-kernel==0.1.0. But I got another error

  File "/home/ubuntu/anaconda3/envs/verl/lib/python3.10/concurrent/futures/_base.py", line 445, in result
    return self.__get_result()
  File "/home/ubuntu/anaconda3/envs/verl/lib/python3.10/concurrent/futures/_base.py", line 390, in __get_result
    raise self._exception
  File "/home/ubuntu/verl/verl/single_controller/ray/base.py", line 645, in func
    return getattr(self.worker_dict[key], name)(*args, **kwargs)
  File "/home/ubuntu/verl/verl/single_controller/base/decorator.py", line 540, in inner
    return func(*args, **kwargs)
  File "/home/ubuntu/verl/verl/workers/fsdp_workers.py", line 537, in init_model
    self.rollout, self.rollout_sharding_manager = self._build_rollout(trust_remote_code=self.config.model.get("trust_remote_code", False))
  File "/home/ubuntu/verl/verl/workers/fsdp_workers.py", line 1408, in _build_rollout
    rollout, rollout_sharding_manager = super()._build_rollout(trust_remote_code)
  File "/home/ubuntu/verl/verl/workers/fsdp_workers.py", line 447, in _build_rollout
    rollout = SGLangRollout(
  File "/home/ubuntu/verl/verl/workers/rollout/sglang_rollout/sglang_rollout.py", line 276, in __init__
    self._init_inference_engine(trust_remote_code, actor_module, port)
  File "/home/ubuntu/verl/verl/workers/rollout/sglang_rollout/sglang_rollout.py", line 359, in _init_inference_engine
    self._engine = AsyncEngine(
  File "/home/ubuntu/verl/verl/workers/rollout/sglang_rollout/sglang_rollout.py", line 138, in __init__
    super().__init__(**kwargs)
  File "/home/ubuntu/anaconda3/envs/verl/lib/python3.10/site-packages/sglang/srt/entrypoints/engine.py", line 122, in __init__
    tokenizer_manager, scheduler_info = _launch_subprocesses(
  File "/home/ubuntu/anaconda3/envs/verl/lib/python3.10/site-packages/sglang/srt/entrypoints/engine.py", line 582, in _launch_subprocesses
    tokenizer_manager = TokenizerManager(server_args, port_args)
  File "/home/ubuntu/anaconda3/envs/verl/lib/python3.10/site-packages/sglang/srt/managers/tokenizer_manager.py", line 228, in __init__
    self.model_update_lock = RWLock()
  File "/home/ubuntu/anaconda3/envs/verl/lib/python3.10/site-packages/sglang/srt/aio_rwlock.py", line 10, in __init__
    self._cond = asyncio.Condition(self._lock)
  File "/home/ubuntu/anaconda3/envs/verl/lib/python3.10/asyncio/locks.py", line 234, in __init__
    raise ValueError("loop argument must agree with lock")
ValueError: loop argument must agree with lock

Could you please provide some hints to help me solve it?

Did you install sglang==0.4.6.post5?
It has been required since #1717.

@TianL123

@chenhaiq I have tried to run the example via bash examples/grpo_trainer/run_qwen2-7b_seq_balance.sh , but I got an error as below: Exception: sgl-kernel is installed with version 0.1.0, which is less than the minimum required version 0.1.1. Please reinstall the latest version with 'pip install sgl-kernel --force-reinstall'. Note that I created my conda env based on the specification verl install-dependencies.

I fixed it via pip install sgl-kernel==0.1.0. But I got another error

  File "/home/ubuntu/anaconda3/envs/verl/lib/python3.10/concurrent/futures/_base.py", line 445, in result
    return self.__get_result()
  File "/home/ubuntu/anaconda3/envs/verl/lib/python3.10/concurrent/futures/_base.py", line 390, in __get_result
    raise self._exception
  File "/home/ubuntu/verl/verl/single_controller/ray/base.py", line 645, in func
    return getattr(self.worker_dict[key], name)(*args, **kwargs)
  File "/home/ubuntu/verl/verl/single_controller/base/decorator.py", line 540, in inner
    return func(*args, **kwargs)
  File "/home/ubuntu/verl/verl/workers/fsdp_workers.py", line 537, in init_model
    self.rollout, self.rollout_sharding_manager = self._build_rollout(trust_remote_code=self.config.model.get("trust_remote_code", False))
  File "/home/ubuntu/verl/verl/workers/fsdp_workers.py", line 1408, in _build_rollout
    rollout, rollout_sharding_manager = super()._build_rollout(trust_remote_code)
  File "/home/ubuntu/verl/verl/workers/fsdp_workers.py", line 447, in _build_rollout
    rollout = SGLangRollout(
  File "/home/ubuntu/verl/verl/workers/rollout/sglang_rollout/sglang_rollout.py", line 276, in __init__
    self._init_inference_engine(trust_remote_code, actor_module, port)
  File "/home/ubuntu/verl/verl/workers/rollout/sglang_rollout/sglang_rollout.py", line 359, in _init_inference_engine
    self._engine = AsyncEngine(
  File "/home/ubuntu/verl/verl/workers/rollout/sglang_rollout/sglang_rollout.py", line 138, in __init__
    super().__init__(**kwargs)
  File "/home/ubuntu/anaconda3/envs/verl/lib/python3.10/site-packages/sglang/srt/entrypoints/engine.py", line 122, in __init__
    tokenizer_manager, scheduler_info = _launch_subprocesses(
  File "/home/ubuntu/anaconda3/envs/verl/lib/python3.10/site-packages/sglang/srt/entrypoints/engine.py", line 582, in _launch_subprocesses
    tokenizer_manager = TokenizerManager(server_args, port_args)
  File "/home/ubuntu/anaconda3/envs/verl/lib/python3.10/site-packages/sglang/srt/managers/tokenizer_manager.py", line 228, in __init__
    self.model_update_lock = RWLock()
  File "/home/ubuntu/anaconda3/envs/verl/lib/python3.10/site-packages/sglang/srt/aio_rwlock.py", line 10, in __init__
    self._cond = asyncio.Condition(self._lock)
  File "/home/ubuntu/anaconda3/envs/verl/lib/python3.10/asyncio/locks.py", line 234, in __init__
    raise ValueError("loop argument must agree with lock")
ValueError: loop argument must agree with lock

Could you please provide some hints to help me solve it?

@ShichaoSun Did you solve it? I'm having the same problem.

@ShichaoSun

@TianL123 No. I have tried installing sglang==0.4.6.post5, but I got the same issue: ValueError: loop argument must agree with lock.

@chenhaiq
Collaborator Author

Exception: sgl-kernel is installed with version 0.1.0, which is less than the minimum required version 0.1.1. Please reinstall the latest version with 'pip install sgl-kernel --force-reinstall'. Note that I created my conda env based on the specification verl install-dependencies.

Can you try this image?
whatcanyousee/verl:ngc-cu124-vllm0.8.5-sglang0.4.6.post5-mcore0.12.0-te2

It is the image used in CI.

@TianL123

@TianL123 No. I have tried installing sglang==0.4.6.post5, but I got the same issue: ValueError: loop argument must agree with lock.

@ShichaoSun
I have added getattr(self._lock, "_get_loop", lambda: None)() so that it can be executed.
sglang/srt/aio_rwlock.py:

class RWLock:
    def __init__(self):
        self._lock = asyncio.Lock()
        # Bind the lock to the current loop (if any) before creating the
        # Condition, so both agree on the same loop.
        getattr(self._lock, "_get_loop", lambda: None)()
        self._cond = asyncio.Condition(self._lock)
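If it helps explain why that one extra call matters: on the Python 3.10 releases where asyncio.Condition still compares its loop with the lock's, a freshly created Lock is not bound to any loop until _get_loop() is first called, so constructing the Condition inside a running loop sees a mismatch. A tiny illustration of the workaround (my reading of the CPython 3.10 asyncio internals, not sglang code):

import asyncio

async def main():
    lock = asyncio.Lock()
    # Without the next line, constructing the Condition below raises
    # "ValueError: loop argument must agree with lock" on some Python 3.10
    # builds, because the lock is still unbound while the Condition binds
    # itself to the running loop.
    getattr(lock, "_get_loop", lambda: None)()
    cond = asyncio.Condition(lock)
    async with cond:
        pass
    print("Condition created without the loop mismatch error")

asyncio.run(main())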

@TianL123

@chenhaiq
Title: state_dict() hangs when running run_qwen2-7b_seq_balance.sh with tp=2

Description:
When I run the following script:

verl/examples/grpo_trainer/run_qwen2-7b_seq_balance.sh
with tp=2, the process appears to hang indefinitely. After some debugging, it seems to be stuck in the following line inside:

verl/verl/workers/sharding_manager/fsdp_sglang.py
Specifically, in this async function:

async def wake_up(self):
    ...
    params = self.module.state_dict()  # <-- hangs here
This issue does not occur when I use tp=1. It seems related to tensor parallelism settings, but I’m not sure what exactly is causing the deadlock or hang.
Can you help me solve this?

@chenhaiq
Collaborator Author

Title: state_dict() hangs when running run_qwen2-7b_seq_balance.sh with tp=2

fixed in #2098

@TianL123

@chenhaiq
When I run the following command:
ENGINE=sglang ROLLOUT_MODE=async bash tests/special_e2e/ppo_trainer/run_function_reward.sh
the program hangs during execution. I am using 2 GPUs, and all other settings and configurations are exactly the same as before (when it worked). However, when I switch to the vllm backend with async mode, the same script runs normally. This suggests that the issue is not due to hardware, GPU setup, or async rollout mode itself.
Training Progress: 0%| | 0/1 [00:00<?, ?it/s]
=========================================+======================+======================|
| 0 NVIDIA H800 PCIe On | 00000000:27:00.0 Off | 0 |
| N/A 48C P0 86W / 350W | 67802MiB / 81559MiB | 100% Default |
| | | Disabled |
+-----------------------------------------+----------------------+----------------------+
| 1 NVIDIA H800 PCIe On | 00000000:B8:00.0 Off | 0 |
| N/A 53C P0 89W / 350W | 68642MiB / 81559MiB | 0% Default |
| | | Disabled |
+-----------------------------------------+----------------------+----------------------+
Could you help me identify the cause? How can I debug or fix this hanging issue when using multiple GPUs?
Thank you!

@chenhaiq
Collaborator Author

@chenhaiq When I run the following command: ENGINE=sglang ROLLOUT_MODE=async bash tests/special_e2e/ppo_trainer/run_function_reward.sh the program hangs during execution. …

Did you run it with the fix from #2098?

If yes, please provide a console log.

@TianL123

TianL123 commented Jun 25, 2025

@chenhaiq When I run the following command: ENGINE=sglang ROLLOUT_MODE=async bash tests/special_e2e/ppo_trainer/run_function_reward.sh the program hangs during execution. …

Did you run it with the fix from #2098?

If yes, please provide a console log.

ENGINE=sglang ROLLOUT_MODE=async bash tests/special_e2e/ppo_trainer/run_function_reward.sh
+ NUM_GPUS=2
+ HOME=/share/huitianlin
+ MODEL_ID=Qwen2.5-0.5B-Instruct
+ MODEL_PATH=/share/huitianlin/models/Qwen2.5-0.5B-Instruct
+ TRAIN_FILES=/share/huitianlin/data/gsm8k/train.parquet
+ VAL_FILES=/share/huitianlin/data/gsm8k/test.parquet
+ MAX_PROMPT_LEN=512
+ MAX_RESPONSE_LEN=512
+ ENGINE=sglang
+ ROLLOUT_MODE=async
+ RETURN_RAW_CHAT=False
+ '[' async = async ']'
+ RETURN_RAW_CHAT=True
+ GPU_MEMORY_UTILIZATION=0.8
+ ACTOR_FSDP_PARAM_OFFLOAD=False
+ ACTOR_FSDP_OPTIMIZER_OFFLOAD=False
+ REF_FSDP_PARAM_OFFLOAD=True
+ RM_PAD=True
+ FUSED_KERNELS=False
+ FUSED_KERNEL_BACKEND=torch
+ ADV_ESTIMATOR=gae
+ USE_KL=False
+ CUSTOM_REWARD_FN=False
+ ENABLE_CHUNKED_PREFILL=True
+ STRATEGY=fsdp
+ LORA_RANK=0
+ LORA_ALPHA=0
+ USE_SHM=False
+ LOAD_FORMAT=dummy_dtensor
+ LAYERED_SUMMON=False
+ VAL_BEFORE_TRAIN=False
+ TEST_FREQ=-1
+ RESUME_MODE=disable
+ SAVE_FREQ=-1
+ TOTAL_TRAIN_STEPS=1
+ SAVE_HF_MODEL=False
+ FSDP_SIZE=-1
+ SP_SIZE=1
+ '[' False = True ']'
+ CHECKPOINT_CONTENTS='['\''model'\'','\''optimizer'\'','\''extra'\'']'
+ train_traj_micro_bsz_per_gpu=2
+ n_resp_per_prompt=4
+ train_traj_micro_bsz=4
+ train_traj_mini_bsz=8
+ train_prompt_mini_bsz=32
+ train_prompt_bsz=64
+ reward_fn_name=null
+ reward_fn_file_path=null
++ pwd
+ output_file=/share/huitianlin/verl/output.txt
+ '[' False = True ']'
++ basename qwen2.5-0.5b-instruct
+ exp_name=qwen2.5-0.5b-instruct-function-reward-minimal
+ python3 -m verl.trainer.main_ppo algorithm.adv_estimator=gae data.train_files=/share/huitianlin/data/gsm8k/train.parquet data.val_files=/share/huitianlin/data/gsm8k/test.parquet data.train_batch_size=64 data.max_prompt_length=512 data.max_response_length=512 data.return_raw_chat=True actor_rollout_ref.model.path=/share/huitianlin/models/Qwen2.5-0.5B-Instruct actor_rollout_ref.model.use_shm=False actor_rollout_ref.model.lora_rank=0 actor_rollout_ref.model.lora_alpha=0 actor_rollout_ref.actor.optim.lr=1e-6 actor_rollout_ref.model.use_remove_padding=True actor_rollout_ref.model.use_fused_kernels=False actor_rollout_ref.model.fused_kernel_options.impl_backend=torch actor_rollout_ref.actor.ppo_mini_batch_size=32 actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=2 actor_rollout_ref.actor.strategy=fsdp actor_rollout_ref.actor.fsdp_config.param_offload=False actor_rollout_ref.actor.fsdp_config.optimizer_offload=False actor_rollout_ref.actor.fsdp_config.fsdp_size=-1 actor_rollout_ref.actor.ulysses_sequence_parallel_size=1 'actor_rollout_ref.actor.checkpoint.save_contents=['\''model'\'','\''optimizer'\'','\''extra'\'']' actor_rollout_ref.actor.use_kl_loss=False actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=2 actor_rollout_ref.rollout.tensor_model_parallel_size=2 actor_rollout_ref.rollout.name=sglang actor_rollout_ref.rollout.mode=async actor_rollout_ref.rollout.load_format=dummy_dtensor actor_rollout_ref.rollout.layered_summon=False actor_rollout_ref.rollout.gpu_memory_utilization=0.8 actor_rollout_ref.rollout.enable_chunked_prefill=True actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=2 actor_rollout_ref.ref.fsdp_config.param_offload=True critic.optim.lr=1e-5 critic.model.use_remove_padding=True critic.model.path=/share/huitianlin/models/Qwen2.5-0.5B-Instruct critic.model.enable_gradient_checkpointing=False critic.ppo_micro_batch_size_per_gpu=2 critic.model.fsdp_config.param_offload=False critic.model.fsdp_config.optimizer_offload=False custom_reward_function.path=null custom_reward_function.name=null algorithm.use_kl_in_reward=False algorithm.kl_penalty=kl algorithm.kl_ctrl.kl_coef=0.001 trainer.critic_warmup=0 'trainer.logger=[console]' trainer.project_name=verl-test trainer.experiment_name=qwen2.5-0.5b-instruct-function-reward-minimal trainer.nnodes=1 trainer.n_gpus_per_node=2 trainer.val_before_train=False trainer.test_freq=-1 trainer.save_freq=-1 trainer.resume_mode=disable trainer.total_epochs=5 trainer.device=cuda
2025-06-25 18:46:39,184 INFO worker.py:1908 -- Started a local Ray instance. View the dashboard at 127.0.0.1:8265 
(TaskRunner pid=192441) TaskRunner hostname: kml-dtmachine-19988-prod, PID: 192441
(TaskRunner pid=192441) {'actor_rollout_ref': {'actor': {'checkpoint': {'load_contents': ['model',
(TaskRunner pid=192441)                                                                   'optimizer',
(TaskRunner pid=192441)                                                                   'extra'],
(TaskRunner pid=192441)                                                 'save_contents': ['model',
(TaskRunner pid=192441)                                                                   'optimizer',
(TaskRunner pid=192441)                                                                   'extra']},
(TaskRunner pid=192441)                                  'clip_ratio': 0.2,
(TaskRunner pid=192441)                                  'clip_ratio_c': 3.0,
(TaskRunner pid=192441)                                  'clip_ratio_high': 0.2,
(TaskRunner pid=192441)                                  'clip_ratio_low': 0.2,
(TaskRunner pid=192441)                                  'entropy_checkpointing': False,
(TaskRunner pid=192441)                                  'entropy_coeff': 0,
(TaskRunner pid=192441)                                  'entropy_from_logits_with_chunking': False,
(TaskRunner pid=192441)                                  'fsdp_config': {'forward_prefetch': False,
(TaskRunner pid=192441)                                                  'fsdp_size': -1,
(TaskRunner pid=192441)                                                  'offload_policy': False,
(TaskRunner pid=192441)                                                  'optimizer_offload': False,
(TaskRunner pid=192441)                                                  'param_offload': False,
(TaskRunner pid=192441)                                                  'reshard_after_forward': True,
(TaskRunner pid=192441)                                                  'wrap_policy': {'min_num_params': 0}},
(TaskRunner pid=192441)                                  'grad_clip': 1.0,
(TaskRunner pid=192441)                                  'kl_loss_coef': 0.001,
(TaskRunner pid=192441)                                  'kl_loss_type': 'low_var_kl',
(TaskRunner pid=192441)                                  'loss_agg_mode': 'token-mean',
(TaskRunner pid=192441)                                  'optim': {'lr': 1e-06,
(TaskRunner pid=192441)                                            'lr_warmup_steps': -1,
(TaskRunner pid=192441)                                            'lr_warmup_steps_ratio': 0.0,
(TaskRunner pid=192441)                                            'min_lr_ratio': 0.0,
(TaskRunner pid=192441)                                            'num_cycles': 0.5,
(TaskRunner pid=192441)                                            'total_training_steps': -1,
(TaskRunner pid=192441)                                            'warmup_style': 'constant',
(TaskRunner pid=192441)                                            'weight_decay': 0.01},
(TaskRunner pid=192441)                                  'policy_loss': {'clip_cov_lb': 1.0,
(TaskRunner pid=192441)                                                  'clip_cov_ratio': 0.0002,
(TaskRunner pid=192441)                                                  'clip_cov_ub': 5.0,
(TaskRunner pid=192441)                                                  'kl_cov_ratio': 0.0002,
(TaskRunner pid=192441)                                                  'loss_mode': 'vanilla',
(TaskRunner pid=192441)                                                  'ppo_kl_coef': 0.1},
(TaskRunner pid=192441)                                  'ppo_epochs': 1,
(TaskRunner pid=192441)                                  'ppo_max_token_len_per_gpu': 16384,
(TaskRunner pid=192441)                                  'ppo_micro_batch_size': None,
(TaskRunner pid=192441)                                  'ppo_micro_batch_size_per_gpu': 2,
(TaskRunner pid=192441)                                  'ppo_mini_batch_size': 32,
(TaskRunner pid=192441)                                  'profiler': {'all_ranks': False,
(TaskRunner pid=192441)                                               'discrete': False,
(TaskRunner pid=192441)                                               'ranks': None},
(TaskRunner pid=192441)                                  'shuffle': False,
(TaskRunner pid=192441)                                  'strategy': 'fsdp',
(TaskRunner pid=192441)                                  'ulysses_sequence_parallel_size': 1,
(TaskRunner pid=192441)                                  'use_dynamic_bsz': False,
(TaskRunner pid=192441)                                  'use_kl_loss': False,
(TaskRunner pid=192441)                                  'use_torch_compile': True},
(TaskRunner pid=192441)                        'hybrid_engine': True,
(TaskRunner pid=192441)                        'model': {'custom_chat_template': None,
(TaskRunner pid=192441)                                  'enable_activation_offload': False,
(TaskRunner pid=192441)                                  'enable_gradient_checkpointing': True,
(TaskRunner pid=192441)                                  'external_lib': None,
(TaskRunner pid=192441)                                  'fused_kernel_options': {'impl_backend': 'torch'},
(TaskRunner pid=192441)                                  'lora_alpha': 0,
(TaskRunner pid=192441)                                  'lora_rank': 0,
(TaskRunner pid=192441)                                  'override_config': {},
(TaskRunner pid=192441)                                  'path': '/share/huitianlin/models/Qwen2.5-0.5B-Instruct',
(TaskRunner pid=192441)                                  'target_modules': 'all-linear',
(TaskRunner pid=192441)                                  'trust_remote_code': False,
(TaskRunner pid=192441)                                  'use_fused_kernels': False,
(TaskRunner pid=192441)                                  'use_liger': False,
(TaskRunner pid=192441)                                  'use_remove_padding': True,
(TaskRunner pid=192441)                                  'use_shm': False},
(TaskRunner pid=192441)                        'ref': {'entropy_checkpointing': False,
(TaskRunner pid=192441)                                'entropy_from_logits_with_chunking': False,
(TaskRunner pid=192441)                                'fsdp_config': {'forward_prefetch': False,
(TaskRunner pid=192441)                                                'param_offload': True,
(TaskRunner pid=192441)                                                'reshard_after_forward': True,
(TaskRunner pid=192441)                                                'wrap_policy': {'min_num_params': 0}},
(TaskRunner pid=192441)                                'log_prob_max_token_len_per_gpu': 16384,
(TaskRunner pid=192441)                                'log_prob_micro_batch_size': None,
(TaskRunner pid=192441)                                'log_prob_micro_batch_size_per_gpu': 2,
(TaskRunner pid=192441)                                'log_prob_use_dynamic_bsz': False,
(TaskRunner pid=192441)                                'profiler': {'all_ranks': False,
(TaskRunner pid=192441)                                             'discrete': False,
(TaskRunner pid=192441)                                             'ranks': None},
(TaskRunner pid=192441)                                'strategy': 'fsdp',
(TaskRunner pid=192441)                                'ulysses_sequence_parallel_size': 1,
(TaskRunner pid=192441)                                'use_torch_compile': True},
(TaskRunner pid=192441)                        'rollout': {'calculate_log_probs': False,
(TaskRunner pid=192441)                                    'disable_log_stats': True,
(TaskRunner pid=192441)                                    'do_sample': True,
(TaskRunner pid=192441)                                    'dtype': 'bfloat16',
(TaskRunner pid=192441)                                    'enable_chunked_prefill': True,
(TaskRunner pid=192441)                                    'enforce_eager': True,
(TaskRunner pid=192441)                                    'engine_kwargs': {'sglang': {'attention_backend': None},
(TaskRunner pid=192441)                                                      'vllm': {'disable_mm_preprocessor_cache': False,
(TaskRunner pid=192441)                                                               'swap_space': None}},
(TaskRunner pid=192441)                                    'free_cache_engine': True,
(TaskRunner pid=192441)                                    'gpu_memory_utilization': 0.8,
(TaskRunner pid=192441)                                    'ignore_eos': False,
(TaskRunner pid=192441)                                    'layered_summon': False,
(TaskRunner pid=192441)                                    'load_format': 'dummy_dtensor',
(TaskRunner pid=192441)                                    'log_prob_max_token_len_per_gpu': 16384,
(TaskRunner pid=192441)                                    'log_prob_micro_batch_size': None,
(TaskRunner pid=192441)                                    'log_prob_micro_batch_size_per_gpu': 2,
(TaskRunner pid=192441)                                    'log_prob_use_dynamic_bsz': False,
(TaskRunner pid=192441)                                    'max_model_len': None,
(TaskRunner pid=192441)                                    'max_num_batched_tokens': 8192,
(TaskRunner pid=192441)                                    'max_num_seqs': 1024,
(TaskRunner pid=192441)                                    'mode': 'async',
(TaskRunner pid=192441)                                    'multi_stage_wake_up': False,
(TaskRunner pid=192441)                                    'multi_turn': {'completion_callback': None,
(TaskRunner pid=192441)                                                   'enable': False,
(TaskRunner pid=192441)                                                   'format': 'hermes',
(TaskRunner pid=192441)                                                   'interaction_config_path': None,
(TaskRunner pid=192441)                                                   'max_assistant_turns': None,
(TaskRunner pid=192441)                                                   'max_user_turns': None,
(TaskRunner pid=192441)                                                   'tokenization_sanity_check_mode': 'strict',
(TaskRunner pid=192441)                                                   'tool_config_path': None,
(TaskRunner pid=192441)                                                   'use_inference_chat_template': False},
(TaskRunner pid=192441)                                    'n': 1,
(TaskRunner pid=192441)                                    'name': 'sglang',
(TaskRunner pid=192441)                                    'profiler': {'all_ranks': False,
(TaskRunner pid=192441)                                                 'discrete': False,
(TaskRunner pid=192441)                                                 'ranks': None},
(TaskRunner pid=192441)                                    'prompt_length': 512,
(TaskRunner pid=192441)                                    'response_length': 512,
(TaskRunner pid=192441)                                    'temperature': 1.0,
(TaskRunner pid=192441)                                    'tensor_model_parallel_size': 2,
(TaskRunner pid=192441)                                    'top_k': -1,
(TaskRunner pid=192441)                                    'top_p': 1,
(TaskRunner pid=192441)                                    'use_fire_sampling': False,
(TaskRunner pid=192441)                                    'val_kwargs': {'do_sample': False,
(TaskRunner pid=192441)                                                   'n': 1,
(TaskRunner pid=192441)                                                   'temperature': 0,
(TaskRunner pid=192441)                                                   'top_k': -1,
(TaskRunner pid=192441)                                                   'top_p': 1.0}}},
(TaskRunner pid=192441)  'algorithm': {'adv_estimator': 'gae',
(TaskRunner pid=192441)                'gamma': 1.0,
(TaskRunner pid=192441)                'kl_ctrl': {'horizon': 10000,
(TaskRunner pid=192441)                            'kl_coef': 0.001,
(TaskRunner pid=192441)                            'target_kl': 0.1,
(TaskRunner pid=192441)                            'type': 'fixed'},
(TaskRunner pid=192441)                'kl_penalty': 'kl',
(TaskRunner pid=192441)                'lam': 1.0,
(TaskRunner pid=192441)                'norm_adv_by_std_in_grpo': True,
(TaskRunner pid=192441)                'pf_ppo': {'reweight_method': 'pow', 'weight_pow': 2.0},
(TaskRunner pid=192441)                'use_kl_in_reward': False,
(TaskRunner pid=192441)                'use_pf_ppo': False},
(TaskRunner pid=192441)  'critic': {'checkpoint': {'load_contents': ['model', 'optimizer', 'extra'],
(TaskRunner pid=192441)                            'save_contents': ['model', 'optimizer', 'extra']},
(TaskRunner pid=192441)             'cliprange_value': 0.5,
(TaskRunner pid=192441)             'forward_max_token_len_per_gpu': 32768,
(TaskRunner pid=192441)             'forward_micro_batch_size': None,
(TaskRunner pid=192441)             'forward_micro_batch_size_per_gpu': 2,
(TaskRunner pid=192441)             'grad_clip': 1.0,
(TaskRunner pid=192441)             'loss_agg_mode': 'token-mean',
(TaskRunner pid=192441)             'model': {'enable_activation_offload': False,
(TaskRunner pid=192441)                       'enable_gradient_checkpointing': False,
(TaskRunner pid=192441)                       'external_lib': None,
(TaskRunner pid=192441)                       'fsdp_config': {'forward_prefetch': False,
(TaskRunner pid=192441)                                       'fsdp_size': -1,
(TaskRunner pid=192441)                                       'offload_policy': False,
(TaskRunner pid=192441)                                       'optimizer_offload': False,
(TaskRunner pid=192441)                                       'param_offload': False,
(TaskRunner pid=192441)                                       'reshard_after_forward': True,
(TaskRunner pid=192441)                                       'wrap_policy': {'min_num_params': 0}},
(TaskRunner pid=192441)                       'lora_alpha': 16,
(TaskRunner pid=192441)                       'lora_rank': 0,
(TaskRunner pid=192441)                       'override_config': {},
(TaskRunner pid=192441)                       'path': '/share/huitianlin/models/Qwen2.5-0.5B-Instruct',
(TaskRunner pid=192441)                       'target_modules': 'all-linear',
(TaskRunner pid=192441)                       'tokenizer_path': '/share/huitianlin/models/Qwen2.5-0.5B-Instruct',
(TaskRunner pid=192441)                       'trust_remote_code': False,
(TaskRunner pid=192441)                       'use_remove_padding': True,
(TaskRunner pid=192441)                       'use_shm': False},
(TaskRunner pid=192441)             'optim': {'lr': 1e-05,
(TaskRunner pid=192441)                       'lr_warmup_steps_ratio': 0.0,
(TaskRunner pid=192441)                       'min_lr_ratio': None,
(TaskRunner pid=192441)                       'total_training_steps': -1,
(TaskRunner pid=192441)                       'warmup_style': 'constant',
(TaskRunner pid=192441)                       'weight_decay': 0.01},
(TaskRunner pid=192441)             'ppo_epochs': 1,
(TaskRunner pid=192441)             'ppo_max_token_len_per_gpu': 32768,
(TaskRunner pid=192441)             'ppo_micro_batch_size': None,
(TaskRunner pid=192441)             'ppo_micro_batch_size_per_gpu': 2,
(TaskRunner pid=192441)             'ppo_mini_batch_size': 32,
(TaskRunner pid=192441)             'profiler': {'all_ranks': False, 'discrete': False, 'ranks': None},
(TaskRunner pid=192441)             'rollout_n': 1,
(TaskRunner pid=192441)             'shuffle': False,
(TaskRunner pid=192441)             'strategy': 'fsdp',
(TaskRunner pid=192441)             'ulysses_sequence_parallel_size': 1,
(TaskRunner pid=192441)             'use_dynamic_bsz': False},
(TaskRunner pid=192441)  'custom_reward_function': {'name': None, 'path': None},
(TaskRunner pid=192441)  'data': {'custom_cls': {'name': None, 'path': None},
(TaskRunner pid=192441)           'filter_overlong_prompts': False,
(TaskRunner pid=192441)           'filter_overlong_prompts_workers': 1,
(TaskRunner pid=192441)           'image_key': 'images',
(TaskRunner pid=192441)           'max_prompt_length': 512,
(TaskRunner pid=192441)           'max_response_length': 512,
(TaskRunner pid=192441)           'prompt_key': 'prompt',
(TaskRunner pid=192441)           'return_full_prompt': False,
(TaskRunner pid=192441)           'return_raw_chat': True,
(TaskRunner pid=192441)           'return_raw_input_ids': False,
(TaskRunner pid=192441)           'reward_fn_key': 'data_source',
(TaskRunner pid=192441)           'shuffle': True,
(TaskRunner pid=192441)           'tokenizer': None,
(TaskRunner pid=192441)           'train_batch_size': 64,
(TaskRunner pid=192441)           'train_files': '/share/huitianlin/data/gsm8k/train.parquet',
(TaskRunner pid=192441)           'truncation': 'error',
(TaskRunner pid=192441)           'trust_remote_code': False,
(TaskRunner pid=192441)           'use_shm': False,
(TaskRunner pid=192441)           'val_batch_size': None,
(TaskRunner pid=192441)           'val_files': '/share/huitianlin/data/gsm8k/test.parquet',
(TaskRunner pid=192441)           'validation_shuffle': False,
(TaskRunner pid=192441)           'video_key': 'videos'},
(TaskRunner pid=192441)  'ray_init': {'num_cpus': None, 'timeline_json_file': None},
(TaskRunner pid=192441)  'reward_model': {'enable': False,
(TaskRunner pid=192441)                   'forward_max_token_len_per_gpu': 32768,
(TaskRunner pid=192441)                   'launch_reward_fn_async': False,
(TaskRunner pid=192441)                   'max_length': None,
(TaskRunner pid=192441)                   'micro_batch_size': None,
(TaskRunner pid=192441)                   'micro_batch_size_per_gpu': None,
(TaskRunner pid=192441)                   'model': {'external_lib': None,
(TaskRunner pid=192441)                             'fsdp_config': {'forward_prefetch': False,
(TaskRunner pid=192441)                                             'fsdp_size': -1,
(TaskRunner pid=192441)                                             'param_offload': False,
(TaskRunner pid=192441)                                             'reshard_after_forward': True,
(TaskRunner pid=192441)                                             'wrap_policy': {'min_num_params': 0}},
(TaskRunner pid=192441)                             'input_tokenizer': '/share/huitianlin/models/Qwen2.5-0.5B-Instruct',
(TaskRunner pid=192441)                             'path': '~/models/FsfairX-LLaMA3-RM-v0.1',
(TaskRunner pid=192441)                             'trust_remote_code': False,
(TaskRunner pid=192441)                             'use_fused_kernels': False,
(TaskRunner pid=192441)                             'use_remove_padding': False,
(TaskRunner pid=192441)                             'use_shm': False},
(TaskRunner pid=192441)                   'profiler': {'all_ranks': False,
(TaskRunner pid=192441)                                'discrete': False,
(TaskRunner pid=192441)                                'ranks': None},
(TaskRunner pid=192441)                   'reward_manager': 'naive',
(TaskRunner pid=192441)                   'sandbox_fusion': {'max_concurrent': 64,
(TaskRunner pid=192441)                                      'memory_limit_mb': 1024,
(TaskRunner pid=192441)                                      'url': None},
(TaskRunner pid=192441)                   'strategy': 'fsdp',
(TaskRunner pid=192441)                   'ulysses_sequence_parallel_size': 1,
(TaskRunner pid=192441)                   'use_dynamic_bsz': False},
(TaskRunner pid=192441)  'trainer': {'balance_batch': True,
(TaskRunner pid=192441)              'controller_nsight_options': {'cuda-graph-trace': 'graph',
(TaskRunner pid=192441)                                            'cuda-memory-usage': 'true',
(TaskRunner pid=192441)                                            'trace': 'cuda,nvtx,cublas,ucx'},
(TaskRunner pid=192441)              'critic_warmup': 0,
(TaskRunner pid=192441)              'default_hdfs_dir': None,
(TaskRunner pid=192441)              'default_local_dir': 'checkpoints/verl-test/qwen2.5-0.5b-instruct-function-reward-minimal',
(TaskRunner pid=192441)              'del_local_ckpt_after_load': False,
(TaskRunner pid=192441)              'device': 'cuda',
(TaskRunner pid=192441)              'experiment_name': 'qwen2.5-0.5b-instruct-function-reward-minimal',
(TaskRunner pid=192441)              'log_val_generations': 0,
(TaskRunner pid=192441)              'logger': ['console'],
(TaskRunner pid=192441)              'max_actor_ckpt_to_keep': None,
(TaskRunner pid=192441)              'max_critic_ckpt_to_keep': None,
(TaskRunner pid=192441)              'n_gpus_per_node': 2,
(TaskRunner pid=192441)              'nnodes': 1,
(TaskRunner pid=192441)              'profile_steps': None,
(TaskRunner pid=192441)              'project_name': 'verl-test',
(TaskRunner pid=192441)              'ray_wait_register_center_timeout': 300,
(TaskRunner pid=192441)              'resume_from_path': None,
(TaskRunner pid=192441)              'resume_mode': 'disable',
(TaskRunner pid=192441)              'rollout_data_dir': None,
(TaskRunner pid=192441)              'save_freq': -1,
(TaskRunner pid=192441)              'test_freq': -1,
(TaskRunner pid=192441)              'total_epochs': 5,
(TaskRunner pid=192441)              'total_training_steps': None,
(TaskRunner pid=192441)              'val_before_train': False,
(TaskRunner pid=192441)              'val_only': False,
(TaskRunner pid=192441)              'validation_data_dir': None,
(TaskRunner pid=192441)              'worker_nsight_options': {'capture-range': 'cudaProfilerApi',
(TaskRunner pid=192441)                                        'capture-range-end': None,
(TaskRunner pid=192441)                                        'cuda-graph-trace': 'graph',
(TaskRunner pid=192441)                                        'cuda-memory-usage': 'true',
(TaskRunner pid=192441)                                        'kill': 'none',
(TaskRunner pid=192441)                                        'trace': 'cuda,nvtx,cublas,ucx'}}}
(TaskRunner pid=192441) Using dataset class: RLHFDataset
(TaskRunner pid=192441) dataset len: 7473
(TaskRunner pid=192441) Using dataset class: RLHFDataset
(TaskRunner pid=192441) dataset len: 1319
(TaskRunner pid=192441) [validate_config] All configuration checks passed successfully!
(TaskRunner pid=192441) Size of train dataloader: 116, Size of val dataloader: 1
(TaskRunner pid=192441) Total training steps: 580
(TaskRunner pid=192441) colocated worker base class <class 'verl.single_controller.base.worker.Worker'>
(TaskRunner pid=192441) bind role actor_rollout method chat_completion to class <class 'verl.single_controller.ray.base.create_colocated_worker_cls.<locals>.WorkerDict'>
(TaskRunner pid=192441) bind role actor_rollout method execute_method to class <class 'verl.single_controller.ray.base.create_colocated_worker_cls.<locals>.WorkerDict'>
(TaskRunner pid=192441) bind role actor_rollout method sleep to class <class 'verl.single_controller.ray.base.create_colocated_worker_cls.<locals>.WorkerDict'>
(TaskRunner pid=192441) bind role actor_rollout method wake_up to class <class 'verl.single_controller.ray.base.create_colocated_worker_cls.<locals>.WorkerDict'>
(TaskRunner pid=192441) DeprecationWarning: `ray.state.available_resources_per_node` is a private attribute and access will be removed in a future Ray version.
(TaskRunner pid=192441) WARNING:2025-06-25 18:46:52,993:Waiting for register center actor RPQ7i9_register_center to be ready. Elapsed time: 0 seconds out of 300 seconds.
(pid=193641) Using blocking ray.get inside async actor. This blocks the event loop. Please use `await` on object ref with asyncio.gather if you want to yield execution to the event loop instead.
(pid=7248) Using blocking ray.get inside async actor. This blocks the event loop. Please use `await` on object ref with asyncio.gather if you want to yield execution to the event loop instead.
(WorkerDict pid=193641) [W625 18:47:04.931574804 Utils.hpp:136] Warning: Environment variable NCCL_ASYNC_ERROR_HANDLING is deprecated; use TORCH_NCCL_ASYNC_ERROR_HANDLING instead (function operator())
(WorkerDict pid=193641) Critic overriding config {'bos_token_id': None, 'eos_token_id': 151645, 'pad_token_id': 151643}
(WorkerDict pid=193641) Flash Attention 2.0 only supports torch.float16 and torch.bfloat16 dtypes, but the current dype in Qwen2ForTokenClassification is torch.float32. You should run training or inference using Automatic Mixed-Precision via the `with torch.autocast(device_type='torch_device'):` decorator, or load the model with the `torch_dtype` argument. Example: `model = AutoModel.from_pretrained("openai/whisper-tiny", attn_implementation="flash_attention_2", torch_dtype=torch.float16)`
(WorkerDict pid=193641) You are attempting to use Flash Attention 2.0 with a model not initialized on GPU. Make sure to move the model to GPU after initializing it on CPU with `model.to('cuda')`.
(WorkerDict pid=193641) Monkey patch _flash_attention_forward in transformers.integrations.flash_attention
(WorkerDict pid=193641) Skipping monkey patch for Qwen2ForTokenClassification as use_fused_kernels is False or fused_kernels_backend is None
(WorkerDict pid=193641) Some weights of Qwen2ForTokenClassification were not initialized from the model checkpoint at /share/huitianlin/models/Qwen2.5-0.5B-Instruct and are newly initialized: ['score.bias', 'score.weight']
(WorkerDict pid=193641) You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
(WorkerDict pid=193641) Qwen2ForTokenClassification contains 494.03M parameters
(WorkerDict pid=193641) Before critic FSDP, memory allocated (GB): 0.00, memory reserved (GB): 0.00, device memory used/total (GB): 0.45/79.11
(WorkerDict pid=193641) NCCL version 2.21.5+cuda12.4
(WorkerDict pid=193641) After critic FSDP, memory allocated (GB): 0.92, memory reserved (GB): 2.67, device memory used/total (GB): 3.47/79.11
(WorkerDict pid=193641) Total steps: 580, num_warmup_steps: 0
(WorkerDict pid=193641) Critic use_remove_padding=True
(WorkerDict pid=193641) Model config after override: Qwen2Config {
(WorkerDict pid=193641)   "architectures": [
(WorkerDict pid=193641)     "Qwen2ForCausalLM"
(WorkerDict pid=193641)   ],
(WorkerDict pid=193641)   "attention_dropout": 0.0,
(WorkerDict pid=193641)   "eos_token_id": 151645,
(WorkerDict pid=193641)   "hidden_act": "silu",
(WorkerDict pid=193641)   "hidden_size": 896,
(WorkerDict pid=193641)   "initializer_range": 0.02,
(WorkerDict pid=193641)   "intermediate_size": 4864,
(WorkerDict pid=193641)   "max_position_embeddings": 32768,
(WorkerDict pid=193641)   "max_window_layers": 21,
(WorkerDict pid=193641)   "model_type": "qwen2",
(WorkerDict pid=193641)   "num_attention_heads": 14,
(WorkerDict pid=193641)   "num_hidden_layers": 24,
(WorkerDict pid=193641)   "num_key_value_heads": 2,
(WorkerDict pid=193641)   "pad_token_id": 151643,
(WorkerDict pid=193641)   "rms_norm_eps": 1e-06,
(WorkerDict pid=193641)   "rope_scaling": null,
(WorkerDict pid=193641)   "rope_theta": 1000000.0,
(WorkerDict pid=193641)   "sliding_window": 32768,
(WorkerDict pid=193641)   "tie_word_embeddings": true,
(WorkerDict pid=193641)   "torch_dtype": "bfloat16",
(WorkerDict pid=193641)   "transformers_version": "4.51.1",
(WorkerDict pid=193641)   "use_cache": true,
(WorkerDict pid=193641)   "use_sliding_window": false,
(WorkerDict pid=193641)   "vocab_size": 151936
(WorkerDict pid=193641) }
(WorkerDict pid=193641) 
(WorkerDict pid=193641) Skipping monkey patch for Qwen2ForCausalLM as use_fused_kernels is False or fused_kernels_backend is torch
(WorkerDict pid=193641) Qwen2ForCausalLM contains 494.03M parameters
(WorkerDict pid=193641) wrap_policy: functools.partial(<function _or_policy at 0x7eda03f40dc0>, policies=[functools.partial(<function transformer_auto_wrap_policy at 0x7eda03f40ca0>, transformer_layer_cls={<class 'transformers.models.qwen2.modeling_qwen2.Qwen2DecoderLayer'>})])
(WorkerDict pid=193641) Total steps: 580, num_warmup_steps: 0
(WorkerDict pid=193641) Actor use_remove_padding=True
(WorkerDict pid=193641) Actor use_fused_kernels=False
(WorkerDict pid=7248) /share/huitianlin/miniconda/envs/verl_sglang/lib/python3.10/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py:690: FutureWarning: FSDP.state_dict_type() and FSDP.set_state_dict_type() are being deprecated. Please use APIs, get_state_dict() and set_state_dict(), which can support different parallelisms, FSDP1, FSDP2, DDP. API doc: https://pytorch.org/docs/stable/distributed.checkpoint.html#torch.distributed.checkpoint.state_dict.get_state_dict .Tutorial: https://pytorch.org/tutorials/recipes/distributed_checkpoint_recipe.html .
(WorkerDict pid=7248)   warnings.warn(
(WorkerDict pid=7248) [W625 18:47:04.931566612 Utils.hpp:136] Warning: Environment variable NCCL_ASYNC_ERROR_HANDLING is deprecated; use TORCH_NCCL_ASYNC_ERROR_HANDLING instead (function operator())
(WorkerDict pid=7248) Flash Attention 2.0 only supports torch.float16 and torch.bfloat16 dtypes, but the current dype in Qwen2ForCausalLM is torch.float32. You should run training or inference using Automatic Mixed-Precision via the `with torch.autocast(device_type='torch_device'):` decorator, or load the model with the `torch_dtype` argument. Example: `model = AutoModel.from_pretrained("openai/whisper-tiny", attn_implementation="flash_attention_2", torch_dtype=torch.float16)` [repeated 3x across cluster] (Ray deduplicates logs by default. Set RAY_DEDUP_LOGS=0 to disable log deduplication, or see https://docs.ray.io/en/master/ray-observability/user-guides/configure-logging.html#log-deduplication for more options.)
(WorkerDict pid=7248) You are attempting to use Flash Attention 2.0 with a model not initialized on GPU. Make sure to move the model to GPU after initializing it on CPU with `model.to('cuda')`.
(WorkerDict pid=7248) Some weights of Qwen2ForTokenClassification were not initialized from the model checkpoint at /share/huitianlin/models/Qwen2.5-0.5B-Instruct and are newly initialized: ['score.bias', 'score.weight']
(WorkerDict pid=7248) You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
(WorkerDict pid=193641) NCCL version 2.21.5+cuda12.4
(WorkerDict pid=7248) Monkey patch _flash_attention_forward in transformers.integrations.flash_attention [repeated 3x across cluster]
(WorkerDict pid=7248) Skipping monkey patch for Qwen2ForTokenClassification as use_fused_kernels is False or fused_kernels_backend is None
(WorkerDict pid=7248) Critic use_remove_padding=True
(WorkerDict pid=7248) Skipping monkey patch for Qwen2ForCausalLM as use_fused_kernels is False or fused_kernels_backend is torch
(WorkerDict pid=193641) [W625 18:47:21.926050073 Utils.hpp:136] Warning: Environment variable NCCL_ASYNC_ERROR_HANDLING is deprecated; use TORCH_NCCL_ASYNC_ERROR_HANDLING instead (function operator())
(WorkerDict pid=193641) [W625 18:47:21.931882010 Utils.hpp:136] Warning: Environment variable NCCL_ASYNC_ERROR_HANDLING is deprecated; use TORCH_NCCL_ASYNC_ERROR_HANDLING instead (function operator())
Capturing batches (avail_mem=13.97 GB):   0%|          | 0/23 [00:00<?, ?it/s]
Capturing batches (avail_mem=13.67 GB):   4%|| 1/23 [00:00<00:20,  1.09it/s]
Capturing batches (avail_mem=13.54 GB):   9%|| 2/23 [00:01<00:11,  1.78it/s]
Capturing batches (avail_mem=13.41 GB):  13%|█▎        | 3/23 [00:01<00:10,  1.86it/s]
Capturing batches (avail_mem=13.29 GB):  17%|█▋        | 4/23 [00:01<00:07,  2.41it/s]
Capturing batches (avail_mem=13.17 GB):  22%|██▏       | 5/23 [00:02<00:06,  2.88it/s]
Capturing batches (avail_mem=13.06 GB):  26%|██▌       | 6/23 [00:02<00:05,  3.29it/s]
Capturing batches (avail_mem=12.95 GB):  30%|███       | 7/23 [00:02<00:04,  3.60it/s]
Capturing batches (avail_mem=12.85 GB):  35%|███▍      | 8/23 [00:02<00:04,  3.71it/s]
Capturing batches (avail_mem=12.75 GB):  39%|███▉      | 9/23 [00:03<00:03,  3.88it/s]
Capturing batches (avail_mem=12.66 GB):  43%|████▎     | 10/23 [00:03<00:03,  4.05it/s]
Capturing batches (avail_mem=12.62 GB):  48%|████▊     | 11/23 [00:03<00:03,  3.96it/s]
Capturing batches (avail_mem=12.53 GB):  52%|█████▏    | 12/23 [00:03<00:02,  4.08it/s]
Capturing batches (avail_mem=12.47 GB):  57%|█████▋    | 13/23 [00:04<00:02,  4.20it/s]
Capturing batches (avail_mem=12.39 GB):  61%|██████    | 14/23 [00:04<00:02,  4.21it/s]
Capturing batches (avail_mem=12.36 GB):  65%|██████▌   | 15/23 [00:04<00:01,  4.25it/s]
Capturing batches (avail_mem=12.31 GB):  70%|██████▉   | 16/23 [00:04<00:01,  4.25it/s]
Capturing batches (avail_mem=12.26 GB):  74%|███████▍  | 17/23 [00:05<00:01,  4.23it/s]
Capturing batches (avail_mem=12.22 GB):  78%|███████▊  | 18/23 [00:05<00:01,  4.19it/s]
Capturing batches (avail_mem=12.19 GB):  83%|████████▎ | 19/23 [00:05<00:00,  4.26it/s]
Capturing batches (avail_mem=12.16 GB):  87%|████████▋ | 20/23 [00:05<00:00,  4.34it/s]
Capturing batches (avail_mem=12.13 GB):  91%|█████████▏| 21/23 [00:05<00:00,  4.37it/s]
Capturing batches (avail_mem=12.10 GB):  96%|█████████▌| 22/23 [00:06<00:00,  4.30it/s]
Capturing batches (avail_mem=12.10 GB): 100%|██████████| 23/23 [00:06<00:00,  3.57it/s]
(WorkerDict pid=193641) /share/huitianlin/miniconda/envs/verl_sglang/lib/python3.10/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py:690: FutureWarning: FSDP.state_dict_type() and FSDP.set_state_dict_type() are being deprecated. Please use APIs, get_state_dict() and set_state_dict(), which can support different parallelisms, FSDP1, FSDP2, DDP. API doc: https://pytorch.org/docs/stable/distributed.checkpoint.html#torch.distributed.checkpoint.state_dict.get_state_dict .Tutorial: https://pytorch.org/tutorials/recipes/distributed_checkpoint_recipe.html .
(WorkerDict pid=193641)   warnings.warn(
(TaskRunner pid=192441) No CUDA runtime is found, using CUDA_HOME='/usr/local/cuda'
(AsyncSglangServer pid=11943) FastAPI listen on 11.46.27.250:60355
(TaskRunner pid=192441) Initialized tools: {}
(TaskRunner pid=192441) WARNING:2025-06-25 18:47:40,303:completion_callback is None, use ToolCompletionCallback
Training Progress:   0%|          | 0/580 [00:00<?, ?it/s]
(TaskRunner pid=192441) [ChatCompletionScheduler] generate_sequences sampling params: {'model': 'models/Qwen2.5-0.5B-Instruct', 'temperature': 1.0, 'top_p': 1}
(TaskRunner pid=192441) [id=chatcmpl-c616f62cbe2d4da4bbb2994d8bbb12c0,turn=2,finish_reason=stop] No tool called, done!
(TaskRunner pid=192441) [id=chatcmpl-82fdc9bbf49a4699b60c2ed9b255a9e1,turn=2,finish_reason=stop] No tool called, done!
(TaskRunner pid=192441) [id=chatcmpl-954818ca46164fc89719bd0ac07b12a1,turn=2,finish_reason=stop] No tool called, done!
(TaskRunner pid=192441) [id=chatcmpl-0a36dade92b947719896b6f00f0a36b0,turn=2,finish_reason=stop] No tool called, done!
(TaskRunner pid=192441) [id=chatcmpl-430eba7f18814efea23dd86caeadbf5e,turn=2,finish_reason=stop] No tool called, done!
(TaskRunner pid=192441) [id=chatcmpl-67aa89c83e9342e3bb2f461bc599a405,turn=2,finish_reason=stop] No tool called, done!
(TaskRunner pid=192441) [id=chatcmpl-cc1c0c57863b48ba8e173b1325f85638,turn=2,finish_reason=stop] No tool called, done!
(TaskRunner pid=192441) [id=chatcmpl-0141bab4dfb54699b3863446b2b357e5,turn=2,finish_reason=stop] No tool called, done!
(TaskRunner pid=192441) [id=chatcmpl-57130be9b4c443cabcf23da09b7bc102,turn=2,finish_reason=stop] No tool called, done!
(TaskRunner pid=192441) [id=chatcmpl-7c656d2e108e4866bf362b4a5326c128,turn=2,finish_reason=stop] No tool called, done!
(TaskRunner pid=192441) [id=chatcmpl-b28e659bf9fa46d197285f00d1497f32,turn=2,finish_reason=stop] No tool called, done!
(TaskRunner pid=192441) [id=chatcmpl-f7d9d3c424104fdd9021c216ce65ca58,turn=2,finish_reason=stop] No tool called, done!
(TaskRunner pid=192441) [id=chatcmpl-3005653c3f9f41cf8a211e6fefc1f5d3,turn=2,finish_reason=stop] No tool called, done!
(TaskRunner pid=192441) [id=chatcmpl-a6c83098be444d2c887b656e96d3ba74,turn=2,finish_reason=stop] No tool called, done!
(TaskRunner pid=192441) [id=chatcmpl-b5bb693a4f7b4c8ca646a6301118b68c,turn=2,finish_reason=stop] No tool called, done!
(TaskRunner pid=192441) [id=chatcmpl-2ce2f5d35bf54a3d86d8d82afff9224b,turn=2,finish_reason=stop] No tool called, done!
(TaskRunner pid=192441) [id=chatcmpl-d1362c7773ae4630b8ac6232e061d390,turn=2,finish_reason=stop] No tool called, done!
(TaskRunner pid=192441) [id=chatcmpl-51cf52159ebb4751bd656b704993793f,turn=2,finish_reason=stop] No tool called, done!
(TaskRunner pid=192441) [id=chatcmpl-5653b428752b4f6b8b411ed0a010df57,turn=2,finish_reason=stop] No tool called, done!
(TaskRunner pid=192441) [id=chatcmpl-79f16c932d2e41d08771e5525d0f1319,turn=2,finish_reason=stop] No tool called, done!
(TaskRunner pid=192441) [id=chatcmpl-b3ecb5660f9e481e80e11b4014b36a36,turn=2,finish_reason=stop] No tool called, done!
(TaskRunner pid=192441) [id=chatcmpl-df7d53f3bf794967a0d3ec92b008cd53,turn=2,finish_reason=stop] No tool called, done!
(TaskRunner pid=192441) [id=chatcmpl-97adf185f1794115884946ed3552fa8e,turn=2,finish_reason=stop] No tool called, done!
(TaskRunner pid=192441) [id=chatcmpl-6ff87f8ca7d5427780fcad5679ea5a52,turn=2,finish_reason=stop] No tool called, done!
(TaskRunner pid=192441) [id=chatcmpl-65f29b534536431ca63ba74fbf9be146,turn=2,finish_reason=stop] No tool called, done!
(TaskRunner pid=192441) [id=chatcmpl-5dd6c2dc17084b979d63984753984253,turn=2,finish_reason=stop] No tool called, done!
(TaskRunner pid=192441) [id=chatcmpl-85e5f0fc05704b60905fa77960767729,turn=2,finish_reason=stop] No tool called, done!
(TaskRunner pid=192441) [id=chatcmpl-63b3acaa171a4c91bf4ad94994ed9733,turn=2,finish_reason=stop] No tool called, done!
(TaskRunner pid=192441) [id=chatcmpl-8e6c5b9a12de4932af40f4212b6db197,turn=2,finish_reason=stop] No tool called, done!
(TaskRunner pid=192441) [id=chatcmpl-0e2e7f5808af4e4d9f6473c2c3fb0660,turn=2,finish_reason=stop] No tool called, done!
(TaskRunner pid=192441) [id=chatcmpl-9718d56709454f7293713c63b99e3cce,turn=2,finish_reason=stop] No tool called, done!
(TaskRunner pid=192441) [id=chatcmpl-ce9e0fe020fd44a79430e217f1380240,turn=2,finish_reason=stop] No tool called, done!
(TaskRunner pid=192441) [id=chatcmpl-0ebc9d303c714ebdaf34696eb3c03b06,turn=2,finish_reason=stop] No tool called, done!
(TaskRunner pid=192441) [id=chatcmpl-5989f03244bd4d8ca3b4c42d22f97227,turn=2,finish_reason=stop] No tool called, done!
(TaskRunner pid=192441) [id=chatcmpl-a0890834ea024a9aa0d8dcc09d61f6b5,turn=2,finish_reason=stop] No tool called, done!
(TaskRunner pid=192441) [id=chatcmpl-183b8e0dc9b5482b902429138d37da68,turn=2,finish_reason=stop] No tool called, done!
(TaskRunner pid=192441) [id=chatcmpl-a3e4ec7afcb64b3abd02670615b4f7a0,turn=2,finish_reason=stop] No tool called, done!
(TaskRunner pid=192441) [id=chatcmpl-8e8189e0a1a4416a83373bf94fe0357a,turn=2,finish_reason=stop] No tool called, done!
(TaskRunner pid=192441) [id=chatcmpl-f886677a757144c79ea25d2c8d11b3d7,turn=2,finish_reason=stop] No tool called, done!
(TaskRunner pid=192441) [id=chatcmpl-6103c3bea218412391fdec5d000bdd99,turn=2,finish_reason=stop] No tool called, done!
(TaskRunner pid=192441) [id=chatcmpl-9f6475cf18734c3c8c9d15987f894093,turn=2,finish_reason=stop] No tool called, done!
(TaskRunner pid=192441) [id=chatcmpl-3572e68dfb6c4597b373bdb82cbff9b5,turn=2,finish_reason=stop] No tool called, done!
(TaskRunner pid=192441) [id=chatcmpl-2848848a4d0f4f819eeb0a562550c398,turn=2,finish_reason=stop] No tool called, done!
(TaskRunner pid=192441) [id=chatcmpl-222e2acdae0e45a0b625785f8f41f7e1,turn=2,finish_reason=stop] No tool called, done!
(TaskRunner pid=192441) [id=chatcmpl-739d97c6518041b78085ee8a41f30d7a,turn=2,finish_reason=stop] No tool called, done!
(TaskRunner pid=192441) [id=chatcmpl-21ad031f715f41f2866506d88a6c1aab,turn=2,finish_reason=stop] No tool called, done!
(TaskRunner pid=192441) [id=chatcmpl-82dde350698949b2a686cbf568918846,turn=2,finish_reason=stop] No tool called, done!
(TaskRunner pid=192441) [id=chatcmpl-f164a91737cc43b1ae754abfc046ad8d,turn=2,finish_reason=stop] No tool called, done!
(TaskRunner pid=192441) [id=chatcmpl-144df18661e7416289fce9328e17e3f1,turn=2,finish_reason=stop] No tool called, done!
(TaskRunner pid=192441) [id=chatcmpl-63123d0e9498428aaf004bcc40445aeb,turn=2,finish_reason=stop] No tool called, done!
(TaskRunner pid=192441) [id=chatcmpl-cba682a79c934613854a8c23f29ed3bc,turn=2,finish_reason=stop] No tool called, done!
(TaskRunner pid=192441) [id=chatcmpl-ea0640dbd2784e1bb51b40714e6f36c5,turn=2,finish_reason=stop] No tool called, done!
(TaskRunner pid=192441) [id=chatcmpl-c656cb4ee50a434f897459c12e04b715,turn=2,finish_reason=stop] No tool called, done!
(TaskRunner pid=192441) [id=chatcmpl-041f21e1b402466b9f7c36e74351ebc9,turn=2,finish_reason=stop] No tool called, done!
(TaskRunner pid=192441) [id=chatcmpl-4aa09c2cf7b042ec90d50b0c580b0b4a,turn=2,finish_reason=stop] No tool called, done!
(TaskRunner pid=192441) [id=chatcmpl-790a6ab9fffe47af8ed99367959f8f1b,turn=2,finish_reason=stop] No tool called, done!
(TaskRunner pid=192441) [id=chatcmpl-749dfe098044424182a8174db9181b63,turn=2,finish_reason=stop] No tool called, done!
(TaskRunner pid=192441) [id=chatcmpl-d79e03be7b9c470fbcde34a00ddbe69e,turn=2,finish_reason=stop] No tool called, done!
(TaskRunner pid=192441) [id=chatcmpl-e2912413c30c45e1ae7e721aa8614bcc,turn=2,finish_reason=stop] No tool called, done!
(TaskRunner pid=192441) [id=chatcmpl-16e9278f45a3436e9cf6c73e02ba7589,turn=2,finish_reason=length] No tool called, done!
(TaskRunner pid=192441) [id=chatcmpl-1870f409adf64532a691903bfe56455a,turn=2,finish_reason=length] No tool called, done!
(TaskRunner pid=192441) [id=chatcmpl-1219d1deb7b64081ad3d651014c2f7ba,turn=2,finish_reason=length] No tool called, done!
(TaskRunner pid=192441) [id=chatcmpl-e1103153b24b41f392cb1f65d0997a5c,turn=2,finish_reason=length] No tool called, done!
(TaskRunner pid=192441) [id=chatcmpl-b5abf663ec0c4618b70574d31f354687,turn=2,finish_reason=length] No tool called, done!
(TaskRunner pid=192441) [ChatCompletionScheduler] generate_sequences done
(TaskRunner pid=192441) step:1 - global_seqlen/min:12171.000 - global_seqlen/max:14002.000 - global_seqlen/minmax_diff:1831.000 - global_seqlen/balanced_min:13086.000 - global_seqlen/balanced_max:13087.000 - global_seqlen/mean:13086.500 - actor/entropy:0.527 - critic/vf_loss:12.674 - critic/vf_clipfrac:0.329 - critic/vpred_mean:1.989 - critic/grad_norm:843.039 - perf/mfu/critic:0.018 - critic/lr:0.000 - actor/pg_loss:-0.019 - actor/pg_clipfrac:0.001 - actor/ppo_kl:0.000 - actor/pg_clipfrac_lower:0.000 - actor/grad_norm:2.653 - perf/mfu/actor:0.010 - perf/max_memory_allocated_gb:8.031 - perf/max_memory_reserved_gb:16.273 - perf/cpu_memory_used_gb:46.973 - actor/lr:0.000 - training/global_step:1.000 - training/epoch:0.000 - critic/score/mean:0.016 - critic/score/max:1.000 - critic/score/min:0.000 - critic/rewards/mean:0.016 - critic/rewards/max:1.000 - critic/rewards/min:0.000 - critic/advantages/mean:0.000 - critic/advantages/max:4.757 - critic/advantages/min:-3.485 - critic/returns/mean:0.011 - critic/returns/max:1.000 - critic/returns/min:0.000 - critic/values/mean:3.547 - critic/values/max:15.875 - critic/values/min:-13.312 - critic/vf_explained_var:-1151.247 - response_length/mean:304.141 - response_length/max:514.000 - response_length/min:112.000 - response_length/clip_ratio:0.078 - prompt_length/mean:104.812 - prompt_length/max:170.000 - prompt_length/min:66.000 - prompt_length/clip_ratio:0.016 - timing_s/generate_sequences:5.143 - timing_s/gen:6.614 - timing_s/reward:0.019 - timing_s/old_log_prob:4.187 - timing_s/values:0.966 - timing_s/adv:0.022 - timing_s/update_critic:2.922 - timing_s/update_actor:4.927 - timing_s/step:19.660 - timing_per_token_ms/update_actor:0.188 - timing_per_token_ms/adv:0.001 - timing_per_token_ms/gen:0.340 - timing_per_token_ms/update_critic:0.112 - timing_per_token_ms/values:0.037 - perf/total_num_tokens:26173.000 - perf/time_per_step:19.660 - perf/throughput:665.629
Training Progress:   0%|          | 1/580 [00:19<3:12:04, 19.90s/it]
(WorkerDict pid=193641) Fatal Python error: Aborted
(WorkerDict pid=193641) 
(WorkerDict pid=193641) Thread 0x00007fdcd74d2700 (most recent call first):
(WorkerDict pid=193641)   File "/share/huitianlin/miniconda/envs/verl_sglang/lib/python3.10/site-packages/sglang/srt/managers/scheduler.py", line 1829 in watchdog_thread
(WorkerDict pid=193641)   File "/share/huitianlin/miniconda/envs/verl_sglang/lib/python3.10/threading.py", line 946 in run
(WorkerDict pid=193641)   File "/share/huitianlin/miniconda/envs/verl_sglang/lib/python3.10/threading.py", line 1009 in _bootstrap_inner
(WorkerDict pid=193641)   File "/share/huitianlin/miniconda/envs/verl_sglang/lib/python3.10/threading.py", line 966 in _bootstrap
(WorkerDict pid=193641) 
(WorkerDict pid=193641) Thread 0x00007fdcd7cd3700 (most recent call first):
(WorkerDict pid=193641)   File "/share/huitianlin/miniconda/envs/verl_sglang/lib/python3.10/threading.py", line 320 in wait
(WorkerDict pid=193641)   File "/share/huitianlin/miniconda/envs/verl_sglang/lib/python3.10/queue.py", line 171 in get
(WorkerDict pid=193641)   File "/share/huitianlin/miniconda/envs/verl_sglang/lib/python3.10/site-packages/sglang/srt/managers/tp_worker_overlap_thread.py", line 130 in forward_thread_func_
(WorkerDict pid=193641)   File "/share/huitianlin/miniconda/envs/verl_sglang/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 116 in decorate_context
(WorkerDict pid=193641)   File "/share/huitianlin/miniconda/envs/verl_sglang/lib/python3.10/site-packages/sglang/srt/managers/tp_worker_overlap_thread.py", line 118 in forward_thread_func
(WorkerDict pid=193641)   File "/share/huitianlin/miniconda/envs/verl_sglang/lib/python3.10/threading.py", line 946 in run
(WorkerDict pid=193641)   File "/share/huitianlin/miniconda/envs/verl_sglang/lib/python3.10/threading.py", line 1009 in _bootstrap_inner
(WorkerDict pid=193641)   File "/share/huitianlin/miniconda/envs/verl_sglang/lib/python3.10/threading.py", line 966 in _bootstrap
(WorkerDict pid=193641) 
(WorkerDict pid=193641) Thread 0x00007fdce5fff700 (most recent call first):
(WorkerDict pid=193641)   File "/share/huitianlin/miniconda/envs/verl_sglang/lib/python3.10/threading.py", line 324 in wait
(WorkerDict pid=193641)   File "/share/huitianlin/miniconda/envs/verl_sglang/lib/python3.10/threading.py", line 600 in wait
(WorkerDict pid=193641)   File "/share/huitianlin/miniconda/envs/verl_sglang/lib/python3.10/site-packages/tqdm/_monitor.py", line 60 in run
(WorkerDict pid=193641)   File "/share/huitianlin/miniconda/envs/verl_sglang/lib/python3.10/threading.py", line 1009 in _bootstrap_inner
(WorkerDict pid=193641)   File "/share/huitianlin/miniconda/envs/verl_sglang/lib/python3.10/threading.py", line 966 in _bootstrap
(WorkerDict pid=193641) 
(WorkerDict pid=193641) Current thread 0x00007fe14a5c4740 (most recent call first):
(WorkerDict pid=193641)   File "/share/huitianlin/miniconda/envs/verl_sglang/lib/python3.10/site-packages/sglang/srt/managers/tp_worker.py", line 255 in update_weights_from_tensor
(WorkerDict pid=193641)   File "/share/huitianlin/miniconda/envs/verl_sglang/lib/python3.10/site-packages/sglang/srt/managers/tp_worker_overlap_thread.py", line 254 in update_weights_from_tensor
(WorkerDict pid=193641)   File "/share/huitianlin/miniconda/envs/verl_sglang/lib/python3.10/site-packages/sglang/srt/managers/scheduler.py", line 2035 in update_weights_from_tensor
(WorkerDict pid=193641)   File "/share/huitianlin/miniconda/envs/verl_sglang/lib/python3.10/site-packages/sglang/utils.py", line 471 in __call__
(WorkerDict pid=193641)   File "/share/huitianlin/miniconda/envs/verl_sglang/lib/python3.10/site-packages/sglang/srt/managers/scheduler.py", line 889 in process_input_requests
(WorkerDict pid=193641)   File "/share/huitianlin/miniconda/envs/verl_sglang/lib/python3.10/site-packages/sglang/srt/managers/scheduler.py", line 662 in event_loop_overlap
(WorkerDict pid=193641)   File "/share/huitianlin/miniconda/envs/verl_sglang/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 116 in decorate_context
(WorkerDict pid=193641)   File "/share/huitianlin/miniconda/envs/verl_sglang/lib/python3.10/site-packages/sglang/srt/managers/scheduler.py", line 2311 in run_scheduler_process
(WorkerDict pid=193641)   File "/share/huitianlin/miniconda/envs/verl_sglang/lib/python3.10/multiprocessing/process.py", line 108 in run
(WorkerDict pid=193641)   File "/share/huitianlin/miniconda/envs/verl_sglang/lib/python3.10/multiprocessing/process.py", line 315 in _bootstrap
(WorkerDict pid=193641)   File "/share/huitianlin/miniconda/envs/verl_sglang/lib/python3.10/multiprocessing/spawn.py", line 129 in _main
(WorkerDict pid=193641)   File "/share/huitianlin/miniconda/envs/verl_sglang/lib/python3.10/multiprocessing/spawn.py", line 116 in spawn_main
(WorkerDict pid=193641)   File "<string>", line 1 in <module>
(WorkerDict pid=193641) 
(WorkerDict pid=193641) Extension modules: msgpack._cmsgpack, google._upb._message, psutil._psutil_linux, psutil._psutil_posix, setproctitle, yaml._yaml, charset_normalizer.md, requests.packages.charset_normalizer.md, requests.packages.chardet.md, uvloop.loop, ray._raylet, numpy.core._multiarray_umath, numpy.core._multiarray_tests, numpy.linalg._umath_linalg, numpy.fft._pocketfft_internal, numpy.random._common, numpy.random.bit_generator, numpy.random._bounded_integers, numpy.random._mt19937, numpy.random.mtrand, numpy.random._philox, numpy.random._pcg64, numpy.random._sfc64, numpy.random._generator, multidict._multidict, yarl._quoting_c, propcache._helpers_c, aiohttp._http_writer, aiohttp._http_parser, aiohttp._websocket.mask, aiohttp._websocket.reader_c, frozenlist._frozenlist, torch._C, torch._C._dynamo.autograd_compiler, torch._C._dynamo.eval_frame, torch._C._dynamo.guards, torch._C._dynamo.utils, torch._C._fft, torch._C._linalg, torch._C._nested, torch._C._nn, torch._C._sparse, torch._C._special, zmq.backend.cython._zmq, markupsafe._speedups, PIL._imaging, PIL._imagingft, av._core, av.logging, av.bytesource, av.buffer, av.audio.format, av.error, av.dictionary, av.container.pyio, av.utils, av.option, av.descriptor, av.format, av.stream, av.container.streams, av.sidedata.motionvectors, av.sidedata.sidedata, av.opaque, av.packet, av.container.input, av.container.output, av.container.core, av.codec.context, av.video.format, av.video.reformatter, av.plane, av.video.plane, av.video.frame, av.video.stream, av.codec.hwaccel, av.codec.codec, av.frame, av.audio.layout, av.audio.plane, av.audio.frame, av.audio.stream, av.filter.pad, av.filter.link, av.filter.context, av.filter.graph, av.filter.filter, av.filter.loudnorm, av.audio.resampler, av.audio.codeccontext, av.audio.fifo, av.bitstream, av.video.codeccontext, sentencepiece._sentencepiece, msgspec._core, _cffi_backend, cuda.bindings._lib.utils, cuda.bindings._bindings.cydriver, cuda.bindings.cydriver, cuda.bindings.driver, cuda.bindings._bindings.cynvrtc, cuda.bindings.cynvrtc, cuda.bindings.nvrtc, regex._regex, scipy._lib._ccallback_c, scipy.linalg._fblas, scipy.linalg._flapack, scipy.linalg.cython_lapack, scipy.linalg._cythonized_array_utils, scipy.linalg._solve_toeplitz, scipy.linalg._decomp_lu_cython, scipy.linalg._matfuncs_sqrtm_triu, scipy.linalg._matfuncs_expm, scipy.linalg._linalg_pythran, scipy.linalg.cython_blas, scipy.linalg._decomp_update, scipy.sparse._sparsetools, _csparsetools, scipy.sparse._csparsetools, scipy.sparse.linalg._dsolve._superlu, scipy.sparse.linalg._eigen.arpack._arpack, scipy.sparse.linalg._propack._spropack, scipy.sparse.linalg._propack._dpropack, scipy.sparse.linalg._propack._cpropack, scipy.sparse.linalg._propack._zpropack, scipy.sparse.csgraph._tools, scipy.sparse.csgraph._shortest_path, scipy.sparse.csgraph._traversal, scipy.sparse.csgraph._min_spanning_tree, scipy.sparse.csgraph._flow, scipy.sparse.csgraph._matching, scipy.sparse.csgraph._reordering, scipy.optimize._group_columns, scipy._lib.messagestream, scipy.optimize._trlib._trlib, scipy.optimize._lbfgsb, _moduleTNC, scipy.optimize._moduleTNC, scipy.optimize._cobyla, scipy.optimize._slsqp, scipy.optimize._minpack, scipy.optimize._lsq.givens_elimination, scipy.optimize._zeros, scipy.optimize._cython_nnls, scipy._lib._uarray._uarray, scipy.special._ufuncs_cxx, scipy.special._ufuncs, scipy.special._specfun, scipy.special._comb, scipy.special._ellip_harm_2, scipy.linalg._decomp_interpolative, scipy.optimize._bglu_dense, scipy.optimize._lsap, 
scipy.spatial._ckdtree, scipy.spatial._qhull, scipy.spatial._voronoi, scipy.spatial._distance_wrap, scipy.spatial._hausdorff, scipy.spatial.transform._rotation, scipy.optimize._direct, cuda_utils, __triton_launcher (total: 162)
(WorkerDict pid=193641) [2025-06-25 18:48:01 TP1] Scheduler hit an exception: Traceback (most recent call last):
(WorkerDict pid=193641)   File "/share/huitianlin/miniconda/envs/verl_sglang/lib/python3.10/site-packages/sglang/srt/managers/scheduler.py", line 2311, in run_scheduler_process
(WorkerDict pid=193641)     scheduler.event_loop_overlap()
(WorkerDict pid=193641)   File "/share/huitianlin/miniconda/envs/verl_sglang/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
(WorkerDict pid=193641)     return func(*args, **kwargs)
(WorkerDict pid=193641)   File "/share/huitianlin/miniconda/envs/verl_sglang/lib/python3.10/site-packages/sglang/srt/managers/scheduler.py", line 661, in event_loop_overlap
(WorkerDict pid=193641)     recv_reqs = self.recv_requests()
(WorkerDict pid=193641)   File "/share/huitianlin/miniconda/envs/verl_sglang/lib/python3.10/site-packages/sglang/srt/managers/scheduler.py", line 872, in recv_requests
(WorkerDict pid=193641)     recv_reqs = broadcast_pyobj(
(WorkerDict pid=193641)   File "/share/huitianlin/miniconda/envs/verl_sglang/lib/python3.10/site-packages/sglang/srt/utils.py", line 950, in broadcast_pyobj
(WorkerDict pid=193641)     dist.broadcast(tensor_size, src=src, group=dist_group)
(WorkerDict pid=193641)   File "/share/huitianlin/miniconda/envs/verl_sglang/lib/python3.10/site-packages/torch/distributed/c10d_logger.py", line 81, in wrapper
(WorkerDict pid=193641)     return func(*args, **kwargs)
(WorkerDict pid=193641)   File "/share/huitianlin/miniconda/envs/verl_sglang/lib/python3.10/site-packages/torch/distributed/distributed_c10d.py", line 2730, in broadcast
(WorkerDict pid=193641)     work.wait()
(WorkerDict pid=193641) RuntimeError: [/pytorch/third_party/gloo/gloo/transport/tcp/pair.cc:534] Connection closed by peer [11.46.27.250]:47809
(WorkerDict pid=193641) 
Training Progress:   0%|          | 1/580 [00:21<3:26:51, 21.44s/it]
(WorkerDict pid=193641) /share/huitianlin/miniconda/envs/verl_sglang/lib/python3.10/multiprocessing/resource_tracker.py:224: UserWarning: resource_tracker: There appear to be 2 leaked semaphore objects to clean up at shutdown
(WorkerDict pid=193641)   warnings.warn('resource_tracker: There appear to be %d '
(WorkerDict pid=193641) /share/huitianlin/miniconda/envs/verl_sglang/lib/python3.10/multiprocessing/resource_tracker.py:224: UserWarning: resource_tracker: There appear to be 1 leaked shared_memory objects to clean up at shutdown
(WorkerDict pid=193641)   warnings.warn('resource_tracker: There appear to be %d '
(raylet) A worker died or was killed while executing a task by an unexpected system error. To troubleshoot the problem, check the logs for the dead worker. RayTask ID: ffffffffffffffff8d6e751785c07185abec77a801000000 Worker ID: 3586079486b2191487fc2ba281d7ef56d4113939c5a2a5f6ef7745ec Node ID: 7c5fd14e0e20c8f1599aec596ae954b16507fa119c7253f4f3ef2330 Worker IP address: 11.46.27.250 Worker port: 43035 Worker PID: 193641 Worker exit type: SYSTEM_ERROR Worker exit detail: Worker unexpectedly exits with a connection error code 2. End of file. There are some potential root causes. (1) The process is killed by SIGKILL by OOM killer due to high memory usage. (2) ray stop --force is called. (3) The worker is crashed unexpectedly due to SIGSEGV or other unexpected errors.
Error executing job with overrides: ['algorithm.adv_estimator=gae', 'data.train_files=/share/huitianlin/data/gsm8k/train.parquet', 'data.val_files=/share/huitianlin/data/gsm8k/test.parquet', 'data.train_batch_size=64', 'data.max_prompt_length=512', 'data.max_response_length=512', 'data.return_raw_chat=True', 'actor_rollout_ref.model.path=/share/huitianlin/models/Qwen2.5-0.5B-Instruct', 'actor_rollout_ref.model.use_shm=False', 'actor_rollout_ref.model.lora_rank=0', 'actor_rollout_ref.model.lora_alpha=0', 'actor_rollout_ref.actor.optim.lr=1e-6', 'actor_rollout_ref.model.use_remove_padding=True', 'actor_rollout_ref.model.use_fused_kernels=False', 'actor_rollout_ref.model.fused_kernel_options.impl_backend=torch', 'actor_rollout_ref.actor.ppo_mini_batch_size=32', 'actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=2', 'actor_rollout_ref.actor.strategy=fsdp', 'actor_rollout_ref.actor.fsdp_config.param_offload=False', 'actor_rollout_ref.actor.fsdp_config.optimizer_offload=False', 'actor_rollout_ref.actor.fsdp_config.fsdp_size=-1', 'actor_rollout_ref.actor.ulysses_sequence_parallel_size=1', "actor_rollout_ref.actor.checkpoint.save_contents=['model','optimizer','extra']", 'actor_rollout_ref.actor.use_kl_loss=False', 'actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=2', 'actor_rollout_ref.rollout.tensor_model_parallel_size=2', 'actor_rollout_ref.rollout.name=sglang', 'actor_rollout_ref.rollout.mode=async', 'actor_rollout_ref.rollout.load_format=dummy_dtensor', 'actor_rollout_ref.rollout.layered_summon=False', 'actor_rollout_ref.rollout.gpu_memory_utilization=0.8', 'actor_rollout_ref.rollout.enable_chunked_prefill=True', 'actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=2', 'actor_rollout_ref.ref.fsdp_config.param_offload=True', 'critic.optim.lr=1e-5', 'critic.model.use_remove_padding=True', 'critic.model.path=/share/huitianlin/models/Qwen2.5-0.5B-Instruct', 'critic.model.enable_gradient_checkpointing=False', 'critic.ppo_micro_batch_size_per_gpu=2', 'critic.model.fsdp_config.param_offload=False', 'critic.model.fsdp_config.optimizer_offload=False', 'custom_reward_function.path=null', 'custom_reward_function.name=null', 'algorithm.use_kl_in_reward=False', 'algorithm.kl_penalty=kl', 'algorithm.kl_ctrl.kl_coef=0.001', 'trainer.critic_warmup=0', 'trainer.logger=[console]', 'trainer.project_name=verl-test', 'trainer.experiment_name=qwen2.5-0.5b-instruct-function-reward-minimal', 'trainer.nnodes=1', 'trainer.n_gpus_per_node=2', 'trainer.val_before_train=False', 'trainer.test_freq=-1', 'trainer.save_freq=-1', 'trainer.resume_mode=disable', 'trainer.total_epochs=5', 'trainer.device=cuda']
Traceback (most recent call last):
  File "/share/huitianlin/verl/verl/trainer/main_ppo.py", line 31, in main
    run_ppo(config)
  File "/share/huitianlin/verl/verl/trainer/main_ppo.py", line 54, in run_ppo
    ray.get(runner.run.remote(config))
  File "/share/huitianlin/miniconda/envs/verl_sglang/lib/python3.10/site-packages/ray/_private/auto_init_hook.py", line 22, in auto_init_wrapper
    return fn(*args, **kwargs)
  File "/share/huitianlin/miniconda/envs/verl_sglang/lib/python3.10/site-packages/ray/_private/client_mode_hook.py", line 104, in wrapper
    return func(*args, **kwargs)
  File "/share/huitianlin/miniconda/envs/verl_sglang/lib/python3.10/site-packages/ray/_private/worker.py", line 2849, in get
    values, debugger_breakpoint = worker.get_objects(object_refs, timeout=timeout)
  File "/share/huitianlin/miniconda/envs/verl_sglang/lib/python3.10/site-packages/ray/_private/worker.py", line 937, in get_objects
    raise value.as_instanceof_cause()
ray.exceptions.RayTaskError(RuntimeError): ray::TaskRunner.run() (pid=192441, ip=11.46.27.250, actor_id=e541f671fe590f5052df22bf01000000, repr=<main_ppo.TaskRunner object at 0x7eec064ad030>)
  File "/share/huitianlin/verl/verl/trainer/main_ppo.py", line 190, in run
    trainer.fit()
  File "/share/huitianlin/verl/verl/trainer/ppo/ray_trainer.py", line 983, in fit
    self.async_rollout_manager.wake_up()
  File "/share/huitianlin/verl/verl/workers/rollout/async_server.py", line 183, in wake_up
    ray.get([server.wake_up.remote() for server in self.async_llm_servers])
ray.exceptions.RayTaskError(RuntimeError): ray::AsyncSglangServer.wake_up() (pid=11943, ip=11.46.27.250, actor_id=7b44345e32ee73cce680842d01000000, repr=<verl.workers.rollout.sglang_rollout.async_sglang_server.AsyncSglangServer object at 0x7fa650047040>)
  File "/share/huitianlin/miniconda/envs/verl_sglang/lib/python3.10/concurrent/futures/_base.py", line 438, in result
    return self.__get_result()
  File "/share/huitianlin/miniconda/envs/verl_sglang/lib/python3.10/concurrent/futures/_base.py", line 390, in __get_result
    raise self._exception
  File "/share/huitianlin/verl/verl/workers/rollout/sglang_rollout/async_sglang_server.py", line 69, in wake_up
    await asyncio.gather(*tasks)
  File "/share/huitianlin/miniconda/envs/verl_sglang/lib/python3.10/asyncio/tasks.py", line 648, in _wrap_awaitable
    return (yield from awaitable.__await__())
ray.exceptions.RayTaskError(RuntimeError): ray::WorkerDict.wake_up() (pid=7248, ip=11.46.27.250, actor_id=e2bec28d1aa44c4bf924970601000000, repr=<verl.single_controller.ray.base.WorkerDict object at 0x7f93ce71a5c0>)
  File "/share/huitianlin/miniconda/envs/verl_sglang/lib/python3.10/concurrent/futures/_base.py", line 445, in result
    return self.__get_result()
  File "/share/huitianlin/miniconda/envs/verl_sglang/lib/python3.10/concurrent/futures/_base.py", line 390, in __get_result
    raise self._exception
  File "/share/huitianlin/verl/verl/single_controller/ray/base.py", line 667, in async_func
    return await getattr(self.worker_dict[key], name)(*args, **kwargs)
  File "/share/huitianlin/verl/verl/single_controller/base/decorator.py", line 546, in async_inner
    return await func(*args, **kwargs)
  File "/share/huitianlin/verl/verl/workers/fsdp_workers.py", line 1518, in wake_up
    await self.rollout.wake_up()
  File "/share/huitianlin/verl/verl/workers/rollout/sglang_rollout/sglang_rollout.py", line 1211, in wake_up
    await self.sharding_manager.wake_up()  # pylint: disable=C2801
  File "/share/huitianlin/verl/verl/workers/sharding_manager/fsdp_sglang.py", line 200, in wake_up
    await self.update_weights(params)
  File "/share/huitianlin/verl/verl/workers/sharding_manager/fsdp_sglang.py", line 166, in update_weights
    dist.gather_object(
  File "/share/huitianlin/miniconda/envs/verl_sglang/lib/python3.10/site-packages/torch/distributed/c10d_logger.py", line 81, in wrapper
    return func(*args, **kwargs)
  File "/share/huitianlin/miniconda/envs/verl_sglang/lib/python3.10/site-packages/torch/distributed/distributed_c10d.py", line 3149, in gather_object
    all_gather(object_size_list, local_size, group=group)
  File "/share/huitianlin/miniconda/envs/verl_sglang/lib/python3.10/site-packages/torch/distributed/c10d_logger.py", line 81, in wrapper
    return func(*args, **kwargs)
  File "/share/huitianlin/miniconda/envs/verl_sglang/lib/python3.10/site-packages/torch/distributed/distributed_c10d.py", line 3709, in all_gather
    work.wait()
RuntimeError: [/pytorch/third_party/gloo/gloo/transport/tcp/pair.cc:534] Connection closed by peer [11.46.27.250]:35016

Set the environment variable HYDRA_FULL_ERROR=1 for a complete stack trace.

@chenhaiq Yes, I used the latest code. Could you please provide some hints to help me solve it?
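
For readers trying to interpret the failure above: one worker's sglang scheduler process aborts inside `update_weights_from_tensor` (the `Fatal Python error: Aborted` block), and the surviving rank then fails in `dist.gather_object` inside `fsdp_sglang.py` because its Gloo peer is gone. The snippet below is only a minimal, self-contained sketch of that failure mode, not verl or sglang code; the port, world size, and the `worker` helper are arbitrary assumptions for illustration.

```python
# Minimal sketch (assumptions: 2 processes, localhost rendezvous on an arbitrary port).
# It reproduces the symptom only: a peer dying before a Gloo collective makes the
# surviving rank raise "Connection closed by peer".
import os

import torch.distributed as dist
import torch.multiprocessing as mp


def worker(rank: int, world_size: int) -> None:
    dist.init_process_group(
        backend="gloo",
        init_method="tcp://127.0.0.1:29531",
        rank=rank,
        world_size=world_size,
    )
    if rank == 1:
        # Simulate a peer process aborting (like the sglang scheduler above)
        # before it ever reaches the collective.
        os._exit(1)
    gathered = [None] * world_size  # only the destination rank needs this list
    # Rank 0 blocks here waiting for rank 1 and typically fails with
    # "RuntimeError: ... gloo/transport/tcp ... Connection closed by peer"
    dist.gather_object({"rank": rank}, gathered, dst=0)


if __name__ == "__main__":
    mp.spawn(worker, args=(2,), nprocs=2, join=True)
```

Rank 1 dies before the collective, so rank 0 typically surfaces the same `Connection closed by peer` error (or the spawner reports the dead child first). That is why the gather in `update_weights` is where the error shows up even though the root cause is the earlier scheduler abort, matching the ordering of the logs above.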

@Alice1998
Copy link

I face the same issue, with "There appear to be 1 leaked shared_memory objects to clean up at shutdown",
when running the run_qwen2.5-3b_gsm8k_multiturn.sh script with the latest code.

chenjiaoAngel added a commit to chenjiaoAngel/verl that referenced this pull request Nov 14, 2025
…#1769)

Changed sglang rollout pipeline to async method to have better
performance.

resolved issue volcengine#1721
