
Commit ee4c88b

jwfromm authored and facebook-github-bot committed

Enable rowwise scaling for DeepGemm (pytorch#3874)
Summary:
X-link: facebookresearch/FBGEMM#964
Pull Request resolved: pytorch#3874

This diff adds [ngimel's support for DeepGemm rowwise scaling](https://github.com/ngimel/DeepGEMM/tree/rowwise) to our fbcode copy. It also includes a few DeepGemm updates that allow operation on M < 128, which is important for any real use case. Rowwise scaling improves performance considerably, especially in memory-bound cases. Notably, this makes DeepGemm the premier solution for slow accumulation, as it now outperforms cublas + unfused rowwise scaling overall.

{F1976375307}

Reviewed By: jianyuh

Differential Revision: D71748927

fbshipit-source-id: 87e287a2cec284bd8fd7c5e80603065a0d662f53
1 parent a5f8150 · commit ee4c88b
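For orientation, the data flow behind the new op is: quantize activations and weights row-wise to FP8 (one scale per row), realign the activation scales into DeepGemm's column-major TMA-aligned layout, then run the FP8 GEMM with a bf16 output. Below is a minimal standalone sketch of that flow. The import paths are assumptions (the fbcode copy vendors DeepGemm, so the deep_gemm module name may differ), and the rowwise_scaling keyword on get_col_major_tma_aligned_tensor exists only in the rowwise branch this diff pulls in.

import torch
# Assumed import paths; adjust to wherever these helpers live in your tree.
from fbgemm_gpu.experimental.gemm.triton_gemm.fp8_gemm import quantize_fp8_row
from deep_gemm import gemm_fp8_fp8_bf16_nt, get_col_major_tma_aligned_tensor

M, N, K = 256, 8192, 4096
x = torch.randn(M, K, device="cuda", dtype=torch.bfloat16)  # activations
w = torch.randn(N, K, device="cuda", dtype=torch.bfloat16)  # weights, N x K

# Row-wise FP8 quantization: one scale per row of x and one per row of w.
xq, x_scale = quantize_fp8_row(x)
wq, w_scale = quantize_fp8_row(w)

# DeepGemm wants activation scales in a column-major, TMA-aligned layout;
# rowwise_scaling=True is specific to the rowwise branch pulled in here.
x_scale = get_col_major_tma_aligned_tensor(x_scale, rowwise_scaling=True)

# NT layout: out = x @ w.T, written out in bf16.
out = torch.empty(M, N, device="cuda", dtype=torch.bfloat16)
gemm_fp8_fp8_bf16_nt((xq, x_scale), (wq, w_scale), out)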

File tree

1 file changed: +42 −0

fbgemm_gpu/experimental/gen_ai/bench/quantize_ops.py

Lines changed: 42 additions & 0 deletions
@@ -875,6 +875,48 @@ def cuda(self) -> bool:
 
 
 
+@register_quantize_op
+class DeepGemmRowwise(QuantizeOpBase):
+    """
+    FP8 matmul with rowwise scaling implemented with DeepGemm.
+    """
+
+    def preprocess(self, x, w):
+        # Quantize weights.
+        wq, w_scale = quantize_fp8_row(w)
+        # Allocate output.
+        out = torch.empty(
+            x.shape[0], wq.shape[0], device=x.device, dtype=torch.bfloat16
+        )
+        # Return processed tensors.
+        return x, wq, w_scale, out
+
+    def quantize(self, x, wq, w_scale, out):
+        xq, x_scale = quantize_fp8_row(x)
+        # Pretranspose scales to deepgemm format.
+        x_scale = get_col_major_tma_aligned_tensor(x_scale, rowwise_scaling=True)
+        return xq, wq, x_scale, w_scale, out
+
+    def compute(self, xq, wq, x_scale, w_scale, out):
+        gemm_fp8_fp8_bf16_nt((xq, x_scale), (wq, w_scale), out)
+        return out
+
+    def quantize_and_compute(self, x, wq, w_scale, out):
+        xq, wq, x_scale, w_scale, out = self.quantize(x, wq, w_scale, out)
+        return self.compute(xq, wq, x_scale, w_scale, out)
+
+    @property
+    def name(self) -> str:
+        return "deepgemm_rowwise"
+
+    @property
+    def hip(self) -> bool:
+        return False
+
+    @property
+    def cuda(self) -> bool:
+        return True
+
 class FP8StackedGroupedGemm(QuantizeOpBase):
     """
     FP8 grouped matmul with rowwise scaling and stacked inputs.
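To exercise the new op outside the benchmark harness, a hypothetical driver might chain the three methods the same way the harness does. This is a sketch, not a verified script: it assumes DeepGemmRowwise can be constructed with no arguments (the diff does not show the base-class constructor) and an SM90 GPU, since DeepGemm targets Hopper.

import torch
from fbgemm_gpu.experimental.gen_ai.bench.quantize_ops import DeepGemmRowwise

op = DeepGemmRowwise()  # assumed zero-arg construction
x = torch.randn(256, 4096, device="cuda", dtype=torch.bfloat16)
w = torch.randn(8192, 4096, device="cuda", dtype=torch.bfloat16)

# preprocess quantizes weights once; quantize handles activations per call.
x, wq, w_scale, out = op.preprocess(x, w)
xq, wq, x_scale, w_scale, out = op.quantize(x, wq, w_scale, out)
y = op.compute(xq, wq, x_scale, w_scale, out)  # bf16 result, shape (256, 8192)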
