[Model] Pixtral Support #253

Open

wants to merge 7 commits into base: main

Changes from 2 commits
1 change: 1 addition & 0 deletions README.md
@@ -243,6 +243,7 @@ loss.backward()
| LLaMA 3.2-Vision | `liger_kernel.transformers.apply_liger_kernel_to_mllama` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
| Mistral | `liger_kernel.transformers.apply_liger_kernel_to_mistral` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
| Mixtral | `liger_kernel.transformers.apply_liger_kernel_to_mixtral` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
| Pixtral | `liger_kernel.transformers.apply_liger_kernel_to_pixtral` | RoPE, RMSNorm, SwiGLU |
| Gemma1 | `liger_kernel.transformers.apply_liger_kernel_to_gemma` | RoPE, RMSNorm, GeGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
| Gemma2 | `liger_kernel.transformers.apply_liger_kernel_to_gemma2` | RoPE, RMSNorm, GeGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
| Gemma3 (Text) | `liger_kernel.transformers.apply_liger_kernel_to_gemma3_text` | RoPE, RMSNorm, GeGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
2 changes: 1 addition & 1 deletion setup.py
@@ -60,7 +60,7 @@ def is_xpu_available():

try:
result = subprocess.run("sycl-ls", check=True, capture_output=True, shell=True)
if 'level_zero:gpu' in result.stdout.decode():
if "level_zero:gpu" in result.stdout.decode():
return True
except (subprocess.SubprocessError, FileNotFoundError):
pass
3 changes: 3 additions & 0 deletions src/liger_kernel/transformers/__init__.py
@@ -37,6 +37,7 @@
from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_olmo2 # noqa: F401
from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_paligemma # noqa: F401
from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_phi3 # noqa: F401
from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_pixtral # noqa: F401
from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_qwen2 # noqa: F401
from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_qwen2_5_vl # noqa: F401
from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_qwen2_vl # noqa: F401
@@ -93,6 +94,7 @@ def __getattr__(name: str):
"apply_liger_kernel_to_olmo2",
"apply_liger_kernel_to_paligemma",
"apply_liger_kernel_to_phi3",
"apply_liger_kernel_to_pixtral",
"apply_liger_kernel_to_qwen2",
"apply_liger_kernel_to_qwen2_5_vl",
"apply_liger_kernel_to_qwen2_vl",
@@ -147,6 +149,7 @@ def __getattr__(name: str):
"apply_liger_kernel_to_olmo2",
"apply_liger_kernel_to_paligemma",
"apply_liger_kernel_to_phi3",
"apply_liger_kernel_to_pixtral",
"apply_liger_kernel_to_qwen2",
"apply_liger_kernel_to_qwen2_5_vl",
"apply_liger_kernel_to_qwen2_vl",
95 changes: 95 additions & 0 deletions src/liger_kernel/transformers/model/pixtral.py
Collaborator

I'm not familiar with Pixtral, but it looks like it's just a base model. The loss isn't computed in the forward pass, so there's no need to patch CrossEntropy and FusedLinearCrossEntropy.

@@ -0,0 +1,95 @@
from typing import Optional
from typing import Tuple
from typing import Union

import torch

from transformers.modeling_outputs import BaseModelOutput


def lce_forward(
self,
inputs_embeds,
attention_mask: Optional[torch.Tensor] = None,
position_embeddings: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
**loss_kwargs,
) -> Union[Tuple, BaseModelOutput]:
r"""
    Copy-paste of Pixtral's forward from transformers v4.44.2, but with torch cross entropy replaced by Liger fused linear cross entropy.

Args:
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
Embeddings which serve as input to the Transformer.
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.

[What are attention masks?](../glossary#attention-mask)
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
returned tensors for more detail.
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
for more detail.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict

    # PixtralTransformer has no wrapped sub-module to call here; start from the
    # input embeddings and run the encoder layers directly, matching upstream.
    hidden_states = inputs_embeds

encoder_states = () if output_hidden_states else None
all_attentions = () if output_attentions else None
for encoder_layer in self.layers:
if output_hidden_states:
encoder_states = encoder_states + (hidden_states,)
if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
encoder_layer.__call__,
hidden_states,
attention_mask,
position_embeddings,
output_attentions,
)
else:
layer_outputs = encoder_layer(
hidden_states,
attention_mask,
position_embeddings=position_embeddings,
output_attentions=output_attentions,
)

hidden_states = layer_outputs[0]

if output_attentions:
all_attentions = all_attentions + (layer_outputs[1],)

if output_hidden_states:
encoder_states = encoder_states + (hidden_states,)

if not return_dict:
return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)

return BaseModelOutput(
        last_hidden_state=hidden_states,
hidden_states=encoder_states,
attentions=all_attentions,
)
29 changes: 29 additions & 0 deletions src/liger_kernel/transformers/monkey_patch.py
@@ -26,6 +26,7 @@
from liger_kernel.transformers.model.mixtral import lce_forward_deprecated as mixtral_lce_forward_deprecated
from liger_kernel.transformers.model.phi3 import lce_forward as phi3_lce_forward
from liger_kernel.transformers.model.phi3 import lce_forward_deprecated as phi3_lce_forward_deprecated
from liger_kernel.transformers.model.pixtral import lce_forward as pixtral_lce_forward
from liger_kernel.transformers.model.qwen2 import lce_forward as qwen2_lce_forward
from liger_kernel.transformers.model.qwen2 import lce_forward_deprecated as qwen2_lce_forward_deprecated
from liger_kernel.transformers.qwen2vl_mrope import liger_multimodal_rotary_pos_emb
@@ -595,6 +596,33 @@ def apply_liger_kernel_to_mixtral(
_patch_rms_norm_module(decoder_layer.post_attention_layernorm)


def apply_liger_kernel_to_pixtral(
rope: bool = True,
rms_norm: bool = True,
fused_linear_cross_entropy: bool = True,
swiglu: bool = True,
) -> None:
"""
Apply Liger kernels to replace original implementation in HuggingFace Pixtral models.

Args:
rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
        fused_linear_cross_entropy (bool): Whether to apply Liger's fused linear cross entropy. If True, the logits are not materialized, which is more memory efficient. Default is True.
swiglu (bool): Whether to apply Liger's SwiGLU MLP. Default is True.
"""
from transformers.models.pixtral import modeling_pixtral

if rope:
modeling_pixtral.apply_rotary_pos_emb = liger_rotary_pos_emb
if rms_norm:
modeling_pixtral.PixtralRMSNorm = LigerRMSNorm
if fused_linear_cross_entropy:
modeling_pixtral.PixtralTransformer.forward = pixtral_lce_forward
if swiglu:
modeling_pixtral.PixtralMLP = LigerSwiGLUMLP


def apply_liger_kernel_to_gemma(
rope: bool = True,
cross_entropy: bool = False,
@@ -1561,6 +1589,7 @@ def apply_liger_kernel_to_glm4(
"mistral": apply_liger_kernel_to_mistral,
"mixtral": apply_liger_kernel_to_mixtral,
"olmo2": apply_liger_kernel_to_olmo2,
"pixtral": apply_liger_kernel_to_pixtral,
"qwen2": apply_liger_kernel_to_qwen2,
"qwen3": apply_liger_kernel_to_qwen3,
"qwen3_moe": apply_liger_kernel_to_qwen3_moe,
39 changes: 38 additions & 1 deletion test/convergence/fp32/test_mini_models_with_logits.py
@@ -15,6 +15,8 @@
from transformers.models.mixtral import MixtralForCausalLM
from transformers.models.phi3 import Phi3Config
from transformers.models.phi3 import Phi3ForCausalLM
from transformers.models.pixtral import PixtralVisionConfig
from transformers.models.pixtral import PixtralVisionModel
from transformers.models.qwen2 import Qwen2Config
from transformers.models.qwen2 import Qwen2ForCausalLM

@@ -30,6 +32,7 @@
from liger_kernel.transformers import apply_liger_kernel_to_mllama
from liger_kernel.transformers import apply_liger_kernel_to_olmo2
from liger_kernel.transformers import apply_liger_kernel_to_phi3
from liger_kernel.transformers import apply_liger_kernel_to_pixtral
from liger_kernel.transformers import apply_liger_kernel_to_qwen2
from liger_kernel.transformers import apply_liger_kernel_to_qwen2_5_vl
from liger_kernel.transformers import apply_liger_kernel_to_qwen2_vl
@@ -50,6 +53,7 @@
from test.utils import revert_liger_kernel_to_mllama
from test.utils import revert_liger_kernel_to_olmo2
from test.utils import revert_liger_kernel_to_phi3
from test.utils import revert_liger_kernel_to_pixtral
from test.utils import revert_liger_kernel_to_qwen2
from test.utils import revert_liger_kernel_to_qwen2_5_vl
from test.utils import revert_liger_kernel_to_qwen2_vl
@@ -280,6 +284,24 @@
attn_implementation="sdpa",
),
),
"mini_pixtral": MiniModelConfig(
liger_kernel_patch_func=apply_liger_kernel_to_pixtral,
liger_kernel_patch_revert_func=revert_liger_kernel_to_pixtral,
model_class=PixtralVisionModel,
mini_model_config=PixtralVisionConfig(
hidden_size=1024, # 1024
intermediate_size=2048, # 4096
num_hidden_layers=4, # 24
num_attention_heads=8, # 16
num_channels=1, # 3
image_size=256, # 1024
patch_size=16, # 16
hidden_act="gelu", # gelu
attention_dropout=0.0,
rope_theta=10000.0,
initializer_range=0.02,
),
),
"mini_gemma1": MiniModelConfig(
liger_kernel_patch_func=apply_liger_kernel_to_gemma,
liger_kernel_patch_revert_func=revert_liger_kernel_to_gemma,
@@ -829,7 +851,21 @@ def run_mini_model(
for i in range(num_steps):
batch = next(loader_iter).to(model.device)
optimizer.zero_grad()
output = model(**batch)
if model_name == "mini_pixtral":
dummy_pixel_values = torch.randn(
1,
1,
model.config.num_channels,
model.config.image_size,
model.config.image_size,
device=model.device,
                dtype=dtype,
)
model_input = {"pixel_values": dummy_pixel_values}
output = model(**model_input)
else:
batch_on_device = {k: v.to(model.device) for k, v in batch.items()}
output = model(**batch_on_device)
output.loss.backward()
optimizer.step()
print(f"Step {i}, Loss: {output.loss.item()}")
@@ -993,6 +1029,7 @@ def run_mini_model(
# TODO: mixtral is flaky so disable the test for now
# ("mini_mixtral", 32, 1e-4, torch.float32, 5e-4, 1e-4, 5e-3, 1e-5, 1e-2, 1e-5),
# Gemma 1.1 and 2 has more tolerance because currently, the kernel is not a perfect match
("mini_pixtral", 32, 1e-4, torch.float32, 1e-8, 1e-5, 5e-3, 1e-5, 5e-3, 1e-5),
("mini_gemma1", 32, 1e-4, torch.float32, 1e-8, 1e-4, 5e-3, 1e-5, 5e-3, 1e-5),
("mini_gemma1.1", 32, 1e-4, torch.float32, 1e-8, 1e-4, 5e-3, 1e-5, 5e-3, 1e-5),
("mini_gemma2", 32, 1e-4, torch.float32, 1e-8, 1e-4, 5e-3, 1e-5, 5e-3, 1e-5),
11 changes: 11 additions & 0 deletions test/utils.py
@@ -440,6 +440,17 @@ def revert_liger_kernel_to_phi3(model_config: MiniModelConfig):
print("Liger kernel patches have been reverted.")


def revert_liger_kernel_to_pixtral(model_config: MiniModelConfig):
"""
Revert all Liger kernel patches applied to Pixtral.
"""
from transformers.models.pixtral import modeling_pixtral

importlib.reload(modeling_pixtral)
model_config.model_class = modeling_pixtral.PixtralVisionModel
print("Liger kernel patches have been reverted.")


def revert_liger_kernel_to_olmo2(model_config: MiniModelConfig):
"""
Revert all Liger kernel patches applied to Olmo2.
Expand Down