
Enable loading Llama 3.2 from Meta checkpoint #1688


Merged: 1 commit merged into pytorch:main on Sep 26, 2024

Conversation

felipemello1 (Contributor)

Context

What is the purpose of this PR? Is it to

  • add a new feature
  • fix a bug
  • update tests and/or documentation
  • other (please add here)

The HF checkpoint doesn't have output.weight. The Meta checkpoint has it, and it is the same tensor as the embedding.
Before this PR, the following error would occur:

RuntimeError: Error(s) in loading state_dict for TransformerDecoder:
        Unexpected key(s) in state_dict: "output.weight".

I tried to minimize the code footprint by not creating new checkpoint loading/saving logic. Instead, we just pop the key when loading and add it back when saving.
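For illustration, a minimal sketch of the pop/add idea in isolation (the function names below are hypothetical and do not reflect torchtune's actual checkpointer code):

import torch

def load_meta_state_dict(path: str) -> dict:
    # Illustrative only: Llama 3.2 ties output.weight to tok_embeddings.weight,
    # so the separate output key is dropped before loading into the model.
    state_dict = torch.load(path, map_location="cpu", weights_only=True)
    state_dict.pop("output.weight", None)
    return state_dict

def save_meta_state_dict(state_dict: dict, path: str) -> None:
    # Illustrative only: re-add the tied output projection so the saved file
    # matches the original Meta checkpoint format.
    state_dict = dict(state_dict)
    state_dict["output.weight"] = state_dict["tok_embeddings.weight"].clone()
    torch.save(state_dict, path)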

NOTE: a more robust solution would be to add a tie_weights flag to all tied-embedding models (Qwen, Gemma, Llama 3.2). If False, the embedding and output would be allowed to train separately, as is the case with Qwen. In the future, someone may have such a model, and when they load it in torchtune the output would be replaced with the embedding, because this behavior is hardcoded. Since this is an edge case, this PR does not address it at this moment.
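As a rough sketch of the NOTE above (hypothetical, not an API introduced by this PR), such a flag could look like:

import torch.nn as nn

def build_output_head(tok_embeddings: nn.Embedding, tie_weights: bool = True) -> nn.Linear:
    # Hypothetical helper: if tie_weights is True, the output projection shares
    # its Parameter with the token embedding (the Llama 3.2 / Qwen / Gemma setup);
    # if False, it gets its own weight so the two can diverge during training.
    vocab_size, embed_dim = tok_embeddings.weight.shape
    output = nn.Linear(embed_dim, vocab_size, bias=False)
    if tie_weights:
        output.weight = tok_embeddings.weight  # shared Parameter, not a copy
    return output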

Test plan

  1. Train with the Meta checkpoint and confirm that the embedding and output weights are identical both before and after training.

Config (YAML):

checkpointer:
  _component_: torchtune.training.FullModelMetaCheckpointer
  checkpoint_dir: /tmp/Llama-3.2-1B-Instruct/original/
  checkpoint_files: [
    consolidated.00.pth
  ]
  recipe_checkpoint: null
  output_dir: /tmp/Llama-3.2-1B-Instruct/
  model_type: LLAMA3_2
resume_from_checkpoint: False

Run:

tune run full_finetune_single_device --config llama3_2/1B_full_single_device epochs=1 max_steps_per_epoch=200 compile=True batch_size=8 metric_logger=torchtune.training.metric_logging.WandBLogger 

Test:

import torch

weight_dict = torch.load("/tmp/Llama-3.2-1B-Instruct/meta_model_0.pt", map_location="cpu")
weight_dict_original = torch.load("/tmp/Llama-3.2-1B-Instruct/original/consolidated.00.pth", map_location="cpu")

# after training (checkpoint saved by torchtune)
print(weight_dict["output.weight"].mean())
print(weight_dict["tok_embeddings.weight"].mean())
print()
# before training (original Meta checkpoint)
print(weight_dict_original["output.weight"].mean())
print(weight_dict_original["tok_embeddings.weight"].mean())

output:

tensor(-8.8692e-05, dtype=torch.bfloat16)
tensor(-8.8692e-05, dtype=torch.bfloat16)

tensor(-9.3460e-05, dtype=torch.bfloat16)
tensor(-9.3460e-05, dtype=torch.bfloat16)
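A stricter check than comparing means (a small sketch reusing the dicts loaded above) is to verify the tied tensors match elementwise:

# The tied weights should be identical, both in the original checkpoint
# and in the one saved after training.
assert torch.equal(weight_dict["output.weight"], weight_dict["tok_embeddings.weight"])
assert torch.equal(
    weight_dict_original["output.weight"],
    weight_dict_original["tok_embeddings.weight"],
)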
  2. Train with the Meta checkpoint and resume training from the saved checkpoint.

(screenshot of the resumed training run)
  3. Train with the HF checkpoint and confirm output.weight is not in the saved checkpoint.
import torch
from safetensors import safe_open

# original (pre-training) HF checkpoint
weight_dict_original = {}
with safe_open("/tmp/Llama-3.2-1B-Instruct/model.safetensors", framework="pt", device="cpu") as f:
    for k in f.keys():
        weight_dict_original[k] = f.get_tensor(k)

# checkpoint saved by torchtune after fine-tuning
weight_dict = torch.load("/tmp/Llama-3.2-1B-Instruct/hf_model_0001_0.pt", map_location="cpu")

# after training (checkpoint saved by torchtune): no output projection key
assert "model.output.weight" not in weight_dict
print(weight_dict["model.embed_tokens.weight"].mean())
print()

# before training (original HF checkpoint): also no output projection key
assert "output.weight" not in weight_dict_original
print(weight_dict_original["model.embed_tokens.weight"].mean())

output:

tensor(-8.8692e-05, dtype=torch.bfloat16)

tensor(-9.3460e-05, dtype=torch.bfloat16)

pytorch-bot bot commented Sep 26, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/1688

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 083d2ad with merge base 4e69db8:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot added the CLA Signed label on Sep 26, 2024
@SLR722 (Contributor) left a comment

Looks good to me! Thanks for the quick fix!

@felipemello1 merged commit 7da96d1 into pytorch:main on Sep 26, 2024
17 checks passed
@felipemello1 deleted the llama32_ckpt branch on September 26, 2024 at 14:01