
Commit 38fb7bb

jwfromm authored and facebook-github-bot committed
Preshuffled BF16I4 Gemm Kernel (#3913)
Summary:

X-link: facebookresearch/FBGEMM#1003

This diff adds a preshuffled BF16I4 mixed-dtype GEMM kernel using CUTLASS. Performance is quite compelling, showing substantial speedups for some shapes compared to a BF16 x BF16 GEMM backed by cuBLAS. Notably, the preshuffle approach is 1.5-2X faster than the standard BF16I4 GEMM for most shapes.

Differential Revision: D72270467
1 parent 2b22ae2 commit 38fb7bb
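
For orientation, here is a minimal end-to-end sketch of how the new path is meant to be used, based only on the APIs touched in this commit (quantize_int4_preshuffle with dtype="bf16" and the new torch.ops.fbgemm.bf16i4bf16_shuffled op). The import path, shapes, and device setup are illustrative assumptions, not part of the diff.

import torch

# Import path assumed from the file locations in this commit; adjust to your install.
from fbgemm_gpu.experimental.gen_ai.quantize import quantize_int4_preshuffle

# Illustrative shapes: [M, K] bf16 activations against [N, K] bf16 weights.
M, N, K = 256, 4096, 4096
x = torch.randn(M, K, device="cuda", dtype=torch.bfloat16)
w = torch.randn(N, K, device="cuda", dtype=torch.bfloat16)

# Quantize and preshuffle the weights for the BF16 activation path.
wq, (group_scale, group_zero) = quantize_int4_preshuffle(w, dtype="bf16")

# Run the new preshuffled BF16I4 GEMM; the output is a [M, N] bf16 tensor.
y = torch.ops.fbgemm.bf16i4bf16_shuffled(x, wq, group_scale, group_zero)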

File tree

7 files changed, +534 -46 lines


fbgemm_gpu/experimental/gen_ai/bench/quantize_ops.py

Lines changed: 44 additions & 1 deletion
@@ -1426,7 +1426,7 @@ def cuda(self) -> bool:
 class F8I4ShuffledGemm(QuantizeOpBase):
     def preprocess(self, x, w):
         # Prequantize and pack weights.
-        wq, row_scale, group_scale = quantize_int4_preshuffle(w)
+        wq, (group_scale, row_scale) = quantize_int4_preshuffle(w)
         return x, wq, row_scale, group_scale

     def quantize(self, x, wq, row_scale, group_scale):
@@ -1470,6 +1470,49 @@ def cuda(self) -> bool:
         return True


+@register_quantize_op
+class BF16I4ShuffledGemm(QuantizeOpBase):
+    def preprocess(self, x, w):
+        # Prequantize and pack weights.
+        wq, (group_scale, group_zero) = quantize_int4_preshuffle(w, dtype="bf16")
+        return x, wq, group_scale, group_zero
+
+    def quantize(self, x, wq, group_scale, group_zero):
+        # No extra action required.
+        return x, wq, group_scale, group_zero
+
+    def compute(self, x, wq, group_scale, group_zero):
+        # Handle batched cases by looping over each batch.
+        if x.dim() == 3:
+            B, M, _ = x.shape
+            _, N, _ = wq.shape
+            y = torch.empty((B, M, N), device=x.device, dtype=torch.bfloat16)
+            for i in range(B):
+                y[i] = torch.ops.fbgemm.bf16i4bf16_shuffled(
+                    x[i], wq[i], group_scale[i], group_zero[i]
+                )
+            return y
+        # Otherwise run gemm normally.
+        return torch.ops.fbgemm.bf16i4bf16_shuffled(x, wq, group_scale, group_zero)
+
+    def quantize_and_compute(self, x, wq, group_scale, group_zero):
+        x, wq, group_scale, group_zero = self.quantize(x, wq, group_scale, group_zero)
+        return self.compute(x, wq, group_scale, group_zero)
+
+    @property
+    def name(self) -> str:
+        return "cutlass_bf16i4_preshuffle"
+
+    @property
+    def hip(self) -> bool:
+        # Not yet supported on AMD.
+        return False
+
+    @property
+    def cuda(self) -> bool:
+        return True
+
+
 @register_quantize_op
 class F8I4ShuffledGroupedGemm(QuantizeOpBase):
     """

fbgemm_gpu/experimental/gen_ai/gen_ai/quantize.py

Lines changed: 73 additions & 35 deletions
@@ -29,6 +29,34 @@ def pack_int4(x: torch.Tensor) -> torch.Tensor:
     return torch.bitwise_or(low_x, high_x).contiguous()


+def int4_row_quantize_zp(
+    x: torch.Tensor,
+    group_size: int = 128,
+) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+    n_bit = 4  # Number of target bits.
+    to_quant = x.reshape(-1, group_size).to(torch.float)
+
+    max_val = to_quant.amax(dim=1, keepdim=True)
+    min_val = to_quant.amin(dim=1, keepdim=True)
+    max_int = 2**n_bit - 1
+    min_int = 0
+    scales = (max_val - min_val).clamp(min=1e-6) / max_int
+
+    zeros = min_val + scales * (2 ** (n_bit - 1))
+
+    out = to_quant.sub(min_val).div(scales).round().clamp_(min_int, max_int)
+
+    # Recenter output and move to int8.
+    out = (out - 2 ** (n_bit - 1)).to(dtype=torch.int8).reshape(x.shape)
+
+    # Cutlass expects column major layout for scale and zero point,
+    # so we transpose here and make them contiguous.
+    scales = scales.view(x.shape[0], -1).t().contiguous()
+    zeros = zeros.view(x.shape[0], -1).t().contiguous()
+
+    return out, scales, zeros
+
+
 def int4_row_quantize(
     x: torch.Tensor,
     group_size: int = 128,
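
The hunk above adds asymmetric (zero-point) per-group int4 quantization. Below is a minimal standalone sketch of the round trip it implies, restating the same arithmetic in plain PyTorch so it runs without FBGEMM installed; it is not the FBGEMM helper itself and skips the column-major transpose of the scales.

import torch

def int4_quant_zp_sketch(x, group_size=128):
    # Same math as int4_row_quantize_zp, kept in float for the check.
    to_quant = x.reshape(-1, group_size).to(torch.float)
    max_val = to_quant.amax(dim=1, keepdim=True)
    min_val = to_quant.amin(dim=1, keepdim=True)
    scales = (max_val - min_val).clamp(min=1e-6) / 15  # max_int = 2**4 - 1
    zeros = min_val + scales * 8                       # zero point at 2**(4 - 1)
    q = to_quant.sub(min_val).div(scales).round().clamp_(0, 15) - 8
    return q, scales, zeros

w = torch.randn(4, 256)  # [N, K] with K divisible by group_size
q, scales, zeros = int4_quant_zp_sketch(w, group_size=128)

# Since zero = min + 8 * scale and q = round((w - min) / scale) - 8, the
# weights are recovered as w ~= q * scale + zero, up to rounding error of
# at most half a scale step per element.
w_hat = (q * scales + zeros).reshape(w.shape)
assert (w - w_hat).abs().max() <= scales.max() / 2 + 1e-6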
@@ -63,8 +91,8 @@ def int4_row_quantize(


 def quantize_int4_preshuffle(
-    w: torch.Tensor, group_size: int = 128
-) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+    w: torch.Tensor, group_size: int = 128, dtype: str = "fp8"
+) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
     """
     Quantizes an input weight tensor to int4 using preshuffling and scale packing.
     This function is intended to be used with fbgemms mixed dtype kernels and is expected
@@ -73,47 +101,57 @@ def quantize_int4_preshuffle(
     Args:
         w (Tensor): [N, K] Higher precision weight tensor to quantize. May optionally have a batch dimension.
         group_size (int): Number of elements to calculate group scale for, must be at least 128.
+        dtype (torch.dtype): Type of corresponding activations. Must be fp8 or bf16.
     Returns:
         wq (Tensor): [N, K // 2] Quantized int4 weight tensor packed into int8 elements.
-        row_scale (Tensor): [N] FP32 Scale per row of the weight tensor.
-        group_scale (Tensor): [K / group_size, 8, N] FP8 Scale per group of the weight tensor.
+        scales (Tuple[Tensor]): Scale tensors for the specified activation type. When FP8 is used,
+        scales is a tuple of row_scale ([N]) and group_scale ([K / group_size, 8, N]). When BF16 is
+        used, scales is a tuple of group_scale([K / group_size, N]) and group_zero ([K / group_size, N])
     """

-    def _quantize(w: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
-        # Start by lowering weights to FP8 and producing row scales.
-        wq, row_scale = quantize_fp8_row(w)
-
-        # Now reduce to INT4.
-        wq, group_scale = int4_row_quantize(wq, group_size)
-        # Reduce group scale to FP8.
-        group_scale = group_scale.to(torch.float8_e4m3fn)
-
-        # Take quantized weights and pack them efficiently.
-        wq = pack_int4(wq)
-
-        # Finally pack weights and scales into efficient preshuffled format.
-        wq, group_scale = torch.ops.fbgemm.preshuffle_i4(wq, group_scale)
-
-        return wq, row_scale, group_scale
+    def _quantize(
+        w: torch.Tensor, dtype: str = "fp8"
+    ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+
+        if dtype == "fp8":
+            # Start by lowering weights to FP8 and producing row scales.
+            wq, row_scale = quantize_fp8_row(w)
+
+            # Now reduce to INT4.
+            wq, group_scale = int4_row_quantize(wq, group_size)
+            # Reduce group scale to FP8.
+            group_scale = group_scale.to(torch.float8_e4m3fn)
+            # Take quantized weights and pack them efficiently.
+            wq = pack_int4(wq)
+            # Finally pack weights and scales into efficient preshuffled format.
+            wq, group_scale = torch.ops.fbgemm.preshuffle_i4(wq, group_scale)
+            return wq, (group_scale, row_scale)
+
+        elif dtype == "bf16":
+            wq, group_scale, group_zero = int4_row_quantize_zp(w, group_size)
+            # Set scales to activation type.
+            group_scale = group_scale.to(torch.bfloat16)
+            group_zero = group_zero.to(torch.bfloat16)
+            # Take quantized weights and pack them efficiently.
+            wq = pack_int4(wq)
+            # Finally pack weights and scales into efficient preshuffled format.
+            wq, group_scale = torch.ops.fbgemm.preshuffle_i4(wq, group_scale)
+            return wq, (group_scale, group_zero)
+        else:
+            raise NotImplementedError("Only fp8 and bf16 activations supported.")

     if w.ndim >= 3:
         orig_shape = w.shape
         # Flatten to 3 dimensions then iterate over batches.
-        w = w.view(-1, *w.shape[1:])
-        w.unbind(dim=0)
-        wq = []
-        row_scale = []
-        group_scale = []
-        for batch in w:
-            wq_, row_scale_, group_scale_ = _quantize(batch)
-            wq.append(wq_)
-            row_scale.append(row_scale_)
-            group_scale.append(group_scale_)
+        wq, scales = zip(*[_quantize(i, dtype=dtype) for i in w])
         wq = torch.stack(wq).view(*orig_shape[:-2], *wq[0].shape)
-        row_scale = torch.stack(row_scale).view(*orig_shape[:-2], *row_scale[0].shape)
-        group_scale = torch.stack(group_scale).view(
-            *orig_shape[:-2], *group_scale[0].shape
+        # Decompose then stack scales back into a tuple.
+        a_scales, b_scales = zip(*scales)
+        scales = (
+            torch.stack(a_scales).view(*orig_shape[:-2], *a_scales[0].shape),
+            torch.stack(b_scales).view(*orig_shape[:-2], *b_scales[0].shape),
         )
     else:
-        wq, row_scale, group_scale = _quantize(w)
-    return wq, row_scale, group_scale
+        wq, scales = _quantize(w, dtype=dtype)
+
+    return wq, scales
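
Taken together, the new contract is that quantize_int4_preshuffle always returns wq plus a dtype-dependent pair of scale tensors, and batched weights simply gain leading batch dimensions on every returned tensor. A short sketch of the two signatures, with illustrative shapes and an assumed import path:

import torch

# Import path assumed; adjust to your install.
from fbgemm_gpu.experimental.gen_ai.quantize import quantize_int4_preshuffle

N, K = 2048, 4096
w = torch.randn(N, K, device="cuda", dtype=torch.bfloat16)

# FP8 activation path (the default): scales come back as (group_scale, row_scale).
# Per the docstring: wq [N, K // 2], group_scale [K / group_size, 8, N] fp8, row_scale [N] fp32.
wq, (group_scale, row_scale) = quantize_int4_preshuffle(w)

# BF16 activation path: scales come back as (group_scale, group_zero), each [K / group_size, N] bf16.
wq, (group_scale, group_zero) = quantize_int4_preshuffle(w, dtype="bf16")

# Batched weights are quantized batch by batch and restacked, so e.g. wq becomes [B, N, K // 2].
wb = torch.randn(2, N, K, device="cuda", dtype=torch.bfloat16)
wq_b, (gs_b, gz_b) = quantize_int4_preshuffle(wb, dtype="bf16")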
