Skip to content

Make name of compressed-tensors quant method consistent across vLLM #17255

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Apr 28, 2025
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions tests/compile/test_full_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,14 +20,14 @@ def models_list(*, all: bool = True, keywords: Optional[list[str]] = None):
("facebook/opt-125m", {}),
("nm-testing/tinyllama-oneshot-w8w8-test-static-shape-change", {
"dtype": torch.float16,
"quantization": "compressed-tensors"
"quantization": "compressed_tensors"
}),
("neuralmagic/Llama-3.2-1B-Instruct-FP8-dynamic", {
"dtype": torch.float16,
"quantization": "compressed-tensors"
"quantization": "compressed_tensors"
}),
("neuralmagic/Llama-3.2-1B-Instruct-quantized.w8a8", {
"quantization": "compressed-tensors"
"quantization": "compressed_tensors"
}),
("meta-llama/Llama-3.2-1B-Instruct", {}),
]
Expand Down
8 changes: 4 additions & 4 deletions vllm/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -752,9 +752,8 @@ def _verify_quantization(self) -> None:
supported_quantization = QUANTIZATION_METHODS
optimized_quantization_methods = [
"fp8", "marlin", "modelopt", "gptq_marlin_24", "gptq_marlin",
"awq_marlin", "fbgemm_fp8", "compressed_tensors",
"compressed-tensors", "experts_int8", "quark", "nvfp4", "bitblas",
"gptq_bitblas"
"awq_marlin", "fbgemm_fp8", "compressed_tensors", "experts_int8",
"quark", "nvfp4", "bitblas", "gptq_bitblas"
]
if self.quantization is not None:
self.quantization = self.quantization.lower()
Expand All @@ -763,7 +762,8 @@ def _verify_quantization(self) -> None:
quant_cfg = self._parse_quant_hf_config()

if quant_cfg is not None:
quant_method = quant_cfg.get("quant_method", "").lower()
quant_method = quant_cfg.get("quant_method", "")
quant_method = quant_method.replace("-", "_").lower()

# Detect which checkpoint is it
for name in QUANTIZATION_METHODS:
Expand Down
4 changes: 2 additions & 2 deletions vllm/model_executor/layers/quantization/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
"gptq_bitblas",
"awq_marlin",
"gptq",
"compressed-tensors",
"compressed_tensors",
"bitsandbytes",
"qqq",
"hqq",
Expand Down Expand Up @@ -130,7 +130,7 @@ def get_quantization_config(quantization: str) -> Type[QuantizationConfig]:
"gptq_bitblas": GPTQBitBLASConfig,
"awq_marlin": AWQMarlinConfig,
"gptq": GPTQConfig,
"compressed-tensors": CompressedTensorsConfig,
"compressed_tensors": CompressedTensorsConfig,
"bitsandbytes": BitsAndBytesConfig,
"ptpc_fp8": PTPCFp8Config,
"qqq": QQQConfig,
Expand Down
2 changes: 1 addition & 1 deletion vllm/model_executor/model_loader/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ def get_model_architecture(
# Special handling for quantized Mixtral.
# FIXME(woosuk): This is a temporary hack.
mixtral_supported = [
"fp8", "compressed-tensors", "gptq_marlin", "awq_marlin"
"fp8", "compressed_tensors", "gptq_marlin", "awq_marlin"
]

if (model_config.quantization is not None
Expand Down
4 changes: 2 additions & 2 deletions vllm/platforms/rocm.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,8 +130,8 @@ class RocmPlatform(Platform):
device_control_env_var: str = "CUDA_VISIBLE_DEVICES"

supported_quantization: list[str] = [
"awq", "gptq", "fp8", "compressed_tensors", "compressed-tensors",
"fbgemm_fp8", "gguf", "quark", "ptpc_fp8"
"awq", "gptq", "fp8", "compressed_tensors", "fbgemm_fp8", "gguf",
"quark", "ptpc_fp8"
]

@classmethod
Expand Down
4 changes: 1 addition & 3 deletions vllm/platforms/tpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,7 @@ class TpuPlatform(Platform):
ray_device_key: str = "TPU"
device_control_env_var: str = "TPU_VISIBLE_CHIPS"

supported_quantization: list[str] = [
"tpu_int8", "compressed-tensors", "compressed_tensors"
]
supported_quantization: list[str] = ["tpu_int8", "compressed_tensors"]

additional_env_vars: list[str] = [
"TPU_CHIPS_PER_HOST_BOUNDS", "TPU_HOST_BOUNDS"
Expand Down
2 changes: 1 addition & 1 deletion vllm/transformers_utils/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -693,7 +693,7 @@ def recurse_elems(elem: Any):
"quant_method": "fp8",
"activation_scheme": "static"
}
elif quantization.get("quant_method") == "compressed-tensors":
elif quantization.get("quant_method") == "compressed_tensors":
# Pass through the quantization config to compressed-tensors
quantization_config = quantization
else:
Expand Down