[Model] Pixtral Support #253
base: main
Conversation
Pixtral isn't yet fully supported in the transformers library. PR pending release of Pixtral in the transformers package.
Exciting! @AndreSlavescu It seems it is now supported in transformers: https://github.com/huggingface/transformers/tree/main/src/transformers/models/pixtral
Yes, I'll try to finish this either today or tomorrow.
@Tcc0403 pinging in case you'd like to take a look.
I'm not familiar with Pixtral, but it looks like it's just a base model. The loss isn't computed in the forward pass, so there's no need to patch CrossEntropy and FusedLinearCrossEntropy.
```python
else:
    output = model(**batch)
    loss = output.loss
loss.backward()
```
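For context, here is a minimal sketch of why the generic `output.loss` branch above doesn't apply to Pixtral (the tiny config values and input shape are illustrative only, and the exact `pixel_values` format can vary across transformers versions):

```python
import torch
from transformers import PixtralVisionConfig, PixtralVisionModel

# A deliberately tiny config so the sketch runs quickly; real values differ.
config = PixtralVisionConfig(
    hidden_size=64,
    intermediate_size=128,
    num_hidden_layers=2,
    num_attention_heads=4,
    image_size=64,
    patch_size=16,
)
model = PixtralVisionModel(config)

pixel_values = torch.randn(1, 3, 64, 64)  # dummy image batch
output = model(pixel_values=pixel_values)

# PixtralVisionModel is a base model: it returns hidden states, not a loss,
# so the test harness has to compute a loss itself for this model.
print(output.last_hidden_state.shape)  # (1, num_patches, hidden_size)
```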
Actually, if we have to generate pixel-value inputs just for this specific vision model, do we really want to support pure vision models in Liger Kernel? cc @lancerts @shivam15s @yundai424
If the answer is yes, then I think we should make another convergence test file for vision models that follows this type of workflow: generating pixel inputs and applying a custom loss function (see the sketch below).
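For example, such a test could generate deterministic pixel inputs along these lines (a sketch only; the function name and shape defaults are mine, not from this PR):

```python
import torch

def generate_dummy_pixel_batch(
    batch_size: int = 2,
    num_channels: int = 3,
    height: int = 64,
    width: int = 64,
    seed: int = 42,
) -> dict:
    """Build a reproducible random image batch for a vision convergence test."""
    generator = torch.Generator().manual_seed(seed)
    pixel_values = torch.randn(
        batch_size, num_channels, height, width, generator=generator
    )
    return {"pixel_values": pixel_values}
```

Seeding the generator keeps the Liger and Hugging Face runs comparable step for step.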
Yes, I was thinking of implementing a custom loss function for this, because patching with FusedLinearCrossEntropy won't work here with the current API.
And yes, the main difficulty with integrating this into the current mini model tests is that they expect pixel inputs to the constructor of the PixtralVisionModel, so I have done a hacky solution for now.
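Roughly what such a custom loss could look like (a sketch under my own assumptions: the mean pooling and the `head` projection are hypothetical placeholders, not this PR's implementation):

```python
import torch
import torch.nn.functional as F

def vision_model_loss(model, pixel_values, labels, head):
    """Cross-entropy loss on top of a base vision model's features.

    `head` is a hypothetical nn.Linear mapping pooled hidden states to class
    logits; PixtralVisionModel itself ships no loss head.
    """
    hidden = model(pixel_values=pixel_values).last_hidden_state  # (B, patches, H)
    pooled = hidden.mean(dim=1)  # simple mean pooling over patches
    logits = head(pooled)        # (B, num_classes)
    return F.cross_entropy(logits, labels)
```

The cross-entropy call is the piece that could later be swapped for a Liger kernel once the API supports this path.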
Yeah, making the model input a fixture (or similar) and the loss function customizable as well (all configured in the mini model config) is a good idea 🤔
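Something along these lines, perhaps (field names are made up for illustration; the repo's actual mini model config differs):

```python
from dataclasses import dataclass
from typing import Callable, Optional

import torch

@dataclass
class VisionMiniModelConfig:
    """Hypothetical mini model config entry extended for vision models."""
    model_class: type                # e.g. PixtralVisionModel
    model_config: object             # the matching transformers config
    data_gen_fn: Callable[[], dict]  # e.g. random pixel-value batches
    loss_fn: Optional[Callable[..., torch.Tensor]] = None  # custom loss, if any
```

The harness would then call `data_gen_fn()` for inputs and fall back to `output.loss` whenever `loss_fn` is None.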
Summary
This PR aims to support Pixtral.
Testing Done
Tested the model and the monkey patch.
- `make test` to ensure correctness
- `make checkstyle` to ensure code style
- `make test-convergence` to ensure convergence