
Commit 851815d

jwfromm authored and facebook-github-bot committed
Enable groupwise scales for F8I4 Grouped Gemm (pytorch#3884)
Summary:

X-link: facebookresearch/FBGEMM#975

Pull Request resolved: pytorch#3884

Due to cutlass support limitations, we previously required that F8I4 grouped gemm use rowwise scales for its weights. This left a lot of accuracy on the table compared to groupwise scales (which we use for the standard f8i4 gemm). This diff adds support for groupwise scaling and lifts the restriction from our implementation.

Reviewed By: jiawenliu64

Differential Revision: D71905839

fbshipit-source-id: f49d54aad558730992c79fe1c714ba3a66fa523e
1 parent 6a6db7c commit 851815d
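
For context on the tradeoff the summary describes, here is a minimal sketch of rowwise versus groupwise weight scales. The sizes and the max-abs scaling rule are illustrative assumptions, not the FBGEMM quantization recipe:

import torch

# Illustrative sizes; a real F8I4 weight would be quantized to int4.
N, K, group_size = 16, 256, 64
w = torch.randn(N, K)

# Rowwise scaling: one scale per output row, i.e. group_size == K.
row_scale = w.abs().amax(dim=1) / 7.0  # [N]; 7 is the max magnitude of int4

# Groupwise scaling: one scale per group_size slice along K, giving
# K // group_size scale groups per row and finer-grained accuracy.
num_scale_groups = K // group_size
group_scale = w.view(N, num_scale_groups, group_size).abs().amax(dim=2) / 7.0  # [N, num_scale_groups]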

File tree

2 files changed: +1 -8 lines changed


fbgemm_gpu/experimental/gen_ai/bench/quantize_ops.py

Lines changed: 1 addition & 4 deletions
@@ -1402,10 +1402,7 @@ def preprocess(self, x, w):
         m_sizes = torch.tensor(m_values).to(dtype=torch.int32, device=x[0].device)
         # Quantize weights.
         # TODO Only rowwise scaling is currently supported. This needs to be fixed.
-        K = x[0].shape[-1]
-        wq, row_scale, group_scale = zip(
-            *[quantize_int4_preshuffle(i, group_size=K) for i in w]
-        )
+        wq, row_scale, group_scale = zip(*[quantize_int4_preshuffle(i) for i in w])
         # Group weights as single tensor.
         wq = torch.stack(wq, dim=0).contiguous()
         row_scale = torch.stack(row_scale, dim=0).contiguous()
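
The changed line uses the zip(*...) transpose idiom: quantizing each expert's weight yields one (wq, row_scale, group_scale) tuple, and zip regroups them into one sequence per kind before stacking. A minimal sketch; fake_quantize is a hypothetical stand-in whose return shape merely mirrors how quantize_int4_preshuffle is used here:

import torch

def fake_quantize(w):
    # Hypothetical stand-in for quantize_int4_preshuffle: returns one
    # (wq, row_scale, group_scale) tuple per weight tensor.
    wq = w.to(torch.int8)                            # placeholder "quantized" weight
    row_scale = w.abs().amax(dim=1)                  # [N]
    group_scale = w.abs().amax(dim=1, keepdim=True)  # placeholder [N, 1]
    return wq, row_scale, group_scale

weights = [torch.randn(8, 32) for _ in range(4)]  # G = 4 experts

# Transpose the list of per-expert tuples into one tuple per kind.
wq, row_scale, group_scale = zip(*[fake_quantize(w) for w in weights])

# Stack each kind into a single [G, ...] tensor, as the benchmark code does.
wq = torch.stack(wq, dim=0).contiguous()                    # [G, N, K]
row_scale = torch.stack(row_scale, dim=0).contiguous()      # [G, N]
group_scale = torch.stack(group_scale, dim=0).contiguous()  # [G, N, 1]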

fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f8i4bf16_shuffled_grouped.cu

Lines changed: 0 additions & 4 deletions
@@ -145,10 +145,6 @@ void _f8i4bf16_shuffled_grouped(
   // Group scales should have shape [G, num_scale_groups, 8, N]
   int num_scale_groups = w_scale_group.size(1);
   int group_size = K / num_scale_groups;
-  TORCH_CHECK(
-      num_scale_groups == 1,
-      "Mixed dtype grouped gemm only supports rowwise scaling currently (group_size=K).");
-
   // Define cutlass types.
   using ProblemShape = cutlass::gemm::GroupProblemShape<
       cute::Shape<int, int, int>>; // <M,N,K> per group.
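
With the TORCH_CHECK gone, the only invariant left is the divisibility the surrounding lines compute: the kernel derives group_size from the scale tensor's second dimension. A sketch of that shape bookkeeping, with illustrative sizes and dtype:

import torch

G, N, K = 4, 512, 4096   # illustrative problem sizes
group_size = 128         # any divisor of K is now accepted, not only K itself
num_scale_groups = K // group_size

# Per the kernel comment, group scales have shape [G, num_scale_groups, 8, N];
# the dtype here is an assumption for illustration.
w_scale_group = torch.ones(G, num_scale_groups, 8, N)

# Mirror the kernel's derivation: group_size = K / w_scale_group.size(1).
assert K % w_scale_group.size(1) == 0
assert K // w_scale_group.size(1) == group_size

# Rowwise scaling is the special case num_scale_groups == 1 (group_size == K),
# which the removed check used to require.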
