Skip to content

Commit a16425b

Browse files
levendleefacebook-github-bot
authored andcommitted
No recompilation caused by varying sequence lengths. (#3903)
Summary: Pull Request resolved: #3903 X-link: facebookresearch/FBGEMM#995 No need to recompile with varying sequence length. Just tuning. Reviewed By: jasonjk-park, jianyuh Differential Revision: D71527635 fbshipit-source-id: 5616f57175d2497d491610ca663123fbe77a12ef
1 parent 4529860 commit a16425b

File tree

1 file changed

+4
-3
lines changed

1 file changed

+4
-3
lines changed

fbgemm_gpu/experimental/gemm/triton_gemm/grouped_gemm.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -136,7 +136,7 @@ def _fbgemm_grouped_gemm(
136136
m_sizes,
137137
# problem sizes
138138
G: tl.constexpr,
139-
M_BUCKET: tl.constexpr,
139+
M_BUCKET,
140140
N: tl.constexpr,
141141
K: tl.constexpr,
142142
NUM_SMS: tl.constexpr,
@@ -281,7 +281,7 @@ def _fbgemm_grouped_gemm_fp8_rowwise(
281281
m_sizes,
282282
# problem sizes
283283
G: tl.constexpr,
284-
M_BUCKET: tl.constexpr,
284+
M_BUCKET,
285285
N: tl.constexpr,
286286
K: tl.constexpr,
287287
NUM_SMS: tl.constexpr,
@@ -483,7 +483,8 @@ def grid(META):
483483

484484
return (NUM_SMS,)
485485

486-
M_BUCKET = triton.next_power_of_2(M)
486+
M_BUCKET_CAP = 16384
487+
M_BUCKET = min(triton.next_power_of_2(M), M_BUCKET_CAP)
487488
if x_scale is not None and w_scale is not None:
488489
assert x_scale.is_contiguous()
489490
assert w_scale.is_contiguous()

0 commit comments

Comments
 (0)