Commit 5380cd7

model(vlm): pixtral (#5084)
1 parent b2e95f6 commit 5380cd7

File tree: 16 files changed (+1125, −39 lines)


docs/supported_models/vision_language_models.md

Lines changed: 1 addition & 0 deletions
@@ -20,6 +20,7 @@ python3 -m sglang.launch_server \
 | **Janus-Pro** (1B, 7B) | `deepseek-ai/Janus-Pro-7B` | `janus-pro` | DeepSeek’s open-source multimodal model capable of both image understanding and generation. Janus-Pro employs a decoupled architecture for separate visual encoding paths, enhancing performance in both tasks. |
 | **MiniCPM-V / MiniCPM-o** | `openbmb/MiniCPM-V-2_6` | `minicpmv` | MiniCPM-V (2.6, ~8B) supports image inputs, and MiniCPM-o adds audio/video; these multimodal LLMs are optimized for end-side deployment on mobile/edge devices. |
 | **Llama 3.2 Vision** (11B) | `meta-llama/Llama-3.2-11B-Vision-Instruct` | `llama_3_vision` | Vision-enabled variant of Llama 3 (11B) that accepts image inputs for visual question answering and other multimodal tasks. |
+| **Pixtral** (12B, 124B) | `mistral-community/pixtral-12b` | `mistral` | Pixtral is a vision-language model from Mistral AI that can process both text and images. |
 | **LLaVA** (v1.5 & v1.6) | *e.g.* `liuhaotian/llava-v1.5-13b` | `vicuna_v1.1` | Open vision-chat models that add an image encoder to LLaMA/Vicuna (e.g. LLaMA2 13B) for following multimodal instruction prompts. |
 | **LLaVA-NeXT** (8B, 72B) | `lmms-lab/llava-next-72b` | `chatml-llava` | Improved LLaVA models (with an 8B Llama3 version and a 72B version) offering enhanced visual instruction-following and accuracy on multimodal benchmarks. |
 | **LLaVA-OneVision** | `lmms-lab/llava-onevision-qwen2-7b-ov` | `chatml-llava` | Enhanced LLaVA variant integrating Qwen as the backbone; supports multiple images (and even video frames) as inputs via an OpenAI Vision API-compatible format. |
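The table above maps Pixtral to the `mistral` chat template. For orientation, here is a minimal sketch, not part of this commit, of querying a Pixtral server through SGLang's OpenAI-compatible endpoint; it assumes the server was launched as in the pixtral_server.py example below and that the /v1 chat route accepts vision-style messages for this model.

# Sketch only: query a running Pixtral server via the OpenAI-compatible API.
# Assumed launch command (from the pixtral_server.py docstring below):
#   python -m sglang.launch_server --model-path mistral-community/pixtral-12b --port=30000
import openai

client = openai.Client(base_url="http://127.0.0.1:30000/v1", api_key="EMPTY")
response = client.chat.completions.create(
    model="mistral-community/pixtral-12b",
    messages=[
        {
            "role": "user",
            "content": [
                {"type": "text", "text": "Describe this image in one sentence."},
                {
                    "type": "image_url",
                    "image_url": {"url": "https://picsum.photos/id/237/400/300"},
                },
            ],
        }
    ],
    max_tokens=64,
)
print(response.choices[0].message.content)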

examples/runtime/README.md

Lines changed: 3 additions & 2 deletions
@@ -33,9 +33,10 @@ The `hidden_states` folder contains examples on how to extract hidden states usi
 * `hidden_states_engine.py`: An example how to extract hidden states using the Engine API.
 * `hidden_states_server.py`: An example how to extract hidden states using the Server API.

-## LLaVA-NeXT
+## Multimodal
+
+SGLang supports multimodal inputs for various model architectures. The `multimodal` folder contains examples showing how to use urls, files or encoded data to make requests to multimodal models. Examples include querying the [Llava-OneVision](multimodal/llava_onevision_server.py) model (image, multi-image, video), Llava-backed [Qwen-Llava](multimodal/qwen_llava_server.py) and [Llama3-Llava](multimodal/llama3_llava_server.py) models (image, multi-image), and Mistral AI's [Pixtral](multimodal/pixtral_server.py) (image, multi-image).

-SGLang support LLaVA-OneVision with single-image, multi-image and video are supported. The folder `llava_onevision` shows how to do this.

 ## Token In, Token Out
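To make the `image_data` field mentioned in the new README paragraph concrete, here is a small sketch, not part of the commit, of the three ways the multimodal examples can pass an image: a URL, a local file path, or base64-encoded bytes. The endpoint and prompt format follow pixtral_server.py below; the local path is hypothetical.

# Sketch only: three interchangeable forms of `image_data` for /generate.
import base64

import requests

URL = "http://127.0.0.1:30000/generate"
LOCAL_IMAGE = "/path/to/image.jpg"  # hypothetical path


def b64_of(path: str) -> str:
    with open(path, "rb") as f:
        return base64.b64encode(f.read()).decode("utf-8")


for image_data in (
    "https://picsum.photos/id/237/400/300",  # URL
    LOCAL_IMAGE,                             # local file path
    b64_of(LOCAL_IMAGE),                     # base64-encoded data
):
    resp = requests.post(
        URL,
        json={
            "text": "<s>[INST]Describe this image.\n[IMG][/INST]",
            "image_data": image_data,
            "sampling_params": {"max_new_tokens": 32},
        },
    )
    print(resp.json().get("text"))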

examples/runtime/llava_onevision/http_llama3_llava_test.py renamed to examples/runtime/multimodal/llama3_llava_server.py

Lines changed: 1 addition & 1 deletion
@@ -6,7 +6,7 @@
 # Endpoint Service CLI:
 python -m sglang.launch_server --model-path lmms-lab/llama3-llava-next-8b --port=30000

-python3 http_llama3_llava_test.py
+python3 llama3_llava_server.py

 Output:
 "Friends posing for a fun photo with a life-sized teddy bear, creating a playful and memorable moment."

examples/runtime/llava_onevision/http_llava_onevision_test.py renamed to examples/runtime/multimodal/llava_onevision_server.py

Lines changed: 1 addition & 1 deletion
@@ -3,7 +3,7 @@

 python3 -m sglang.launch_server --model-path lmms-lab/llava-onevision-qwen2-72b-ov --port=30000 --tp-size=8

-python3 http_llava_onevision_test.py
+python3 llava_onevision_server.py
 """

 import base64
examples/runtime/multimodal/pixtral_server.py

Lines changed: 127 additions & 0 deletions
@@ -0,0 +1,127 @@
"""
Usage:
# Run a Pixtral model with SGLang:
# HuggingFace:
python -m sglang.launch_server --model-path mistral-community/pixtral-12b --port=30000
# ModelScope:
python -m sglang.launch_server --model-path AI-ModelScope/pixtral-12b --port=30000

# Then test it with:
python pixtral_server.py

This script tests Pixtral model with both single and multiple images.
"""

import argparse
import asyncio
import json

import aiohttp
import requests

IMAGE_TOKEN_SEP = "\n[IMG]"
ROUTE = "/generate"


async def send_request(url, data, delay=0):
    await asyncio.sleep(delay)
    async with aiohttp.ClientSession() as session:
        async with session.post(url, json=data) as resp:
            output = await resp.json()
    return output


async def test_concurrent(args):
    url = f"{args.host}:{args.port}{ROUTE}"

    # Single image test
    if args.single_image:
        prompt = f"<s>[INST]Describe this image in detail.{IMAGE_TOKEN_SEP}[/INST]"
        image_url = "https://picsum.photos/id/237/400/300"
        modality = ["image"]
    # Multiple images test
    else:
        image_urls = [
            "https://picsum.photos/id/237/400/300",
            "https://picsum.photos/id/27/500/500",
        ]
        prompt = f"<s>[INST]How many photos are there? Describe each in a very short sentence.{IMAGE_TOKEN_SEP * len(image_urls)}[/INST]"
        image_url = image_urls
        modality = ["multi-images"]

    response = await send_request(
        url,
        {
            "text": prompt,
            "image_data": image_url,
            "sampling_params": {
                "max_new_tokens": 100,
                "temperature": 0.7,
                "top_p": 0.9,
            },
            "modalities": modality,
        },
    )

    print(f"Response: {response}")
    if "text" in response:
        print("\nOutput text:", response["text"])


def test_streaming(args):
    url = f"{args.host}:{args.port}/generate"

    # Single image test
    if args.single_image:
        prompt = f"<s>[INST]Describe this image in detail.{IMAGE_TOKEN_SEP}[/INST]"
        image_data = "https://picsum.photos/id/237/400/300"
        modality = ["image"]
    # Multiple images test
    else:
        image_urls = [
            "https://picsum.photos/id/237/400/300",
            "https://picsum.photos/id/27/500/500",
        ]
        prompt = f"<s>[INST]How many photos are there? Describe each in a very short sentence.{IMAGE_TOKEN_SEP * len(image_urls)}[/INST]"
        image_data = image_urls
        modality = ["multi-images"]

    pload = {
        "text": prompt,
        "image_data": image_data,
        "sampling_params": {"max_new_tokens": 100, "temperature": 0.7, "top_p": 0.9},
        "modalities": modality,
        "stream": True,
    }

    response = requests.post(url, json=pload, stream=True)

    print("Streaming response:")
    prev = 0
    for chunk in response.iter_lines(decode_unicode=False):
        chunk = chunk.decode("utf-8")
        if chunk and chunk.startswith("data:"):
            if chunk == "data: [DONE]":
                break
            data = json.loads(chunk[5:].strip("\n"))
            output = data["text"].strip()
            print(output[prev:], end="", flush=True)
            prev = len(output)
    print("\n")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--host", type=str, default="http://127.0.0.1")
    parser.add_argument("--port", type=int, default=30000)
    parser.add_argument(
        "--single-image",
        action="store_true",
        help="Test with single image instead of multiple images",
    )
    parser.add_argument("--no-stream", action="store_true", help="Don't test streaming")
    args = parser.parse_args()

    asyncio.run(test_concurrent(args))
    if not args.no_stream:
        test_streaming(args)

examples/runtime/llava_onevision/http_qwen_llava_test.py renamed to examples/runtime/multimodal/qwen_llava_server.py

Lines changed: 1 addition & 1 deletion
@@ -6,7 +6,7 @@
 # Endpoint Service CLI:
 python -m sglang.launch_server --model-path lmms-lab/llava-next-72b --port=30000 --tp-size=8

-python3 http_qwen_llava_test.py
+python3 qwen_llava_server.py

 Output:
 "Two children pose with a large teddy bear, one holding a smaller stuffed bear, in a room with an American flag and potted plants."

python/sglang/lang/chat_template.py

Lines changed: 22 additions & 1 deletion
@@ -194,6 +194,21 @@ def get_chat_template_by_model_path(model_path):
     )
 )

+# Reference: https://huggingface.co/mistralai/Mistral-Small-3.1-24B-Instruct-2503/blob/main/chat_template.json
+register_chat_template(
+    ChatTemplate(
+        name="mistral",
+        default_system_prompt=None,
+        role_prefix_and_suffix={
+            "system": ("[SYSTEM_PROMPT] ", " [/SYSTEM_PROMPT]"),
+            "user": ("[INST] ", " [/INST]"),
+            "assistant": ("", " </s><s>"),
+        },
+        stop_str=("</s>",),
+        image_token="[IMG]",
+    )
+)
+
 register_chat_template(
     ChatTemplate(
         name="llama-3-instruct",

@@ -509,13 +524,19 @@ def match_vicuna(model_path: str):
 @register_chat_template_matching_function
 def match_llama2_chat(model_path: str):
     if re.search(
-        r"llama-2.*chat|(mistral|mixtral).*instruct|codellama.*instruct",
+        r"llama-2.*chat|codellama.*instruct",
         model_path,
         re.IGNORECASE,
     ):
         return "llama-2-chat"


+@register_chat_template_matching_function
+def match_mistral(model_path: str):
+    if re.search(r"pixtral|(mistral|mixtral).*instruct", model_path, re.IGNORECASE):
+        return "mistral"
+
+
 @register_chat_template_matching_function
 def match_llama3_instruct(model_path: str):
     if re.search(r"llama-3.*instruct", model_path, re.IGNORECASE):

python/sglang/srt/configs/model_config.py

Lines changed: 1 addition & 0 deletions
@@ -545,6 +545,7 @@ def is_generation_model(model_architectures: List[str], is_embedding: bool = Fal
         "Llama4ForConditionalGeneration",
         "LlavaMistralForCausalLM",
         "LlavaQwenForCausalLM",
+        "LlavaForConditionalGeneration",
         "LlavaVidForCausalLM",
         "MiniCPMO",
         "MiniCPMV",

python/sglang/srt/conversation.py

Lines changed: 21 additions & 1 deletion
@@ -634,6 +634,20 @@ def generate_chat_conv(
     )
 )

+# reference: https://huggingface.co/mistralai/Mistral-Small-3.1-24B-Instruct-2503/blob/main/chat_template.json
+register_conv_template(
+    Conversation(
+        name="mistral",
+        system_template="[SYSTEM_PROMPT]\n{system_message}\n[/SYSTEM_PROMPT]\n\n",
+        roles=("[INST]", "[/INST]"),
+        sep_style=SeparatorStyle.LLAMA2,
+        sep=" ",
+        sep2=" </s><s>",
+        stop_str=["[INST]", "[/INST]", "[SYSTEM_PROMPT]", "[/SYSTEM_PROMPT]"],
+        image_token="[IMG]",
+    )
+)
+
 # reference: https://huggingface.co/meta-llama/Llama-4-Scout-17B-16E-Instruct/blob/main/chat_template.json
 register_conv_template(
     Conversation(

@@ -880,13 +894,19 @@ def match_vicuna(model_path: str):
 @register_conv_template_matching_function
 def match_llama2_chat(model_path: str):
     if re.search(
-        r"llama-2.*chat|(mistral|mixtral).*instruct|codellama.*instruct",
+        r"llama-2.*chat|codellama.*instruct",
         model_path,
         re.IGNORECASE,
     ):
         return "llama-2"


+@register_conv_template_matching_function
+def match_mistral(model_path: str):
+    if re.search(r"pixtral|(mistral|mixtral).*instruct", model_path, re.IGNORECASE):
+        return "mistral"
+
+
 @register_conv_template_matching_function
 def match_deepseek_vl(model_path: str):
     if re.search(r"deepseek.*vl2", model_path, re.IGNORECASE):

python/sglang/srt/managers/multimodal_processors/llava.py

Lines changed: 46 additions & 0 deletions
@@ -1,14 +1,20 @@
 import asyncio
+import importlib
 from typing import List, Optional, Union

 import numpy as np
+from transformers.models.auto.processing_auto import (
+    PROCESSOR_MAPPING_NAMES as HF_MAPPING_NAMES,
+)

+import sglang.srt.managers.multimodal_processor as sgl_mm_processor_utils
 from sglang.srt.managers.multimodal_processors.base_processor import (
     BaseMultimodalProcessor,
 )
 from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem
 from sglang.srt.mm_utils import expand2square, process_anyres_image
 from sglang.srt.models.llava import (
+    LlavaForConditionalGeneration,
     LlavaLlamaForCausalLM,
     LlavaMistralForCausalLM,
     LlavaQwenForCausalLM,

@@ -133,6 +139,7 @@ async def process_mm_data_async(
                     img_data, aspect_ratio, grid_pinpoints
                 )
             )
+
         res = await asyncio.gather(*res)
         for pixel_v, image_h, image_s in res:
             pixel_values.append(pixel_v)

@@ -165,3 +172,42 @@ async def process_mm_data_async(
             )
         ],
     }
+
+
+class LlavaMultimodalProcessor(BaseMultimodalProcessor):
+    """
+    This is a wrapper class used to identify the multimodal processor for Llava architecture models.
+    """
+
+    models = [LlavaForConditionalGeneration]
+
+    def _get_sgl_processor_cls(self, model_type: str):
+        if hf_name := HF_MAPPING_NAMES.get(model_type):
+            sgl_mm_processor_set = sgl_mm_processor_utils.PROCESSOR_MAPPING.values()
+            sgl_processor_cls = list(
+                filter(lambda p: p.__name__ == hf_name, sgl_mm_processor_set)
+            )
+            if sgl_processor_cls:
+                return sgl_processor_cls[0]
+        raise ValueError(
+            f"Cannot find corresponding multimodal processor registered in sglang for model type `{model_type}`"
+        )
+
+    def __init__(self, hf_config, server_args, _processor):
+        assert hasattr(hf_config, "vision_config")
+        assert hasattr(hf_config, "text_config")
+        self.vision_config = hf_config.vision_config
+        self.text_config = hf_config.text_config
+        self.hf_config = hf_config
+
+        if vision_type := getattr(self.vision_config, "model_type"):
+            self.inner = self._get_sgl_processor_cls(vision_type)(
+                hf_config, server_args, _processor
+            )
+        else:
+            raise ValueError(
+                f"Required `vision_config.model_type` is not found in hf_config: `{hf_config}`"
+            )
+
+    async def process_mm_data_async(self, *args, **kwargs):
+        return await self.inner.process_mm_data_async(*args, **kwargs)
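The wrapper above resolves the vision tower's `model_type` to a Hugging Face processor class name and then picks the sglang processor registered under the same name. A self-contained sketch of that lookup pattern follows; the mapping contents and the stub class are illustrative, not sglang's real registry.

# Illustrative sketch of the model_type -> processor-class dispatch used by
# LlavaMultimodalProcessor; the registry and mapping below are made up.
from typing import Dict, Set, Type


class PixtralProcessor:
    """Stub standing in for an sglang-side multimodal processor."""


SGL_PROCESSORS: Set[Type] = {PixtralProcessor}                 # toy sglang registry
HF_MAPPING: Dict[str, str] = {"pixtral": "PixtralProcessor"}   # toy HF mapping


def resolve(vision_model_type: str) -> Type:
    # model_type -> HF processor name -> sglang class with the same __name__
    hf_name = HF_MAPPING.get(vision_model_type)
    matches = [cls for cls in SGL_PROCESSORS if cls.__name__ == hf_name]
    if matches:
        return matches[0]
    raise ValueError(f"No multimodal processor registered for `{vision_model_type}`")


print(resolve("pixtral"))  # -> <class '__main__.PixtralProcessor'>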
