
Commit dcb347f

jasonjk-park authored and facebook-github-bot committed
Handle 0 inputs for gmm (pytorch#3901)
Summary:
Pull Request resolved: pytorch#3901
X-link: facebookresearch/FBGEMM#992

Add support for 0-sized inputs to the Triton grouped GEMM (gmm) kernel, and add a unit test covering the zero-size case.

Reviewed By: levendlee

Differential Revision: D72134331

fbshipit-source-id: 6b212993e0e270caefed27de892033ab75a57027
1 parent cfd715b commit dcb347f
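At its core, the change is a two-line early return in the kernel wrapper. A minimal, self-contained sketch of that pattern follows (my illustration, not the FBGEMM API: matmul_with_zero_guard and its plain-matmul body are stand-ins for the Triton grouped GEMM), mirroring the two lines added to _grouped_gemm in the diff below:

import torch

def matmul_with_zero_guard(x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
    # Allocate the output up front, then return it immediately when either
    # output dimension is zero, so no kernel ever runs on empty work.
    M, K = x.shape
    N = w.shape[0]
    assert K == w.shape[1]
    y = torch.empty((M, N), device=x.device, dtype=torch.bfloat16)
    if M == 0 or N == 0:
        return y  # empty (M, N) result; same shape/dtype contract as usual
    # The real code dispatches to the Triton grouped-GEMM kernel here;
    # a plain matmul stands in for it in this sketch.
    y.copy_(x.to(torch.float32) @ w.to(torch.float32).t())
    return y

x = torch.empty((0, 64))    # M == 0 input
w = torch.randn((128, 64))  # (N, K) weight, matching the kernel's asserts
print(matmul_with_zero_guard(x, w).shape)  # torch.Size([0, 128])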

File tree: 2 files changed (+8, -2 lines)


fbgemm_gpu/experimental/gemm/test/grouped_gemm_test.py

Lines changed: 6 additions & 2 deletions
@@ -45,6 +45,8 @@ def _test_grouped_gemm_fp8_rowwise(
             torch.randint(
                 low=0, high=M, size=[G - 1], device=device, dtype=torch.int32
             )
+            if M > 0
+            else torch.zeros([G - 1], device=device, dtype=torch.int32)
         )
         m_ends = m_ends.tolist()
         m_starts = [0] + m_ends
@@ -85,7 +87,7 @@ def _test_grouped_gemm_fp8_rowwise(
         torch.testing.assert_close(result, expected_result, atol=2e-2, rtol=1.6e-2)
 
     for G in (1, 4, 16):
-        for M in (64, 512):
+        for M in (0, 64, 512):
             for fast_accu in (True, False):
                 for ws in (True, False):
                     logging.info(
@@ -111,6 +113,8 @@ def _test_grouped_gemm_bf16(
             torch.randint(
                 low=0, high=M, size=[G - 1], device=device, dtype=torch.int32
             )
+            if M > 0
+            else torch.zeros([G - 1], device=device, dtype=torch.int32)
         )
         m_ends = m_ends.tolist()
         m_starts = [0] + m_ends
@@ -138,7 +142,7 @@ def _test_grouped_gemm_bf16(
         torch.testing.assert_close(result, expected_result, atol=1e-5, rtol=1.6e-2)
 
     for G in (1, 4, 16):
-        for M in (64, 512):
+        for M in (0, 64, 512):
             for ws in (True, False):
                 logging.info(f"Testing BF16 GMM with G={G}, M={M}")
                 _test_grouped_gemm_bf16(
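The conditional added to both test helpers is needed because torch.randint requires high > low, so high=M raises for M == 0; the tests fall back to all-zero group boundaries in that case. A standalone sketch of the edge case (assuming the surrounding test derives group starts and ends exactly as the context lines show):

import torch

G, M, device = 4, 0, "cpu"  # M == 0 is the new case the tests cover

m_ends, _ = torch.sort(
    torch.randint(low=0, high=M, size=[G - 1], device=device, dtype=torch.int32)
    if M > 0  # randint would raise "from >= to" when M == 0
    else torch.zeros([G - 1], device=device, dtype=torch.int32)
)
m_ends = m_ends.tolist()
m_starts = [0] + m_ends
print(m_starts, m_ends + [M])  # [0, 0, 0, 0] [0, 0, 0, 0]: every group is empty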

fbgemm_gpu/experimental/gemm/triton_gemm/grouped_gemm.py

Lines changed: 2 additions & 0 deletions
@@ -780,6 +780,8 @@ def _grouped_gemm(
     assert K == w.shape[1]
 
     y = torch.empty((M, N), device=x.device, dtype=torch.bfloat16)
+    if M == 0 or N == 0:
+        return y
 
     NUM_SMS = torch.cuda.get_device_properties("cuda").multi_processor_count
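Note the ordering of the guard: y is allocated before the early return, so callers of _grouped_gemm receive a correctly shaped (M, N) bfloat16 tensor even for zero-sized inputs, while the NUM_SMS query and the kernel launch that follow it are skipped entirely when there is no work to do.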
