@@ -29,6 +29,34 @@ def pack_int4(x: torch.Tensor) -> torch.Tensor:
     return torch.bitwise_or(low_x, high_x).contiguous()
 
 
+def int4_row_quantize_zp(
+    x: torch.Tensor,
+    group_size: int = 128,
+) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+    n_bit = 4  # Number of target bits.
+    to_quant = x.reshape(-1, group_size).to(torch.float)
+
+    max_val = to_quant.amax(dim=1, keepdim=True)
+    min_val = to_quant.amin(dim=1, keepdim=True)
+    max_int = 2**n_bit - 1
+    min_int = 0
+    scales = (max_val - min_val).clamp(min=1e-6) / max_int
+
+    zeros = min_val + scales * (2 ** (n_bit - 1))
+
+    out = to_quant.sub(min_val).div(scales).round().clamp_(min_int, max_int)
+
+    # Recenter output and move to int8.
+    out = (out - 2 ** (n_bit - 1)).to(dtype=torch.int8).reshape(x.shape)
+
+    # Cutlass expects column major layout for scale and zero point,
+    # so we transpose here and make them contiguous.
+    scales = scales.view(x.shape[0], -1).t().contiguous()
+    zeros = zeros.view(x.shape[0], -1).t().contiguous()
+
+    return out, scales, zeros
+
+
 def int4_row_quantize(
     x: torch.Tensor,
     group_size: int = 128,
@@ -63,8 +91,8 @@ def int4_row_quantize(
 
 
 def quantize_int4_preshuffle(
-    w: torch.Tensor, group_size: int = 128
-) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+    w: torch.Tensor, group_size: int = 128, dtype: str = "fp8"
+) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
     """
     Quantizes an input weight tensor to int4 using preshuffling and scale packing.
     This function is intended to be used with fbgemms mixed dtype kernels and is expected
@@ -73,47 +101,57 @@ def quantize_int4_preshuffle(
     Args:
         w (Tensor): [N, K] Higher precision weight tensor to quantize. May optionally have a batch dimension.
         group_size (int): Number of elements to calculate group scale for, must be at least 128.
+        dtype (str): Type of corresponding activations. Must be fp8 or bf16.
     Returns:
         wq (Tensor): [N, K // 2] Quantized int4 weight tensor packed into int8 elements.
-        row_scale (Tensor): [N] FP32 Scale per row of the weight tensor.
-        group_scale (Tensor): [K / group_size, 8, N] FP8 Scale per group of the weight tensor.
+        scales (Tuple[Tensor]): Scale tensors for the specified activation type. When FP8 is used,
+        scales is a tuple of row_scale ([N]) and group_scale ([K / group_size, 8, N]). When BF16 is
+        used, scales is a tuple of group_scale ([K / group_size, N]) and group_zero ([K / group_size, N]).
     """
 
-    def _quantize(w: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
-        # Start by lowering weights to FP8 and producing row scales.
-        wq, row_scale = quantize_fp8_row(w)
-
-        # Now reduce to INT4.
-        wq, group_scale = int4_row_quantize(wq, group_size)
-        # Reduce group scale to FP8.
-        group_scale = group_scale.to(torch.float8_e4m3fn)
-
-        # Take quantized weights and pack them efficiently.
-        wq = pack_int4(wq)
-
-        # Finally pack weights and scales into efficient preshuffled format.
-        wq, group_scale = torch.ops.fbgemm.preshuffle_i4(wq, group_scale)
-
-        return wq, row_scale, group_scale
+    def _quantize(
+        w: torch.Tensor, dtype: str = "fp8"
+    ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+
+        if dtype == "fp8":
+            # Start by lowering weights to FP8 and producing row scales.
+            wq, row_scale = quantize_fp8_row(w)
+
+            # Now reduce to INT4.
+            wq, group_scale = int4_row_quantize(wq, group_size)
+            # Reduce group scale to FP8.
+            group_scale = group_scale.to(torch.float8_e4m3fn)
+            # Take quantized weights and pack them efficiently.
+            wq = pack_int4(wq)
+            # Finally pack weights and scales into efficient preshuffled format.
+            wq, group_scale = torch.ops.fbgemm.preshuffle_i4(wq, group_scale)
+            return wq, (group_scale, row_scale)
+
+        elif dtype == "bf16":
+            wq, group_scale, group_zero = int4_row_quantize_zp(w, group_size)
+            # Set scales to activation type.
+            group_scale = group_scale.to(torch.bfloat16)
+            group_zero = group_zero.to(torch.bfloat16)
+            # Take quantized weights and pack them efficiently.
+            wq = pack_int4(wq)
+            # Finally pack weights and scales into efficient preshuffled format.
+            wq, group_scale = torch.ops.fbgemm.preshuffle_i4(wq, group_scale)
+            return wq, (group_scale, group_zero)
+        else:
+            raise NotImplementedError("Only fp8 and bf16 activations supported.")
 
     if w.ndim >= 3:
         orig_shape = w.shape
         # Flatten to 3 dimensions then iterate over batches.
-        w = w.view(-1, *w.shape[1:])
-        w.unbind(dim=0)
-        wq = []
-        row_scale = []
-        group_scale = []
-        for batch in w:
-            wq_, row_scale_, group_scale_ = _quantize(batch)
-            wq.append(wq_)
-            row_scale.append(row_scale_)
-            group_scale.append(group_scale_)
+        wq, scales = zip(*[_quantize(i, dtype=dtype) for i in w])
         wq = torch.stack(wq).view(*orig_shape[:-2], *wq[0].shape)
-        row_scale = torch.stack(row_scale).view(*orig_shape[:-2], *row_scale[0].shape)
-        group_scale = torch.stack(group_scale).view(
-            *orig_shape[:-2], *group_scale[0].shape
+        # Decompose then stack scales back into a tuple.
+        a_scales, b_scales = zip(*scales)
+        scales = (
+            torch.stack(a_scales).view(*orig_shape[:-2], *a_scales[0].shape),
+            torch.stack(b_scales).view(*orig_shape[:-2], *b_scales[0].shape),
         )
     else:
-        wq, row_scale, group_scale = _quantize(w)
-    return wq, row_scale, group_scale
+        wq, scales = _quantize(w, dtype=dtype)
+
+    return wq, scales
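
For reference, a minimal round-trip check of the new int4_row_quantize_zp path (a sketch, not part of the diff; it assumes the function above is importable). It undoes the column-major transpose applied for CUTLASS and reconstructs each value as q * scale + zero:

import torch

# Quantize a small bf16 weight and reconstruct it; the error should stay
# within half a quantization step per group.
x = torch.randn(4, 256, dtype=torch.bfloat16)
wq, scales, zeros = int4_row_quantize_zp(x, group_size=128)

# scales/zeros were transposed to [K // group_size, N]; undo that and
# broadcast them back over each group of 128 values.
s = scales.t().reshape(-1, 1)
z = zeros.t().reshape(-1, 1)
q = wq.reshape(-1, 128).to(torch.float)  # recentered int4 values in [-8, 7]
x_hat = (q * s + z).reshape(x.shape)

assert (x.float() - x_hat).abs().max() <= scales.max() / 2 + 1e-6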
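And a usage sketch for the new dtype argument (hypothetical call sites, assuming a CUDA build of fbgemm_gpu that registers torch.ops.fbgemm.preshuffle_i4 and the mixed-dtype GEMM kernels):

import torch

w = torch.randn(4096, 4096, dtype=torch.bfloat16, device="cuda")

# FP8 activations: scales unpacks as (group_scale, row_scale).
wq_fp8, (group_scale, row_scale) = quantize_int4_preshuffle(w, group_size=128, dtype="fp8")

# BF16 activations: scales unpacks as (group_scale, group_zero).
wq_bf16, (group_scale, group_zero) = quantize_int4_preshuffle(w, group_size=128, dtype="bf16")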