Commit 77c4b3a

jwfromm authored and facebook-github-bot committed
Provide helper functions for int4 quantization (#3775)
Summary:
X-link: facebookresearch/FBGEMM#855

This diff introduces a set of quantization helper functions to fbgemm_gpu/experimental/gen_ai to make it easier to apply the new Int4 packing and preshuffling to weights.

Reviewed By: summerdengfb

Differential Revision: D70643388
1 parent 7b9bffa commit 77c4b3a
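
For reference, a minimal end-to-end sketch of the new helper as exercised by the benchmark op in this diff. The shapes, dtypes, and CUDA device below are illustrative assumptions, not part of the commit; the call sequence mirrors F8I4ShuffledGemm.preprocess/quantize/compute.

import torch
from fbgemm_gpu.experimental.gemm.triton_gemm.fp8_gemm import quantize_fp8_row
from fbgemm_gpu.experimental.gen_ai.quantize import quantize_int4_preshuffle

# Hypothetical sizes, chosen only for illustration.
M, N, K = 16, 4096, 4096
x = torch.randn(M, K, device="cuda", dtype=torch.bfloat16)  # activations
w = torch.randn(N, K, device="cuda", dtype=torch.bfloat16)  # weights

# One-time weight preparation: FP8 row scales, INT4 group scales, preshuffled layout.
wq, row_scale, group_scale = quantize_int4_preshuffle(w)

# Per-call activation quantization followed by the mixed-dtype GEMM.
xq, x_scale = quantize_fp8_row(x)
y = torch.ops.fbgemm.f8i4bf16_shuffled(xq, wq, x_scale, row_scale, group_scale)  # [M, N] bfloat16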

5 files changed (+271 −82 lines)


fbgemm_gpu/experimental/gen_ai/bench/quantize_ops.py

Lines changed: 36 additions & 41 deletions
@@ -24,6 +24,7 @@
 from fbgemm_gpu.experimental.gemm.triton_gemm.grouped_gemm import (
     grouped_gemm_fp8_rowwise,
 )
+from fbgemm_gpu.experimental.gen_ai.quantize import quantize_int4_preshuffle
 from tinygemm.utils import group_quantize_tensor

 if torch.cuda.is_available() and torch.version.cuda:
@@ -1277,58 +1278,52 @@ def cuda(self) -> bool:


 @register_quantize_op
-class F8I4ShuffledGemm(F8I4RowwiseGemm):
-    def _int4_row_quantize(
-        self,
-        x: torch.Tensor,
-        group_size: int = 128,
-    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
-        n_bit = 4  # Number of target bits.
-        to_quant = x.reshape(-1, group_size).to(torch.float)
-
-        max_val = torch.abs(to_quant).amax(dim=1, keepdim=True)
-        max_int = 2 ** (n_bit - 1)
-        min_int = -(2 ** (n_bit - 1))
-        scales = max_val.clamp(min=1e-6) / max_int
-
-        out = to_quant.div(scales).round().clamp_(min_int, max_int - 1)
-
-        # Cast to int8 and restore shape.
-        out = out.to(dtype=torch.int8).reshape(x.shape)
-
-        # Scales should be in [num_groups, N] layout.
-        scales = scales.view(x.shape[0], -1).t().contiguous().to(torch.float8_e4m3fn)
-
-        return out, scales
+class F8I4ShuffledGemm(QuantizeOpBase):
+    def preprocess(self, x, w):
+        # Prequantize and pack weights.
+        wq, row_scale, group_scale = quantize_int4_preshuffle(w)
+        return x, wq, row_scale, group_scale

-    def quantize(self, x, w):
+    def quantize(self, x, wq, row_scale, group_scale):
         # Quantize both input tensors.
         xq, x_scale = quantize_fp8_row(x)
-        # Weight quantization happens in two steps. First we quantize to fp8
-        # then to int4.
-        wq, w_scale = quantize_fp8_row(w)
-        # Now quantize to int4 with group scaling.
-        wq, w_scale_group = self._int4_row_quantize(wq)
-        # Pack int4 values together.
-        wq = self._pack_int4(wq)
-        # Shuffle weights and scales for faster compute.
-        wq, w_scale_group = torch.ops.fbgemm.preshuffle_i4(wq, w_scale_group)
-        return xq, wq, x_scale, w_scale, w_scale_group
+        return xq, wq, x_scale, row_scale, group_scale

-    def compute(self, xq, wq, x_scale, w_scale, w_scale_group):
-        out = torch.ops.fbgemm.f8i4bf16_shuffled(
-            xq, wq, x_scale, w_scale, w_scale_group
+    def compute(self, xq, wq, x_scale, row_scale, group_scale):
+        # Handle batched cases by looping over each batch.
+        if xq.dim() == 3:
+            B, M, _ = xq.shape
+            _, N, _ = wq.shape
+            y = torch.empty((B, M, N), device=xq.device, dtype=torch.bfloat16)
+            for i in range(B):
+                y[i] = torch.ops.fbgemm.f8i4bf16_shuffled(
+                    xq[i], wq[i], x_scale[i], row_scale[i], group_scale[i]
+                )
+            return y
+        # Otherwise run gemm normally.
+        return torch.ops.fbgemm.f8i4bf16_shuffled(
+            xq, wq, x_scale, row_scale, group_scale
         )
-        return out

-    def quantize_and_compute(self, x, w):
-        xq, wq, x_scale, w_scale, w_scale_group = self.quantize(x, w)
-        return self.compute(xq, wq, x_scale, w_scale, w_scale_group)
+    def quantize_and_compute(self, x, wq, row_scale, group_scale):
+        xq, wq, x_scale, row_scale, group_scale = self.quantize(
+            x, wq, row_scale, group_scale
+        )
+        return self.compute(xq, wq, x_scale, row_scale, group_scale)

     @property
     def name(self) -> str:
         return "cutlass_f8i4_preshuffle"

+    @property
+    def hip(self) -> bool:
+        # Not yet supported on AMD.
+        return False
+
+    @property
+    def cuda(self) -> bool:
+        return True
+

 @register_quantize_op
 class BF16I4RowwiseGemm(F8I4RowwiseGemm):
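
The compute() path added above loops over the batch dimension when its inputs are 3-D. A hedged sketch of that batched flow, with illustrative shapes and dtypes (only the call pattern is taken from the diff):

import torch
from fbgemm_gpu.experimental.gemm.triton_gemm.fp8_gemm import quantize_fp8_row
from fbgemm_gpu.experimental.gen_ai.quantize import quantize_int4_preshuffle

# Hypothetical batched sizes for illustration.
B, M, N, K = 4, 16, 4096, 4096
x = torch.randn(B, M, K, device="cuda", dtype=torch.bfloat16)
w = torch.randn(B, N, K, device="cuda", dtype=torch.bfloat16)

wq, row_scale, group_scale = quantize_int4_preshuffle(w)  # helper handles the batch dim
xq, x_scale = quantize_fp8_row(x)

# Mirror of the loop in F8I4ShuffledGemm.compute: one shuffled GEMM per batch entry.
y = torch.stack([
    torch.ops.fbgemm.f8i4bf16_shuffled(
        xq[i], wq[i], x_scale[i], row_scale[i], group_scale[i]
    )
    for i in range(B)
])  # [B, M, N] bfloat16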
fbgemm_gpu/experimental/gen_ai/quantize.py

Lines changed: 119 additions & 0 deletions
@@ -0,0 +1,119 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# pyre-strict
+
+# Helper functions for using FBGEMM quantized operators.
+
+from typing import Tuple
+
+import torch
+
+from fbgemm_gpu.experimental.gemm.triton_gemm.fp8_gemm import quantize_fp8_row
+
+
+def pack_int4(x: torch.Tensor) -> torch.Tensor:
+    # Given int8 x, pack adjacent int4 values into a single int8.
+    low_x = x[:, ::2]
+    high_x = x[:, 1::2]
+
+    # High bits need to left shift, this also masks off extra bits.
+    high_x = torch.bitwise_left_shift(high_x, 4)
+    # Low bits need to have sign bits removed.
+    low_x = torch.bitwise_and(low_x, 0xF)
+
+    # Recombine into a single value with bitwise or.
+    return torch.bitwise_or(low_x, high_x).contiguous()
+
+
+def int4_row_quantize(
+    x: torch.Tensor,
+    group_size: int = 128,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    """
+    Helper function to quantize a tensor to int4 with groupwise scales.
+
+    Args:
+        x (Tensor): [N, K] Higher precision weight tensor to quantize.
+        group_size (int): Number of elements to calculate group scale for.
+    Returns:
+        wq (Tensor): [N, K // 2] Quantized int4 tensor stored in int8 elements.
+        group_scale (Tensor): [K / group_size, N] FP32 Scale per group.
+    """
+    n_bit = 4  # Number of target bits.
+    to_quant = x.reshape(-1, group_size).to(torch.float)
+
+    max_val = torch.abs(to_quant).amax(dim=1, keepdim=True)
+    max_int = 2 ** (n_bit - 1)
+    min_int = -(2 ** (n_bit - 1))
+    scales = max_val.clamp(min=1e-6) / max_int
+
+    out = to_quant.div(scales).round().clamp_(min_int, max_int - 1)
+
+    # Cast to int8 and restore shape.
+    out = out.to(dtype=torch.int8).reshape(x.shape)
+
+    # Scales should be in [num_groups, N] layout.
+    scales = scales.view(x.shape[0], -1).t().contiguous()
+
+    return out, scales
+
+
+def quantize_int4_preshuffle(
+    w: torch.Tensor, group_size: int = 128
+) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+    """
+    Quantizes an input weight tensor to int4 using preshuffling and scale packing.
+    This function is intended to be used with fbgemm's mixed dtype kernels and is expected
+    to be applied to weights ahead of time. As such, it is not perfectly optimized.
+
+    Args:
+        w (Tensor): [N, K] Higher precision weight tensor to quantize. May optionally have a batch dimension.
+        group_size (int): Number of elements to calculate group scale for, must be at least 128.
+    Returns:
+        wq (Tensor): [N, K // 2] Quantized int4 weight tensor packed into int8 elements.
+        row_scale (Tensor): [N] FP32 Scale per row of the weight tensor.
+        group_scale (Tensor): [K / group_size, 8, N] FP8 Scale per group of the weight tensor.
+    """
+
+    def _quantize(w: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        # Start by lowering weights to FP8 and producing row scales.
+        wq, row_scale = quantize_fp8_row(w)
+
+        # Now reduce to INT4.
+        wq, group_scale = int4_row_quantize(wq, group_size)
+        # Reduce group scale to FP8.
+        group_scale = group_scale.to(torch.float8_e4m3fn)
+
+        # Take quantized weights and pack them efficiently.
+        wq = pack_int4(wq)
+
+        # Finally pack weights and scales into efficient preshuffled format.
+        wq, group_scale = torch.ops.fbgemm.preshuffle_i4(wq, group_scale)
+
+        return wq, row_scale, group_scale
+
+    if w.ndim >= 3:
+        orig_shape = w.shape
+        # Flatten to 3 dimensions then iterate over batches.
+        w = w.view(-1, *w.shape[1:])
+        w.unbind(dim=0)
+        wq = []
+        row_scale = []
+        group_scale = []
+        for batch in w:
+            wq_, row_scale_, group_scale_ = _quantize(batch)
+            wq.append(wq_)
+            row_scale.append(row_scale_)
+            group_scale.append(group_scale_)
+        wq = torch.stack(wq).view(*orig_shape[:-2], *wq[0].shape)
+        row_scale = torch.stack(row_scale).view(*orig_shape[:-2], *row_scale[0].shape)
+        group_scale = torch.stack(group_scale).view(
+            *orig_shape[:-2], *group_scale[0].shape
+        )
+    else:
+        wq, row_scale, group_scale = _quantize(w)
+    return wq, row_scale, group_scale
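
As a quick illustration of the packing convention used by pack_int4 above (the input values are arbitrary): the even-indexed element of each pair lands in the low nibble of the output byte, and the odd-indexed element, shifted left by 4, in the high nibble.

import torch
from fbgemm_gpu.experimental.gen_ai.quantize import pack_int4

x = torch.tensor([[1, -2, 7, -8]], dtype=torch.int8)  # int4 values stored in int8
packed = pack_int4(x)                                 # shape [1, 2]
# First byte: low nibble 0x1 (from 1), high nibble 0xE (from -2) -> 0xE1 == -31 as int8.
# Second byte: low nibble 0x7 (from 7), high nibble 0x8 (from -8) -> 0x87 == -121 as int8.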

fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f8i4bf16_rowwise.cu

Lines changed: 4 additions & 0 deletions
@@ -47,6 +47,10 @@ at::Tensor f8i4bf16_rowwise_impl(

   int group_size = K / num_groups;

+  // Return immediately if input is empty.
+  if (M == 0 || N == 0 || K == 0) {
+    return at::zeros({M, N}, XQ.options().dtype(at::kBFloat16));
+  }
   auto Y = at::empty({M, N}, XQ.options().dtype(at::kBFloat16));

   using ElementInputA = INPUT_DTYPE;
