
Commit 1c50e10

[Bugfix] fix quark ptpc (vllm-project#20251)
Signed-off-by: Haoyang Li <[email protected]>
Co-authored-by: Haoyang Li <[email protected]>
1 parent 3ee56e2 commit 1c50e10

File tree: 2 files changed (+23 additions, -16 deletions)

vllm/model_executor/layers/quantization/quark/quark.py

Lines changed: 1 addition & 5 deletions
@@ -312,11 +312,7 @@ def _get_scheme_from_config(self, config: dict[str, Any]) -> "QuarkScheme":
             is_fp8_w8a8_supported = self._check_scheme_supported(
                 QuarkW8A8Fp8.get_min_capability(), error=False)
             if is_fp8_w8a8_supported:
-                weight_qscheme = cast(str, weight_config.get("qscheme"))
-                input_static = (input_config is not None and
-                                not cast(bool, input_config.get("is_dynamic")))
-                return QuarkW8A8Fp8(qscheme=weight_qscheme,
-                                    is_static_input_scheme=input_static)
+                return QuarkW8A8Fp8(weight_config, input_config)
         elif self._is_static_tensor_w8a8(weight_config, input_config):
             weight_qscheme = cast(str, weight_config.get("qscheme"))
             return QuarkW8A8Int8(qscheme=weight_qscheme,
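
Note on the hunk above: "ptpc" refers to per-token, per-channel FP8 quantization. Instead of pre-extracting the weight qscheme and a static-input flag at this call site, the scheme now receives the raw Quark weight_config and input_config dictionaries and decides the mode itself. The sketch below is a minimal, self-contained rendering of that selection logic, assuming config dicts with the "qscheme" and "is_dynamic" keys shown in the diff; the helper name resolve_ptpc_mode and the example values are illustrative, not part of the commit.

from typing import Any, Optional, cast

def resolve_ptpc_mode(weight_config: dict[str, Any],
                      input_config: Optional[dict[str, Any]]) -> bool:
    # Mirrors the new QuarkW8A8Fp8.__init__: dynamic per-token activation
    # quantization is used only when the input side is dynamic
    # ("is_dynamic": True) and quantized per channel.
    is_static_input = False
    input_qscheme: Optional[str] = None
    if input_config is not None:
        is_static_input = not cast(bool, input_config.get("is_dynamic"))
        input_qscheme = cast(str, input_config.get("qscheme"))
    return (not is_static_input) and input_qscheme == "per_channel"

# Illustrative configs: per-channel weights with dynamic per-channel
# activations -> PTPC path enabled.
weight_cfg = {"qscheme": "per_channel", "is_dynamic": False}
input_cfg = {"qscheme": "per_channel", "is_dynamic": True}
assert resolve_ptpc_mode(weight_cfg, input_cfg) is True

# Static per-tensor activations -> per-tensor FP8 path, no per-token scaling.
assert resolve_ptpc_mode(weight_cfg,
                         {"qscheme": "per_tensor", "is_dynamic": False}) is False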

vllm/model_executor/layers/quantization/quark/schemes/quark_w8a8_fp8.py

Lines changed: 22 additions & 11 deletions
@@ -1,7 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
-from typing import Callable, Optional
+from typing import Any, Callable, Optional, cast
 
 import torch
 from torch.nn import Parameter
@@ -19,10 +19,19 @@
 
 class QuarkW8A8Fp8(QuarkScheme):
 
-    def __init__(self, qscheme: str, is_static_input_scheme: Optional[bool]):
-        self.qscheme = qscheme
-        self.is_static_input_scheme = is_static_input_scheme
-        self.fp8_linear = Fp8LinearOp(use_per_token_if_dynamic=False)
+    def __init__(self, weight_config: dict[str, Any],
+                 input_config: Optional[dict[str, Any]]):
+        self.weight_qscheme = cast(str, weight_config.get("qscheme"))
+        self.is_static_input_scheme: bool = False
+        self.input_qscheme: Optional[str] = None
+        if input_config is not None:
+            self.is_static_input_scheme = not cast(
+                bool, input_config.get("is_dynamic"))
+            self.input_qscheme = cast(str, input_config.get("qscheme"))
+        self.use_per_token_if_dynamic = (not self.is_static_input_scheme \
+            and self.input_qscheme == "per_channel")
+        self.fp8_linear = Fp8LinearOp(
+            use_per_token_if_dynamic=self.use_per_token_if_dynamic)
         self.out_dtype = torch.get_default_dtype()
 
     @classmethod
@@ -34,7 +43,7 @@ def process_weights_after_loading(self, layer) -> None:
         # If per tensor, when we have a fused module (e.g. QKV) with per
         # tensor scales (thus N scales being passed to the kernel),
         # requantize so we can always run per tensor
-        if self.qscheme == "per_tensor":
+        if self.weight_qscheme == "per_tensor":
             if current_platform.is_rocm():
                 input_scale = getattr(layer, 'input_scale', None)
                 weight, max_w_scale, input_scale = normalize_e4m3fn_to_e4m3fnuz(
@@ -58,7 +67,7 @@ def process_weights_after_loading(self, layer) -> None:
             layer.weight_scale = Parameter(max_w_scale, requires_grad=False)
 
         # If channelwise, scales are already lined up, so just transpose.
-        elif self.qscheme == "per_channel":
+        elif self.weight_qscheme == "per_channel":
             weight = layer.weight
 
             if current_platform.is_fp8_fnuz():
@@ -73,13 +82,15 @@ def process_weights_after_loading(self, layer) -> None:
                                                   requires_grad=False)
             else:
                 weight_scale = layer.weight_scale.data
-
+            if self.use_per_token_if_dynamic:
+                weight_scale = weight_scale.view(-1, 1)
             layer.weight = Parameter(weight.t(), requires_grad=False)
             # required by torch.compile to be torch.nn.Parameter
             layer.weight_scale = Parameter(weight_scale, requires_grad=False)
 
         else:
-            raise ValueError(f"Unknown quantization scheme {self.qscheme}")
+            raise ValueError(
+                f"Unknown quantization scheme {self.weight_qscheme}")
 
         # INPUT SCALE
         if self.is_static_input_scheme:
@@ -109,14 +120,14 @@ def create_weights(self, layer: torch.nn.Module,
         # WEIGHT SCALE
         # TODO: update create_xxx_parameter functions to return
         # the newly added parameters
-        if self.qscheme == "per_channel":
+        if self.weight_qscheme == "per_channel":
             weight_scale = ChannelQuantScaleParameter(
                 data=torch.empty((sum(output_partition_sizes)),
                                  dtype=torch.float32),
                 output_dim=0,
                 weight_loader=weight_loader)
         else:
-            assert self.qscheme == "per_tensor"
+            assert self.weight_qscheme == "per_tensor"
             weight_scale = PerTensorScaleParameter(data=torch.empty(
                 len(output_partition_sizes), dtype=torch.float32),
                                                    weight_loader=weight_loader)
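
Note on the weight_scale.view(-1, 1) addition above: on the dynamic per-token path, the channelwise weight scales are reshaped from a flat vector into a column so the FP8 linear op receives them in a 2-D, one-scale-per-output-channel layout alongside the per-token activation scales it computes at runtime. A minimal shape sketch follows (tensor sizes are illustrative, not from the commit).

import torch

out_features = 4096
# Channelwise weight scales are loaded as a flat float32 vector,
# one scale per output channel.
weight_scale = torch.rand(out_features, dtype=torch.float32)

# The PTPC path reshapes them into a column vector of shape
# (out_features, 1) before handing them to the FP8 linear op.
weight_scale_ptpc = weight_scale.view(-1, 1)
assert weight_scale_ptpc.shape == (out_features, 1)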
