@@ -22,6 +22,7 @@
     triton_quantize_fp8_row,
 )
 from fbgemm_gpu.experimental.gemm.triton_gemm.grouped_gemm import (
+    grouped_gemm,
     grouped_gemm_fp8_rowwise,
 )
 from fbgemm_gpu.experimental.gen_ai.quantize import quantize_int4_preshuffle
@@ -729,6 +730,45 @@ def cuda(self) -> bool:
         return True


+@register_quantize_op
+class BF16TritonStackedGroupedGemm(QuantizeOpBase):
+    """
+    BF16 grouped matmul with stacked inputs implemented with triton.
+    """
+
+    def preprocess(self, x, w):
+        m_values = [i.shape[0] for i in x]
+        # Convert m_values into offsets into grouped tensor.
+        m_sizes = torch.tensor(m_values).to(dtype=torch.int32, device=x[0].device)
+        w = torch.concat(w, dim=0).contiguous()
+        # Also view input as flattened.
+        x = torch.concat(x, dim=0).contiguous()
+        # Return processed tensors.
+        return x, w, m_sizes
+
+    def quantize(self, x, w, m_sizes):
+        return x, w, m_sizes
+
+    def compute(self, x, w, m_sizes):
+        return grouped_gemm(x, w, m_sizes)
+
+    def quantize_and_compute(self, x, w, m_sizes):
+        x, w, m_sizes = self.quantize(x, w, m_sizes)
+        return self.compute(x, w, m_sizes)
+
+    @property
+    def name(self) -> str:
+        return "triton_bf16_grouped_stacked"
+
+    @property
+    def hip(self) -> bool:
+        return True
+
+    @property
+    def cuda(self) -> bool:
+        return True
+
+
 @register_quantize_op
 class FP8TritonStackedGroupedGemm(QuantizeOpBase):
     """
@@ -1488,6 +1528,46 @@ def cuda(self) -> bool:
         return True


+@register_quantize_op
+class BF16GroupedStacked(QuantizeOpBase):
+    """
+    BF16 grouped matmul with stacked inputs backed by cutlass or ck.
+    """
+
+    def preprocess(self, x, w):
+        m_values = [i.shape[0] for i in x]
+        # Convert m_values into offsets into grouped tensor.
+        m_sizes = torch.tensor(m_values).to(dtype=torch.int64, device=x[0].device)
+        # Group weights as single tensor.
+        w = torch.stack(w, dim=0).contiguous()
+        # Also view input as flattened.
+        x = torch.concat(x, dim=0).contiguous()
+        # Return processed tensors.
+        return x, w, m_sizes
+
+    def quantize(self, x, w, m_sizes):
+        return x, w, m_sizes
+
+    def compute(self, x, w, m_sizes):
+        return torch.ops.fbgemm.bf16bf16bf16_grouped_stacked(x, w, m_sizes)
+
+    def quantize_and_compute(self, x, w, m_sizes):
+        x, w, m_sizes = self.quantize(x, w, m_sizes)
+        return self.compute(x, w, m_sizes)
+
+    @property
+    def name(self) -> str:
+        return "bf16_grouped_stacked"
+
+    @property
+    def hip(self) -> bool:
+        return True
+
+    @property
+    def cuda(self) -> bool:
+        return True
+
+
 @register_quantize_op
 class BF16I4RowwiseGemm(F8I4RowwiseGemm):
     """
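The cutlass/ck-backed `BF16GroupedStacked` op differs only in how `preprocess()` packs its operands: weights stay as a single 3D stack rather than a dim-0 concatenation, and the group sizes are int64. A hedged sketch under the same assumed shapes as above:

```python
# Illustrative sketch only: same assumed shapes; only the packing differs from the triton op.
import torch

K, N = 512, 1024
m_per_group = [16, 64, 32, 128]
xs = [torch.randn(m, K, dtype=torch.bfloat16, device="cuda") for m in m_per_group]
ws = [torch.randn(N, K, dtype=torch.bfloat16, device="cuda") for _ in m_per_group]

x = torch.concat(xs, dim=0).contiguous()  # activations flattened to [sum(M_i), K]
w = torch.stack(ws, dim=0).contiguous()   # weights kept as one stacked [G, N, K] tensor
m_sizes = torch.tensor(m_per_group).to(dtype=torch.int64, device=x.device)  # int64, not int32

out = torch.ops.fbgemm.bf16bf16bf16_grouped_stacked(x, w, m_sizes)
```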