
Commit d9e4e3f

jwfromm authored and facebook-github-bot committed
Provide helper functions for int4 quantization (pytorch#3775)
Summary:
X-link: facebookresearch/FBGEMM#855

This diff introduces a set of quantization helper functions to fbgemm_gpu/experimental/gen_ai to make it easier to apply the new Int4 packing and preshuffling to weights.

Reviewed By: summerdengfb

Differential Revision: D70643388
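For orientation, a minimal usage sketch assembled from the call patterns in the diffs below; the tensor shapes, the bf16 inputs, and a working CUDA build of fbgemm_gpu are assumptions, not part of the change itself:

import torch

from fbgemm_gpu.experimental.gemm.triton_gemm.fp8_gemm import quantize_fp8_row
from fbgemm_gpu.experimental.gen_ai.quantize import quantize_int4_preshuffle

# Assumed shapes: activations [M, K] and weights [N, K], bf16 on GPU.
x = torch.randn(128, 4096, device="cuda", dtype=torch.bfloat16)
w = torch.randn(2048, 4096, device="cuda", dtype=torch.bfloat16)

# One-time weight preparation: FP8 row quantization, INT4 group quantization,
# int4 packing, and preshuffling, all wrapped by the new helper.
wq, row_scale, group_scale = quantize_int4_preshuffle(w)

# Per-call activation quantization followed by the mixed-dtype GEMM.
xq, x_scale = quantize_fp8_row(x)
y = torch.ops.fbgemm.f8i4bf16_shuffled(xq, wq, x_scale, row_scale, group_scale)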
1 parent aec40d1 commit d9e4e3f

File tree

5 files changed: +223 −70 lines changed


fbgemm_gpu/experimental/gen_ai/bench/quantize_ops.py

Lines changed: 36 additions & 41 deletions
@@ -24,6 +24,7 @@
 from fbgemm_gpu.experimental.gemm.triton_gemm.grouped_gemm import (
     grouped_gemm_fp8_rowwise,
 )
+from fbgemm_gpu.experimental.gen_ai.quantize import quantize_int4_preshuffle
 from tinygemm.utils import group_quantize_tensor

 if torch.cuda.is_available() and torch.version.cuda:
@@ -1277,58 +1278,52 @@ def cuda(self) -> bool:


 @register_quantize_op
-class F8I4ShuffledGemm(F8I4RowwiseGemm):
-    def _int4_row_quantize(
-        self,
-        x: torch.Tensor,
-        group_size: int = 128,
-    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
-        n_bit = 4  # Number of target bits.
-        to_quant = x.reshape(-1, group_size).to(torch.float)
-
-        max_val = torch.abs(to_quant).amax(dim=1, keepdim=True)
-        max_int = 2 ** (n_bit - 1)
-        min_int = -(2 ** (n_bit - 1))
-        scales = max_val.clamp(min=1e-6) / max_int
-
-        out = to_quant.div(scales).round().clamp_(min_int, max_int - 1)
-
-        # Cast to int8 and restore shape.
-        out = out.to(dtype=torch.int8).reshape(x.shape)
-
-        # Scales should be in [num_groups, N] layout.
-        scales = scales.view(x.shape[0], -1).t().contiguous().to(torch.float8_e4m3fn)
-
-        return out, scales
+class F8I4ShuffledGemm(QuantizeOpBase):
+    def preprocess(self, x, w):
+        # Prequantize and pack weights.
+        wq, row_scale, group_scale = quantize_int4_preshuffle(w)
+        return x, wq, row_scale, group_scale

-    def quantize(self, x, w):
+    def quantize(self, x, wq, row_scale, group_scale):
         # Quantize both input tensors.
         xq, x_scale = quantize_fp8_row(x)
-        # Weight quantization happens in two steps. First we quantize to fp8
-        # then to int4.
-        wq, w_scale = quantize_fp8_row(w)
-        # Now quantize to int4 with group scaling.
-        wq, w_scale_group = self._int4_row_quantize(wq)
-        # Pack int4 values together.
-        wq = self._pack_int4(wq)
-        # Shuffle weights and scales for faster compute.
-        wq, w_scale_group = torch.ops.fbgemm.preshuffle_i4(wq, w_scale_group)
-        return xq, wq, x_scale, w_scale, w_scale_group
+        return xq, wq, x_scale, row_scale, group_scale

-    def compute(self, xq, wq, x_scale, w_scale, w_scale_group):
-        out = torch.ops.fbgemm.f8i4bf16_shuffled(
-            xq, wq, x_scale, w_scale, w_scale_group
+    def compute(self, xq, wq, x_scale, row_scale, group_scale):
+        # Handle batched cases by looping over each batch.
+        if xq.dim() == 3:
+            B, M, _ = xq.shape
+            _, N, _ = wq.shape
+            y = torch.empty((B, M, N), device=xq.device, dtype=torch.bfloat16)
+            for i in range(B):
+                y[i] = torch.ops.fbgemm.f8i4bf16_shuffled(
+                    xq[i], wq[i], x_scale[i], row_scale[i], group_scale[i]
+                )
+            return y
+        # Otherwise run gemm normally.
+        return torch.ops.fbgemm.f8i4bf16_shuffled(
+            xq, wq, x_scale, row_scale, group_scale
         )
-        return out

-    def quantize_and_compute(self, x, w):
-        xq, wq, x_scale, w_scale, w_scale_group = self.quantize(x, w)
-        return self.compute(xq, wq, x_scale, w_scale, w_scale_group)
+    def quantize_and_compute(self, x, wq, row_scale, group_scale):
+        xq, wq, x_scale, row_scale, group_scale = self.quantize(
+            x, wq, row_scale, group_scale
+        )
+        return self.compute(xq, wq, x_scale, row_scale, group_scale)

     @property
     def name(self) -> str:
         return "cutlass_f8i4_preshuffle"

+    @property
+    def hip(self) -> bool:
+        # Not yet supported on AMD.
+        return False
+
+    @property
+    def cuda(self) -> bool:
+        return True
+

 @register_quantize_op
 class BF16I4RowwiseGemm(F8I4RowwiseGemm):
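A rough sketch of how the reworked benchmark op might be driven end to end; the direct no-argument construction and the import path for F8I4ShuffledGemm are assumptions, since the real harness goes through the quantize-op registry:

import torch

# Assumed import path (the class lives in the file shown above) and assumed
# no-argument construction; the benchmark harness normally drives registered
# ops itself.
from fbgemm_gpu.experimental.gen_ai.bench.quantize_ops import F8I4ShuffledGemm

x = torch.randn(128, 4096, device="cuda", dtype=torch.bfloat16)
w = torch.randn(2048, 4096, device="cuda", dtype=torch.bfloat16)

op = F8I4ShuffledGemm()
# preprocess() now does all weight-side work once via quantize_int4_preshuffle.
x, wq, row_scale, group_scale = op.preprocess(x, w)
# quantize() only has to handle the activation on each iteration.
xq, wq, x_scale, row_scale, group_scale = op.quantize(x, wq, row_scale, group_scale)
# compute() loops over the batch dimension when xq is 3D, else runs one GEMM.
y = op.compute(xq, wq, x_scale, row_scale, group_scale)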
fbgemm_gpu/experimental/gen_ai/quantize.py (new file)

Lines changed: 119 additions & 0 deletions
@@ -0,0 +1,119 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# pyre-strict
+
+# Helper functions for using FBGEMM quantized operators.
+
+from typing import Tuple
+
+import torch
+
+from fbgemm_gpu.experimental.gemm.triton_gemm.fp8_gemm import quantize_fp8_row
+
+
+def pack_int4(x: torch.Tensor) -> torch.Tensor:
+    # Given int8 x, pack adjacent int4 values into a single int8.
+    low_x = x[:, ::2]
+    high_x = x[:, 1::2]
+
+    # High bits need to left shift, this also masks off extra bits.
+    high_x = torch.bitwise_left_shift(high_x, 4)
+    # Low bits need to have sign bits removed.
+    low_x = torch.bitwise_and(low_x, 0xF)
+
+    # Recombine into a single value with bitwise or.
+    return torch.bitwise_or(low_x, high_x).contiguous()
+
+
+def int4_row_quantize(
+    x: torch.Tensor,
+    group_size: int = 128,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    """
+    Helper function to quantize a tensor to int4 with groupwise scales.
+
+    Args:
+        x (Tensor): [N, K] Higher precision weight tensor to quantize.
+        group_size (int): Number of elements to calculate group scale for.
+    Returns:
+        wq (Tensor): [N, K] Quantized int4 values stored unpacked in int8 elements.
+        group_scale (Tensor): [K / group_size, N] FP32 Scale per group.
+    """
+    n_bit = 4  # Number of target bits.
+    to_quant = x.reshape(-1, group_size).to(torch.float)
+
+    max_val = torch.abs(to_quant).amax(dim=1, keepdim=True)
+    max_int = 2 ** (n_bit - 1)
+    min_int = -(2 ** (n_bit - 1))
+    scales = max_val.clamp(min=1e-6) / max_int
+
+    out = to_quant.div(scales).round().clamp_(min_int, max_int - 1)
+
+    # Cast to int8 and restore shape.
+    out = out.to(dtype=torch.int8).reshape(x.shape)
+
+    # Scales should be in [num_groups, N] layout.
+    scales = scales.view(x.shape[0], -1).t().contiguous()
+
+    return out, scales
+
+
+def quantize_int4_preshuffle(
+    w: torch.Tensor, group_size: int = 128
+) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+    """
+    Quantizes an input weight tensor to int4 using preshuffling and scale packing.
+    This function is intended to be used with fbgemm's mixed dtype kernels and is expected
+    to be applied to weights ahead of time. As such, it is not perfectly optimized.
+
+    Args:
+        w (Tensor): [N, K] Higher precision weight tensor to quantize. May optionally have a batch dimension.
+        group_size (int): Number of elements to calculate group scale for, must be at least 128.
+    Returns:
+        wq (Tensor): [N, K // 2] Quantized int4 weight tensor packed into int8 elements.
+        row_scale (Tensor): [N] FP32 Scale per row of the weight tensor.
+        group_scale (Tensor): [K / group_size, 8, N] FP8 Scale per group of the weight tensor.
+    """
+
+    def _quantize(w: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        # Start by lowering weights to FP8 and producing row scales.
+        wq, row_scale = quantize_fp8_row(w)
+
+        # Now reduce to INT4.
+        wq, group_scale = int4_row_quantize(wq, group_size)
+        # Reduce group scale to FP8.
+        group_scale = group_scale.to(torch.float8_e4m3fn)
+
+        # Take quantized weights and pack them efficiently.
+        wq = pack_int4(wq)
+
+        # Finally pack weights and scales into efficient preshuffled format.
+        wq, group_scale = torch.ops.fbgemm.preshuffle_i4(wq, group_scale)
+
+        return wq, row_scale, group_scale
+
+    if w.ndim >= 3:
+        orig_shape = w.shape
+        # Flatten to 3 dimensions then iterate over batches.
+        w = w.view(-1, *w.shape[1:])
+        w.unbind(dim=0)
+        wq = []
+        row_scale = []
+        group_scale = []
+        for batch in w:
+            wq_, row_scale_, group_scale_ = _quantize(batch)
+            wq.append(wq_)
+            row_scale.append(row_scale_)
+            group_scale.append(group_scale_)
+        wq = torch.stack(wq).view(*orig_shape[:-2], *wq[0].shape)
+        row_scale = torch.stack(row_scale).view(*orig_shape[:-2], *row_scale[0].shape)
+        group_scale = torch.stack(group_scale).view(
+            *orig_shape[:-2], *group_scale[0].shape
+        )
+    else:
+        wq, row_scale, group_scale = _quantize(w)
+    return wq, row_scale, group_scale
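A small shape check for the two lower-level helpers, as a sketch; the sizes are arbitrary, and the import path assumes the new module is importable as fbgemm_gpu.experimental.gen_ai.quantize, matching the import added in quantize_ops.py above:

import torch

# Assumed import path for the new module.
from fbgemm_gpu.experimental.gen_ai.quantize import int4_row_quantize, pack_int4

N, K, group_size = 16, 256, 128
w = torch.randn(N, K, dtype=torch.bfloat16)

# Groupwise symmetric quantization: values land in [-8, 7] and every group of
# 128 elements along K shares one scale.
wq_int8, group_scale = int4_row_quantize(w, group_size)
assert wq_int8.shape == (N, K) and wq_int8.dtype == torch.int8
assert group_scale.shape == (K // group_size, N)

# Packing halves the last dimension: element 2*i fills the low nibble and
# element 2*i + 1 the high nibble of a single int8.
wq_packed = pack_int4(wq_int8)
assert wq_packed.shape == (N, K // 2)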

fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f8i4bf16_rowwise.cu

Lines changed: 4 additions & 0 deletions
@@ -47,6 +47,10 @@ at::Tensor f8i4bf16_rowwise_impl(

   int group_size = K / num_groups;

+  // Return immediately if input is empty.
+  if (M == 0 || N == 0 || K == 0) {
+    return at::zeros({M, N}, XQ.options().dtype(at::kBFloat16));
+  }
   auto Y = at::empty({M, N}, XQ.options().dtype(at::kBFloat16));

   using ElementInputA = INPUT_DTYPE;

fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f8i4bf16_shuffled.cu

Lines changed: 51 additions & 28 deletions
@@ -22,8 +22,6 @@
 #include "cutlass/util/mixed_dtype_utils.hpp"
 #include "cutlass/util/packed_stride.hpp"

-#include "cutlass_extensions/include/kernel_mode.h"
-
 namespace fbgemm_gpu {

 #if CUDART_VERSION >= 12000
@@ -34,19 +32,14 @@ at::Tensor _f8i4bf16_shuffled(
     at::Tensor WQ,
     at::Tensor x_scale,
     at::Tensor w_scale,
-    at::Tensor w_scale_group) {
+    at::Tensor w_scale_group,
+    at::Tensor Y) {
   // Get shape information from input tensors.
-  int M = XQ.size(0);
-  int K = XQ.size(1);
-  int N = WQ.size(0);
-  // Make sure w_scale_group is in proper format.
-  TORCH_CHECK(
-      w_scale_group.size(1) == 8,
-      "Weights and group scales must be prepacked with preshuffle_i4.");
+  int M = size_to_dim_(XQ.dim() - 1, XQ.sizes());
+  int K = XQ.size(-1);
+  int N = size_to_dim_(WQ.dim() - 1, WQ.sizes());
   int num_groups = w_scale_group.size(0);
   int group_size = K / num_groups;
-  // Allocate output.
-  at::Tensor Y = at::empty({M, N}, XQ.options().dtype(at::kBFloat16));

   // Define input types.
   using MmaType = cutlass::float_e4m3_t;
@@ -273,56 +266,86 @@ at::Tensor f8i4bf16_shuffled(
     at::Tensor x_scale,
     at::Tensor w_scale,
     at::Tensor w_scale_group) {
-  int M = XQ.size(0);
-  int K = XQ.size(1);
-  int N = WQ.size(0);
+  int M = size_to_dim_(XQ.dim() - 1, XQ.sizes());
+  int K = XQ.size(-1);
+  int N = size_to_dim_(WQ.dim() - 1, WQ.sizes());
+  // Check input types and shapes.
+  TORCH_CHECK(
+      XQ.is_cuda() && XQ.is_contiguous() && XQ.dtype() == at::kFloat8_e4m3fn,
+      "XQ must be FP8 and contiguous on GPU.");
+  TORCH_CHECK(
+      WQ.size(-1) == K / 2 && WQ.is_cuda() && WQ.is_contiguous() &&
+          WQ.dtype() == at::kChar,
+      "WQ should be int8 (which represent two int4 values), have shape [..., N, K/2], "
+      "and be contiguous on GPU.");
+  TORCH_CHECK(
+      x_scale.numel() == M && x_scale.dtype() == at::kFloat &&
+          x_scale.is_cuda(),
+      "x_scale must be fp32 and have M total elements.");
+  TORCH_CHECK(
+      w_scale.numel() == N && w_scale.dtype() == at::kFloat &&
+          w_scale.is_cuda(),
+      "Weight row scale should have N elements and be on GPU.");
+  // Make sure w_scale_group is in proper format.
+  TORCH_CHECK(
+      w_scale_group.dtype() == at::kFloat8_e4m3fn && w_scale_group.dim() == 3 &&
+          w_scale_group.size(1) == 8 && w_scale_group.size(2) == N,
+      "Weights and group scales must be prepacked with preshuffle_i4. "
+      "Group scales are expected to be FP8 and have shape [num_groups, 8, N].");
+
+  // Allocate output or return an empty tensor if input is empty.
+  if (M == 0 || N == 0 || K == 0) {
+    return at::zeros({M, N}, XQ.options().dtype(at::kBFloat16));
+  }
+  at::Tensor Y = at::empty({M, N}, XQ.options().dtype(at::kBFloat16));
+
   // Use shape heuristics to dispatch to optimized kernel configuration.
   if (M <= 16) {
     return _f8i4bf16_shuffled<64, 16, 2, 1, 1, false>(
-        XQ, WQ, x_scale, w_scale, w_scale_group);
+        XQ, WQ, x_scale, w_scale, w_scale_group, Y);
   } else if (M <= 32) {
     return _f8i4bf16_shuffled<64, 32, 2, 1, 1, false>(
-        XQ, WQ, x_scale, w_scale, w_scale_group);
+        XQ, WQ, x_scale, w_scale, w_scale_group, Y);
   } else if (M <= 64) {
     return _f8i4bf16_shuffled<64, 64, 2, 1, 1, false>(
-        XQ, WQ, x_scale, w_scale, w_scale_group);
+        XQ, WQ, x_scale, w_scale, w_scale_group, Y);
   } else if (M <= 128) {
     return _f8i4bf16_shuffled<64, 128, 2, 1, 1, false>(
-        XQ, WQ, x_scale, w_scale, w_scale_group);
+        XQ, WQ, x_scale, w_scale, w_scale_group, Y);
   } else if (M <= 256) {
     if (N <= 4096) {
       return _f8i4bf16_shuffled<64, 128, 2, 1, 1, false>(
-          XQ, WQ, x_scale, w_scale, w_scale_group);
+          XQ, WQ, x_scale, w_scale, w_scale_group, Y);
     } else {
       return _f8i4bf16_shuffled<64, 256, 1, 1, 1, false>(
-          XQ, WQ, x_scale, w_scale, w_scale_group);
+          XQ, WQ, x_scale, w_scale, w_scale_group, Y);
     }
   } else if (M <= 512) {
    if (N <= 4096) {
       return _f8i4bf16_shuffled<64, 256, 2, 1, 1, false>(
-          XQ, WQ, x_scale, w_scale, w_scale_group);
+          XQ, WQ, x_scale, w_scale, w_scale_group, Y);
     } else {
       return _f8i4bf16_shuffled<128, 256, 2, 1, 1, true>(
-          XQ, WQ, x_scale, w_scale, w_scale_group);
+          XQ, WQ, x_scale, w_scale, w_scale_group, Y);
     }
   } else if (M <= 1024) {
     if (N <= 1024) {
       return _f8i4bf16_shuffled<64, 128, 2, 1, 1, false>(
-          XQ, WQ, x_scale, w_scale, w_scale_group);
+          XQ, WQ, x_scale, w_scale, w_scale_group, Y);
     } else if (N <= 2048) {
       return _f8i4bf16_shuffled<64, 256, 2, 1, 1, false>(
-          XQ, WQ, x_scale, w_scale, w_scale_group);
+          XQ, WQ, x_scale, w_scale, w_scale_group, Y);
     } else {
       return _f8i4bf16_shuffled<128, 256, 2, 1, 1, true>(
-          XQ, WQ, x_scale, w_scale, w_scale_group);
+          XQ, WQ, x_scale, w_scale, w_scale_group, Y);
     }
   } else {
     if (N <= 1024) {
       return _f8i4bf16_shuffled<64, 256, 2, 1, 1, false>(
-          XQ, WQ, x_scale, w_scale, w_scale_group);
+          XQ, WQ, x_scale, w_scale, w_scale_group, Y);
     } else {
       return _f8i4bf16_shuffled<128, 256, 2, 1, 1, true>(
-          XQ, WQ, x_scale, w_scale, w_scale_group);
+          XQ, WQ, x_scale, w_scale, w_scale_group, Y);
     }
   }
 }
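The new TORCH_CHECKs pin down the operator's call contract. As an illustration only, here is a Python-side mirror of those checks; the function name check_f8i4bf16_shuffled_args is mine, not part of the change:

import math

import torch


def check_f8i4bf16_shuffled_args(xq, wq, x_scale, w_scale, w_scale_group):
    # Python-side mirror of the checks added above; the authoritative versions
    # are the TORCH_CHECKs in the CUDA host code.
    m = math.prod(xq.shape[:-1])  # size_to_dim_(XQ.dim() - 1, XQ.sizes())
    k = xq.shape[-1]
    n = math.prod(wq.shape[:-1])
    assert xq.is_cuda and xq.is_contiguous() and xq.dtype == torch.float8_e4m3fn
    assert wq.is_cuda and wq.is_contiguous() and wq.dtype == torch.int8
    assert wq.shape[-1] == k // 2  # two int4 values packed per int8 element
    assert x_scale.is_cuda and x_scale.dtype == torch.float32 and x_scale.numel() == m
    assert w_scale.is_cuda and w_scale.dtype == torch.float32 and w_scale.numel() == n
    assert w_scale_group.dtype == torch.float8_e4m3fn and w_scale_group.dim() == 3
    assert w_scale_group.shape[1] == 8 and w_scale_group.shape[2] == n
    # With M, N, or K equal to zero the op now returns a zero-filled
    # bf16 [M, N] tensor instead of launching the kernel.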
