File tree: 1 file changed, +4 -3 lines changed
fbgemm_gpu/experimental/gemm/triton_gemm: 1 file changed, +4 -3 lines changed
Columns: Original file line number | Diff line number | Diff line change
@@ -136,7 +136,7 @@ def _fbgemm_grouped_gemm(
136
136
m_sizes ,
137
137
# problem sizes
138
138
G : tl .constexpr ,
139
- M_BUCKET : tl . constexpr ,
139
+ M_BUCKET ,
140
140
N : tl .constexpr ,
141
141
K : tl .constexpr ,
142
142
NUM_SMS : tl .constexpr ,
@@ -281,7 +281,7 @@ def _fbgemm_grouped_gemm_fp8_rowwise(
281
281
m_sizes ,
282
282
# problem sizes
283
283
G : tl .constexpr ,
284
- M_BUCKET : tl . constexpr ,
284
+ M_BUCKET ,
285
285
N : tl .constexpr ,
286
286
K : tl .constexpr ,
287
287
NUM_SMS : tl .constexpr ,
@@ -483,7 +483,8 @@ def grid(META):
483
483
484
484
return (NUM_SMS ,)
485
485
486
- M_BUCKET = triton .next_power_of_2 (M )
486
+ M_BUCKET_CAP = 16384
487
+ M_BUCKET = min (triton .next_power_of_2 (M ), M_BUCKET_CAP )
487
488
if x_scale is not None and w_scale is not None :
488
489
assert x_scale .is_contiguous ()
489
490
assert w_scale .is_contiguous ()
You can’t perform that action at this time.
0 commit comments