Commit f409a85

jwfromm authored and facebook-github-bot committed
Refactor stacked version of FP8 Grouped Gemm for reduced overhead (pytorch#780)
Summary:
X-link: pytorch#3699
Pull Request resolved: facebookresearch/FBGEMM#780

Currently, the stacked version of FP8 grouped gemm accepts lists of tensor inputs and produces a single tensor output. This reduces quite a bit of overhead when cuda graphs are used, but it still requires splitting input tensors in prefill, which can be costly.

This diff updates the input types of stacked grouped gemm to support single tensors. Notably, since M varies across groups and we do no padding, this change requires a new input tensor called `M_sizes` that indicates the number of rows in each group.

This diff also includes a long-overdue refactor of grouped gemm setup for nvidia so that we launch only a single kernel rather than one per group. This should reduce overhead by quite a bit in some cases.

Reviewed By: jiawenliu64, mxz297

Differential Revision: D69544396

fbshipit-source-id: bdbcf0df3e4c479df84996cc22b167a68ea75b84
1 parent bc52f44 commit f409a85
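
To illustrate the new calling convention, here is a minimal sketch of driving torch.ops.fbgemm.f8f8bf16_rowwise_grouped_stacked with an M_sizes tensor. The operator call and the quantize_fp8_row/torch.stack preparation mirror the benchmark code added below; the shapes, dtypes, device, and the helper's import path are illustrative assumptions rather than part of this commit.

# Sketch only: single-tensor stacked grouped gemm with per-group row counts.
# Shapes, dtypes, and the quantize_fp8_row import path are assumptions; the
# operator call mirrors FP8StackedGroupedGemm.compute() in the diff below.
import torch

from fbgemm_gpu.experimental.gemm.triton_gemm.fp8_gemm import quantize_fp8_row

G, N, K = 4, 256, 512
m_per_group = [16, 64, 32, 8]  # M varies per group; no padding is applied.
M_sizes = torch.tensor(m_per_group, dtype=torch.int64, device="cuda")

# Activations of all groups concatenated along the row dimension.
x = torch.randn(sum(m_per_group), K, device="cuda", dtype=torch.bfloat16)

# Per-group weights quantized rowwise to FP8, then stacked into [G, N, K],
# exactly as FP8StackedGroupedGemm.preprocess() does below.
w = [torch.randn(N, K, device="cuda", dtype=torch.bfloat16) for _ in range(G)]
wq, w_scale = zip(*[quantize_fp8_row(t) for t in w])
wq = torch.stack(wq, dim=0).contiguous()
w_scale = torch.stack(w_scale, dim=0).contiguous()

# Rowwise FP8 quantization of the stacked activations (one scale per row).
xq, x_scale = quantize_fp8_row(x)

# One launch covers all groups; group g owns the next M_sizes[g] rows.
out = torch.ops.fbgemm.f8f8bf16_rowwise_grouped_stacked(xq, wq, x_scale, w_scale, M_sizes)
print(out.shape)  # expected: torch.Size([sum(m_per_group), N])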

Showing 74 changed files with 516 additions and 344 deletions.


fbgemm_gpu/experimental/gen_ai/bench/quantize_ops.py

Lines changed: 49 additions & 0 deletions
@@ -867,6 +867,55 @@ def cuda(self) -> bool:
         return True


+class FP8StackedGroupedGemm(QuantizeOpBase):
+    """
+    FP8 grouped matmul with rowwise scaling and stacked inputs.
+    """
+
+    def preprocess(self, x, w):
+        m_values = [i.shape[0] for i in x]
+        m_sizes = torch.tensor(m_values).to(dtype=torch.int64, device=x[0].device)
+        # Quantize weights.
+        wq, w_scale = zip(*[quantize_fp8_row(i) for i in w])
+        # Group weights as single tensor.
+        wq = torch.stack(wq, dim=0).contiguous()
+        w_scale = torch.stack(w_scale, dim=0).contiguous()
+        # Also view input as flattened.
+        x = torch.concat(x, dim=0).contiguous()
+        # Return processed tensors.
+        return x, wq, w_scale, m_sizes
+
+    def quantize(self, x, wq, w_scale, m_sizes):
+        B = x.shape[0]
+        xq, x_scale = triton_quantize_fp8_row(x)
+        x_scale = x_scale.view(B, -1)
+        return xq, wq, x_scale, w_scale, m_sizes
+
+    def compute(self, xq, wq, x_scale, w_scale, m_sizes):
+        return torch.ops.fbgemm.f8f8bf16_rowwise_grouped_stacked(
+            xq, wq, x_scale, w_scale, m_sizes
+        )
+
+    def quantize_and_compute(self, x, wq, w_scale, m_sizes):
+        xq, wq, x_scale, w_scale, m_sizes = self.quantize(x, wq, w_scale, m_sizes)
+        return self.compute(xq, wq, x_scale, w_scale, m_sizes)
+
+    @property
+    def name(self) -> str:
+        if torch.version.cuda:
+            return "cutlass_grouped_stacked"
+        else:
+            return "ck_grouped_stacked"
+
+    @property
+    def hip(self) -> bool:
+        return True
+
+    @property
+    def cuda(self) -> bool:
+        return True
+
+
 @register_quantize_op
 class BF16GroupedGemm(QuantizeOpBase):
     """

fbgemm_gpu/experimental/gen_ai/src/quantize/ck_extensions/fp8_rowwise_batched/kernels/fp8_rowwise_batched_common.h

Lines changed: 1 addition & 1 deletion
@@ -132,7 +132,7 @@ at::Tensor f8f8bf16_rowwise_batched_impl(
   int B = XQ.size(0);
   int M = XQ.size(1);
   int N = WQ.size(1);
-  int K = XQ.size(2);
+  int K = WQ.size(2);

   int StrideA = K;
   int StrideB = K;
