-
Notifications
You must be signed in to change notification settings - Fork 373
llama4 ckpt conversion #1816
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
llama4 ckpt conversion #1816
Conversation
4c2fd98
to
3b1d43d
Compare
8981d6d
to
4ff7e82
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
4ff7e82
to
f589d21
Compare
MaxText/llama4_ckpt_unscanned.py
Outdated
# vision model ########################################### | ||
max_logging.log("Processing vision model") | ||
jax_weights["vision_encoder"]["Llama4VisionModel_0"]["class_embedding"] = ( | ||
chkpt_vars["vision_model.class_embedding"].to(torch.float32).numpy().astype(CAST_DTYPE) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
could you extract the conversions part into a function and re-use it?
def torch_to_numpy(tensor: torch.Tensor, transpose: bool = False):
"""Converts a PyTorch tensor to a NumPy array with the target dtype."""
result = tensor.to(torch.float32).numpy().astype(CAST_DTYPE)
if transpose:
return result.transpose()
return result
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Good point! I added function _pt_to_np and refactored
f589d21
to
fbb76c9
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks!
fbb76c9
to
1d19c88
Compare
Description
Joint by @aireenmei and @hengtaoguo .
Add ckpt conversion for llama4 vision model from hf to maxtext.
Tests
Verified in #1809
Checklist
Before submitting this PR, please make sure (put X in square brackets):