
Commit 7827d97

jwfromm authored and facebook-github-bot committed
Provide helper functions for int4 quantization (pytorch#3775)
Summary:
X-link: facebookresearch/FBGEMM#855

This diff introduces a set of quantization helper functions to fbgemm_gpu/experimental/gen_ai to make it easier to apply the new Int4 packing and preshuffling to weights.

Reviewed By: summerdengfb

Differential Revision: D70643388
1 parent 53bfbe8 commit 7827d97
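
For context, a minimal sketch of how the new helpers are intended to be combined with the shuffled mixed-dtype kernel. The import paths and the f8i4bf16_shuffled op are taken from the diff below; the tensor sizes, device, and bf16 input dtype are illustrative assumptions, not part of this commit.

    import torch

    from fbgemm_gpu.experimental.gemm.triton_gemm.fp8_gemm import quantize_fp8_row
    from fbgemm_gpu.experimental.gen_ai.quantize import quantize_int4_preshuffle

    # Illustrative problem sizes (assumed, not from the commit).
    M, N, K = 16, 4096, 4096
    x = torch.randn(M, K, device="cuda", dtype=torch.bfloat16)  # activations
    w = torch.randn(N, K, device="cuda", dtype=torch.bfloat16)  # weights

    # One-time weight preparation: fp8 row scales, int4 group quantization,
    # packing, and preshuffling.
    wq, row_scale, group_scale = quantize_int4_preshuffle(w)

    # Per-call activation quantization, then the shuffled f8i4 GEMM.
    xq, x_scale = quantize_fp8_row(x)
    y = torch.ops.fbgemm.f8i4bf16_shuffled(xq, wq, x_scale, row_scale, group_scale)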

File tree

5 files changed: +271 -82 lines changed

fbgemm_gpu/experimental/gen_ai/bench/quantize_ops.py

Lines changed: 36 additions & 41 deletions
@@ -24,6 +24,7 @@
 from fbgemm_gpu.experimental.gemm.triton_gemm.grouped_gemm import (
     grouped_gemm_fp8_rowwise,
 )
+from fbgemm_gpu.experimental.gen_ai.quantize import quantize_int4_preshuffle
 from tinygemm.utils import group_quantize_tensor
 
 if torch.cuda.is_available() and torch.version.cuda:
@@ -1326,58 +1327,52 @@ def cuda(self) -> bool:
 
 
 @register_quantize_op
-class F8I4ShuffledGemm(F8I4RowwiseGemm):
-    def _int4_row_quantize(
-        self,
-        x: torch.Tensor,
-        group_size: int = 128,
-    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
-        n_bit = 4  # Number of target bits.
-        to_quant = x.reshape(-1, group_size).to(torch.float)
-
-        max_val = torch.abs(to_quant).amax(dim=1, keepdim=True)
-        max_int = 2 ** (n_bit - 1)
-        min_int = -(2 ** (n_bit - 1))
-        scales = max_val.clamp(min=1e-6) / max_int
-
-        out = to_quant.div(scales).round().clamp_(min_int, max_int - 1)
-
-        # Cast to int8 and restore shape.
-        out = out.to(dtype=torch.int8).reshape(x.shape)
-
-        # Scales should be in [num_groups, N] layout.
-        scales = scales.view(x.shape[0], -1).t().contiguous().to(torch.float8_e4m3fn)
-
-        return out, scales
+class F8I4ShuffledGemm(QuantizeOpBase):
+    def preprocess(self, x, w):
+        # Prequantize and pack weights.
+        wq, row_scale, group_scale = quantize_int4_preshuffle(w)
+        return x, wq, row_scale, group_scale
 
-    def quantize(self, x, w):
+    def quantize(self, x, wq, row_scale, group_scale):
         # Quantize both input tensors.
         xq, x_scale = quantize_fp8_row(x)
-        # Weight quantization happens in two steps. First we quantize to fp8
-        # then to int4.
-        wq, w_scale = quantize_fp8_row(w)
-        # Now quantize to int4 with group scaling.
-        wq, w_scale_group = self._int4_row_quantize(wq)
-        # Pack int4 values together.
-        wq = self._pack_int4(wq)
-        # Shuffle weights and scales for faster compute.
-        wq, w_scale_group = torch.ops.fbgemm.preshuffle_i4(wq, w_scale_group)
-        return xq, wq, x_scale, w_scale, w_scale_group
+        return xq, wq, x_scale, row_scale, group_scale
 
-    def compute(self, xq, wq, x_scale, w_scale, w_scale_group):
-        out = torch.ops.fbgemm.f8i4bf16_shuffled(
-            xq, wq, x_scale, w_scale, w_scale_group
+    def compute(self, xq, wq, x_scale, row_scale, group_scale):
+        # Handle batched cases by looping over each batch.
+        if xq.dim() == 3:
+            B, M, _ = xq.shape
+            _, N, _ = wq.shape
+            y = torch.empty((B, M, N), device=xq.device, dtype=torch.bfloat16)
+            for i in range(B):
+                y[i] = torch.ops.fbgemm.f8i4bf16_shuffled(
+                    xq[i], wq[i], x_scale[i], row_scale[i], group_scale[i]
+                )
+            return y
+        # Otherwise run gemm normally.
+        return torch.ops.fbgemm.f8i4bf16_shuffled(
+            xq, wq, x_scale, row_scale, group_scale
         )
-        return out
 
-    def quantize_and_compute(self, x, w):
-        xq, wq, x_scale, w_scale, w_scale_group = self.quantize(x, w)
-        return self.compute(xq, wq, x_scale, w_scale, w_scale_group)
+    def quantize_and_compute(self, x, wq, row_scale, group_scale):
+        xq, wq, x_scale, row_scale, group_scale = self.quantize(
+            x, wq, row_scale, group_scale
+        )
+        return self.compute(xq, wq, x_scale, row_scale, group_scale)
 
     @property
     def name(self) -> str:
         return "cutlass_f8i4_preshuffle"
 
+    @property
+    def hip(self) -> bool:
+        # Not yet supported on AMD.
+        return False
+
+    @property
+    def cuda(self) -> bool:
+        return True
+
 
 @register_quantize_op
 class BF16I4RowwiseGemm(F8I4RowwiseGemm):
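
A rough sketch of how the updated benchmark op is driven. The preprocess/quantize/compute methods and their signatures come from the diff above; the no-argument constructor, the exact driver sequence, and the input shapes, device, and dtype are assumptions.

    import torch

    # Assumed import path, derived from the benchmark file path above.
    from fbgemm_gpu.experimental.gen_ai.bench.quantize_ops import F8I4ShuffledGemm

    x = torch.randn(16, 4096, device="cuda", dtype=torch.bfloat16)
    w = torch.randn(4096, 4096, device="cuda", dtype=torch.bfloat16)

    op = F8I4ShuffledGemm()  # assumed no-arg construction

    # preprocess() packs and preshuffles the weights once, ahead of time.
    x, wq, row_scale, group_scale = op.preprocess(x, w)
    # quantize() then only has to handle the activations at benchmark time.
    xq, wq, x_scale, row_scale, group_scale = op.quantize(x, wq, row_scale, group_scale)
    # compute() dispatches to the shuffled kernel, looping over a batch dim if present.
    y = op.compute(xq, wq, x_scale, row_scale, group_scale)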
fbgemm_gpu/experimental/gen_ai/quantize.py

Lines changed: 119 additions & 0 deletions
@@ -0,0 +1,119 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# pyre-strict
+
+# Helper functions for using FBGEMM quantized operators.
+
+from typing import Tuple
+
+import torch
+
+from fbgemm_gpu.experimental.gemm.triton_gemm.fp8_gemm import quantize_fp8_row
+
+
+def pack_int4(x: torch.Tensor) -> torch.Tensor:
+    # Given int8 x, pack adjacent int4 values into a single int8.
+    low_x = x[:, ::2]
+    high_x = x[:, 1::2]
+
+    # High bits need to left shift, this also masks off extra bits.
+    high_x = torch.bitwise_left_shift(high_x, 4)
+    # Low bits need to have sign bits removed.
+    low_x = torch.bitwise_and(low_x, 0xF)
+
+    # Recombine into a single value with bitwise or.
+    return torch.bitwise_or(low_x, high_x).contiguous()
+
+
+def int4_row_quantize(
+    x: torch.Tensor,
+    group_size: int = 128,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    """
+    Helper function to quantize a tensor to int4 with groupwise scales.
+
+    Args:
+        x (Tensor): [N, K] Higher precision weight tensor to quantize.
+        group_size (int): Number of elements to calculate group scale for.
+    Returns:
+        wq (Tensor): [N, K // 2] Quantized int4 tensor stored in int8 elements.
+        group_scale (Tensor): [K / group_size, N] FP32 Scale per group.
+    """
+    n_bit = 4  # Number of target bits.
+    to_quant = x.reshape(-1, group_size).to(torch.float)
+
+    max_val = torch.abs(to_quant).amax(dim=1, keepdim=True)
+    max_int = 2 ** (n_bit - 1)
+    min_int = -(2 ** (n_bit - 1))
+    scales = max_val.clamp(min=1e-6) / max_int
+
+    out = to_quant.div(scales).round().clamp_(min_int, max_int - 1)
+
+    # Cast to int8 and restore shape.
+    out = out.to(dtype=torch.int8).reshape(x.shape)
+
+    # Scales should be in [num_groups, N] layout.
+    scales = scales.view(x.shape[0], -1).t().contiguous()
+
+    return out, scales
+
+
+def quantize_int4_preshuffle(
+    w: torch.Tensor, group_size: int = 128
+) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+    """
+    Quantizes an input weight tensor to int4 using preshuffling and scale packing.
+    This function is intended to be used with fbgemm's mixed dtype kernels and is expected
+    to be applied to weights ahead of time. As such, it is not perfectly optimized.
+
+    Args:
+        w (Tensor): [N, K] Higher precision weight tensor to quantize. May optionally have a batch dimension.
+        group_size (int): Number of elements to calculate group scale for, must be at least 128.
+    Returns:
+        wq (Tensor): [N, K // 2] Quantized int4 weight tensor packed into int8 elements.
+        row_scale (Tensor): [N] FP32 Scale per row of the weight tensor.
+        group_scale (Tensor): [K / group_size, 8, N] FP8 Scale per group of the weight tensor.
+    """
+
+    def _quantize(w: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        # Start by lowering weights to FP8 and producing row scales.
+        wq, row_scale = quantize_fp8_row(w)
+
+        # Now reduce to INT4.
+        wq, group_scale = int4_row_quantize(wq, group_size)
+        # Reduce group scale to FP8.
+        group_scale = group_scale.to(torch.float8_e4m3fn)
+
+        # Take quantized weights and pack them efficiently.
+        wq = pack_int4(wq)
+
+        # Finally pack weights and scales into efficient preshuffled format.
+        wq, group_scale = torch.ops.fbgemm.preshuffle_i4(wq, group_scale)
+
+        return wq, row_scale, group_scale
+
+    if w.ndim >= 3:
+        orig_shape = w.shape
+        # Flatten to 3 dimensions then iterate over batches.
+        w = w.view(-1, *w.shape[1:])
+        w.unbind(dim=0)
+        wq = []
+        row_scale = []
+        group_scale = []
+        for batch in w:
+            wq_, row_scale_, group_scale_ = _quantize(batch)
+            wq.append(wq_)
+            row_scale.append(row_scale_)
+            group_scale.append(group_scale_)
+        wq = torch.stack(wq).view(*orig_shape[:-2], *wq[0].shape)
+        row_scale = torch.stack(row_scale).view(*orig_shape[:-2], *row_scale[0].shape)
+        group_scale = torch.stack(group_scale).view(
+            *orig_shape[:-2], *group_scale[0].shape
+        )
+    else:
+        wq, row_scale, group_scale = _quantize(w)
+    return wq, row_scale, group_scale
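
To make the bit layout used by pack_int4 concrete, here is a small standalone sketch. The example values are arbitrary and chosen for illustration; only torch is required.

    import torch

    # One row with two int4-range values: column 0 -> low nibble, column 1 -> high nibble.
    x = torch.tensor([[-3, 5]], dtype=torch.int8)

    low = torch.bitwise_and(x[:, ::2], 0xF)         # -3 (0b...1101) -> 0b0000_1101
    high = torch.bitwise_left_shift(x[:, 1::2], 4)  # 5 -> 0b0101_0000
    packed = torch.bitwise_or(low, high)

    print(packed)  # tensor([[93]], dtype=torch.int8), i.e. 0b0101_1101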

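A hedged usage sketch of quantize_int4_preshuffle for a batched weight tensor. The output shapes in the comments follow the docstring above; the sizes, device, and bf16 dtype are illustrative assumptions.

    import torch

    from fbgemm_gpu.experimental.gen_ai.quantize import quantize_int4_preshuffle

    E, N, K = 8, 4096, 4096  # e.g. a stack of per-expert weights (illustrative)
    w = torch.randn(E, N, K, device="cuda", dtype=torch.bfloat16)

    # Each batch entry is quantized independently and the results are restacked.
    wq, row_scale, group_scale = quantize_int4_preshuffle(w)
    # Per the docstring: wq ~ [E, N, K // 2] int8 (two int4 values per byte),
    # row_scale ~ [E, N] fp32, group_scale ~ [E, K // group_size, 8, N] fp8.
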
fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f8i4bf16_rowwise.cu

Lines changed: 4 additions & 0 deletions
@@ -47,6 +47,10 @@ at::Tensor f8i4bf16_rowwise_impl(
 
   int group_size = K / num_groups;
 
+  // Return immediately if input is empty.
+  if (M == 0 || N == 0 || K == 0) {
+    return at::zeros({M, N}, XQ.options().dtype(at::kBFloat16));
+  }
   auto Y = at::empty({M, N}, XQ.options().dtype(at::kBFloat16));
 
   using ElementInputA = INPUT_DTYPE;
