Commit a6f9891

levendlee authored and liligwu committed
Makes use_fast_accum configurable. (pytorch#3829)
Summary:
Pull Request resolved: pytorch#3829
X-link: https://github.com/facebookresearch/FBGEMM/pull/913

[Public to OSS] Thanks htyu for pointing out the issue. Looking forward to warp specialization support on Nvidia!

- Exposes fast accumulation as a configurable option.
- Does not enable it by default, so there is no change in default behavior.
- No additional tuning with respect to `use_fast_accum=True`. With the HIP backend, the semantics of `c += tl.dot(a, b)` and `c = tl.dot(a, b, c)` appear to be the same.

Reviewed By: htyu

Differential Revision: D71290596

fbshipit-source-id: 8e2a20899f301f861d8d72f6290e573e23288e63
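
For readers browsing this commit, a minimal call-site sketch of the updated Python API follows. The import path mirrors the file location in this diff; the tensor layouts, the int32 dtype of m_sizes, and the (G * N, K) weight shape are assumptions for illustration, not taken from this change:

import torch

from fbgemm_gpu.experimental.gemm.triton_gemm.grouped_gemm import grouped_gemm

G, M, N, K = 4, 512, 256, 256
device = torch.device("cuda")

# Assumed layout: x stacks the rows of all G groups (M rows total, K columns),
# w stacks the G weight matrices along dim 0, and m_sizes holds per-group row
# counts that sum to M.
x = torch.randn(M, K, dtype=torch.bfloat16, device=device)
w = torch.randn(G * N, K, dtype=torch.bfloat16, device=device)
m_sizes = torch.full((G,), M // G, dtype=torch.int32, device=device)

y_default = grouped_gemm(x, w, m_sizes)                    # default path, unchanged
y_fast = grouped_gemm(x, w, m_sizes, use_fast_accum=True)  # opt in to fast accumulation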
1 parent 12df261 commit a6f9891

2 files changed (+29, −7 lines)

fbgemm_gpu/experimental/gemm/test/grouped_gemm_test.py

Lines changed: 9 additions & 2 deletions
@@ -35,6 +35,7 @@ def test_grouped_gemm_fp8_rowwise(self) -> None:
         def _test_grouped_gemm_fp8_rowwise(
             shape: Tuple[int, int, int, int],
             device: torch.device,
+            fast_accu: bool,
         ) -> None:
             G, M, N, K = shape
             a = torch.randn(M, K, dtype=torch.bfloat16, device=device)
@@ -60,6 +61,7 @@ def _test_grouped_gemm_fp8_rowwise(
                 m_sizes,
                 a_scale,
                 b_scale,
+                use_fast_accum=fast_accu,
             )
             self.assertTrue(result.shape == (M, N))
 
@@ -82,8 +84,13 @@ def _test_grouped_gemm_fp8_rowwise(
 
         for G in (1, 4, 16):
             for M in (64, 512):
-                logging.info(f"Testing FP8 GMM with G={G}, M={M}")
-                _test_grouped_gemm_fp8_rowwise((G, M, 256, 256), torch.device("cuda"))
+                for fast_accu in (True, False):
+                    logging.info(
+                        f"Testing FP8 GMM with G={G}, M={M}, FastAccu={fast_accu}"
+                    )
+                    _test_grouped_gemm_fp8_rowwise(
+                        (G, M, 256, 256), torch.device("cuda"), fast_accu=fast_accu
+                    )
 
     def test_grouped_gemm_bf16(self) -> None:
         def _test_grouped_gemm_bf16(

fbgemm_gpu/experimental/gemm/triton_gemm/grouped_gemm.py

Lines changed: 20 additions & 5 deletions
@@ -142,6 +142,7 @@ def _kernel_grouped_gemm(
     NUM_SMS: tl.constexpr,
     USE_TMA_LOAD: tl.constexpr,
     USE_TMA_STORE: tl.constexpr,
+    USE_FAST_ACCUM: tl.constexpr,
     # tile sizes
     BLOCK_SIZE_M: tl.constexpr,
     BLOCK_SIZE_N: tl.constexpr,
@@ -208,7 +209,10 @@ def _kernel_grouped_gemm(
                     [BLOCK_SIZE_N, BLOCK_SIZE_K],
                     dtype,
                 )
-                accumulator += tl.dot(a, b.T)
+                if USE_FAST_ACCUM:
+                    accumulator = tl.dot(a, b.T, accumulator)
+                else:
+                    accumulator += tl.dot(a, b.T)
             else:
                 offs_am = tile_m_idx * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
                 offs_bn = tile_n_idx * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
@@ -283,6 +287,7 @@ def _kernel_grouped_gemm_fp8_rowwise(
     NUM_SMS: tl.constexpr,
     USE_TMA_LOAD: tl.constexpr,
     USE_TMA_STORE: tl.constexpr,
+    USE_FAST_ACCUM: tl.constexpr,
     # tile sizes
     BLOCK_SIZE_M: tl.constexpr,
     BLOCK_SIZE_N: tl.constexpr,
@@ -349,7 +354,10 @@ def _kernel_grouped_gemm_fp8_rowwise(
                     [BLOCK_SIZE_N, BLOCK_SIZE_K],
                     dtype,
                 )
-                accumulator += tl.dot(a, b.T)
+                if USE_FAST_ACCUM:
+                    accumulator = tl.dot(a, b.T, accumulator)
+                else:
+                    accumulator += tl.dot(a, b.T)
             else:
                 offs_am = tile_m_idx * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
                 offs_bn = tile_n_idx * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
@@ -410,6 +418,7 @@ def _grouped_gemm(
     m_sizes: torch.Tensor,
     x_scale: Optional[torch.Tensor] = None,
     w_scale: Optional[torch.Tensor] = None,
+    use_fast_accum: bool = False,
 ) -> torch.Tensor:
     if not utils.HAS_TMA_DESC:
         raise NotImplementedError("Grouped GEMM without TMA is not supported yet")
@@ -493,6 +502,7 @@ def grid(META):
             NUM_SMS,
             USE_TMA_LOAD,
             USE_TMA_STORE,
+            use_fast_accum,
         )
     else:
         assert x_scale is None
@@ -510,15 +520,19 @@ def grid(META):
             NUM_SMS,
             USE_TMA_LOAD,
             USE_TMA_STORE,
+            use_fast_accum,
         )
 
     return y
 
 
 def grouped_gemm(
-    x: torch.Tensor, w: torch.Tensor, m_sizes: torch.Tensor
+    x: torch.Tensor,
+    w: torch.Tensor,
+    m_sizes: torch.Tensor,
+    use_fast_accum: bool = False,
 ) -> torch.Tensor:
-    return _grouped_gemm(x, w, m_sizes)
+    return _grouped_gemm(x, w, m_sizes, use_fast_accum=use_fast_accum)
 
 
 def grouped_gemm_fp8_rowwise(
@@ -527,5 +541,6 @@ def grouped_gemm_fp8_rowwise(
     m_sizes: torch.Tensor,
     x_scale: torch.Tensor,
     w_scale: torch.Tensor,
+    use_fast_accum: bool = False,
 ) -> torch.Tensor:
-    return _grouped_gemm(x, w, m_sizes, x_scale, w_scale)
+    return _grouped_gemm(x, w, m_sizes, x_scale, w_scale, use_fast_accum=use_fast_accum)
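
The kernel-side change above comes down to choosing between the two accumulation forms mentioned in the summary. As a rough illustration, here is a standalone toy Triton kernel (not the FBGEMM kernel; the single-tile layout and pointer arithmetic are assumptions) showing how a constexpr flag selects between passing the running accumulator into tl.dot and adding the dot result afterwards:

import triton
import triton.language as tl


@triton.jit
def _toy_matmul_kernel(
    a_ptr,
    b_ptr,
    c_ptr,
    K,
    N,
    USE_FAST_ACCUM: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
):
    # Toy kernel: computes one BLOCK_M x BLOCK_N tile of A @ B, only to show
    # how the constexpr flag picks the accumulation form.
    offs_m = tl.arange(0, BLOCK_M)
    offs_n = tl.arange(0, BLOCK_N)
    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    for k in range(0, K, BLOCK_K):
        offs_k = k + tl.arange(0, BLOCK_K)
        a = tl.load(a_ptr + offs_m[:, None] * K + offs_k[None, :])
        b = tl.load(b_ptr + offs_k[:, None] * N + offs_n[None, :])
        if USE_FAST_ACCUM:
            # Fused form: the running accumulator is handed to tl.dot.
            acc = tl.dot(a, b, acc)
        else:
            # Separate form: tl.dot produces a fresh fp32 tile that is then added.
            acc += tl.dot(a, b)
    tl.store(c_ptr + offs_m[:, None] * N + offs_n[None, :], acc)

On Nvidia, the fused form can let the backend keep accumulation inside the MMA pipeline, which for FP8 inputs may trade some accumulation precision for speed; on the HIP backend, as the summary notes, the two forms appear to behave the same.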
