
Commit e9c6cbc

hmellor authored and Mu Huai committed
Improve-mm-and-pooler-and-decoding-configs (vllm-project#16789)
Signed-off-by: Harry Mellor <[email protected]>
Signed-off-by: Mu Huai <[email protected]>
1 parent bb37ee5 commit e9c6cbc

File tree

14 files changed: +84 / -78 lines changed


docs/source/models/supported_models.md

Lines changed: 1 addition & 1 deletion
@@ -788,7 +788,7 @@ llm = LLM(
 Online serving:

 ```bash
-vllm serve Qwen/Qwen2-VL-7B-Instruct --limit-mm-per-prompt image=4
+vllm serve Qwen/Qwen2-VL-7B-Instruct --limit-mm-per-prompt '{"image":4}'
 ```

 **This is no longer required if you are using vLLM V1.**
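
For the offline API, the same limit is passed as a plain dict rather than a JSON string. A minimal sketch, assuming the `limit_mm_per_prompt` keyword of `LLM` mirrors the CLI flag:

```python
from vllm import LLM

# Offline counterpart of the CLI flag above (assumed equivalent): the limit
# is given as a Python dict instead of a JSON string.
llm = LLM(
    model="Qwen/Qwen2-VL-7B-Instruct",
    limit_mm_per_prompt={"image": 4},
)
```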

docs/source/serving/multimodal_inputs.md

Lines changed: 1 addition & 1 deletion
@@ -228,7 +228,7 @@ First, launch the OpenAI-compatible server:

 ```bash
 vllm serve microsoft/Phi-3.5-vision-instruct --task generate \
-  --trust-remote-code --max-model-len 4096 --limit-mm-per-prompt image=2
+  --trust-remote-code --max-model-len 4096 --limit-mm-per-prompt '{"image":2}'
 ```

 Then, you can use the OpenAI client as follows:

examples/offline_inference/mistral-small.py

Lines changed: 2 additions & 2 deletions
@@ -16,11 +16,11 @@
 # # Mistral format
 # vllm serve mistralai/Mistral-Small-3.1-24B-Instruct-2503 \
 #   --tokenizer-mode mistral --config-format mistral --load-format mistral \
-#   --limit-mm-per-prompt 'image=4' --max-model-len 16384
+#   --limit-mm-per-prompt '{"image":4}' --max-model-len 16384
 #
 # # HF format
 # vllm serve mistralai/Mistral-Small-3.1-24B-Instruct-2503 \
-#   --limit-mm-per-prompt 'image=4' --max-model-len 16384
+#   --limit-mm-per-prompt '{"image":4}' --max-model-len 16384
 # ```
 #
 # - Client:

examples/online_serving/openai_chat_completion_client_for_multimodal.py

Lines changed: 1 addition & 1 deletion
@@ -9,7 +9,7 @@

 (multi-image inference with Phi-3.5-vision-instruct)
 vllm serve microsoft/Phi-3.5-vision-instruct --task generate \
-    --trust-remote-code --max-model-len 4096 --limit-mm-per-prompt image=2
+    --trust-remote-code --max-model-len 4096 --limit-mm-per-prompt '{"image":2}'

 (audio inference with Ultravox)
 vllm serve fixie-ai/ultravox-v0_5-llama-3_2-1b --max-model-len 4096

tests/engine/test_arg_utils.py

Lines changed: 4 additions & 0 deletions
@@ -24,6 +24,10 @@
     }),
 ])
 def test_limit_mm_per_prompt_parser(arg, expected):
+    """This functionality is deprecated and will be removed in the future.
+    This argument should be passed as JSON string instead.
+
+    TODO: Remove with nullable_kvs."""
     parser = EngineArgs.add_cli_args(FlexibleArgumentParser())
     if arg is None:
         args = parser.parse_args([])
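
The test above exercises the deprecated `key=value` parser. As a rough standalone illustration of the JSON form it is being replaced by (not vLLM's actual parser):

```python
import json
from argparse import ArgumentParser

# Standalone sketch: accept the new JSON form of --limit-mm-per-prompt.
parser = ArgumentParser()
parser.add_argument("--limit-mm-per-prompt", type=json.loads, default={})

args = parser.parse_args(["--limit-mm-per-prompt", '{"image": 16, "video": 2}'])
assert args.limit_mm_per_prompt == {"image": 16, "video": 2}
```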

tests/entrypoints/openai/test_audio.py

Lines changed: 1 addition & 1 deletion
@@ -27,7 +27,7 @@ def server():
         "--enforce-eager",
         "--trust-remote-code",
         "--limit-mm-per-prompt",
-        f"audio={MAXIMUM_AUDIOS}",
+        str({"audio": MAXIMUM_AUDIOS}),
     ]

     with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:

tests/entrypoints/openai/test_video.py

Lines changed: 1 addition & 1 deletion
@@ -31,7 +31,7 @@ def server():
         "--enforce-eager",
         "--trust-remote-code",
         "--limit-mm-per-prompt",
-        f"video={MAXIMUM_VIDEOS}",
+        str({"video": MAXIMUM_VIDEOS}),
     ]

     with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:

tests/entrypoints/openai/test_vision.py

Lines changed: 1 addition & 1 deletion
@@ -35,7 +35,7 @@ def server():
         "--enforce-eager",
         "--trust-remote-code",
         "--limit-mm-per-prompt",
-        f"image={MAXIMUM_IMAGES}",
+        str({"image": MAXIMUM_IMAGES}),
     ]

     with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:

tests/entrypoints/openai/test_vision_embedding.py

Lines changed: 1 addition & 1 deletion
@@ -37,7 +37,7 @@ def server():
         "--enforce-eager",
         "--trust-remote-code",
         "--limit-mm-per-prompt",
-        f"image={MAXIMUM_IMAGES}",
+        str({"image": MAXIMUM_IMAGES}),
         "--chat-template",
         str(vlm2vec_jinja_path),
     ]

tests/models/decoder_only/audio_language/test_ultravox.py

Lines changed: 3 additions & 3 deletions
@@ -48,9 +48,9 @@ def audio(request):
 ])
 def server(request, audio_assets):
     args = [
-        "--dtype=bfloat16", "--max-model-len=4096", "--enforce-eager",
-        f"--limit-mm-per-prompt=audio={len(audio_assets)}",
-        "--trust-remote-code"
+        "--dtype", "bfloat16", "--max-model-len", "4096", "--enforce-eager",
+        "--limit-mm-per-prompt",
+        str({"audio": len(audio_assets)}), "--trust-remote-code"
     ] + [
         f"--{key.replace('_','-')}={value}"
         for key, value in request.param.items()
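
Note that the tests build the flag value with `str(...)`, which produces Python dict repr rather than strict JSON. The snippet below only shows what that value looks like; it is an assumption here that vLLM's flag parser accepts this relaxed quoting alongside the double-quoted JSON shown in the docs.

```python
import json

# str() on a dict yields single-quoted, Python-style repr, not strict JSON.
assert str({"audio": 2}) == "{'audio': 2}"

# The documented CLI examples use strict JSON instead.
assert json.dumps({"audio": 2}) == '{"audio": 2}'
```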

vllm/config.py

Lines changed: 28 additions & 21 deletions
@@ -17,7 +17,7 @@
 from importlib.util import find_spec
 from pathlib import Path
 from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Final, Literal,
-                    Optional, Protocol, TypeVar, Union)
+                    Optional, Protocol, TypeVar, Union, get_args)

 import torch
 from pydantic import BaseModel, Field, PrivateAttr
@@ -2725,13 +2725,16 @@ def verify_with_model_config(self, model_config: ModelConfig):
                               self.prompt_adapter_dtype)


+@config
 @dataclass
 class MultiModalConfig:
     """Controls the behavior of multimodal models."""

     limit_per_prompt: Mapping[str, int] = field(default_factory=dict)
     """
     The maximum number of input items allowed per prompt for each modality.
+    This should be a JSON string that will be parsed into a dictionary.
+    Defaults to 1 (V0) or 999 (V1) for each modality.
     """

     def compute_hash(self) -> str:
@@ -2753,24 +2756,20 @@ def compute_hash(self) -> str:
                             usedforsecurity=False).hexdigest()
         return hash_str

-    def get_default_limit_per_prompt(self) -> int:
-        """
-        Return the default number of input items allowed per prompt
-        for any modality if not specified by the user.
-        """
-        return 999 if envs.VLLM_USE_V1 else 1
-
     def get_limit_per_prompt(self, modality: str) -> int:
         """
         Get the maximum number of input items allowed per prompt
         for the given modality.
         """
-        default = self.get_default_limit_per_prompt()
-        return self.limit_per_prompt.get(modality, default)
+        return self.limit_per_prompt.get(
+            modality,
+            999 if envs.VLLM_USE_V1 else 1,
+        )

     # TODO: Add configs to init vision tower or not.


+@config
 @dataclass
 class PoolerConfig:
     """Controls the behavior of output pooling in pooling models."""
@@ -3095,15 +3094,28 @@ def get_served_model_name(model: str,
     return served_model_name


+GuidedDecodingBackendV0 = Literal["auto", "outlines", "lm-format-enforcer",
+                                  "xgrammar"]
+GuidedDecodingBackendV1 = Literal["auto", "xgrammar", "guidance"]
+
+
+@config
 @dataclass
 class DecodingConfig:
-    """Dataclass which contains the decoding strategy of the engine"""
+    """Dataclass which contains the decoding strategy of the engine."""

-    # Which guided decoding algo to use.
-    # 'outlines' / 'lm-format-enforcer' / 'xgrammar'
-    guided_decoding_backend: str = "auto" if envs.VLLM_USE_V1 else "xgrammar"
+    guided_decoding_backend: Union[
+        GuidedDecodingBackendV0,
+        GuidedDecodingBackendV1] = "auto" if envs.VLLM_USE_V1 else "xgrammar"
+    """Which engine will be used for guided decoding (JSON schema / regex etc)
+    by default. With "auto", we will make opinionated choices based on request
+    contents and what the backend libraries currently support, so the behavior
+    is subject to change in each release."""

     reasoning_backend: Optional[str] = None
+    """Select the reasoning parser depending on the model that you're using.
+    This is used to parse the reasoning content into OpenAI API format.
+    Required for `--enable-reasoning`."""

     def compute_hash(self) -> str:
         """
@@ -3125,17 +3137,12 @@ def compute_hash(self) -> str:
         return hash_str

     def __post_init__(self):
-        v0_valid_guided_backends = [
-            'outlines', 'lm-format-enforcer', 'xgrammar', 'auto'
-        ]
-        v1_valid_guided_backends = ['xgrammar', 'guidance', 'auto']
-
         backend = GuidedDecodingParams(
             backend=self.guided_decoding_backend).backend_name
         if envs.VLLM_USE_V1:
-            valid_guided_backends = v1_valid_guided_backends
+            valid_guided_backends = get_args(GuidedDecodingBackendV1)
         else:
-            valid_guided_backends = v0_valid_guided_backends
+            valid_guided_backends = get_args(GuidedDecodingBackendV0)
         if backend not in valid_guided_backends:
            raise ValueError(f"Invalid guided_decoding_backend '{backend}',"
                             f" must be one of {valid_guided_backends}")
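
The validation in `__post_init__` now derives the allowed backends from the `Literal` aliases via `typing.get_args`, so the checked list cannot drift from the annotation. A minimal illustration of the mechanism:

```python
from typing import Literal, get_args

GuidedDecodingBackendV1 = Literal["auto", "xgrammar", "guidance"]

# get_args() recovers the allowed values straight from the Literal alias.
valid_guided_backends = get_args(GuidedDecodingBackendV1)
assert valid_guided_backends == ("auto", "xgrammar", "guidance")

backend = "outlines"  # valid in V0, but not in V1
if backend not in valid_guided_backends:
    print(f"Invalid guided_decoding_backend '{backend}', "
          f"must be one of {valid_guided_backends}")
```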

vllm/engine/arg_utils.py

Lines changed: 32 additions & 40 deletions
@@ -20,11 +20,12 @@
                         DecodingConfig, Device, DeviceConfig,
                         DistributedExecutorBackend, HfOverrides,
                         KVTransferConfig, LoadConfig, LoadFormat, LoRAConfig,
-                        ModelConfig, ModelImpl, ObservabilityConfig,
-                        ParallelConfig, PoolerConfig, PoolType,
-                        PromptAdapterConfig, SchedulerConfig, SchedulerPolicy,
-                        SpeculativeConfig, TaskOption, TokenizerPoolConfig,
-                        VllmConfig, get_attr_docs, get_field)
+                        ModelConfig, ModelImpl, MultiModalConfig,
+                        ObservabilityConfig, ParallelConfig, PoolerConfig,
+                        PoolType, PromptAdapterConfig, SchedulerConfig,
+                        SchedulerPolicy, SpeculativeConfig, TaskOption,
+                        TokenizerPoolConfig, VllmConfig, get_attr_docs,
+                        get_field)
 from vllm.executor.executor_base import ExecutorBase
 from vllm.logger import init_logger
 from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
@@ -190,7 +191,8 @@ class EngineArgs:
         TokenizerPoolConfig.pool_type
     tokenizer_pool_extra_config: dict[str, Any] = \
         get_field(TokenizerPoolConfig, "extra_config")
-    limit_mm_per_prompt: Optional[Mapping[str, int]] = None
+    limit_mm_per_prompt: Mapping[str, int] = \
+        get_field(MultiModalConfig, "limit_per_prompt")
     mm_processor_kwargs: Optional[Dict[str, Any]] = None
     disable_mm_preprocessor_cache: bool = False
     enable_lora: bool = False
@@ -252,7 +254,7 @@ class EngineArgs:

     additional_config: Optional[Dict[str, Any]] = None
     enable_reasoning: Optional[bool] = None
-    reasoning_parser: Optional[str] = None
+    reasoning_parser: Optional[str] = DecodingConfig.reasoning_backend
     use_tqdm_on_load: bool = LoadConfig.use_tqdm_on_load

     def __post_init__(self):
@@ -478,18 +480,22 @@ def get_kwargs(cls: type[Config]) -> dict[str, Any]:
             'Examples:\n'
             '- 1k → 1000\n'
             '- 1K → 1024\n')
-        parser.add_argument(
+
+        # Guided decoding arguments
+        guided_decoding_kwargs = get_kwargs(DecodingConfig)
+        guided_decoding_group = parser.add_argument_group(
+            title="DecodingConfig",
+            description=DecodingConfig.__doc__,
+        )
+        guided_decoding_group.add_argument(
             '--guided-decoding-backend',
-            type=str,
-            default=DecodingConfig.guided_decoding_backend,
-            help='Which engine will be used for guided decoding'
-            ' (JSON schema / regex etc) by default. Currently support '
-            'https://github.com/mlc-ai/xgrammar and '
-            'https://github.com/guidance-ai/llguidance.'
-            'Valid backend values are "xgrammar", "guidance", and "auto". '
-            'With "auto", we will make opinionated choices based on request '
-            'contents and what the backend libraries currently support, so '
-            'the behavior is subject to change in each release.')
+            **guided_decoding_kwargs["guided_decoding_backend"])
+        guided_decoding_group.add_argument(
+            "--reasoning-parser",
+            # This choices is a special case because it's not static
+            choices=list(ReasoningParserManager.reasoning_parsers),
+            **guided_decoding_kwargs["reasoning_backend"])
+
         parser.add_argument(
             '--logits-processor-pattern',
             type=optional_str,
@@ -697,18 +703,14 @@ def get_kwargs(cls: type[Config]) -> dict[str, Any]:
             **tokenizer_kwargs["extra_config"])

         # Multimodal related configs
-        parser.add_argument(
-            '--limit-mm-per-prompt',
-            type=nullable_kvs,
-            default=EngineArgs.limit_mm_per_prompt,
-            # The default value is given in
-            # MultiModalConfig.get_default_limit_per_prompt
-            help=('For each multimodal plugin, limit how many '
-                  'input instances to allow for each prompt. '
-                  'Expects a comma-separated list of items, '
-                  'e.g.: `image=16,video=2` allows a maximum of 16 '
-                  'images and 2 videos per prompt. Defaults to '
-                  '1 (V0) or 999 (V1) for each modality.'))
+        multimodal_kwargs = get_kwargs(MultiModalConfig)
+        multimodal_group = parser.add_argument_group(
+            title="MultiModalConfig",
+            description=MultiModalConfig.__doc__,
+        )
+        multimodal_group.add_argument('--limit-mm-per-prompt',
+                                      **multimodal_kwargs["limit_per_prompt"])

         parser.add_argument(
             '--mm-processor-kwargs',
             default=None,
@@ -1018,16 +1020,6 @@ def get_kwargs(cls: type[Config]) -> dict[str, Any]:
             "If enabled, the model will be able to generate reasoning content."
         )

-        parser.add_argument(
-            "--reasoning-parser",
-            type=str,
-            choices=list(ReasoningParserManager.reasoning_parsers),
-            default=None,
-            help=
-            "Select the reasoning parser depending on the model that you're "
-            "using. This is used to parse the reasoning content into OpenAI "
-            "API format. Required for ``--enable-reasoning``.")
-
         parser.add_argument(
             "--disable-cascade-attn",
             action="store_true",
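
The `get_kwargs` helper referenced above is vLLM-internal and not reproduced here. The following is only a simplified sketch of the general pattern this diff adopts (argparse options generated from a config dataclass and attached to a named group), with the docstring-to-help plumbing omitted; the dataclass name and field are stand-ins, not vLLM's:

```python
from argparse import ArgumentParser
from dataclasses import MISSING, dataclass, fields


@dataclass
class ExampleConfig:
    """Stand-in config used only for this illustration."""
    limit_per_prompt: str = "{}"


parser = ArgumentParser()
group = parser.add_argument_group(title="ExampleConfig",
                                  description=ExampleConfig.__doc__)
# Derive one CLI option per dataclass field, reusing the field's default.
for f in fields(ExampleConfig):
    default = f.default if f.default is not MISSING else None
    group.add_argument(f"--{f.name.replace('_', '-')}", default=default)

args = parser.parse_args(["--limit-per-prompt", '{"image": 4}'])
print(args.limit_per_prompt)  # {"image": 4}
```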

vllm/multimodal/processing.py

Lines changed: 4 additions & 2 deletions
@@ -1,4 +1,5 @@
 # SPDX-License-Identifier: Apache-2.0
+import json
 import re
 import sys
 from abc import ABC, abstractmethod
@@ -1117,8 +1118,9 @@ def _to_mm_items(

         if num_items > allowed_limit:
             raise ValueError(
-                f"You set or defaulted to {modality}={allowed_limit} "
-                f"in --limit-mm-per-prompt`, but passed {num_items} "
+                "You set or defaulted to "
+                f"'{json.dumps({modality: allowed_limit})}' in "
+                f"`--limit-mm-per-prompt`, but passed {num_items} "
                 f"{modality} items in the same prompt.")

         return mm_items
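
The reworded error now embeds the limit as a JSON fragment, so the value can be pasted straight back into `--limit-mm-per-prompt`. For example:

```python
import json

modality, allowed_limit, num_items = "image", 2, 3
print("You set or defaulted to "
      f"'{json.dumps({modality: allowed_limit})}' in "
      f"`--limit-mm-per-prompt`, but passed {num_items} "
      f"{modality} items in the same prompt.")
# -> You set or defaulted to '{"image": 2}' in `--limit-mm-per-prompt`,
#    but passed 3 image items in the same prompt.
```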

vllm/multimodal/registry.py

Lines changed: 4 additions & 3 deletions
@@ -1,6 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0

 import functools
+import json
 from collections import UserDict
 from collections.abc import Mapping, Sequence
 from dataclasses import dataclass
@@ -194,9 +195,9 @@ def map_input(
             max_items = self._limits_by_model[model_config][data_key]
             if num_items > max_items:
                 raise ValueError(
-                    f"You set {data_key}={max_items} (or defaulted to 1) in "
-                    f"`--limit-mm-per-prompt`, but found {num_items} items "
-                    "in the same prompt.")
+                    f"You set '{json.dumps({data_key: max_items})}' (or "
+                    "defaulted to 1) in `--limit-mm-per-prompt`, but found "
+                    f"{num_items} items in the same prompt.")

             input_dict = plugin.map_input(model_config, data_value,
                                           mm_processor_kwargs)
