enable loading llama 3_2 from meta checkpoint #1688
Merged
Context
What is the purpose of this PR? Is it to
The HF checkpoint doesn't have output.weight. The Meta checkpoint has it, and it's the same as the embedding.
Before this PR, the following error would happen:
I tried to minimize the code footprint by not creating new checkpoint loading/saving logic. Instead, we just pop the output.weight key on load or add it back on save.
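A minimal sketch of the pop-or-add idea, for illustration only and not the code in this PR. It uses plain dicts in place of real state dicts. The helper names `drop_tied_output` and `restore_tied_output` are hypothetical, and `tok_embeddings.weight` is an assumed name for the embedding key; only `output.weight` comes from the description above.

```python
# Illustrative sketch only: plain dicts stand in for checkpoint state dicts.
OUTPUT_KEY = "output.weight"              # key named in the PR description
EMBEDDING_KEY = "tok_embeddings.weight"   # ASSUMED embedding key name


def drop_tied_output(state_dict):
    """On load: drop the output projection, since it duplicates the embedding."""
    converted = dict(state_dict)      # shallow copy, leaves the input untouched
    converted.pop(OUTPUT_KEY, None)   # no error if the key is already absent
    return converted


def restore_tied_output(state_dict):
    """On save: add the output key back, pointing at the embedding weights."""
    converted = dict(state_dict)
    converted[OUTPUT_KEY] = converted[EMBEDDING_KEY]
    return converted
```

Under these assumptions, a Meta-style checkpoint passed through `drop_tied_output` and then `restore_tied_output` ends up with the same set of keys it started with, which is why no separate loading/saving path is needed.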
NOTE: A more robust solution would be to add a tie_weights flag to all tied-embedding models (qwen, gemma, llama3_2). If False, the output and embedding weights could be trained separately, as is the case with qwen. In the future, it's possible that someone has such a model, and when they load it in torchtune, the output would be replaced with the embedding, because this is hardcoded. Since this is an edge case, this PR does not address it for now.
Test plan
code:
yaml
run:
test
output:
output