
Commit 22b97a5

jwfromm authored and facebook-github-bot committed
Fix handling of dynamic FP8 grouped gemm on Nvidia (pytorch#3616)
Summary: X-link: facebookresearch/FBGEMM#695

This diff is the Nvidia mirror of D68686266, which changes dynamic grouped gemm to return a tensor of shape [total_M, N] when zero_start_index_M isn't provided. We also add appropriate tests to make sure the behavior doesn't break going forward.

Reviewed By: jasonjk-park, jianyuh, jiawenliu64

Differential Revision: D68689077
1 parent 98d54f7 commit 22b97a5
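
As context for the shape change, here is a hypothetical sketch of the two calling conventions described in the summary. It is not taken from the commit: the scale shapes, the int64 dtype of zero_start_index_M, and the FP8 casts are illustrative assumptions, and running it requires fbgemm_gpu on a CUDA device.

import torch

G, M, N, K = 4, 128, 256, 512

# Illustrative per-group FP8 operands and rowwise scales (shapes assumed).
xq = [torch.randn(M, K, device="cuda").to(torch.float8_e4m3fn) for _ in range(G)]
wq = [torch.randn(N, K, device="cuda").to(torch.float8_e4m3fn) for _ in range(G)]
x_scale = [torch.ones(M, device="cuda") for _ in range(G)]
w_scale = [torch.ones(N, device="cuda") for _ in range(G)]

# Padded mode: zero_start_index_M marks each group's valid row count,
# so the output can be viewed as [G, M, N] (per the commit summary).
zero_start_index_M = torch.full((G,), M, dtype=torch.int64, device="cuda")
y_padded = torch.ops.fbgemm.f8f8bf16_rowwise_grouped_dynamic(
    xq, wq, x_scale, w_scale, zero_start_index_M
)
assert y_padded.shape == (G, M, N)

# Dynamic mode after this fix: with no zero_start_index_M, groups are
# concatenated along the row dimension instead -> [total_M, N].
y_dynamic = torch.ops.fbgemm.f8f8bf16_rowwise_grouped_dynamic(
    xq, wq, x_scale, w_scale, None
)
assert y_dynamic.shape == (G * M, N)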

File tree

2 files changed: +39 −58 lines


fbgemm_gpu/experimental/gen_ai/src/quantize/cutlass_extensions/f8f8bf16_rowwise_grouped.cu

Lines changed: 7 additions & 3 deletions
@@ -462,8 +462,6 @@ std::tuple<at::Tensor, std::vector<at::Tensor>> f8f8bf16_rowwise_grouped_impl(
          reinterpret_cast<GroupedGemmArgs::ElementOutput**>(output_ptr),
          stride_output_ptr}};
 
-  int M = XQ[0].size(0);
-  int N = WQ[0].size(0);
   arguments.epilogue.thread = {
       {reinterpret_cast<const GroupedGemmArgs::ElementComputeEpilogue**>(
           x_scale_ptr)}, // x_scale
@@ -599,7 +597,13 @@ at::Tensor f8f8bf16_rowwise_grouped_dynamic(
   at::Tensor output = std::get<0>(dispatch_fp8_grouped_kernel(
       XQ, WQ, x_scale, w_scale, Y, zero_start_index_M));
   // View as proper shape.
-  output = output.view({-1, XQ[0].size(0), WQ[0].size(0)});
+  // When zero_start_index_M is provided, we can view as [G, M, N].
+  if (zero_start_index_M.has_value()) {
+    output = output.view({-1, XQ[0].size(0), WQ[0].size(0)});
+    // Otherwise we view as {total_M, N}.
+  } else {
+    output = output.view({-1, WQ[0].size(0)});
+  }
   return output;
 }
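
Callers that need per-group views of the new {total_M, N} output can split it along dim 0 by each group's row count, which is what the updated test below does. A minimal self-contained sketch with made-up sizes:

import torch

ms = [3, 5, 2]          # per-group row counts M_i (example values)
N = 4
total_M = sum(ms)
y = torch.arange(total_M * N, dtype=torch.float32).reshape(total_M, N)

# Recover one [M_i, N] tensor per group from the concatenated output.
y_per_group = torch.split(y, ms, dim=0)
assert [t.shape for t in y_per_group] == [(m, N) for m in ms]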

fbgemm_gpu/experimental/gen_ai/test/quantize/quantize_test.py

Lines changed: 32 additions & 55 deletions
@@ -726,7 +726,8 @@ def fp8_loopover_bmm(
         torch.testing.assert_close(y_ref, y_fp8, atol=8.0e-2, rtol=8.0e-2)
 
     @unittest.skipIf(
-        not torch.version.cuda, "Skip on AMD: GMM ops are not yet suported."
+        not torch.version.cuda and torch.version.hip < "6.2",
+        "Skip on AMD with < RoCM 6.2",
     )
     @settings(deadline=None)
     @given(
@@ -805,63 +806,39 @@ def test_fp8_grouped_gemm(
         w_scale_group = torch.unbind(torch.stack(w_scale_group, dim=0).contiguous())
 
         # FP8 grouped gemm kernel
+        fp8_args = (
+            [
+                xq_group,
+                wq_group,
+                x_scale_group,
+                w_scale_group,
+                zero_start_index_M if use_padding_zeros else None,
+            ]
+            if use_dynamic
+            else [xq_group, wq_group, x_scale_group, w_scale_group]
+        )
+        fp8_op = (
+            torch.ops.fbgemm.f8f8bf16_rowwise_grouped_dynamic
+            if use_dynamic
+            else torch.ops.fbgemm.f8f8bf16_rowwise_grouped
+        )
         if use_cudagraph:
-            if use_padding_zeros:
-                # warmup
-                torch.ops.fbgemm.f8f8bf16_rowwise_grouped_dynamic(
-                    xq_group,
-                    wq_group,
-                    x_scale_group,
-                    w_scale_group,
-                    zero_start_index_M,
-                )
-                # With cudagraph
-                g = torch.cuda.CUDAGraph()
-                with torch.cuda.graph(g):
-                    y_fp8_group = torch.ops.fbgemm.f8f8bf16_rowwise_grouped_dynamic(
-                        xq_group,
-                        wq_group,
-                        x_scale_group,
-                        w_scale_group,
-                        zero_start_index_M,
-                    )
-                g.replay()
-                y_fp8_group = y_fp8_group.unbind(dim=0)
-            else:
-                # warmup
-                torch.ops.fbgemm.f8f8bf16_rowwise_grouped(
-                    xq_group,
-                    wq_group,
-                    x_scale_group,
-                    w_scale_group,
-                )
-                # With cudagraph
-                g = torch.cuda.CUDAGraph()
-                with torch.cuda.graph(g):
-                    y_fp8_group = torch.ops.fbgemm.f8f8bf16_rowwise_grouped(
-                        xq_group,
-                        wq_group,
-                        x_scale_group,
-                        w_scale_group,
-                    )
-                g.replay()
+            # warmup
+            fp8_op(*fp8_args)
+            # With cudagraph
+            g = torch.cuda.CUDAGraph()
+            with torch.cuda.graph(g):
+                y_fp8_group = fp8_op(*fp8_args)
+            g.replay()
         else:
-            if use_padding_zeros:
-                y_fp8_group = torch.ops.fbgemm.f8f8bf16_rowwise_grouped_dynamic(
-                    xq_group,
-                    wq_group,
-                    x_scale_group,
-                    w_scale_group,
-                    zero_start_index_M,
-                )
-                y_fp8_group = y_fp8_group.unbind(dim=0)
+            y_fp8_group = fp8_op(*fp8_args)
+
+        # Massage output into proper format.
+        if not isinstance(y_fp8_group, (tuple, list)):
+            if y_fp8_group.ndim == 2:
+                y_fp8_group = torch.split(y_fp8_group, tuple(ms.tolist()), dim=0)
             else:
-                y_fp8_group = torch.ops.fbgemm.f8f8bf16_rowwise_grouped(
-                    xq_group,
-                    wq_group,
-                    x_scale_group,
-                    w_scale_group,
-                )
+                y_fp8_group = torch.unbind(y_fp8_group)
 
         # BF16 grouped gemm kernel
         bf16_args = (
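
The cudagraph branch above follows the standard PyTorch CUDA Graphs recipe: one eager warmup call, then capture inside torch.cuda.graph, then replay. A minimal standalone sketch of that pattern, using a stand-in op rather than the fbgemm kernel (requires a CUDA device):

import torch

def op(x):  # stand-in for fp8_op(*fp8_args)
    return x * 2

x = torch.ones(8, device="cuda")
op(x)  # warmup: allocations and lazy init happen outside capture

g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
    y = op(x)  # kernels launched here are recorded, not run

g.replay()  # replays the recorded kernels; y now holds the result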
