
Commit 107e9f6

jwfromm authored and facebook-github-bot committed
Add Preshuffled FP8 x INT4 Grouped Gemm Kernel (pytorch#3800)
Summary:
X-link: facebookresearch/FBGEMM#897

Efficient FP8 x INT4 grouped GEMM with preshuffling and scale packing. This implementation uses the "stacked" API, where inputs and outputs are single contiguous tensors and the group boundaries are indicated by an `M_sizes` tensor that contains the number of rows in each group.

Reviewed By: jiawenliu64

Differential Revision: D70870933
1 parent c7343e0 commit 107e9f6
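As a rough illustration of the stacked layout described in the summary (the group counts and shapes below are invented for the example, not taken from this commit): each group contributes M_i rows to a single contiguous activation tensor, and `M_sizes` records those row counts so the kernel can recover the group boundaries. The benchmark wrapper added in this commit passes the cumulative sum of these sizes instead.

```python
# Illustrative sketch of the "stacked" grouped-GEMM layout; shapes are made up.
import torch

# Three groups sharing the same K dimension but with different numbers of rows.
group_rows = [128, 64, 256]   # M_i for each group
K = 512

xs = [torch.randn(m, K) for m in group_rows]

# Stacked API: all groups live in one contiguous tensor...
x_stacked = torch.cat(xs, dim=0).contiguous()   # shape (sum(M_i), K)

# ...and group boundaries are described by a tensor of per-group row counts.
M_sizes = torch.tensor(group_rows, dtype=torch.int64)

# The benchmark code below instead passes cumulative offsets derived from
# these sizes, which encode the same boundaries.
m_offsets = torch.cumsum(M_sizes, dim=0)
```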

File tree

4 files changed: +597 -1 lines changed


fbgemm_gpu/experimental/gen_ai/bench/quantize_ops.py

Lines changed: 64 additions & 0 deletions
@@ -1374,6 +1374,70 @@ def cuda(self) -> bool:
         return True
 
 
+@register_quantize_op
+class F8I4ShuffledGroupedGemm(QuantizeOpBase):
+    """
+    FP8 x Int4 mixed dtype grouped gemm with preshuffling.
+    """
+
+    def preprocess(self, x, w):
+        assert isinstance(x, list) and isinstance(
+            w, list
+        ), "Only supported for grouped inputs."
+        m_values = [i.shape[0] for i in x]
+        # Convert m_values into offsets into grouped tensor.
+        m_offsets = torch.tensor(np.cumsum(m_values)).to(
+            dtype=torch.int64, device=x[0].device
+        )
+        # Quantize weights.
+        # TODO Only rowwise scaling is currently supported. This needs to be fixed.
+        K = x[0].shape[-1]
+        wq, row_scale, group_scale = zip(
+            *[quantize_int4_preshuffle(i, group_size=K) for i in w]
+        )
+        # Group weights as single tensor.
+        wq = torch.stack(wq, dim=0).contiguous()
+        row_scale = torch.stack(row_scale, dim=0).contiguous()
+        group_scale = torch.stack(group_scale, dim=0).contiguous()
+        # Also view input as flattened.
+        x = torch.concat(x, dim=0).contiguous()
+        # Return processed tensors.
+        return x, wq, row_scale, group_scale, m_offsets
+
+    def quantize(self, x, wq, row_scale, group_scale, m_offsets):
+        B = x.shape[0]
+        xq, x_scale = triton_quantize_fp8_row(x)
+        x_scale = x_scale.view(B, -1)
+        return xq, wq, x_scale, row_scale, group_scale, m_offsets
+
+    def compute(self, xq, wq, x_scale, row_scale, group_scale, m_offsets):
+        out = torch.ops.fbgemm.f8i4bf16_shuffled_grouped(
+            xq, wq, x_scale, row_scale, group_scale, m_offsets
+        )
+        return out
+
+    def quantize_and_compute(self, x, wq, row_scale, group_scale, m_offsets):
+        xq, wq, x_scale, row_scale, group_scale, m_offsets = self.quantize(
+            x, wq, row_scale, group_scale, m_offsets
+        )
+        return self.compute(xq, wq, x_scale, row_scale, group_scale, m_offsets)
+
+    @property
+    def name(self) -> str:
+        if torch.version.cuda:
+            return "cutlass_f8i4_grouped_preshuffle"
+        else:
+            return "ck_f8i4_grouped_preshuffle"
+
+    @property
+    def hip(self) -> bool:
+        return False
+
+    @property
+    def cuda(self) -> bool:
+        return True
+
+
 @register_quantize_op
 class BF16I4RowwiseGemm(F8I4RowwiseGemm):
     """

fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f8i4bf16_shuffled.cu

Lines changed: 0 additions & 1 deletion
@@ -186,7 +186,6 @@ at::Tensor _f8i4bf16_shuffled(
   auto shape_B = cute::make_shape(N, K, 1);
   StrideA stride_A =
       cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(M, K, 1));
-  StrideB stride_B = cutlass::make_cute_packed_stride(StrideB{}, shape_B);
   StrideC stride_C =
       cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(N, M, 1));
   LayoutB_Reordered layout_B_reordered =
