
Commit 2db3c2f

jwfromm authored and facebook-github-bot committed
FP8 Grouped Gemm Optimization (#3655)
Summary:
Pull Request resolved: #3655
X-link: facebookresearch/FBGEMM#731

While optimizing MoE, we found that small overheads were a major bottleneck for grouped GEMM performance. This diff tackles a few of them, specifically the overhead from torch.dynamo wrapping `quantize_fp8_row` and from having to slice input tensors before calling `f8f8bf16_rowwise_grouped`.

To fix the former, we enable `triton_quantize_fp8_row` to be called directly, skipping dynamo compatibility. In cases where AOTI isn't needed, this removes a bit of overhead. To fix the latter, we templatize `f8f8bf16_rowwise_grouped_dynamic` to accept at::Tensor arguments instead of lists. We introduce a new wrapper called `f8f8bf16_rowwise_grouped_stacked` to maintain the behavior where `zero_start_index_M` isn't provided but a user wants a single contiguous output tensor.

In microbenchmarks, we've found that these seemingly small changes can improve TFLOPs by 2X for small workloads.

Reviewed By: jianyuh, jiawenliu64

Differential Revision: D69072529
1 parent dced756 commit 2db3c2f
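
A minimal sketch of the first optimization described in the summary, assuming a CUDA build of fbgemm_gpu with the experimental Triton FP8 kernels importable (shapes and dtypes below are illustrative, not taken from this commit):

```python
# Sketch only: assumes fbgemm_gpu's experimental Triton FP8 ops are importable
# and a CUDA device is available; tensor shapes are made up for illustration.
import torch
from fbgemm_gpu.experimental.gemm.triton_gemm.fp8_gemm import (
    quantize_fp8_row,         # torch.library custom-op wrapper (dynamo/AOTI friendly)
    triton_quantize_fp8_row,  # direct Triton entry point
)

x = torch.randn(64, 512, device="cuda", dtype=torch.bfloat16)

# Wrapped path: dispatches through the registered custom op.
xq, x_scale = quantize_fp8_row(x)

# Direct path: skips the custom-op wrapper when AOTI/dynamo compatibility
# is not needed, which is the per-call overhead this diff targets.
xq, x_scale = triton_quantize_fp8_row(x)
```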

File tree

79 files changed: +6720 −4714 lines changed


fbgemm_gpu/experimental/gemm/triton_gemm/fp8_gemm.py

Lines changed: 9 additions & 12 deletions
@@ -2447,6 +2447,13 @@ def triton_quantize_fp8_row(
         torch.Tensor: fp8 scaled tensor.
         torch.Tensor: reciprocal scale tensor per row.
     """
+    assert a.dim() <= 4, "Triton only supports up to 4 dimension input tensor."
+    a_shape = a.shape
+    while a.dim() < 4:
+        a = a.unsqueeze(0)
+    if zero_start_index_M is not None:
+        # There should be one value of zero_start_index_M per NxK matrix.
+        zero_start_index_M = zero_start_index_M.view(a.shape[0], a.shape[1])
     # Get constant values.
     pt_dtype, tl_dtype, max_fp8, eps = get_fp8_constants()
     num_rows = a.numel() // a.shape[-1]
@@ -2484,7 +2491,7 @@ def triton_quantize_fp8_row(
         USE_INT64=use_int64,
     )
 
-    return a_fp8, a_scale
+    return a_fp8.view(a_shape), a_scale.view(a_shape[:-1])
 
 
 @torch.library.custom_op("triton::quantize_fp8_row", mutates_args=())
@@ -2514,17 +2521,7 @@ def quantize_fp8_row(
         logger.info("Triton does not support cpu, falling back to torch ops.")
         use_triton = False
     if use_triton:
-        assert (
-            a.dim() <= 4
-        ), "Only up to 4 dimension input tensor is supported if use_triton is True"
-        a_shape = a.shape
-        while a.dim() < 4:
-            a = a.unsqueeze(0)
-        if zero_start_index_M is not None:
-            # There should be one value of zero_start_index_M per NxK matrix.
-            zero_start_index_M = zero_start_index_M.view(a.shape[0], a.shape[1])
-        a_fp8, a_scale = triton_quantize_fp8_row(a, scale_ub, zero_start_index_M)
-        return a_fp8.view(a_shape), a_scale.view(a_shape[:-1])
+        return triton_quantize_fp8_row(a, scale_ub, zero_start_index_M)
     # else use pytorch implementation.
     if not output_device:
         output_device = a.device
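
The hunks above move shape handling into `triton_quantize_fp8_row` itself: inputs below 4D are padded internally, `zero_start_index_M` is reshaped to one value per NxK matrix, and both outputs are viewed back to the caller's original shape. A short usage sketch, assuming a CUDA device and illustrative shapes:

```python
import torch
from fbgemm_gpu.experimental.gemm.triton_gemm.fp8_gemm import triton_quantize_fp8_row

# 2D activation matrix; the kernel internally treats it as (1, 1, M, K).
a = torch.randn(128, 512, device="cuda", dtype=torch.bfloat16)
a_fp8, a_scale = triton_quantize_fp8_row(a)

assert a_fp8.shape == a.shape         # fp8 tensor keeps the original (M, K) shape
assert a_scale.shape == a.shape[:-1]  # one reciprocal scale per row: (M,)
```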

fbgemm_gpu/experimental/gen_ai/bench/profile_grouped_gemm.py

Lines changed: 0 additions & 103 deletions
This file was deleted.

fbgemm_gpu/experimental/gen_ai/bench/quantize_ops.py

Lines changed: 8 additions & 14 deletions
@@ -19,6 +19,7 @@
     quantize_fp8_block,
     quantize_fp8_row,
     scale_fp8_row,
+    triton_quantize_fp8_row,
 )
 from tinygemm.utils import group_quantize_tensor
 
@@ -553,38 +554,31 @@ def preprocess(self, x, w):
     def quantize(self, x, wq, w_scale, m_values=None):
         # Handle case where inputs are explicitly grouped and non-sparse.
         if isinstance(x, (tuple, list)):
-            xq, x_scale = zip(*[quantize_fp8_row(i) for i in x])
+            xq, x_scale = zip(*[triton_quantize_fp8_row(i) for i in x])
             return xq, wq, x_scale, w_scale, m_values
         # Otherwise inputs are unified tensors and sparse.
         else:
             B = x.shape[0]
-            xq, x_scale = quantize_fp8_row(x, zero_start_index_M=m_values)
+            xq, x_scale = triton_quantize_fp8_row(x, zero_start_index_M=m_values)
             x_scale = x_scale.view(B, -1)
             return xq, wq, x_scale, w_scale, m_values
 
-    def compute(self, xq, wq, x_scale, w_scale, m_values, kernel_name=None):
+    def compute(self, xq, wq, x_scale, w_scale, m_values):
         if m_values is None:
             return torch.ops.fbgemm.f8f8bf16_rowwise_grouped(
                 xq,
                 wq,
                 x_scale,
                 w_scale,
-                kernel_name=kernel_name,
             )
         else:
             # Break tensor into groups, simulates what is done e2e.
-            B = xq.shape[0]
-            xq_group = [xq[i, :, :] for i in range(B)]
-            x_scale_group = [x_scale[i, :] for i in range(B)]
-            wq_group = [wq[i, :, :] for i in range(B)]
-            w_scale_group = [w_scale[i, :] for i in range(B)]
             return torch.ops.fbgemm.f8f8bf16_rowwise_grouped_dynamic(
-                xq_group,
-                wq_group,
-                x_scale_group,
-                w_scale_group,
+                xq,
+                wq,
+                x_scale,
+                w_scale,
                 zero_start_index_M=m_values,
-                kernel_name=kernel_name,
             )
 
     def quantize_and_compute(self, x, wq, w_scale, m_values=None):
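
For context on the second change exercised by this benchmark: the old path sliced each stacked tensor into per-group Python lists before calling the grouped kernel, while the templatized op now takes the stacked tensors directly. Below is a hedged sketch with invented group count and shapes; the dtype and semantics of `zero_start_index_M` (assumed here to be the number of valid rows per group) follow the benchmark's usage rather than a documented contract:

```python
import torch
from fbgemm_gpu.experimental.gemm.triton_gemm.fp8_gemm import triton_quantize_fp8_row

G, M, N, K = 8, 64, 256, 512  # illustrative group count and per-group shapes
x = torch.randn(G, M, K, device="cuda", dtype=torch.bfloat16)
w = torch.randn(G, N, K, device="cuda", dtype=torch.bfloat16)
m_values = torch.full((G,), M, device="cuda", dtype=torch.int64)  # assumed: valid rows per group

xq, x_scale = triton_quantize_fp8_row(x, zero_start_index_M=m_values)
wq, w_scale = triton_quantize_fp8_row(w)

# Old benchmark path (removed above): slice every operand into per-group lists,
# e.g. xq_group = [xq[i] for i in range(G)], before calling the grouped op.

# New path: pass the stacked 3D tensors straight to the templatized kernel.
out = torch.ops.fbgemm.f8f8bf16_rowwise_grouped_dynamic(
    xq, wq, x_scale, w_scale, zero_start_index_M=m_values
)
```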
