[Model][VLM] Add Qwen2.5-Omni model support (thinker only) #15130

Merged
merged 39 commits on Apr 19, 2025

Commits (39), changes from all commits
732ec71
Initial commit for Qwen2.5-Omni (thinker only).
fyabc Mar 19, 2025
30d8ac6
update doc and typing
fyabc Mar 27, 2025
f8668bf
Merge branch 'refs/heads/main' into qwen2_omni_public_v1
fyabc Mar 27, 2025
3108e36
fix typing error
fyabc Mar 27, 2025
42ab3b7
fix typing
fyabc Mar 27, 2025
139d305
fix typing
fyabc Mar 27, 2025
04fe220
fix bug in multi-audio merging
fyabc Mar 27, 2025
7bbd238
add more examples
fyabc Mar 27, 2025
d4a61c9
adapt for transformers update
fyabc Mar 28, 2025
d40f54f
fix bug of 'use_audio_in_video'
fyabc Mar 28, 2025
a6f878e
Merge branch 'main' into qwen2_omni_public_v1
ywang96 Mar 31, 2025
d3eb60d
precommit
ywang96 Mar 31, 2025
0cd5aa8
update V1 interface
ywang96 Mar 31, 2025
98226ad
add TODO
ywang96 Mar 31, 2025
6b4c705
Update docs/source/models/supported_models.md
ywang96 Mar 31, 2025
286e755
assert VLLM_USE_V1=0 audio in video example
ywang96 Mar 31, 2025
9cf9d26
adapt for transformers PR
fyabc Mar 31, 2025
53501f3
multiple fixes
ywang96 Mar 31, 2025
512f874
squeeze only one dimension
ywang96 Apr 1, 2025
9c984d0
fix squeezing
ywang96 Apr 1, 2025
3908518
minor refactoring
ywang96 Apr 1, 2025
864accf
precommit
ywang96 Apr 1, 2025
adc5cdf
Merge branch 'main' into qwen2_omni_public_v1
ywang96 Apr 1, 2025
7108ba3
reformat
fyabc Apr 2, 2025
71f96e4
add omni to chat utils
ywang96 Apr 2, 2025
ebd8b88
fix model type
ywang96 Apr 2, 2025
512bd41
fix typo
ywang96 Apr 3, 2025
68004d8
Merge remote-tracking branch 'upstream/main' into qwen2_omni_public_v1
ywang96 Apr 3, 2025
1dac918
Merge branch 'refs/heads/main' into qwen2_omni_public_v1
fyabc Apr 8, 2025
1b0bf89
Fix vision attention qkv
fyabc Apr 9, 2025
f5def01
fix hard code
fyabc Apr 9, 2025
753858b
fix hidden_size
fyabc Apr 9, 2025
ed6dca1
Merge branch 'refs/heads/main' into qwen2_omni_public_v1
fyabc Apr 15, 2025
976fbf0
fix tests
fyabc Apr 16, 2025
f726ac8
refactor dummy inputs builder
fyabc Apr 16, 2025
e168e09
Update qwen2_5_omni_thinker.py
wangxiongts Apr 18, 2025
7994068
Update qwen2_5_omni_thinker.py
wangxiongts Apr 18, 2025
41c5855
fix test registry
fyabc Apr 18, 2025
d1e2046
Merge remote-tracking branch 'origin/qwen2_omni_public_v1' into qwen2…
fyabc Apr 18, 2025
15 changes: 15 additions & 0 deletions docs/source/models/supported_models.md
@@ -1026,6 +1026,13 @@ See [this page](#generative-models) for more information on how to use generative models
* ✅︎
* ✅︎
* ✅︎
- * `Qwen2_5OmniThinkerForConditionalGeneration`
* Qwen2.5-Omni
* T + I<sup>E+</sup> + V<sup>E+</sup> + A<sup>+</sup>
* `Qwen/Qwen2.5-Omni-7B`
*
* ✅︎
* ✅︎\*
- * `SkyworkR1VChatModel`
* Skywork-R1V-38B
* T + I
@@ -1095,6 +1102,14 @@ For more details, please see: <gh-pr:4087#issuecomment-2250397630>
Our PaliGemma implementations have the same problem as Gemma 3 (see above) for both V0 and V1.
:::

:::{note}
To use Qwen2.5-Omni, you have to install a fork of the Hugging Face Transformers library from source via
`pip install git+https://github.com/BakerBunker/transformers.git@qwen25omni`.

Reading audio from video inputs during pre-processing is currently supported on V0 only (not V1), because V1 does not yet support overlapping modalities. Enable it with
`--mm-processor-kwargs '{"use_audio_in_video": true}'`.
:::
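
For offline inference, the same switch can be passed per request via the prompt dictionary. The snippet below is a minimal sketch only (system prompt omitted, question abbreviated), adapted from `examples/offline_inference/qwen2_5_omni/only_thinker.py` in this PR; run it with `VLLM_USE_V1=0`.

```python
# Sketch: audio is read out of the video file and passed alongside the frames.
from vllm import LLM, SamplingParams
from vllm.assets.video import VideoAsset

asset = VideoAsset(name="sample_demo_1.mp4", num_frames=16)

llm = LLM(model="Qwen/Qwen2.5-Omni-7B", max_model_len=5632, max_num_seqs=5)
outputs = llm.generate(
    {
        "prompt": ("<|im_start|>user\n<|vision_bos|><|VIDEO|><|vision_eos|>"
                   "What does the baby say?<|im_end|>\n<|im_start|>assistant\n"),
        "multi_modal_data": {
            "video": asset.np_ndarrays,
            "audio": asset.get_audio(sampling_rate=16000),
        },
        # Per-request switch; V0 only for now.
        "mm_processor_kwargs": {"use_audio_in_video": True},
    },
    SamplingParams(max_tokens=64),
)
print(outputs[0].outputs[0].text)
```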

### Pooling Models

See [this page](pooling-models) for more information on how to use pooling models.
31 changes: 31 additions & 0 deletions examples/offline_inference/audio_language.py
@@ -130,6 +130,36 @@ def run_qwen2_audio(question: str, audio_count: int) -> ModelRequestData:
)


# Qwen2.5-Omni
def run_qwen2_5_omni(question: str, audio_count: int) -> ModelRequestData:
model_name = "Qwen/Qwen2.5-Omni-7B"

engine_args = EngineArgs(
model=model_name,
max_model_len=4096,
max_num_seqs=5,
limit_mm_per_prompt={"audio": audio_count},
)

audio_in_prompt = "".join([
"<|audio_bos|><|AUDIO|><|audio_eos|>\n" for idx in range(audio_count)
])

default_system = (
"You are Qwen, a virtual human developed by the Qwen Team, Alibaba "
"Group, capable of perceiving auditory and visual inputs, as well as "
"generating text and speech.")

prompt = (f"<|im_start|>system\n{default_system}<|im_end|>\n"
"<|im_start|>user\n"
f"{audio_in_prompt}{question}<|im_end|>\n"
"<|im_start|>assistant\n")
return ModelRequestData(
engine_args=engine_args,
prompt=prompt,
)


# Ultravox 0.5-1B
def run_ultravox(question: str, audio_count: int) -> ModelRequestData:
model_name = "fixie-ai/ultravox-v0_5-llama-3_2-1b"
@@ -182,6 +212,7 @@ def run_whisper(question: str, audio_count: int) -> ModelRequestData:
"minicpmo": run_minicpmo,
"phi4_mm": run_phi4mm,
"qwen2_audio": run_qwen2_audio,
"qwen2_5_omni": run_qwen2_5_omni,
"ultravox": run_ultravox,
"whisper": run_whisper,
}
32 changes: 32 additions & 0 deletions examples/offline_inference/qwen2_5_omni/README.md
@@ -0,0 +1,32 @@
# Qwen2.5-Omni Offline Inference Examples

This folder provides several example scripts showing how to run offline inference with Qwen2.5-Omni.

## Thinker Only

```bash
# Audio + image + video
python examples/offline_inference/qwen2_5_omni/only_thinker.py -q mixed_modalities

# Read vision and audio inputs from a single video file
# NOTE: V1 engine does not support interleaved modalities yet.
VLLM_USE_V1=0 python examples/offline_inference/qwen2_5_omni/only_thinker.py -q use_audio_in_video

# Multiple audios
VLLM_USE_V1=0 python examples/offline_inference/qwen2_5_omni/only_thinker.py -q multi_audios
```

This script runs the thinker part of Qwen2.5-Omni and generates a text response.

You can also test Qwen2.5-Omni on a single modality:

```bash
# Process audio inputs
python examples/offline_inference/audio_language.py --model-type qwen2_5_omni

# Process image inputs
python examples/offline_inference/vision_language.py --modality image --model-type qwen2_5_omni

# Process video inputs
python examples/offline_inference/vision_language.py --modality video --model-type qwen2_5_omni
```
160 changes: 160 additions & 0 deletions examples/offline_inference/qwen2_5_omni/only_thinker.py
@@ -0,0 +1,160 @@
# SPDX-License-Identifier: Apache-2.0
"""
This example shows how to use vLLM for running offline inference
with the correct prompt format on Qwen2.5-Omni (thinker only).
"""

from typing import NamedTuple

import vllm.envs as envs
from vllm import LLM, SamplingParams
from vllm.assets.audio import AudioAsset
from vllm.assets.image import ImageAsset
from vllm.assets.video import VideoAsset
from vllm.utils import FlexibleArgumentParser


class QueryResult(NamedTuple):
inputs: dict
limit_mm_per_prompt: dict[str, int]


# NOTE: The default `max_num_seqs` and `max_model_len` may result in OOM on
# lower-end GPUs.
# Unless specified, these settings have been tested to work on a single L4.

default_system = (
"You are Qwen, a virtual human developed by the Qwen Team, Alibaba "
"Group, capable of perceiving auditory and visual inputs, as well as "
"generating text and speech.")


def get_mixed_modalities_query() -> QueryResult:
question = ("What is recited in the audio? "
"What is the content of this image? Why is this video funny?")
prompt = (f"<|im_start|>system\n{default_system}<|im_end|>\n"
"<|im_start|>user\n<|audio_bos|><|AUDIO|><|audio_eos|>"
"<|vision_bos|><|IMAGE|><|vision_eos|>"
"<|vision_bos|><|VIDEO|><|vision_eos|>"
f"{question}<|im_end|>\n"
f"<|im_start|>assistant\n")
return QueryResult(
inputs={
"prompt": prompt,
"multi_modal_data": {
"audio":
AudioAsset("mary_had_lamb").audio_and_sample_rate,
"image":
ImageAsset("cherry_blossom").pil_image.convert("RGB"),
"video":
VideoAsset(name="sample_demo_1.mp4",
num_frames=16).np_ndarrays,
},
},
limit_mm_per_prompt={
"audio": 1,
"image": 1,
"video": 1
},
)


def get_use_audio_in_video_query() -> QueryResult:
question = ("Describe the content of the video, "
"then convert what the baby say into text.")
prompt = (f"<|im_start|>system\n{default_system}<|im_end|>\n"
"<|im_start|>user\n<|vision_bos|><|VIDEO|><|vision_eos|>"
f"{question}<|im_end|>\n"
f"<|im_start|>assistant\n")
asset = VideoAsset(name="sample_demo_1.mp4", num_frames=16)
audio = asset.get_audio(sampling_rate=16000)
assert not envs.VLLM_USE_V1, ("V1 does not support use_audio_in_video. "
"Please launch this example with "
"`VLLM_USE_V1=0`.")
return QueryResult(
inputs={
"prompt": prompt,
"multi_modal_data": {
"video": asset.np_ndarrays,
"audio": audio,
},
"mm_processor_kwargs": {
"use_audio_in_video": True,
},
},
limit_mm_per_prompt={
"audio": 1,
"video": 1
},
)


def get_multi_audios_query() -> QueryResult:
question = "Are these two audio clips the same?"
prompt = (f"<|im_start|>system\n{default_system}<|im_end|>\n"
"<|im_start|>user\n<|audio_bos|><|AUDIO|><|audio_eos|>"
"<|audio_bos|><|AUDIO|><|audio_eos|>"
f"{question}<|im_end|>\n"
f"<|im_start|>assistant\n")
return QueryResult(
inputs={
"prompt": prompt,
"multi_modal_data": {
"audio": [
AudioAsset("winning_call").audio_and_sample_rate,
AudioAsset("mary_had_lamb").audio_and_sample_rate,
],
},
},
limit_mm_per_prompt={
"audio": 2,
},
)


query_map = {
"mixed_modalities": get_mixed_modalities_query,
"use_audio_in_video": get_use_audio_in_video_query,
"multi_audios": get_multi_audios_query,
}


def main(args):
model_name = "Qwen/Qwen2.5-Omni-7B"
query_result = query_map[args.query_type]()

llm = LLM(model=model_name,
max_model_len=5632,
max_num_seqs=5,
limit_mm_per_prompt=query_result.limit_mm_per_prompt,
seed=args.seed)

# We set temperature to 0.2 so that outputs can be different
# even when all prompts are identical when running batch inference.
sampling_params = SamplingParams(temperature=0.2, max_tokens=64)

outputs = llm.generate(query_result.inputs,
sampling_params=sampling_params)

for o in outputs:
generated_text = o.outputs[0].text
print(generated_text)


if __name__ == "__main__":
parser = FlexibleArgumentParser(
description='Demo on using vLLM for offline inference with '
'audio language models')
parser.add_argument('--query-type',
'-q',
type=str,
default="mixed_modalities",
choices=query_map.keys(),
help='Query type.')
parser.add_argument("--seed",
type=int,
default=None,
help="Set the seed when initializing `vllm.LLM`.")

args = parser.parse_args()
main(args)
37 changes: 37 additions & 0 deletions examples/offline_inference/vision_language.py
@@ -941,6 +941,42 @@ def run_qwen2_5_vl(questions: list[str], modality: str) -> ModelRequestData:
)


# Qwen2.5-Omni
def run_qwen2_5_omni(questions: list[str], modality: str) -> ModelRequestData:
model_name = "Qwen/Qwen2.5-Omni-7B"

engine_args = EngineArgs(
model=model_name,
max_model_len=4096,
max_num_seqs=5,
mm_processor_kwargs={
"min_pixels": 28 * 28,
"max_pixels": 1280 * 28 * 28,
"fps": [1],
},
disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
)

if modality == "image":
placeholder = "<|IMAGE|>"
elif modality == "video":
placeholder = "<|VIDEO|>"

default_system = (
"You are Qwen, a virtual human developed by the Qwen Team, Alibaba "
"Group, capable of perceiving auditory and visual inputs, as well as "
"generating text and speech.")

prompts = [(f"<|im_start|>system\n{default_system}<|im_end|>\n"
f"<|im_start|>user\n<|vision_bos|>{placeholder}<|vision_eos|>"
f"{question}<|im_end|>\n"
"<|im_start|>assistant\n") for question in questions]
return ModelRequestData(
engine_args=engine_args,
prompts=prompts,
)


# SkyworkR1V
def run_skyworkr1v(questions: list[str], modality: str) -> ModelRequestData:
assert modality == "image"
@@ -1010,6 +1046,7 @@ def run_skyworkr1v(questions: list[str], modality: str) -> ModelRequestData:
"qwen_vl": run_qwen_vl,
"qwen2_vl": run_qwen2_vl,
"qwen2_5_vl": run_qwen2_5_vl,
"qwen2_5_omni": run_qwen2_5_omni,
"skywork_chat": run_skyworkr1v,
"smolvlm": run_smolvlm,
}
17 changes: 17 additions & 0 deletions tests/models/decoder_only/vision_language/test_models.py
@@ -139,6 +139,23 @@
image_size_factors=[(), (0.25,), (0.25, 0.25, 0.25), (0.25, 0.2, 0.15)],
marks=[pytest.mark.core_model, pytest.mark.cpu_model],
),
"qwen2_5_omni": VLMTestInfo(
models=["Qwen/Qwen2.5-Omni-7B"],
test_type=(
VLMTestType.IMAGE,
VLMTestType.MULTI_IMAGE,
VLMTestType.VIDEO
),
prompt_formatter=lambda img_prompt: f"<|im_start|>User\n{img_prompt}<|im_end|>\n<|im_start|>assistant\n", # noqa: E501
img_idx_to_prompt=lambda idx: "<|vision_bos|><|IMAGE|><|vision_eos|>", # noqa: E501
video_idx_to_prompt=lambda idx: "<|vision_bos|><|VIDEO|><|vision_eos|>", # noqa: E501
max_model_len=4096,
max_num_seqs=2,
auto_cls=AutoModelForVision2Seq,
vllm_output_post_proc=model_utils.qwen2_vllm_to_hf_output,
image_size_factors=[(), (0.25,), (0.25, 0.25, 0.25), (0.25, 0.2, 0.15)],
marks=[pytest.mark.core_model, pytest.mark.cpu_model],
),
#### Extended model tests
"aria": VLMTestInfo(
models=["rhymes-ai/Aria"],
1 change: 1 addition & 0 deletions tests/models/multimodal/processing/test_common.py
@@ -280,6 +280,7 @@ def _test_processing_correctness_mistral(
"Qwen/Qwen2-VL-2B-Instruct",
"Qwen/Qwen2.5-VL-3B-Instruct",
"Qwen/Qwen2-Audio-7B-Instruct",
"Qwen/Qwen2.5-Omni-7B",
"Skywork/Skywork-R1V-38B",
"fixie-ai/ultravox-v0_5-llama-3_2-1b",
"openai/whisper-large-v3",
2 changes: 2 additions & 0 deletions tests/models/registry.py
@@ -353,6 +353,8 @@ def check_available_online(
"Qwen2VLForConditionalGeneration": _HfExamplesInfo("Qwen/Qwen2-VL-2B-Instruct"), # noqa: E501
"Qwen2_5_VLForConditionalGeneration": _HfExamplesInfo("Qwen/Qwen2.5-VL-3B-Instruct", # noqa: E501
min_transformers_version="4.49"), # noqa: E501
"Qwen2_5OmniModel": _HfExamplesInfo("Qwen/Qwen2.5-Omni-7B", # noqa: E501
min_transformers_version="4.52"), # noqa: E501
"SkyworkR1VChatModel": _HfExamplesInfo("Skywork/Skywork-R1V-38B"),
"SmolVLMForConditionalGeneration": _HfExamplesInfo("HuggingFaceTB/SmolVLM2-2.2B-Instruct"), # noqa: E501
"UltravoxModel": _HfExamplesInfo("fixie-ai/ultravox-v0_5-llama-3_2-1b", # noqa: E501
18 changes: 17 additions & 1 deletion vllm/assets/video.py
@@ -2,16 +2,23 @@

from dataclasses import dataclass
from functools import lru_cache
from typing import Literal
from typing import Literal, Optional

import cv2
import numpy as np
import numpy.typing as npt
from huggingface_hub import hf_hub_download
from PIL import Image

from vllm.utils import PlaceholderModule

from .base import get_cache_dir

try:
import librosa
except ImportError:
librosa = PlaceholderModule("librosa") # type: ignore[assignment]


@lru_cache
def download_video_asset(filename: str) -> str:
@@ -85,3 +92,12 @@ def np_ndarrays(self) -> npt.NDArray:
video_path = download_video_asset(self.name)
ret = video_to_ndarrays(video_path, self.num_frames)
return ret

def get_audio(self, sampling_rate: Optional[float] = None) -> npt.NDArray:
"""
Read audio data from the video asset, used in Qwen2.5-Omni examples.

See also: examples/offline_inference/qwen2_5_omni/only_thinker.py
"""
video_path = download_video_asset(self.name)
return librosa.load(video_path, sr=sampling_rate)[0]
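
A brief usage sketch of the new `VideoAsset.get_audio` helper (asset name and sampling rate follow the examples in this PR):

```python
from vllm.assets.video import VideoAsset

asset = VideoAsset(name="sample_demo_1.mp4", num_frames=16)
frames = asset.np_ndarrays                       # decoded video frames as a numpy array
waveform = asset.get_audio(sampling_rate=16000)  # mono waveform resampled to 16 kHz
```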