
Commit 8363057

hmellor authored and lionelvillard committed
Improve configs - TokenizerPoolConfig + DeviceConfig (vllm-project#16603)
Signed-off-by: Harry Mellor <[email protected]>
1 parent f1c539b commit 8363057

File tree

3 files changed: +136 additions, -81 deletions


tests/test_config.py

Lines changed: 24 additions & 2 deletions
@@ -1,14 +1,36 @@
 # SPDX-License-Identifier: Apache-2.0

-from dataclasses import asdict
+from dataclasses import MISSING, Field, asdict, dataclass, field

 import pytest

-from vllm.config import ModelConfig, PoolerConfig
+from vllm.config import ModelConfig, PoolerConfig, get_field
 from vllm.model_executor.layers.pooler import PoolingType
 from vllm.platforms import current_platform


+def test_get_field():
+
+    @dataclass
+    class TestConfig:
+        a: int
+        b: dict = field(default_factory=dict)
+        c: str = "default"
+
+    with pytest.raises(ValueError):
+        get_field(TestConfig, "a")
+
+    b = get_field(TestConfig, "b")
+    assert isinstance(b, Field)
+    assert b.default is MISSING
+    assert b.default_factory is dict
+
+    c = get_field(TestConfig, "c")
+    assert isinstance(c, Field)
+    assert c.default == "default"
+    assert c.default_factory is MISSING
+
+
 @pytest.mark.parametrize(
     ("model_id", "expected_runner_type", "expected_task"),
     [

vllm/config.py

Lines changed: 51 additions & 22 deletions
@@ -182,6 +182,23 @@ def config(cls: type[Config]) -> type[Config]:
     return cls


+def get_field(cls: type[Config], name: str) -> Field:
+    """Get the default factory field of a dataclass by name. Used for getting
+    default factory fields in `EngineArgs`."""
+    if not is_dataclass(cls):
+        raise TypeError("The given class is not a dataclass.")
+    cls_fields = {f.name: f for f in fields(cls)}
+    if name not in cls_fields:
+        raise ValueError(f"Field '{name}' not found in {cls.__name__}.")
+    named_field: Field = cls_fields.get(name)
+    if (default_factory := named_field.default_factory) is not MISSING:
+        return field(default_factory=default_factory)
+    if (default := named_field.default) is not MISSING:
+        return field(default=default)
+    raise ValueError(
+        f"{cls.__name__}.{name} must have a default value or default factory.")
+
+
 class ModelConfig:
     """Configuration for the model.
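Editorial note, for illustration only (not part of the diff): fields declared with a default_factory leave no class attribute behind, so a consumer such as `EngineArgs` cannot simply reference `SomeConfig.some_field` the way it does for plain defaults; `get_field` copies the factory into a fresh `Field` that can be reused as another dataclass's default. A minimal sketch of that distinction, using a made-up `Toy` dataclass:

from dataclasses import dataclass, field, MISSING

from vllm.config import get_field


@dataclass
class Toy:  # hypothetical example class, not part of vLLM
    plain: int = 0
    factory: dict = field(default_factory=dict)


# Plain defaults survive as class attributes...
assert Toy.plain == 0
# ...but default_factory fields do not, so they cannot be referenced directly.
assert not hasattr(Toy, "factory")

# get_field copies the factory into a fresh Field usable as another
# dataclass's default.
copied = get_field(Toy, "factory")
assert copied.default is MISSING and copied.default_factory is dict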
@@ -1364,20 +1381,26 @@ def verify_with_parallel_config(
             logger.warning("Possibly too large swap space. %s", msg)


+PoolType = Literal["ray"]
+
+
+@config
 @dataclass
 class TokenizerPoolConfig:
-    """Configuration for the tokenizer pool.
+    """Configuration for the tokenizer pool."""

-    Args:
-        pool_size: Number of tokenizer workers in the pool.
-        pool_type: Type of the pool.
-        extra_config: Additional config for the pool.
-            The way the config will be used depends on the
-            pool type.
-    """
-    pool_size: int
-    pool_type: Union[str, type["BaseTokenizerGroup"]]
-    extra_config: dict
+    pool_size: int = 0
+    """Number of tokenizer workers in the pool to use for asynchronous
+    tokenization. If 0, will use synchronous tokenization."""
+
+    pool_type: Union[PoolType, type["BaseTokenizerGroup"]] = "ray"
+    """Type of tokenizer pool to use for asynchronous tokenization. Ignored if
+    tokenizer_pool_size is 0."""
+
+    extra_config: dict = field(default_factory=dict)
+    """Additional config for the pool. The way the config will be used depends
+    on the pool type. This should be a JSON string that will be parsed into a
+    dictionary. Ignored if tokenizer_pool_size is 0."""

     def compute_hash(self) -> str:
         """
@@ -1408,7 +1431,7 @@ def __post_init__(self):
     @classmethod
     def create_config(
         cls, tokenizer_pool_size: int,
-        tokenizer_pool_type: Union[str, type["BaseTokenizerGroup"]],
+        tokenizer_pool_type: Union[PoolType, type["BaseTokenizerGroup"]],
         tokenizer_pool_extra_config: Optional[Union[str, dict]]
     ) -> Optional["TokenizerPoolConfig"]:
         """Create a TokenizerPoolConfig from the given parameters.
@@ -1483,7 +1506,7 @@ class LoadConfig:
     download_dir: Optional[str] = None
     """Directory to download and load the weights, default to the default
     cache directory of Hugging Face."""
-    model_loader_extra_config: Optional[Union[str, dict]] = None
+    model_loader_extra_config: dict = field(default_factory=dict)
     """Extra config for model loader. This will be passed to the model loader
     corresponding to the chosen load_format. This should be a JSON string that
     will be parsed into a dictionary."""
@@ -1514,10 +1537,6 @@ def compute_hash(self) -> str:
         return hash_str

     def __post_init__(self):
-        model_loader_extra_config = self.model_loader_extra_config or {}
-        if isinstance(model_loader_extra_config, str):
-            self.model_loader_extra_config = json.loads(
-                model_loader_extra_config)
         if isinstance(self.load_format, str):
             load_format = self.load_format.lower()
             self.load_format = LoadFormat(load_format)
@@ -2029,9 +2048,19 @@ def is_multi_step(self) -> bool:
         return self.num_scheduler_steps > 1


+Device = Literal["auto", "cuda", "neuron", "cpu", "tpu", "xpu", "hpu"]
+
+
+@config
+@dataclass
 class DeviceConfig:
-    device: Optional[torch.device]
-    device_type: str
+    """Configuration for the device to use for vLLM execution."""
+
+    device: Union[Device, torch.device] = "auto"
+    """Device type for vLLM execution."""
+    device_type: str = field(init=False)
+    """Device type from the current platform. This is set in
+    `__post_init__`."""

     def compute_hash(self) -> str:
         """
@@ -2053,8 +2082,8 @@ def compute_hash(self) -> str:
                                usedforsecurity=False).hexdigest()
         return hash_str

-    def __init__(self, device: str = "auto") -> None:
-        if device == "auto":
+    def __post_init__(self):
+        if self.device == "auto":
             # Automated device type detection
             from vllm.platforms import current_platform
             self.device_type = current_platform.device_type
@@ -2065,7 +2094,7 @@ def __init__(self, device: str = "auto") -> None:
                     "to turn on verbose logging to help debug the issue.")
         else:
             # Device type is assigned explicitly
-            self.device_type = device
+            self.device_type = self.device

         # Some device types require processing inputs on CPU
         if self.device_type in ["neuron"]:
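Editorial note, for illustration only (not part of the diff): because `DeviceConfig` is now a dataclass whose `device_type` is derived in `__post_init__`, constructing it with no arguments auto-detects the platform, while an explicit device string is copied into `device_type` directly. A minimal sketch, assuming the definitions above:

from vllm.config import DeviceConfig

# Default: "auto" lets the current platform decide the device type.
auto_config = DeviceConfig()
print(auto_config.device_type)  # e.g. "cuda" on a typical GPU machine

# Explicit: the string is taken as the device type as-is.
cpu_config = DeviceConfig(device="cpu")
assert cpu_config.device_type == "cpu"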

vllm/engine/arg_utils.py

Lines changed: 61 additions & 57 deletions
@@ -16,15 +16,15 @@

 import vllm.envs as envs
 from vllm import version
-from vllm.config import (CacheConfig, CompilationConfig, ConfigFormat,
-                         DecodingConfig, DeviceConfig,
+from vllm.config import (CacheConfig, CompilationConfig, Config, ConfigFormat,
+                         DecodingConfig, Device, DeviceConfig,
                          DistributedExecutorBackend, HfOverrides,
                          KVTransferConfig, LoadConfig, LoadFormat, LoRAConfig,
                          ModelConfig, ModelImpl, ObservabilityConfig,
-                         ParallelConfig, PoolerConfig, PromptAdapterConfig,
-                         SchedulerConfig, SchedulerPolicy, SpeculativeConfig,
-                         TaskOption, TokenizerPoolConfig, VllmConfig,
-                         get_attr_docs)
+                         ParallelConfig, PoolerConfig, PoolType,
+                         PromptAdapterConfig, SchedulerConfig, SchedulerPolicy,
+                         SpeculativeConfig, TaskOption, TokenizerPoolConfig,
+                         VllmConfig, get_attr_docs, get_field)
 from vllm.executor.executor_base import ExecutorBase
 from vllm.logger import init_logger
 from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
@@ -44,27 +44,17 @@

 ALLOWED_DETAILED_TRACE_MODULES = ["model", "worker", "all"]

-DEVICE_OPTIONS = [
-    "auto",
-    "cuda",
-    "neuron",
-    "cpu",
-    "tpu",
-    "xpu",
-    "hpu",
-]
-
 # object is used to allow for special typing forms
 T = TypeVar("T")
 TypeHint = Union[type[Any], object]
 TypeHintT = Union[type[T], object]


-def optional_arg(val: str, return_type: type[T]) -> Optional[T]:
+def optional_arg(val: str, return_type: Callable[[str], T]) -> Optional[T]:
     if val == "" or val == "None":
         return None
     try:
-        return cast(Callable, return_type)(val)
+        return return_type(val)
     except ValueError as e:
         raise argparse.ArgumentTypeError(
             f"Value {val} cannot be converted to {return_type}.") from e
@@ -82,8 +72,11 @@ def optional_float(val: str) -> Optional[float]:
     return optional_arg(val, float)


-def nullable_kvs(val: str) -> Optional[Mapping[str, int]]:
-    """Parses a string containing comma separate key [str] to value [int]
+def nullable_kvs(val: str) -> Optional[dict[str, int]]:
+    """NOTE: This function is deprecated, args should be passed as JSON
+    strings instead.
+
+    Parses a string containing comma separate key [str] to value [int]
     pairs into a dictionary.

     Args:
@@ -117,6 +110,17 @@ def nullable_kvs(val: str) -> Optional[Mapping[str, int]]:
     return out_dict


+def optional_dict(val: str) -> Optional[dict[str, int]]:
+    try:
+        return optional_arg(val, json.loads)
+    except ValueError:
+        logger.warning(
+            "Failed to parse JSON string. Attempting to parse as "
+            "comma-separated key=value pairs. This will be deprecated in a "
+            "future release.")
+        return nullable_kvs(val)
+
+
 @dataclass
 class EngineArgs:
     """Arguments for vLLM engine."""
@@ -178,12 +182,14 @@ class EngineArgs:
     enforce_eager: Optional[bool] = None
     max_seq_len_to_capture: int = 8192
     disable_custom_all_reduce: bool = ParallelConfig.disable_custom_all_reduce
-    tokenizer_pool_size: int = 0
+    tokenizer_pool_size: int = TokenizerPoolConfig.pool_size
     # Note: Specifying a tokenizer pool by passing a class
     # is intended for expert use only. The API may change without
     # notice.
-    tokenizer_pool_type: Union[str, Type["BaseTokenizerGroup"]] = "ray"
-    tokenizer_pool_extra_config: Optional[Dict[str, Any]] = None
+    tokenizer_pool_type: Union[PoolType, Type["BaseTokenizerGroup"]] = \
+        TokenizerPoolConfig.pool_type
+    tokenizer_pool_extra_config: dict[str, Any] = \
+        get_field(TokenizerPoolConfig, "extra_config")
     limit_mm_per_prompt: Optional[Mapping[str, int]] = None
     mm_processor_kwargs: Optional[Dict[str, Any]] = None
     disable_mm_preprocessor_cache: bool = False
@@ -199,14 +205,14 @@ class EngineArgs:
     long_lora_scaling_factors: Optional[Tuple[float]] = None
     lora_dtype: Optional[Union[str, torch.dtype]] = 'auto'
     max_cpu_loras: Optional[int] = None
-    device: str = 'auto'
+    device: Device = DeviceConfig.device
     num_scheduler_steps: int = SchedulerConfig.num_scheduler_steps
     multi_step_stream_outputs: bool = SchedulerConfig.multi_step_stream_outputs
     ray_workers_use_nsight: bool = ParallelConfig.ray_workers_use_nsight
     num_gpu_blocks_override: Optional[int] = None
     num_lookahead_slots: int = SchedulerConfig.num_lookahead_slots
-    model_loader_extra_config: Optional[
-        dict] = LoadConfig.model_loader_extra_config
+    model_loader_extra_config: dict = \
+        get_field(LoadConfig, "model_loader_extra_config")
     ignore_patterns: Optional[Union[str,
                                     List[str]]] = LoadConfig.ignore_patterns
     preemption_mode: Optional[str] = SchedulerConfig.preemption_mode
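Editorial note, for illustration only (not part of the diff): `EngineArgs` now sources its defaults from the config dataclasses themselves; immutable defaults are read straight off the class attribute, while mutable defaults are copied with `get_field` so every instance gets a fresh object. A minimal sketch of the two patterns, with `HypotheticalArgs` as a made-up stand-in for `EngineArgs`:

from dataclasses import dataclass
from typing import Any

from vllm.config import DeviceConfig, TokenizerPoolConfig, get_field


@dataclass
class HypotheticalArgs:  # hypothetical, not part of vLLM
    # Immutable default: read straight from the config class attribute.
    device: str = DeviceConfig.device
    # Mutable default: re-created per instance via the copied default factory.
    tokenizer_pool_extra_config: dict[str, Any] = get_field(
        TokenizerPoolConfig, "extra_config")


a, b = HypotheticalArgs(), HypotheticalArgs()
assert a.device == "auto"
# Each instance gets its own dict, not a shared mutable default.
assert a.tokenizer_pool_extra_config is not b.tokenizer_pool_extra_config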
@@ -294,14 +300,15 @@ def is_custom_type(cls: TypeHint) -> bool:
             """Check if the class is a custom type."""
             return cls.__module__ != "builtins"

-        def get_kwargs(cls: type[Any]) -> dict[str, Any]:
+        def get_kwargs(cls: type[Config]) -> dict[str, Any]:
             cls_docs = get_attr_docs(cls)
             kwargs = {}
             for field in fields(cls):
                 name = field.name
-                # One of these will always be present
-                default = (field.default_factory
-                           if field.default is MISSING else field.default)
+                default = field.default
+                # This will only be True if default is MISSING
+                if field.default_factory is not MISSING:
+                    default = field.default_factory()
                 kwargs[name] = {"default": default, "help": cls_docs[name]}

                 # Make note of if the field is optional and get the actual
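Editorial note, for illustration only (not part of the diff): the revised default handling calls the field's default_factory when one is present, so argparse sees a concrete value (for example an empty dict) rather than the factory callable itself. A rough sketch of the effect, using a toy dataclass and plain stdlib:

from dataclasses import MISSING, dataclass, field, fields


@dataclass
class ToyConfig:  # hypothetical example class, not part of vLLM
    retries: int = 3
    extra: dict = field(default_factory=dict)


def resolve_default(f):
    """Mirror the new get_kwargs logic for computing argparse defaults."""
    default = f.default
    if f.default_factory is not MISSING:
        default = f.default_factory()
    return default


defaults = {f.name: resolve_default(f) for f in fields(ToyConfig)}
assert defaults == {"retries": 3, "extra": {}}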
@@ -331,8 +338,9 @@ def get_kwargs(cls: type[Any]) -> dict[str, Any]:
                 elif can_be_type(field_type, float):
                     kwargs[name][
                         "type"] = optional_float if optional else float
+                elif can_be_type(field_type, dict):
+                    kwargs[name]["type"] = optional_dict
                 elif (can_be_type(field_type, str)
-                      or can_be_type(field_type, dict)
                       or is_custom_type(field_type)):
                     kwargs[name]["type"] = optional_str if optional else str
                 else:
@@ -674,25 +682,19 @@ def get_kwargs(cls: type[Any]) -> dict[str, Any]:
                             'Additionally for encoder-decoder models, if the '
                             'sequence length of the encoder input is larger '
                             'than this, we fall back to the eager mode.')
-        parser.add_argument('--tokenizer-pool-size',
-                            type=int,
-                            default=EngineArgs.tokenizer_pool_size,
-                            help='Size of tokenizer pool to use for '
-                            'asynchronous tokenization. If 0, will '
-                            'use synchronous tokenization.')
-        parser.add_argument('--tokenizer-pool-type',
-                            type=str,
-                            default=EngineArgs.tokenizer_pool_type,
-                            help='Type of tokenizer pool to use for '
-                            'asynchronous tokenization. Ignored '
-                            'if tokenizer_pool_size is 0.')
-        parser.add_argument('--tokenizer-pool-extra-config',
-                            type=optional_str,
-                            default=EngineArgs.tokenizer_pool_extra_config,
-                            help='Extra config for tokenizer pool. '
-                            'This should be a JSON string that will be '
-                            'parsed into a dictionary. Ignored if '
-                            'tokenizer_pool_size is 0.')
+
+        # Tokenizer arguments
+        tokenizer_kwargs = get_kwargs(TokenizerPoolConfig)
+        tokenizer_group = parser.add_argument_group(
+            title="TokenizerPoolConfig",
+            description=TokenizerPoolConfig.__doc__,
+        )
+        tokenizer_group.add_argument('--tokenizer-pool-size',
+                                     **tokenizer_kwargs["pool_size"])
+        tokenizer_group.add_argument('--tokenizer-pool-type',
+                                     **tokenizer_kwargs["pool_type"])
+        tokenizer_group.add_argument('--tokenizer-pool-extra-config',
+                                     **tokenizer_kwargs["extra_config"])

         # Multimodal related configs
         parser.add_argument(
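Editorial note, for illustration only (not part of the diff): with the flags generated from the dataclass, the help text comes from the attribute docstrings, the defaults from the field defaults, and a JSON value is parsed by `optional_dict`. A hypothetical parse, assuming the argument wiring above:

from vllm.engine.arg_utils import EngineArgs
from vllm.utils import FlexibleArgumentParser

parser = EngineArgs.add_cli_args(FlexibleArgumentParser())

# The tokenizer pool flags now live in their own "TokenizerPoolConfig" group;
# the extra-config value is a JSON string parsed into a dict.
args = parser.parse_args([
    "--tokenizer-pool-size", "4",
    "--tokenizer-pool-extra-config", '{"some_key": "some_value"}',
])
assert args.tokenizer_pool_size == 4
assert args.tokenizer_pool_extra_config == {"some_key": "some_value"}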
@@ -784,11 +786,15 @@ def get_kwargs(cls: type[Any]) -> dict[str, Any]:
                             type=int,
                             default=EngineArgs.max_prompt_adapter_token,
                             help='Max number of PromptAdapters tokens')
-        parser.add_argument("--device",
-                            type=str,
-                            default=EngineArgs.device,
-                            choices=DEVICE_OPTIONS,
-                            help='Device type for vLLM execution.')
+
+        # Device arguments
+        device_kwargs = get_kwargs(DeviceConfig)
+        device_group = parser.add_argument_group(
+            title="DeviceConfig",
+            description=DeviceConfig.__doc__,
+        )
+        device_group.add_argument("--device", **device_kwargs["device"])
+
         parser.add_argument('--num-scheduler-steps',
                             type=int,
                             default=1,
@@ -1302,8 +1308,6 @@ def create_engine_config(

         if self.qlora_adapter_name_or_path is not None and \
             self.qlora_adapter_name_or_path != "":
-            if self.model_loader_extra_config is None:
-                self.model_loader_extra_config = {}
             self.model_loader_extra_config[
                 "qlora_adapter_name_or_path"] = self.qlora_adapter_name_or_path
