
Commit 7e88324

jwfromm authored and facebook-github-bot committed
Add Preshuffled FP8 x INT4 Grouped Gemm Kernel (pytorch#3800)
Summary:
X-link: facebookresearch/FBGEMM#897
Pull Request resolved: pytorch#3800

Efficient FP8 x INT4 grouped GEMM with preshuffling and scale packing. This implementation uses the "stacked" API, where inputs and outputs are single contiguous tensors and the group boundaries are indicated by an `M_sizes` tensor that contains the number of rows in each group.

Differential Revision: D70870933
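For concreteness, a minimal sketch of the stacked layout described above, assuming per-group activations that share K; the tensor names and shapes are illustrative only, and the FBGEMM op itself is not called here:

import torch

# Per-group activations, each of shape [M_g, K]; K is shared across groups.
x_groups = [torch.randn(128, 512), torch.randn(64, 512), torch.randn(256, 512)]

# Stacked API: one contiguous tensor plus the per-group row counts.
x_stacked = torch.cat(x_groups, dim=0)                  # [sum(M_g), K]
M_sizes = torch.tensor([g.shape[0] for g in x_groups])  # rows per group

# Group g occupies rows [starts[g], starts[g] + M_sizes[g]) of x_stacked.
starts = torch.cumsum(M_sizes, dim=0) - M_sizes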
1 parent e371eb7 · commit 7e88324

File tree

4 files changed, +597 -1 lines changed


fbgemm_gpu/experimental/gen_ai/bench/quantize_ops.py

Lines changed: 64 additions & 0 deletions
@@ -1325,6 +1325,70 @@ def cuda(self) -> bool:
         return True
 
 
+@register_quantize_op
+class F8I4ShuffledGroupedGemm(QuantizeOpBase):
+    """
+    FP8 x Int4 mixed dtype grouped gemm with preshuffling.
+    """
+
+    def preprocess(self, x, w):
+        assert isinstance(x, list) and isinstance(
+            w, list
+        ), "Only supported for grouped inputs."
+        m_values = [i.shape[0] for i in x]
+        # Convert m_values into offsets into grouped tensor.
+        m_offsets = torch.tensor(np.cumsum(m_values)).to(
+            dtype=torch.int64, device=x[0].device
+        )
+        # Quantize weights.
+        # TODO Only rowwise scaling is currently supported. This needs to be fixed.
+        K = x[0].shape[-1]
+        wq, row_scale, group_scale = zip(
+            *[quantize_int4_preshuffle(i, group_size=K) for i in w]
+        )
+        # Group weights as single tensor.
+        wq = torch.stack(wq, dim=0).contiguous()
+        row_scale = torch.stack(row_scale, dim=0).contiguous()
+        group_scale = torch.stack(group_scale, dim=0).contiguous()
+        # Also view input as flattened.
+        x = torch.concat(x, dim=0).contiguous()
+        # Return processed tensors.
+        return x, wq, row_scale, group_scale, m_offsets
+
+    def quantize(self, x, wq, row_scale, group_scale, m_offsets):
+        B = x.shape[0]
+        xq, x_scale = triton_quantize_fp8_row(x)
+        x_scale = x_scale.view(B, -1)
+        return xq, wq, x_scale, row_scale, group_scale, m_offsets
+
+    def compute(self, xq, wq, x_scale, row_scale, group_scale, m_offsets):
+        out = torch.ops.fbgemm.f8i4bf16_shuffled_grouped(
+            xq, wq, x_scale, row_scale, group_scale, m_offsets
+        )
+        return out
+
+    def quantize_and_compute(self, x, wq, row_scale, group_scale, m_offsets):
+        xq, wq, x_scale, row_scale, group_scale, m_offsets = self.quantize(
+            x, wq, row_scale, group_scale, m_offsets
+        )
+        return self.compute(xq, wq, x_scale, row_scale, group_scale, m_offsets)
+
+    @property
+    def name(self) -> str:
+        if torch.version.cuda:
+            return "cutlass_f8i4_grouped_preshuffle"
+        else:
+            return "ck_f8i4_grouped_preshuffle"
+
+    @property
+    def hip(self) -> bool:
+        return False
+
+    @property
+    def cuda(self) -> bool:
+        return True
+
+
 @register_quantize_op
 class BF16I4RowwiseGemm(F8I4RowwiseGemm):
     """

fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f8i4bf16_shuffled.cu

Lines changed: 0 additions & 1 deletion
@@ -186,7 +186,6 @@ at::Tensor _f8i4bf16_shuffled(
   auto shape_B = cute::make_shape(N, K, 1);
   StrideA stride_A =
       cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(M, K, 1));
-  StrideB stride_B = cutlass::make_cute_packed_stride(StrideB{}, shape_B);
   StrideC stride_C =
       cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(N, M, 1));
   LayoutB_Reordered layout_B_reordered =
