
Commit 9391323

jwfromm authored and facebook-github-bot committed
Add Preshuffled FP8 x INT4 Grouped Gemm Kernel (pytorch#3800)
Summary: Adds initial support for stacked mixed dtype grouped gemm with preshuffling. Differential Revision: D70870933
1 parent e80288d commit 9391323

File tree

4 files changed: +597, -1 lines changed

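Note (illustrative, not part of the commit): the new benchmark op below stacks all per-group activations into one tensor and marks group boundaries with cumulative-sum offsets, as its preprocess method does with np.cumsum. For example, groups with 3, 5, and 2 rows yield offsets [3, 8, 10]:

import numpy as np
import torch

# Hypothetical group sizes; in the benchmark they come from the per-group input shapes.
m_values = [3, 5, 2]
# Offsets into the stacked activation tensor, mirroring the preprocess logic below.
m_offsets = torch.tensor(np.cumsum(m_values)).to(dtype=torch.int64)
print(m_offsets)  # tensor([ 3,  8, 10])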

fbgemm_gpu/experimental/gen_ai/bench/quantize_ops.py

Lines changed: 64 additions & 0 deletions
@@ -1325,6 +1325,70 @@ def cuda(self) -> bool:
         return True


+@register_quantize_op
+class F8I4ShuffledGroupedGemm(QuantizeOpBase):
+    """
+    FP8 x Int4 mixed dtype grouped gemm with preshuffling.
+    """
+
+    def preprocess(self, x, w):
+        assert isinstance(x, list) and isinstance(
+            w, list
+        ), "Only supported for grouped inputs."
+        m_values = [i.shape[0] for i in x]
+        # Convert m_values into offsets into grouped tensor.
+        m_offsets = torch.tensor(np.cumsum(m_values)).to(
+            dtype=torch.int64, device=x[0].device
+        )
+        # Quantize weights.
+        # TODO Only rowwise scaling is currently supported. This needs to be fixed.
+        K = x[0].shape[-1]
+        wq, row_scale, group_scale = zip(
+            *[quantize_int4_preshuffle(i, group_size=K) for i in w]
+        )
+        # Group weights as single tensor.
+        wq = torch.stack(wq, dim=0).contiguous()
+        row_scale = torch.stack(row_scale, dim=0).contiguous()
+        group_scale = torch.stack(group_scale, dim=0).contiguous()
+        # Also view input as flattened.
+        x = torch.concat(x, dim=0).contiguous()
+        # Return processed tensors.
+        return x, wq, row_scale, group_scale, m_offsets
+
+    def quantize(self, x, wq, row_scale, group_scale, m_offsets):
+        B = x.shape[0]
+        xq, x_scale = triton_quantize_fp8_row(x)
+        x_scale = x_scale.view(B, -1)
+        return xq, wq, x_scale, row_scale, group_scale, m_offsets
+
+    def compute(self, xq, wq, x_scale, row_scale, group_scale, m_offsets):
+        out = torch.ops.fbgemm.f8i4bf16_shuffled_grouped(
+            xq, wq, x_scale, row_scale, group_scale, m_offsets
+        )
+        return out
+
+    def quantize_and_compute(self, x, wq, row_scale, group_scale, m_offsets):
+        xq, wq, x_scale, row_scale, group_scale, m_offsets = self.quantize(
+            x, wq, row_scale, group_scale, m_offsets
+        )
+        return self.compute(xq, wq, x_scale, row_scale, group_scale, m_offsets)
+
+    @property
+    def name(self) -> str:
+        if torch.version.cuda:
+            return "cutlass_f8i4_grouped_preshuffle"
+        else:
+            return "ck_f8i4_grouped_preshuffle"
+
+    @property
+    def hip(self) -> bool:
+        return False
+
+    @property
+    def cuda(self) -> bool:
+        return True
+
+
 @register_quantize_op
 class BF16I4RowwiseGemm(F8I4RowwiseGemm):
     """

fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f8i4bf16_shuffled.cu

Lines changed: 0 additions & 1 deletion
@@ -186,7 +186,6 @@ at::Tensor _f8i4bf16_shuffled(
   auto shape_B = cute::make_shape(N, K, 1);
   StrideA stride_A =
       cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(M, K, 1));
-  StrideB stride_B = cutlass::make_cute_packed_stride(StrideB{}, shape_B);
   StrideC stride_C =
       cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(N, M, 1));
   LayoutB_Reordered layout_B_reordered =
