Commit 9025a9a

[Quant] [Bugfix] Fix quantization config matching with hf_to_vllm_mapper (vllm-project#20046)
Parent: c05596f

17 files changed: +107 -29 lines


tests/quantization/test_register_quantization_config.py

Lines changed: 1 addition & 0 deletions
@@ -53,6 +53,7 @@ class CustomQuantConfig(QuantizationConfig):
 
     def __init__(self, num_bits: int = 8) -> None:
         """Initialize the quantization config."""
+        super().__init__()
         self.num_bits = num_bits
 
     def get_name(self) -> QuantizationMethods:

vllm/lora/models.py

Lines changed: 1 addition & 1 deletion
@@ -805,7 +805,7 @@ def create_lora_manager(
         lora_manager_cls: type[LoRAModelManager] = LoRAModelManager,
         **kwargs) -> LoRAModelManager:
     """Create a LoRA adapter for a given model."""
-    if not hasattr(model, "packed_modules_mapping"):
+    if not isinstance(model, SupportsLoRA):
         raise ValueError(f"Model {type(model)} is not supported for LoRA.")
     lora_manager = lora_manager_cls(
         model=model,

vllm/lora/worker_manager.py

Lines changed: 1 addition & 4 deletions
@@ -111,10 +111,7 @@ def _load_adapter(self, lora_request: LoRARequest) -> LoRAModel:
             # For some models like Qwen2VL, we need to use hf_to_vllm_mapper
             # to ensure correct loading of lora weights.
             model = self._adapter_manager.model
-            hf_to_vllm_mapper = None
-            if (hasattr(model, "hf_to_vllm_mapper")
-                    and model.hf_to_vllm_mapper is not None):
-                hf_to_vllm_mapper = model.hf_to_vllm_mapper
+            hf_to_vllm_mapper = getattr(model, "hf_to_vllm_mapper", None)
 
             lora = self._lora_model_cls.from_local_checkpoint(
                 lora_path,

vllm/model_executor/layers/quantization/base_config.py

Lines changed: 13 additions & 0 deletions
@@ -10,6 +10,7 @@
 
 if TYPE_CHECKING:
     from vllm.model_executor.layers.quantization import QuantizationMethods
+    from vllm.model_executor.models.utils import WeightsMapper
 else:
     QuantizationMethods = str
 
@@ -149,3 +150,15 @@ def get_quant_method(self, layer: torch.nn.Module,
 
     def get_cache_scale(self, name: str) -> Optional[str]:
         return None
+
+    def apply_vllm_mapper(  # noqa: B027
+            self, hf_to_vllm_mapper: "WeightsMapper"):
+        """
+        Interface for models to update module names referenced in
+        quantization configs in order to reflect the vllm model structure
+
+        :param hf_to_vllm_mapper: maps from hf model structure (the assumed
+            structure of the qconfig) to vllm model structure
+        """
+        # TODO (@kylesayrs): add implementations for all subclasses
+        pass
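
As an illustration of the new hook: a quantization config that stores HF-style module names (for example an ignore list) can override apply_vllm_mapper so those names are rewritten to vLLM's module names before layer matching. This is a minimal sketch; ExampleInt8Config is hypothetical and the other abstract QuantizationConfig methods are omitted for brevity.

from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig)
from vllm.model_executor.models.utils import WeightsMapper


class ExampleInt8Config(QuantizationConfig):  # hypothetical subclass
    """Keeps an ignore list keyed on HF module names."""

    def __init__(self, ignore: list[str]) -> None:
        super().__init__()
        self.ignore = ignore

    def apply_vllm_mapper(self, hf_to_vllm_mapper: WeightsMapper):
        # Rename HF-style entries (dropping any that map to None) so the
        # list matches the modules vLLM actually instantiates.
        self.ignore = hf_to_vllm_mapper.apply_list(self.ignore)

    # get_name, get_quant_method, from_config, etc. omitted for brevity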

vllm/model_executor/layers/quantization/bitblas.py

Lines changed: 1 addition & 0 deletions
@@ -63,6 +63,7 @@ def __init__(
         # (since we have only one group per output channel)
         desc_act = False
 
+        super().__init__()
         self.weight_bits = weight_bits
         self.group_size = group_size
         self.desc_act = desc_act

vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py

Lines changed: 16 additions & 1 deletion
@@ -2,7 +2,7 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
 from contextlib import suppress
-from typing import Any, Literal, Optional, cast
+from typing import TYPE_CHECKING, Any, Literal, Optional, cast
 
 import torch
 from compressed_tensors.config import (CompressionFormat,
@@ -37,6 +37,9 @@
     cutlass_fp4_supported)
 from vllm.platforms import current_platform
 
+if TYPE_CHECKING:
+    from vllm.model_executor.models.utils import WeightsMapper
+
 logger = init_logger(__name__)
 
 __all__ = ["CompressedTensorsLinearMethod"]
@@ -80,6 +83,18 @@ def get_min_capability(cls) -> int:
     def get_name(self) -> QuantizationMethods:
         return "compressed-tensors"
 
+    def apply_vllm_mapper(self, hf_to_vllm_mapper: "WeightsMapper"):
+        self.target_scheme_map = hf_to_vllm_mapper.apply_dict(
+            self.target_scheme_map)
+        self.ignore = hf_to_vllm_mapper.apply_list(self.ignore)
+        self.sparsity_scheme_map = hf_to_vllm_mapper.apply_dict(
+            self.sparsity_scheme_map)
+        self.sparsity_ignore_list = hf_to_vllm_mapper.apply_list(
+            self.sparsity_ignore_list)
+        if self.kv_cache_scheme is not None:
+            self.kv_cache_scheme = hf_to_vllm_mapper.apply_dict(
+                self.kv_cache_scheme)
+
     def get_quant_method(
         self,
         layer: torch.nn.Module,

vllm/model_executor/layers/quantization/fp8.py

Lines changed: 9 additions & 1 deletion
@@ -2,7 +2,7 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
 import functools
-from typing import Any, Callable, Optional, Union
+from typing import TYPE_CHECKING, Any, Callable, Optional, Union
 
 import torch
 import torch.nn.functional as F
@@ -39,6 +39,9 @@
 from vllm.scalar_type import scalar_types
 from vllm.utils import has_deep_gemm
 
+if TYPE_CHECKING:
+    from vllm.model_executor.models.utils import WeightsMapper
+
 ACTIVATION_SCHEMES = ["static", "dynamic"]
 
 logger = init_logger(__name__)
@@ -100,6 +103,11 @@ def get_min_capability(cls) -> int:
     def get_config_filenames(cls) -> list[str]:
         return []
 
+    def apply_vllm_mapper(self, hf_to_vllm_mapper: "WeightsMapper"):
+        if self.ignored_layers is not None:
+            self.ignored_layers = hf_to_vllm_mapper.apply_list(
+                self.ignored_layers)
+
     @classmethod
     def from_config(cls, config: dict[str, Any]) -> "Fp8Config":
         quant_method = cls.get_from_keys(config, ["quant_method"])

vllm/model_executor/layers/quantization/gptq_bitblas.py

Lines changed: 1 addition & 0 deletions
@@ -81,6 +81,7 @@ def __init__(
         # (since we have only one group per output channel)
         desc_act = False
 
+        super().__init__()
         self.weight_bits = weight_bits
         self.group_size = group_size
         self.desc_act = desc_act

vllm/model_executor/layers/quantization/marlin.py

Lines changed: 2 additions & 0 deletions
@@ -32,6 +32,8 @@ def __init__(
         group_size: int,
         lm_head_quantized: bool,
     ) -> None:
+        super().__init__()
+
         # Group size for the quantization.
         self.group_size = group_size
         self.lm_head_quantized = lm_head_quantized

vllm/model_executor/layers/quantization/modelopt.py

Lines changed: 1 addition & 0 deletions
@@ -181,6 +181,7 @@ def __init__(
         exclude_modules: list[str],
         group_size: int = 16,
     ) -> None:
+        super().__init__()
         self.is_checkpoint_nvfp4_serialized = is_checkpoint_nvfp4_serialized
         if is_checkpoint_nvfp4_serialized:
             logger.warning(

vllm/model_executor/layers/quantization/torchao.py

Lines changed: 1 addition & 0 deletions
@@ -55,6 +55,7 @@ def __init__(self,
         os.environ["VLLM_DISABLE_COMPILE_CACHE"] = "1"
         logger.info("Using TorchAO: Setting VLLM_DISABLE_COMPILE_CACHE=1")
         """
+        super().__init__()
         self.torchao_config = torchao_config
         self.skip_modules = skip_modules or []
 

vllm/model_executor/model_loader/utils.py

Lines changed: 13 additions & 9 deletions
@@ -24,6 +24,7 @@
 from vllm.model_executor.models.adapters import (as_classification_model,
                                                  as_embedding_model,
                                                  as_reward_model)
+from vllm.model_executor.models.interfaces import SupportsQuant
 from vllm.utils import is_pin_memory_available
 
 logger = init_logger(__name__)
@@ -294,13 +295,16 @@ def configure_quant_config(quant_config: QuantizationConfig,
 
     Note that model attributes are passed by reference to quant_config,
     enabling them to be updated by model_class.__new__ (ex. chatglm, qwen)
+
+    Once the `SupportsQuant` mixin has been added to all models, this
+    function can be removed
     """
-    packed_mapping = getattr(model_class, "packed_modules_mapping", None)
-    if packed_mapping is not None:
-        # pass packed_modules_mapping by reference to quant_config
-        quant_config.packed_modules_mapping = packed_mapping
-    else:
-        logger.warning(
-            "The model class %s has not defined `packed_modules_mapping`, "
-            "this may lead to incorrect mapping of quantized or ignored "
-            "modules", model_class.__name__)
+    if not issubclass(model_class, SupportsQuant):
+        hf_to_vllm_mapper = getattr(model_class, "hf_to_vllm_mapper", None)
+        packed_mapping = getattr(model_class, "packed_modules_mapping", None)
+
+        # pass mappings by reference to quant_config
+        if hf_to_vllm_mapper is not None:
+            quant_config.apply_vllm_mapper(hf_to_vllm_mapper)
+        if packed_mapping is not None:
+            quant_config.packed_modules_mapping = packed_mapping
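
For a model class that has not yet adopted SupportsQuant, the updated flow amounts to the sketch below. LegacyVLModel is hypothetical; the sketch also assumes Fp8Config accepts ignored_layers as a constructor keyword (its attribute is used that way above) and that WeightsMapper exposes the orig_to_new_prefix mapping used by existing model mappers.

import torch.nn as nn

from vllm.model_executor.layers.quantization.fp8 import Fp8Config
from vllm.model_executor.model_loader.utils import configure_quant_config
from vllm.model_executor.models.utils import WeightsMapper


class LegacyVLModel(nn.Module):  # hypothetical, not a real vLLM model
    hf_to_vllm_mapper = WeightsMapper(
        orig_to_new_prefix={"model.": "language_model.model."})
    packed_modules_mapping = {"qkv_proj": ["q_proj", "k_proj", "v_proj"]}


# assumes Fp8Config exposes `ignored_layers` as a keyword argument
quant_config = Fp8Config(ignored_layers=["model.layers.0.mlp.down_proj"])
configure_quant_config(quant_config, LegacyVLModel)
# quant_config.ignored_layers is now keyed on vLLM module names, e.g.
# ["language_model.model.layers.0.mlp.down_proj"], and
# quant_config.packed_modules_mapping references the class-level mapping.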

vllm/model_executor/models/interfaces.py

Lines changed: 20 additions & 3 deletions
@@ -18,6 +18,7 @@
 
 if TYPE_CHECKING:
     from vllm.attention import AttentionMetadata
+    from vllm.model_executor.models.utils import WeightsMapper
     from vllm.sequence import IntermediateTensors
 
 logger = init_logger(__name__)
@@ -566,20 +567,36 @@ def has_step_pooler(model: Union[type[object], object]) -> bool:
 class SupportsQuant:
     """The interface required for all models that support quantization."""
 
-    packed_modules_mapping: ClassVar[dict[str, list[str]]] = {}
+    hf_to_vllm_mapper: ClassVar[Optional["WeightsMapper"]] = None
+    packed_modules_mapping: ClassVar[Optional[dict[str, list[str]]]] = None
     quant_config: Optional[QuantizationConfig] = None
 
     def __new__(cls, *args, **kwargs) -> Self:
         instance = super().__new__(cls)
+
+        # find config passed in arguments
         quant_config = cls._find_quant_config(*args, **kwargs)
         if quant_config is not None:
+
+            # attach config to model for general use
             instance.quant_config = quant_config
-            instance.quant_config.packed_modules_mapping.update(
-                cls.packed_modules_mapping)
+
+            # apply model mappings to config for proper config-model matching
+            # NOTE: `TransformersForCausalLM` is not supported due to how this
+            # class defines `hf_to_vllm_mapper` as a post-init `@property`.
+            # After this is fixed, get `instance.hf_to_vllm_mapper` directly
+            if getattr(instance, "hf_to_vllm_mapper", None) is not None:
+                instance.quant_config.apply_vllm_mapper(
+                    instance.hf_to_vllm_mapper)
+            if getattr(instance, "packed_modules_mapping", None) is not None:
+                instance.quant_config.packed_modules_mapping.update(
+                    instance.packed_modules_mapping)
+
         return instance
 
     @staticmethod
     def _find_quant_config(*args, **kwargs) -> Optional[QuantizationConfig]:
+        """Find quant config passed through model constructor args"""
         from vllm.config import VllmConfig  # avoid circular import
 
         args_values = list(args) + list(kwargs.values())
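
With the mixin in place, a model only declares its mappings as class attributes; SupportsQuant.__new__ then finds the quantization config among the constructor arguments (e.g. inside the VllmConfig) and applies both mappings before __init__ runs. A minimal sketch with a hypothetical model class (the orig_to_new_prefix field is assumed, following existing model mappers):

import torch.nn as nn

from vllm.model_executor.models.interfaces import SupportsQuant
from vllm.model_executor.models.utils import WeightsMapper


class MyMultiModalModel(nn.Module, SupportsQuant):  # hypothetical model
    hf_to_vllm_mapper = WeightsMapper(
        orig_to_new_prefix={"model.": "language_model.model."})
    packed_modules_mapping = {"qkv_proj": ["q_proj", "k_proj", "v_proj"]}

    def __init__(self, *, vllm_config, prefix: str = ""):
        super().__init__()
        # self.quant_config was attached (and remapped) in __new__, so the
        # config's module references already match this class's structure.
        ...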

vllm/model_executor/models/qwen2_5_vl.py

Lines changed: 7 additions & 7 deletions
@@ -61,7 +61,7 @@
 from vllm.transformers_utils.config import uses_mrope
 
 from .interfaces import (MultiModalEmbeddings, SupportsLoRA,
-                         SupportsMultiModal, SupportsPP)
+                         SupportsMultiModal, SupportsPP, SupportsQuant)
 from .qwen2_vl import Qwen2VLDummyInputsBuilder as Qwen2_5_VLDummyInputsBuilder
 from .qwen2_vl import (Qwen2VLMultiModalProcessor, Qwen2VLProcessingInfo,
                        apply_rotary_pos_emb_vision)
@@ -821,7 +821,8 @@ def _get_mm_fields_config(
     info=Qwen2_5_VLProcessingInfo,
     dummy_inputs=Qwen2_5_VLDummyInputsBuilder)
 class Qwen2_5_VLForConditionalGeneration(nn.Module, SupportsMultiModal,
-                                         SupportsLoRA, SupportsPP):
+                                         SupportsLoRA, SupportsPP,
+                                         SupportsQuant):
 
     # To ensure correct weight loading and mapping.
     hf_to_vllm_mapper = WeightsMapper(
@@ -837,7 +838,6 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module, SupportsMultiModal,
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__()
         config: Qwen2_5_VLConfig = vllm_config.model_config.hf_config
-        quant_config = vllm_config.quant_config
         multimodal_config = vllm_config.model_config.multimodal_config
 
         self.config = config
@@ -846,7 +846,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         self.visual = Qwen2_5_VisionTransformer(
             config.vision_config,
             norm_eps=getattr(config, "rms_norm_eps", 1e-6),
-            quant_config=self._maybe_ignore_quant_config(quant_config),
+            quant_config=self._maybe_ignore_quant_config(self.quant_config),
             prefix=maybe_prefix(prefix, "visual"),
         )
 
@@ -859,12 +859,12 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         self.make_empty_intermediate_tensors = (
             self.language_model.make_empty_intermediate_tensors)
 
-    def _maybe_ignore_quant_config(self, quant_config: QuantizationConfig):
+    def _maybe_ignore_quant_config(self, config: Optional[QuantizationConfig]):
         # GPTQ configs do not have a list of ignored modules, however AutoGPTQ
         # seems to avoid vision encoder sections for some models.
-        if isinstance(quant_config, (GPTQConfig, GPTQMarlinConfig)):
+        if isinstance(config, (GPTQConfig, GPTQMarlinConfig)):
             return None
-        return quant_config
+        return config
 
     def _validate_and_reshape_mm_tensor(self, mm_input: object,
                                         name: str) -> torch.Tensor:

vllm/model_executor/models/transformers.py

Lines changed: 1 addition & 0 deletions
@@ -467,6 +467,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
     # FIXME(Isotr0py): Don't use any weights mapper for Transformers backend,
     # this makes thing complicated. We need to remove this mapper after refactor
     # `TransformersModel` in the future.
+    # NOTE: `SupportsQuant` can be updated after property decorator is removed
     @property
     def hf_to_vllm_mapper(self):
         prefix_mapper = {

vllm/model_executor/models/utils.py

Lines changed: 14 additions & 1 deletion
@@ -4,7 +4,7 @@
 import itertools
 from collections.abc import Iterable, Mapping
 from dataclasses import dataclass, field
-from typing import Callable, Literal, Optional, Protocol, Union, overload
+from typing import Any, Callable, Literal, Optional, Protocol, Union, overload
 
 import torch
 import torch.nn as nn
@@ -64,6 +64,19 @@ def apply(
         return ((out_name, data) for name, data in weights
                 if (out_name := self._map_name(name)) is not None)
 
+    def apply_list(self, values: list[str]) -> list[str]:
+        return [
+            out_name for name in values
+            if (out_name := self._map_name(name)) is not None
+        ]
+
+    def apply_dict(self, values: dict[str, Any]) -> dict[str, Any]:
+        return {
+            out_name: value
+            for name, value in values.items()
+            if (out_name := self._map_name(name)) is not None
+        }
+
 
 class AutoWeightsLoader:
     """

vllm/model_executor/utils.py

Lines changed: 5 additions & 2 deletions
@@ -58,15 +58,18 @@ def _synced_weight_loader(param, *args, **kwargs):
 
 
 def get_packed_modules_mapping(model: torch.nn.Module) -> dict[str, list[str]]:
-    parent_map = copy.deepcopy(getattr(model, "packed_modules_mapping", {}))
+    parent_map = getattr(model, "packed_modules_mapping", None)
+    parent_map = copy.deepcopy(parent_map) if parent_map is not None else {}
 
     # don't infer mapping if the model has defined it explicitly.
     if parent_map:
         return parent_map
 
     # We only check main components instead of whole model submodules
     for child in model.children():
-        child_map = getattr(child, "packed_modules_mapping", {})
+        child_map = getattr(child, "packed_modules_mapping", None)
+        child_map = copy.deepcopy(child_map) if child_map is not None else {}
+
         if any((k in parent_map and parent_map[k] != v)
                for k, v in child_map.items()):
             raise ValueError(
