Commit 54a96f2

levendlee authored and facebook-github-bot committed
Triton GroupedGEMM. WS. (pytorch#1002)
Summary:
X-link: pytorch#3912

Pull Request resolved: facebookresearch/FBGEMM#1002

Enable warp-specialization.

Reviewed By: jianyuh

Differential Revision: D70800084

fbshipit-source-id: 481486fcbf95acf136e83072203fc3466bb2e904
1 parent 663c39a commit 54a96f2

File tree

2 files changed (+403, -20 lines)


fbgemm_gpu/experimental/gemm/test/grouped_gemm_test.py

Lines changed: 21 additions & 8 deletions
@@ -36,6 +36,7 @@ def _test_grouped_gemm_fp8_rowwise(
             shape: Tuple[int, int, int, int],
             device: torch.device,
             fast_accu: bool,
+            use_warp_specialization: bool,
         ) -> None:
             G, M, N, K = shape
             a = torch.randn(M, K, dtype=torch.bfloat16, device=device)
@@ -62,6 +63,7 @@ def _test_grouped_gemm_fp8_rowwise(
                 a_scale,
                 b_scale,
                 use_fast_accum=fast_accu,
+                _use_warp_specialization=use_warp_specialization,
             )
             self.assertTrue(result.shape == (M, N))
@@ -85,17 +87,22 @@ def _test_grouped_gemm_fp8_rowwise(
         for G in (1, 4, 16):
             for M in (64, 512):
                 for fast_accu in (True, False):
-                    logging.info(
-                        f"Testing FP8 GMM with G={G}, M={M}, FastAccu={fast_accu}"
-                    )
-                    _test_grouped_gemm_fp8_rowwise(
-                        (G, M, 256, 256), torch.device("cuda"), fast_accu=fast_accu
-                    )
+                    for ws in (True, False):
+                        logging.info(
+                            f"Testing FP8 GMM with G={G}, M={M}, FastAccu={fast_accu}"
+                        )
+                        _test_grouped_gemm_fp8_rowwise(
+                            (G, M, 256, 256),
+                            torch.device("cuda"),
+                            fast_accu=fast_accu,
+                            use_warp_specialization=ws,
+                        )
 
     def test_grouped_gemm_bf16(self) -> None:
         def _test_grouped_gemm_bf16(
             shape: Tuple[int, int, int, int],
             device: torch.device,
+            use_warp_specialization: bool,
         ) -> None:
             G, M, N, K = shape
             a = torch.randn(M, K, dtype=torch.bfloat16, device=device)
@@ -116,6 +123,7 @@ def _test_grouped_gemm_bf16(
                 a,
                 b,
                 m_sizes,
+                _use_warp_specialization=use_warp_specialization,
             )
             self.assertTrue(result.shape == (M, N))
 
@@ -131,5 +139,10 @@ def _test_grouped_gemm_bf16(
 
         for G in (1, 4, 16):
            for M in (64, 512):
-                logging.info(f"Testing BF16 GMM with G={G}, M={M}")
-                _test_grouped_gemm_bf16((G, M, 256, 256), torch.device("cuda"))
+                for ws in (True, False):
+                    logging.info(f"Testing BF16 GMM with G={G}, M={M}")
+                    _test_grouped_gemm_bf16(
+                        (G, M, 256, 256),
+                        torch.device("cuda"),
+                        use_warp_specialization=ws,
+                    )
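For readers trying the change locally, here is a minimal usage sketch of the new flag. The module path and function name (grouped_gemm for the BF16 path) and the shapes of b and m_sizes are assumptions; only the argument pattern (a, b, m_sizes, _use_warp_specialization=...) and the construction of a are taken from the diff above.

# Hedged sketch: exercises the warp-specialized BF16 grouped GEMM the same way
# the updated test does. Module path, function name, and the shapes of `b` and
# `m_sizes` are assumptions, not confirmed by this diff.
import torch

from fbgemm_gpu.experimental.gemm.triton_gemm.grouped_gemm import grouped_gemm  # assumed path/name

G, M, N, K = 4, 512, 256, 256
a = torch.randn(M, K, dtype=torch.bfloat16, device="cuda")  # as in the test
b = torch.randn(G * N, K, dtype=torch.bfloat16, device="cuda")  # assumed per-group weights
m_sizes = torch.full((G,), M // G, dtype=torch.int32, device="cuda")  # assumed row split summing to M

# Warp specialization is opted into via the underscore-prefixed keyword added
# by this commit; omitting it should fall back to the non-WS kernel.
out = grouped_gemm(a, b, m_sizes, _use_warp_specialization=True)
assert out.shape == (M, N)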
