
Commit 3eadc7c

jwfromm authored and facebook-github-bot committed
Modernize bf16 cutlass grouped gemm (pytorch#3889)
Summary:
Pull Request resolved: pytorch#3889
X-link: facebookresearch/FBGEMM#982

This diff unifies the API between FP8 and BF16 grouped gemm. Specifically, we add the same dynamic, concatenated, and stacked APIs that are used for FP8 across both cutlass and CK. After this change, our tests can also be unified into a single grouped gemm test that covers all the various modes.

Reviewed By: jiawenliu64, mxz297
Differential Revision: D71920813
1 parent a5f8150 commit 3eadc7c

File tree: 52 files changed, +804 −854 lines (large commit: only part of the diff is shown below)

fbgemm_gpu/experimental/gen_ai/bench/quantize_ops.py

Lines changed: 80 additions & 0 deletions
@@ -22,6 +22,7 @@
     triton_quantize_fp8_row,
 )
 from fbgemm_gpu.experimental.gemm.triton_gemm.grouped_gemm import (
+    grouped_gemm,
     grouped_gemm_fp8_rowwise,
 )
 from fbgemm_gpu.experimental.gen_ai.quantize import quantize_int4_preshuffle
@@ -729,6 +730,45 @@ def cuda(self) -> bool:
         return True


+@register_quantize_op
+class BF16TritonStackedGroupedGemm(QuantizeOpBase):
+    """
+    BF16 grouped matmul with stacked inputs implemented with triton.
+    """
+
+    def preprocess(self, x, w):
+        m_values = [i.shape[0] for i in x]
+        # Convert m_values into offsets into grouped tensor.
+        m_sizes = torch.tensor(m_values).to(dtype=torch.int32, device=x[0].device)
+        w = torch.concat(w, dim=0).contiguous()
+        # Also view input as flattened.
+        x = torch.concat(x, dim=0).contiguous()
+        # Return processed tensors.
+        return x, w, m_sizes
+
+    def quantize(self, x, w, m_sizes):
+        return x, w, m_sizes
+
+    def compute(self, x, w, m_sizes):
+        return grouped_gemm(x, w, m_sizes)
+
+    def quantize_and_compute(self, x, w, m_sizes):
+        x, w, m_sizes = self.quantize(x, w, m_sizes)
+        return self.compute(x, w, m_sizes)
+
+    @property
+    def name(self) -> str:
+        return "triton_bf16_grouped_stacked"
+
+    @property
+    def hip(self) -> bool:
+        return True
+
+    @property
+    def cuda(self) -> bool:
+        return True
+
+
 @register_quantize_op
 class FP8TritonStackedGroupedGemm(QuantizeOpBase):
     """
@@ -1446,6 +1486,46 @@ def cuda(self) -> bool:
         return True


+@register_quantize_op
+class BF16GroupedStacked(QuantizeOpBase):
+    """
+    BF16 grouped matmul with stacked inputs backed by cutlass or ck.
+    """
+
+    def preprocess(self, x, w):
+        m_values = [i.shape[0] for i in x]
+        # Convert m_values into offsets into grouped tensor.
+        m_sizes = torch.tensor(m_values).to(dtype=torch.int64, device=x[0].device)
+        # Group weights as single tensor.
+        w = torch.stack(w, dim=0).contiguous()
+        # Also view input as flattened.
+        x = torch.concat(x, dim=0).contiguous()
+        # Return processed tensors.
+        return x, w, m_sizes
+
+    def quantize(self, x, w, m_sizes):
+        return x, w, m_sizes
+
+    def compute(self, x, w, m_sizes):
+        return torch.ops.fbgemm.bf16bf16bf16_grouped_stacked(x, w, m_sizes)
+
+    def quantize_and_compute(self, x, w, m_sizes):
+        x, w, m_sizes = self.quantize(x, w, m_sizes)
+        return self.compute(x, w, m_sizes)
+
+    @property
+    def name(self) -> str:
+        return "bf16_grouped_stacked"
+
+    @property
+    def hip(self) -> bool:
+        return True
+
+    @property
+    def cuda(self) -> bool:
+        return True
+
+
 @register_quantize_op
 class BF16I4RowwiseGemm(F8I4RowwiseGemm):
     """
