
Commit 7b4324e

mxz297 authored and facebook-github-bot committed
Modernize bf16 cutlass grouped gemm (pytorch#3889)
Summary:
Pull Request resolved: pytorch#3889
X-link: facebookresearch/FBGEMM#982

This diff unifies the API between FP8 and BF16 grouped gemm. Specifically we add the same dynamic, concatenated, and stacked APIs that are used for FP8 across both cutlass and CK. After this change, our tests can also be unified into a single grouped gemm test that covers all the various modes.

Reviewed By: jiawenliu64

Differential Revision: D71920813
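To make the unified "stacked" interface concrete, here is a minimal sketch of the input layout the BF16 benchmark classes below construct; the group sizes, dimensions, and the per-group [N, K] weight layout are illustrative assumptions, not part of this commit:

import torch

# Hypothetical group sizes and dimensions, for illustration only.
ms, K, N = [64, 128, 32], 256, 512
xs = [torch.randn(m, K, dtype=torch.bfloat16, device="cuda") for m in ms]
ws = [torch.randn(N, K, dtype=torch.bfloat16, device="cuda") for _ in ms]  # [N, K] per group assumed

# "Stacked" layout: activations concatenated along M into one tensor,
# weights grouped into one tensor, plus a small tensor of per-group M sizes.
x = torch.concat(xs, dim=0).contiguous()   # [sum(ms), K]
w = torch.stack(ws, dim=0).contiguous()    # [num_groups, N, K]
m_sizes = torch.tensor(ms, device=x.device)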
1 parent 4587ad0 commit 7b4324e

File tree

52 files changed, +804 -854 lines changed


fbgemm_gpu/experimental/gen_ai/bench/quantize_ops.py

Lines changed: 80 additions & 0 deletions
@@ -22,6 +22,7 @@
     triton_quantize_fp8_row,
 )
 from fbgemm_gpu.experimental.gemm.triton_gemm.grouped_gemm import (
+    grouped_gemm,
     grouped_gemm_fp8_rowwise,
 )
 from fbgemm_gpu.experimental.gen_ai.quantize import quantize_int4_preshuffle
@@ -729,6 +730,45 @@ def cuda(self) -> bool:
         return True


+@register_quantize_op
+class BF16TritonStackedGroupedGemm(QuantizeOpBase):
+    """
+    BF16 grouped matmul with stacked inputs implemented with triton.
+    """
+
+    def preprocess(self, x, w):
+        m_values = [i.shape[0] for i in x]
+        # Convert m_values into offsets into grouped tensor.
+        m_sizes = torch.tensor(m_values).to(dtype=torch.int32, device=x[0].device)
+        w = torch.concat(w, dim=0).contiguous()
+        # Also view input as flattened.
+        x = torch.concat(x, dim=0).contiguous()
+        # Return processed tensors.
+        return x, w, m_sizes
+
+    def quantize(self, x, w, m_sizes):
+        return x, w, m_sizes
+
+    def compute(self, x, w, m_sizes):
+        return grouped_gemm(x, w, m_sizes)
+
+    def quantize_and_compute(self, x, w, m_sizes):
+        x, w, m_sizes = self.quantize(x, w, m_sizes)
+        return self.compute(x, w, m_sizes)
+
+    @property
+    def name(self) -> str:
+        return "triton_bf16_grouped_stacked"
+
+    @property
+    def hip(self) -> bool:
+        return True
+
+    @property
+    def cuda(self) -> bool:
+        return True
+
+
 @register_quantize_op
 class FP8TritonStackedGroupedGemm(QuantizeOpBase):
     """
@@ -1488,6 +1528,46 @@ def cuda(self) -> bool:
         return True


+@register_quantize_op
+class BF16GroupedStacked(QuantizeOpBase):
+    """
+    BF16 grouped matmul with stacked inputs backed by cutlass or ck.
+    """
+
+    def preprocess(self, x, w):
+        m_values = [i.shape[0] for i in x]
+        # Convert m_values into offsets into grouped tensor.
+        m_sizes = torch.tensor(m_values).to(dtype=torch.int64, device=x[0].device)
+        # Group weights as single tensor.
+        w = torch.stack(w, dim=0).contiguous()
+        # Also view input as flattened.
+        x = torch.concat(x, dim=0).contiguous()
+        # Return processed tensors.
+        return x, w, m_sizes
+
+    def quantize(self, x, w, m_sizes):
+        return x, w, m_sizes
+
+    def compute(self, x, w, m_sizes):
+        return torch.ops.fbgemm.bf16bf16bf16_grouped_stacked(x, w, m_sizes)
+
+    def quantize_and_compute(self, x, w, m_sizes):
+        x, w, m_sizes = self.quantize(x, w, m_sizes)
+        return self.compute(x, w, m_sizes)
+
+    @property
+    def name(self) -> str:
+        return "bf16_grouped_stacked"
+
+    @property
+    def hip(self) -> bool:
+        return True
+
+    @property
+    def cuda(self) -> bool:
+        return True
+
+
 @register_quantize_op
 class BF16I4RowwiseGemm(F8I4RowwiseGemm):
     """
