     triton_quantize_fp8_row,
 )
 from fbgemm_gpu.experimental.gemm.triton_gemm.grouped_gemm import (
+    grouped_gemm,
     grouped_gemm_fp8_rowwise,
 )
 from fbgemm_gpu.experimental.gen_ai.quantize import quantize_int4_preshuffle
@@ -729,6 +730,45 @@ def cuda(self) -> bool:
         return True


+@register_quantize_op
+class BF16TritonStackedGroupedGemm(QuantizeOpBase):
+    """
+    BF16 grouped matmul with stacked inputs implemented with triton.
+    """
+
+    def preprocess(self, x, w):
+        m_values = [i.shape[0] for i in x]
+        # Convert m_values into per-group sizes for the grouped tensor.
+        m_sizes = torch.tensor(m_values).to(dtype=torch.int32, device=x[0].device)
+        w = torch.concat(w, dim=0).contiguous()
+        # Also view input as flattened.
+        x = torch.concat(x, dim=0).contiguous()
+        # Return processed tensors.
+        return x, w, m_sizes
+
+    def quantize(self, x, w, m_sizes):
+        return x, w, m_sizes
+
+    def compute(self, x, w, m_sizes):
+        return grouped_gemm(x, w, m_sizes)
+
+    def quantize_and_compute(self, x, w, m_sizes):
+        x, w, m_sizes = self.quantize(x, w, m_sizes)
+        return self.compute(x, w, m_sizes)
+
+    @property
+    def name(self) -> str:
+        return "triton_bf16_grouped_stacked"
+
+    @property
+    def hip(self) -> bool:
+        return True
+
+    @property
+    def cuda(self) -> bool:
+        return True
+
+
 @register_quantize_op
 class FP8TritonStackedGroupedGemm(QuantizeOpBase):
     """
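For context, here is a minimal usage sketch of the new Triton-backed op, mirroring the `preprocess`/`compute` flow above. Only `grouped_gemm` and its `(x, w, m_sizes)` signature come from the diff; the group count, shapes, and per-group sizes below are illustrative assumptions.

```python
# Minimal sketch (assumed shapes/sizes) of the Triton stacked grouped GEMM path.
import torch
from fbgemm_gpu.experimental.gemm.triton_gemm.grouped_gemm import grouped_gemm

G, N, K = 4, 256, 512                       # groups, output dim, reduction dim (assumed)
ms = [64, 128, 32, 96]                      # per-group row counts (assumed)
xs = [torch.randn(m, K, dtype=torch.bfloat16, device="cuda") for m in ms]
ws = [torch.randn(N, K, dtype=torch.bfloat16, device="cuda") for _ in range(G)]

# Mirror preprocess(): flatten activations and weights along dim 0 and
# record each group's row count as an int32 tensor.
x = torch.concat(xs, dim=0).contiguous()    # [sum(ms), K]
w = torch.concat(ws, dim=0).contiguous()    # [G * N, K]
m_sizes = torch.tensor(ms, dtype=torch.int32, device=x.device)

out = grouped_gemm(x, w, m_sizes)           # stacked BF16 output, one row block per group
```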
@@ -1446,6 +1486,46 @@ def cuda(self) -> bool:
         return True


+@register_quantize_op
+class BF16GroupedStacked(QuantizeOpBase):
+    """
+    BF16 grouped matmul with stacked inputs backed by cutlass or ck.
+    """
+
+    def preprocess(self, x, w):
+        m_values = [i.shape[0] for i in x]
+        # Convert m_values into per-group sizes for the grouped tensor.
+        m_sizes = torch.tensor(m_values).to(dtype=torch.int64, device=x[0].device)
+        # Group weights as a single tensor.
+        w = torch.stack(w, dim=0).contiguous()
+        # Also view input as flattened.
+        x = torch.concat(x, dim=0).contiguous()
+        # Return processed tensors.
+        return x, w, m_sizes
+
+    def quantize(self, x, w, m_sizes):
+        return x, w, m_sizes
+
+    def compute(self, x, w, m_sizes):
+        return torch.ops.fbgemm.bf16bf16bf16_grouped_stacked(x, w, m_sizes)
+
+    def quantize_and_compute(self, x, w, m_sizes):
+        x, w, m_sizes = self.quantize(x, w, m_sizes)
+        return self.compute(x, w, m_sizes)
+
+    @property
+    def name(self) -> str:
+        return "bf16_grouped_stacked"
+
+    @property
+    def hip(self) -> bool:
+        return True
+
+    @property
+    def cuda(self) -> bool:
+        return True
+
+
 @register_quantize_op
 class BF16I4RowwiseGemm(F8I4RowwiseGemm):
     """
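Similarly, a minimal sketch for the cutlass/ck-backed stacked op. The operator name and argument order follow `compute()` above; the group count and shapes are assumed. Note the differences from the Triton path: weights are stacked into a 3D `[G, N, K]` tensor rather than flattened, and the group sizes are int64 rather than int32.

```python
# Minimal sketch (assumed shapes/sizes) of the cutlass/ck stacked grouped GEMM path.
import torch

G, N, K = 4, 256, 512                       # groups, output dim, reduction dim (assumed)
ms = [64, 128, 32, 96]                      # per-group row counts (assumed)
xs = [torch.randn(m, K, dtype=torch.bfloat16, device="cuda") for m in ms]
ws = [torch.randn(N, K, dtype=torch.bfloat16, device="cuda") for _ in range(G)]

# Mirror preprocess(): flatten activations, stack weights into 3D, int64 sizes.
x = torch.concat(xs, dim=0).contiguous()    # [sum(ms), K]
w = torch.stack(ws, dim=0).contiguous()     # [G, N, K]
m_sizes = torch.tensor(ms, dtype=torch.int64, device=x.device)

out = torch.ops.fbgemm.bf16bf16bf16_grouped_stacked(x, w, m_sizes)
```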