# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

-from typing import Callable, Optional
+from typing import Any, Callable, Optional, cast

import torch
from torch.nn import Parameter
@@ -19,10 +19,19 @@

class QuarkW8A8Fp8(QuarkScheme):

-    def __init__(self, qscheme: str, is_static_input_scheme: Optional[bool]):
-        self.qscheme = qscheme
-        self.is_static_input_scheme = is_static_input_scheme
-        self.fp8_linear = Fp8LinearOp(use_per_token_if_dynamic=False)
+    def __init__(self, weight_config: dict[str, Any],
+                 input_config: Optional[dict[str, Any]]):
+        self.weight_qscheme = cast(str, weight_config.get("qscheme"))
+        self.is_static_input_scheme: bool = False
+        self.input_qscheme: Optional[str] = None
+        if input_config is not None:
+            self.is_static_input_scheme = not cast(
+                bool, input_config.get("is_dynamic"))
+            self.input_qscheme = cast(str, input_config.get("qscheme"))
+        self.use_per_token_if_dynamic = (not self.is_static_input_scheme \
+            and self.input_qscheme == "per_channel")
+        self.fp8_linear = Fp8LinearOp(
+            use_per_token_if_dynamic=self.use_per_token_if_dynamic)

        self.out_dtype = torch.get_default_dtype()

    @classmethod
@@ -34,7 +43,7 @@ def process_weights_after_loading(self, layer) -> None:
        # If per tensor, when we have a fused module (e.g. QKV) with per
        # tensor scales (thus N scales being passed to the kernel),
        # requantize so we can always run per tensor
-        if self.qscheme == "per_tensor":
+        if self.weight_qscheme == "per_tensor":
            if current_platform.is_rocm():
                input_scale = getattr(layer, 'input_scale', None)
                weight, max_w_scale, input_scale = normalize_e4m3fn_to_e4m3fnuz(
@@ -58,7 +67,7 @@ def process_weights_after_loading(self, layer) -> None:
            layer.weight_scale = Parameter(max_w_scale, requires_grad=False)

        # If channelwise, scales are already lined up, so just transpose.
-        elif self.qscheme == "per_channel":
+        elif self.weight_qscheme == "per_channel":
            weight = layer.weight

            if current_platform.is_fp8_fnuz():
@@ -73,13 +82,15 @@ def process_weights_after_loading(self, layer) -> None:
                                                  requires_grad=False)
            else:
                weight_scale = layer.weight_scale.data
-
+            if self.use_per_token_if_dynamic:
+                weight_scale = weight_scale.view(-1, 1)
            layer.weight = Parameter(weight.t(), requires_grad=False)
            # required by torch.compile to be torch.nn.Parameter
            layer.weight_scale = Parameter(weight_scale, requires_grad=False)

        else:
-            raise ValueError(f"Unknown quantization scheme {self.qscheme}")
+            raise ValueError(
+                f"Unknown quantization scheme {self.weight_qscheme}")

        # INPUT SCALE
        if self.is_static_input_scheme:
@@ -109,14 +120,14 @@ def create_weights(self, layer: torch.nn.Module,
        # WEIGHT SCALE
        # TODO: update create_xxx_parameter functions to return
        # the newly added parameters
-        if self.qscheme == "per_channel":
+        if self.weight_qscheme == "per_channel":
            weight_scale = ChannelQuantScaleParameter(
                data=torch.empty((sum(output_partition_sizes)),
                                 dtype=torch.float32),
                output_dim=0,
                weight_loader=weight_loader)
        else:
-            assert self.qscheme == "per_tensor"
+            assert self.weight_qscheme == "per_tensor"
            weight_scale = PerTensorScaleParameter(data=torch.empty(
                len(output_partition_sizes), dtype=torch.float32),
                                                   weight_loader=weight_loader)
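For quick reference, here is a minimal, standalone sketch of the flag resolution the new constructor performs. The `resolve_fp8_flags` helper and the example config dicts are illustrative assumptions, not part of the PR; only the keys it reads (`"qscheme"`, `"is_dynamic"`) are taken from the diff above.

```python
from typing import Any, Optional


def resolve_fp8_flags(weight_config: dict[str, Any],
                      input_config: Optional[dict[str, Any]]):
    """Mirror the flag derivation in QuarkW8A8Fp8.__init__ (sketch only)."""
    weight_qscheme = weight_config.get("qscheme")
    is_static_input_scheme = False
    input_qscheme = None
    if input_config is not None:
        # The diff treats a missing/False "is_dynamic" as a static input scheme.
        is_static_input_scheme = not bool(input_config.get("is_dynamic"))
        input_qscheme = input_config.get("qscheme")
    # Per-token dynamic quantization is only enabled for dynamic
    # per-channel activation configs.
    use_per_token_if_dynamic = (not is_static_input_scheme
                                and input_qscheme == "per_channel")
    return weight_qscheme, is_static_input_scheme, use_per_token_if_dynamic


# Hypothetical config: dynamic per-channel activations -> per-token path.
print(resolve_fp8_flags({"qscheme": "per_channel"},
                        {"qscheme": "per_channel", "is_dynamic": True}))
# Hypothetical config: static per-tensor activations -> per-tensor path.
print(resolve_fp8_flags({"qscheme": "per_tensor"},
                        {"qscheme": "per_tensor", "is_dynamic": False}))
```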