[Bug]: Llama4 MoE weight_loaders Removed from Parameters After Initial Load, Causing Errors During Refitting #17915

Closed
@yaoyu-33

Description

Your current environment

Environment:

  • vLLM version: [0.8.4]
  • ray version: [2.43.0]

🐛 Describe the bug

Description:
We are integrating Llama4 (with Mixture of Experts) support into a reinforcement learning framework. While the initial model loading and weight initialization from Hugging Face checkpoints work correctly with vLLM, we encounter an issue when attempting to refit or update the model weights via worker extensions. The process fails with an error indicating that an MoE parameter (specifically observed with w2_weight) is missing its weight_loader. Dense parameters do not exhibit this issue.

Through debugging, we've observed that the weight_loader attribute for MoE parameters (e.g., w2_weight) is removed after the initial model loading and processing steps, specifically during the _process_weights_after_loading call, which in turn triggers the Llama4 model's __setattr__. This removal breaks subsequent operations, such as weight updates, that rely on these loaders.
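The failure mode can be reproduced without vLLM or torch at all: an attribute attached dynamically to a parameter object does not survive when the parameter is replaced by a freshly constructed object, which is what happens when post-load processing re-wraps the weight in a new torch.nn.Parameter. A minimal pure-Python sketch (FakeParameter is a hypothetical stand-in for torch.nn.Parameter):

```python
class FakeParameter:
    """Hypothetical stand-in for torch.nn.Parameter, for illustration only."""
    def __init__(self, data):
        self.data = data

# Initial load: a weight_loader is attached to the parameter object.
w2_weight = FakeParameter([1.0, 2.0])
w2_weight.weight_loader = lambda param, loaded: setattr(param, "data", loaded)

# Post-load processing re-wraps the (possibly padded) data in a *new* object,
# analogous to `layer.w2_weight = torch.nn.Parameter(...)` in the traceback:
w2_weight = FakeParameter(w2_weight.data + [0.0])

print(hasattr(w2_weight, "weight_loader"))  # False: the loader is gone
```

The same mechanism applies to the real torch.nn.Parameter: its constructor copies only the tensor data, not Python attributes monkey-patched onto the old parameter.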

Steps to Reproduce:

  1. Initialize a Llama4 model with MoE layers using vLLM.
  2. Load Hugging Face pretrained weights. (This step completes successfully).
  3. Attempt to update or refit the MoE weights using a mechanism that might rely on the weight_loader attribute (e.g., through worker extensions or a custom refitting loop).
  4. Observe an error related to a missing weight_loader for an MoE parameter.
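Step 4 can be checked programmatically with a small diagnostic helper run after load_model(). The helper below is a hypothetical sketch, not vLLM code; the parameter-name suffixes follow the paths observed in this report:

```python
def find_missing_loaders(model, suffixes=("w13_weight", "w2_weight")):
    """Return names of MoE expert weights that lost their `weight_loader`.

    `model` is anything exposing named_parameters() (e.g. an nn.Module).
    """
    return [
        name
        for name, param in model.named_parameters()
        if name.endswith(suffixes) and not hasattr(param, "weight_loader")
    ]
```

Calling this from a worker extension immediately after model load should print the affected MoE parameters while leaving dense parameters out of the list.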

Expected Behavior:
The weight_loader attributes for MoE parameters should persist after model initialization and weight loading, allowing for subsequent weight modification or refitting operations.

Actual Behavior:
The weight_loader attribute for MoE parameters (e.g., model.layers[x].block_sparse_moe.experts.experts[y].w2_weight.weight_loader) is missing after the initial model loading process, leading to errors during weight update attempts.

Traceback/Debugging Information:
The following traceback was observed when a watcher was placed on the w2_weight attribute to monitor its modifications. The weight_loader was removed from w2_weight at the point indicated, leading to failures in later refitting stages.

(RayWorkerWrapper pid=2587067)   File "/usr/local/lib/python3.12/dist-packages/ray/_private/workers/default_worker.py", line 297, in 
(RayWorkerWrapper pid=2587067)     worker.main_loop()
(RayWorkerWrapper pid=2587067)   File "[USER_CODE_PATH]/lib/python3.12/site-packages/ray/_private/worker.py", line 935, in main_loop
(RayWorkerWrapper pid=2587067)     self.core_worker.run_task_loop()
(RayWorkerWrapper pid=2587067)   File "[USER_CODE_PATH]/lib/python3.12/site-packages/ray/_private/function_manager.py", line 696, in actor_method_executor
(RayWorkerWrapper pid=2587067)     return method(__ray_actor, *args, **kwargs)
(RayWorkerWrapper pid=2587067)   File "[USER_CODE_PATH]/lib/python3.12/site-packages/ray/util/tracing/tracing_helper.py", line 463, in _resume_span
(RayWorkerWrapper pid=2587067)     return method(self, *_args, **_kwargs)
(RayWorkerWrapper pid=2587067)   File "[USER_CODE_PATH]/lib/python3.12/site-packages/vllm/worker/worker_base.py", line 612, in execute_method
(RayWorkerWrapper pid=2587067)     return run_method(self, method, args, kwargs)
(RayWorkerWrapper pid=2587067)   File "[USER_CODE_PATH]/lib/python3.12/site-packages/vllm/utils.py", line 2378, in run_method
(RayWorkerWrapper pid=2587067)     return func(*args, **kwargs)
(RayWorkerWrapper pid=2587067)   File "[USER_CODE_PATH]/lib/python3.12/site-packages/vllm/v1/worker/gpu_worker.py", line 136, in load_model
(RayWorkerWrapper pid=2587067)     self.model_runner.load_model()
(RayWorkerWrapper pid=2587067)   File "[USER_CODE_PATH]/lib/python3.12/site-packages/vllm/v1/worker/gpu_model_runner.py", line 1289, in load_model
(RayWorkerWrapper pid=2587067)     self.model = get_model(vllm_config=self.vllm_config)
(RayWorkerWrapper pid=2587067)   File "[USER_CODE_PATH]/lib/python3.12/site-packages/vllm/model_executor/model_loader/__init__.py", line 14, in get_model
(RayWorkerWrapper pid=2587067)     return loader.load_model(vllm_config=vllm_config)
(RayWorkerWrapper pid=2587067)   File "[USER_CODE_PATH]/lib/python3.12/site-packages/vllm/model_executor/model_loader/loader.py", line 504, in load_model
(RayWorkerWrapper pid=2587067)     _process_weights_after_loading(model, model_config, target_device)
(RayWorkerWrapper pid=2587067)   File "[USER_CODE_PATH]/lib/python3.12/site-packages/vllm/model_executor/model_loader/loader.py", line 180, in _process_weights_after_loading
(RayWorkerWrapper pid=2587067)     quant_method.process_weights_after_loading(module)
(RayWorkerWrapper pid=2587067)   File "[USER_CODE_PATH]/lib/python3.12/site-packages/vllm/model_executor/layers/fused_moe/layer.py", line 119, in process_weights_after_loading
(RayWorkerWrapper pid=2587067)     layer.w2_weight = torch.nn.Parameter(self._maybe_pad_weight(
(RayWorkerWrapper pid=2587067)   File "[USER_CODE_PATH]/lib/python3.12/site-packages/vllm/model_executor/models/llama4.py", line 363, in __setattr__ 
(RayWorkerWrapper pid=2587067)     # This is where weight_loader was observed to be removed from w2_weight
(RayWorkerWrapper pid=2587067)     traceback.print_stack() # User added for debugging

Additional Context/Observations:

  • The issue seems specific to MoE parameters; dense parameters in the same model can be refitted without this problem.
  • The problem occurs after vllm.model_executor.model_loader.loader._process_weights_after_loading calls quant_method.process_weights_after_loading, which for MoE layers such as FusedMoELayer re-assigns the weight parameter (e.g., layer.w2_weight = torch.nn.Parameter(...)). This reassignment appears to lose the original weight_loader that was attached during the initial loading phase.
  • The __setattr__ method in vllm.model_executor.models.llama4.py (or a similar model-specific file) is implicated by the traceback as the point where the attribute modification (and effective removal of weight_loader) occurs.

We suspect that when the MoE weights are processed and potentially padded/reformatted (e.g., in FusedMoELayer.process_weights_after_loading), the new torch.nn.Parameter created does not retain the weight_loader attribute from the original tensor.
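One possible workaround on the user side, pending a fix in process_weights_after_loading itself, is to stash the loaders before post-load processing and re-attach them afterwards. A hedged sketch (both helpers are hypothetical, not part of vLLM; they only assume the model exposes named_parameters()):

```python
def stash_weight_loaders(model):
    """Snapshot weight_loader callables by parameter name before re-wrapping."""
    return {
        name: param.weight_loader
        for name, param in model.named_parameters()
        if hasattr(param, "weight_loader")
    }

def restore_weight_loaders(model, stash):
    """Re-attach stashed loaders to whatever parameter now lives under each name."""
    for name, param in model.named_parameters():
        if name in stash and not hasattr(param, "weight_loader"):
            param.weight_loader = stash[name]
```

This assumes the parameter names are stable across the re-wrap, which holds here since only the Parameter object is replaced, not its name in the module tree.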

Please let us know if any further information or a minimal reproducible example would be helpful.

Before submitting a new issue...

  • Make sure you already searched for relevant issues, and asked the chatbot living at the bottom right corner of the documentation page, which can answer lots of frequently asked questions.

Metadata

Assignees: No one assigned
Labels: bug (Something isn't working)
Status: Done
Milestone: No milestone