Skip to content

[Bugfix] fix pp for llama4 #16746

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Apr 18, 2025
Merged

Conversation

luccafong
Copy link
Collaborator

@luccafong luccafong commented Apr 17, 2025

This PR fixes PP for llama4 (#16385)

Loading language model with architecture to support PP verification and correct the prefix to separate model loading weights for language model and rest of the models so both PP=0 and PP=1 can work.

VLLM_DISABLE_COMPILE_CACHE=1 python -m vllm.entrypoints.openai.api_server --model $LLAMA_DIR --served-model-name maverick  --max-model-len 20000 --tensor-parallel 4 --pipeline-parallel-size 2 --gpu-memory-utilization 0.95  --port 8000 
curl http://localhost:8000/v1/completions \
    -H "Content-Type: application/json" \
    -d '{
        "model": "maverick",
        "prompt": "San Francisco is a",
        "max_tokens": 7,
        "temperature": 0
    }'
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value|   |Stderr|
|-----|------:|----------------|-----:|-----------|---|----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.930|±  |0.0181|
|     |       |strict-match    |     5|exact_match|↑  |0.935|±  |0.0175|

Copy link

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs would not trigger full CI run by default. Instead, it would only run fastcheck CI which starts running only a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build on Buildkite UI (linked in the PR checks section) and unblock them. If you do not have permission to unblock, ping simon-mo or khluu to add you in our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either: Add ready label to the PR or enable auto-merge.

🚀

@luccafong luccafong marked this pull request as ready for review April 17, 2025 01:16
Signed-off-by: Lu Fang <[email protected]>
@luccafong luccafong requested a review from houseroad April 17, 2025 03:56
@luccafong luccafong mentioned this pull request Apr 17, 2025
1 task
@luccafong luccafong requested a review from ywang96 April 17, 2025 04:03
self.language_model = _initialize_model(
vllm_config=vllm_config.with_hf_config(config.text_config),
vllm_config=vllm_config.with_hf_config(config.text_config,
["LlamaForCausalLM"]),
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we use Llama4ForCausalLM?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Llama4ForCasualLm is not registered architecture, we should avoid using that which requires adding a lot of hacks as in the initial PR

Copy link
Member

@ywang96 ywang96 Apr 17, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Probably a dumb question, but since _initialize_model is already pointing to model_class=Llama4ForCausalLM, why do we need to override the architectures here to LlamaForCausalLM?

def _initialize_model(
vllm_config: VllmConfig,
*,
prefix: str = "",
model_class: Optional[type[nn.Module]] = None,
) -> nn.Module:
"""Initialize a model with the given configurations."""
model_config = vllm_config.model_config
if model_class is None:
model_class, _ = get_model_architecture(model_config)

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

during __post__init__ of the hf config when we call replace inside with_hf_config function

return replace(self, model_config=model_config)
, it will check if the architecture is supporting PP or not via
self.model_config.verify_with_parallel_config(self.parallel_config)
, where None architecture will raise an issue as in #16385 in below code

normalized_arch = list(
filter(lambda model: model in self.models, architectures))

@@ -824,7 +824,7 @@ def load_weights(self, weights: Iterable[Tuple[str,
# language_model is an Llama4ForCausalLM instance. We load it's
# using llama4's load_weights routine.
language_model_weights, other_weights = self.separate_weights(
weights, prefix="language_model.model.")
weights, prefix="language_model.")
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Wondering why this issue was not triggered before?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

language_model.lm_head can also be loaded by the parent model weight loader if not PP enabled, but for PP, as we split the weights into 2 parts, the lm_head is missing in PP=0 so it will raise an issue weight not found. while in llama4.py model loading we have logic handling is_pp_missing_parameter to avoid the exception

Copy link
Collaborator

@houseroad houseroad left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the fix.

@houseroad houseroad added the ready ONLY add when PR is ready to merge/full CI is needed label Apr 17, 2025
@DarkLight1337 DarkLight1337 merged commit e31045f into vllm-project:main Apr 18, 2025
62 checks passed
yangw-dev pushed a commit to yangw-dev/vllm that referenced this pull request Apr 21, 2025
jikunshang pushed a commit to jikunshang/vllm that referenced this pull request Apr 29, 2025
lk-chen pushed a commit to lk-chen/vllm that referenced this pull request Apr 29, 2025
adobrzyn pushed a commit to HabanaAI/vllm-fork that referenced this pull request Apr 30, 2025
Signed-off-by: Lu Fang <[email protected]>
Signed-off-by: Agata Dobrzyniewicz <[email protected]>
RichardoMrMu pushed a commit to RichardoMrMu/vllm that referenced this pull request May 12, 2025
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
ready ONLY add when PR is ready to merge/full CI is needed
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants