Skip to content

Qwen3 VL Moe: Expected self.dtype to be equal to src.dtype #41418

@danielquintas8

Description

@danielquintas8

System Info

  • transformers version: 4.57.0
  • Platform: Linux-5.14.0-452.el9.x86_64-x86_64-with-glibc2.34
  • Python version: 3.12.11
  • Huggingface_hub version: 0.35.3
  • Safetensors version: 0.6.2
  • Accelerate version: 1.10.1
  • Accelerate config: not found
  • DeepSpeed version: 0.17.6
  • PyTorch version (accelerator?): 2.8.0+cu128 (CUDA)
  • Tensorflow version (GPU?): not installed (NA)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Using distributed or parallel set-up in script?: Yes
  • Using GPU in script?: Yes
  • GPU type: NVIDIA A100 80GB PCIe

Who can help?

No response

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

import os
import torch
from transformers import EarlyStoppingCallback, Qwen3VLMoeForConditionalGeneration, AutoProcessor
from trl import SFTTrainer, SFTConfig
from peft import LoraConfig

from datasets import load_dataset


# Prepare dataset
# Prepare dataset
def formatting_prompts_func(examples):
    """Build one prompt/completion conversation pair per image in a batch.

    Args:
        examples: a batched ``datasets`` mapping — a dict of columns where
            ``examples["image"]`` is the list of images for this batch.

    Returns:
        dict with ``images``, ``prompt`` and ``completion`` columns, each
        containing exactly one entry per input row.
    """
    images = [[image] for image in examples["image"]]
    prompt_template =[
        {
            "role": "system",
            "content": [
                {
                    "type": "text",
                    "text": """You are an helpful assistant"""
                },
            ]
        },
        {
            "role": "user",
            "content": [
                {
                    "type": "image",
                },
                {
                    "type": "text",
                    "text": "Process the provided images."
                }
            ]
        }
    ]
    completion_template = [
        {
            "role": "assistant",
            "content": [
                {
                    "type": "text",
                    "text": "example" ## just a text field to simulate
                }
            ]
        }
    ]

    # BUG FIX: with batched=True, `examples` is a dict of columns, so
    # len(examples) counts the number of COLUMNS, not rows. Size the
    # prompt/completion lists by the number of rows in the image column
    # so every returned column has matching length.
    num_rows = len(examples["image"])
    prompts = [prompt_template] * num_rows
    completions = [completion_template] * num_rows
    return {"images": images, "prompt": prompts, "completion": completions}



# Hyperparameters for the LoRA fine-tuning run.
epochs = 15
learning_rate = 1e-4
per_device_batch_size = 1
gradient_accumulation_steps = 2
weight_decay = 0.01
warmup_steps = 0
lora_rank = 8

# Output locations: adapter weights and trainer checkpoints/logs.
adapter_path = f"adapters/reproduce"
output_path = f"{adapter_path}/outputs/"

# Load tokenizer
# NOTE(review): loads the full 30B MoE checkpoint in bfloat16; assumes
# enough host/GPU memory is available — sharding is handled by FSDP at
# launch time, not here.
model_path = "Qwen/Qwen3-VL-30B-A3B-Instruct"
processor = AutoProcessor.from_pretrained(model_path)
model = Qwen3VLMoeForConditionalGeneration.from_pretrained(
    model_path,
    dtype=torch.bfloat16,
)

# TRL SFT configuration: per-epoch eval/save with best-model reload keyed
# on eval_loss; bf16 is used when the GPU supports it, fp16 otherwise.
training_args = SFTConfig(
    max_length=None,
    # padding="longest",
    seed=3407,
    output_dir=output_path,
    num_train_epochs=epochs,
    learning_rate=learning_rate,
    lr_scheduler_type="cosine",
    per_device_train_batch_size=per_device_batch_size,
    gradient_accumulation_steps=gradient_accumulation_steps,
    per_device_eval_batch_size=1,
    eval_accumulation_steps=1,
    warmup_steps=warmup_steps,
    logging_steps=10,
    logging_strategy="steps",
    do_eval=True,
    bf16_full_eval=True,  # evaluation also runs in bf16 (matches model dtype)
    eval_strategy="epoch",
    save_strategy="epoch",
    packing=False,
    fp16=not torch.cuda.is_bf16_supported(),
    bf16=torch.cuda.is_bf16_supported(),
    gradient_checkpointing=True,
    weight_decay=weight_decay,
    load_best_model_at_end=True,
    metric_for_best_model="eval_loss",
    greater_is_better=False,
    completion_only_loss=True,  # mask prompt tokens; train only on the completion
    eos_token=processor.tokenizer.eos_token,
)

# LoRA configuration;
# Adapters are attached only to the attention projections (q/k/v/o);
# MoE expert MLPs are left frozen.
peft_config = LoraConfig(
    r=lora_rank,
    lora_alpha=lora_rank * 2,
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
    target_modules=[
        "self_attn.q_proj",
        "self_attn.k_proj",
        "self_attn.v_proj",
        "self_attn.o_proj",
    ],
)
# Required so gradient checkpointing works with frozen base-model inputs.
model.enable_input_require_grads()

# Load dataset
# A single-row image dataset stands in for the real training data.
dataset = load_dataset("huggingface/documentation-images", split="train[:1]")  ## just an image/text dataset to simulate

# Convert each row into (images, prompt, completion) columns; original
# columns are dropped so the trainer only sees the conversation format.
formatted_dataset = dataset.map(
    formatting_prompts_func,
    batched=True,
    remove_columns=dataset.column_names,
)

# Initialize trainer
# NOTE(review): train and eval use the same one-row dataset — fine for a
# bug reproduction, not for a real run.
trainer = SFTTrainer(
    model=model,
    peft_config=peft_config,
    processing_class=processor,
    args=training_args,
    train_dataset=formatted_dataset,
    eval_dataset=formatted_dataset,
)

# Run fine-tuning, then persist the LoRA adapter weights.
trainer.train()
trainer.save_model(output_dir=adapter_path)
$ accelerate launch --mixed_precision bf16 --use_fsdp --fsdp_sharding_strategy 1 --fsdp_backward_prefetch NO_PREFETCH --fsdp_offload_params true --fsdp_auto_wrap_policy TRANSFORMER_BASED_WRAP --fsdp_transformer_layer_cls_to_wrap Qwen3VLMoeVisionBlock,Qwen3VLMoeTextDecoderLayer --num_machines 1 --num_processes 4 main.py
[rank0]: Traceback (most recent call last):
[rank0]:   File "/home/quintas/workspace/reproduce.py", line 142, in <module>
[rank0]:     # accelerate launch --mixed_precision bf16 --use_fsdp --fsdp_sharding_strategy 1 --fsdp_backward_prefetch NO_PREFETCH --fsdp_offload_params true --fsdp_auto_wrap_policy TRANSFORMER_BASED_WRAP --fsdp_transformer_layer_cls_to_wrap Qwen3VLMoeVisionBlock,Qwen3VLMoeTextDecoderLayer --num_machines 1 --num_processes 4 Qwen3-VL-30B-A3B-Instruct/adapters/invoice_reading/train.py
[rank0]:     ^^^^^^^^^^^^^^^
[rank0]:   File "/home/quintas/miniconda3/envs/Comudel/lib/python3.12/site-packages/transformers/trainer.py", line 2325, in train
[rank0]:     return inner_training_loop(
[rank0]:            ^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/quintas/miniconda3/envs/Comudel/lib/python3.12/site-packages/transformers/trainer.py", line 2790, in _inner_training_loop
[rank0]:     self._maybe_log_save_evaluate(
[rank0]:   File "/home/quintas/miniconda3/envs/Comudel/lib/python3.12/site-packages/transformers/trainer.py", line 3221, in _maybe_log_save_evaluate
[rank0]:     metrics = self._evaluate(trial, ignore_keys_for_eval)
[rank0]:               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/quintas/miniconda3/envs/Comudel/lib/python3.12/site-packages/transformers/trainer.py", line 3170, in _evaluate
[rank0]:     metrics = self.evaluate(ignore_keys=ignore_keys_for_eval)
[rank0]:               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/quintas/miniconda3/envs/Comudel/lib/python3.12/site-packages/transformers/trainer.py", line 4489, in evaluate
[rank0]:     output = eval_loop(
[rank0]:              ^^^^^^^^^^
[rank0]:   File "/home/quintas/miniconda3/envs/Comudel/lib/python3.12/site-packages/transformers/trainer.py", line 4685, in evaluation_loop
[rank0]:     losses, logits, labels = self.prediction_step(model, inputs, prediction_loss_only, ignore_keys=ignore_keys)
[rank0]:                              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/quintas/miniconda3/envs/Comudel/lib/python3.12/site-packages/transformers/trainer.py", line 4902, in prediction_step
[rank0]:     loss, outputs = self.compute_loss(
[rank0]:                     ^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/quintas/miniconda3/envs/Comudel/lib/python3.12/site-packages/trl/trainer/sft_trainer.py", line 1096, in compute_loss
[rank0]:     (loss, outputs) = super().compute_loss(
[rank0]:                       ^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/quintas/miniconda3/envs/Comudel/lib/python3.12/site-packages/transformers/trainer.py", line 4110, in compute_loss
[rank0]:     outputs = model(**inputs)
[rank0]:               ^^^^^^^^^^^^^^^
[rank0]:   File "/home/quintas/miniconda3/envs/Comudel/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/quintas/miniconda3/envs/Comudel/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1784, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/quintas/miniconda3/envs/Comudel/lib/python3.12/site-packages/accelerate/utils/operations.py", line 818, in forward
[rank0]:     return model_forward(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/quintas/miniconda3/envs/Comudel/lib/python3.12/site-packages/accelerate/utils/operations.py", line 806, in __call__
[rank0]:     return convert_to_fp32(self.model_forward(*args, **kwargs))
[rank0]:                            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/quintas/miniconda3/envs/Comudel/lib/python3.12/site-packages/torch/amp/autocast_mode.py", line 44, in decorate_autocast
[rank0]:     return func(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/quintas/miniconda3/envs/Comudel/lib/python3.12/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 854, in forward
[rank0]:     output = self._fsdp_wrapped_module(*args, **kwargs)
[rank0]:              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/quintas/miniconda3/envs/Comudel/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/quintas/miniconda3/envs/Comudel/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1784, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/quintas/miniconda3/envs/Comudel/lib/python3.12/site-packages/accelerate/utils/operations.py", line 818, in forward
[rank0]:     return model_forward(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/quintas/miniconda3/envs/Comudel/lib/python3.12/site-packages/accelerate/utils/operations.py", line 806, in __call__
[rank0]:     return convert_to_fp32(self.model_forward(*args, **kwargs))
[rank0]:                            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/quintas/miniconda3/envs/Comudel/lib/python3.12/site-packages/torch/amp/autocast_mode.py", line 44, in decorate_autocast
[rank0]:     return func(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/quintas/miniconda3/envs/Comudel/lib/python3.12/site-packages/peft/peft_model.py", line 1850, in forward
[rank0]:     return self.base_model(
[rank0]:            ^^^^^^^^^^^^^^^^
[rank0]:   File "/home/quintas/miniconda3/envs/Comudel/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/quintas/miniconda3/envs/Comudel/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1784, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/quintas/miniconda3/envs/Comudel/lib/python3.12/site-packages/peft/tuners/tuners_utils.py", line 222, in forward
[rank0]:     return self.model.forward(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/quintas/miniconda3/envs/Comudel/lib/python3.12/site-packages/transformers/utils/generic.py", line 1064, in wrapper
[rank0]:     outputs = func(self, *args, **kwargs)
[rank0]:               ^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/quintas/miniconda3/envs/Comudel/lib/python3.12/site-packages/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py", line 1601, in forward
[rank0]:     outputs = self.model(
[rank0]:               ^^^^^^^^^^^
[rank0]:   File "/home/quintas/miniconda3/envs/Comudel/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/quintas/miniconda3/envs/Comudel/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1784, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/quintas/miniconda3/envs/Comudel/lib/python3.12/site-packages/transformers/utils/generic.py", line 1064, in wrapper
[rank0]:     outputs = func(self, *args, **kwargs)
[rank0]:               ^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/quintas/miniconda3/envs/Comudel/lib/python3.12/site-packages/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py", line 1389, in forward
[rank0]:     outputs = self.language_model(
[rank0]:               ^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/quintas/miniconda3/envs/Comudel/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/quintas/miniconda3/envs/Comudel/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1784, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/quintas/miniconda3/envs/Comudel/lib/python3.12/site-packages/transformers/utils/generic.py", line 1064, in wrapper
[rank0]:     outputs = func(self, *args, **kwargs)
[rank0]:               ^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/quintas/miniconda3/envs/Comudel/lib/python3.12/site-packages/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py", line 962, in forward
[rank0]:     layer_outputs = decoder_layer(
[rank0]:                     ^^^^^^^^^^^^^^
[rank0]:   File "/home/quintas/miniconda3/envs/Comudel/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/quintas/miniconda3/envs/Comudel/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1784, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/quintas/miniconda3/envs/Comudel/lib/python3.12/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 854, in forward
[rank0]:     output = self._fsdp_wrapped_module(*args, **kwargs)
[rank0]:              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/quintas/miniconda3/envs/Comudel/lib/python3.12/site-packages/transformers/modeling_layers.py", line 94, in __call__
[rank0]:     return super().__call__(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/quintas/miniconda3/envs/Comudel/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/quintas/miniconda3/envs/Comudel/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1784, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/quintas/miniconda3/envs/Comudel/lib/python3.12/site-packages/transformers/utils/deprecation.py", line 172, in wrapped_func
[rank0]:     return func(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/quintas/miniconda3/envs/Comudel/lib/python3.12/site-packages/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py", line 391, in forward
[rank0]:     hidden_states = self.mlp(hidden_states)
[rank0]:                     ^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/quintas/miniconda3/envs/Comudel/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/quintas/miniconda3/envs/Comudel/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1784, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/quintas/miniconda3/envs/Comudel/lib/python3.12/site-packages/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py", line 149, in forward
[rank0]:     router_weights = torch.zeros_like(router_logits).scatter_(1, router_indices, routing_weights)
[rank0]:                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: RuntimeError: scatter(): Expected self.dtype to be equal to src.dtype

Expected behavior

Training should complete successfully without a dtype error.

Metadata

Metadata

Assignees

No one assigned

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions