Commit b5655fb

choutim authored and facebook-github-bot committed

FP8 Rowwise Dequant Kernel (pytorch#962)

Summary:
X-link: pytorch#3873
Pull Request resolved: facebookresearch/FBGEMM#962

Add Triton rowwise dequant kernel.

Reviewed By: jiawenliu64

Differential Revision: D71655200

fbshipit-source-id: 10b329b3da5fa3398145ec78ab98bdbd1402eafd
1 parent 9583884 commit b5655fb
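
For orientation, the rowwise dequantization added by this commit is a per-row rescale along the last axis: each row of the FP8 tensor is multiplied by its single dequantization scale to recover a BF16 approximation of the original values. A minimal eager-mode sketch of that math follows; the reference helper name is hypothetical, not part of this change, and it assumes x_scale holds one scale per row (as the kernel's indexing implies).

import torch

def dequantize_fp8_row_ref(xq: torch.Tensor, x_scale: torch.Tensor) -> torch.Tensor:
    # One scale per row (all leading dims flattened); broadcast it along the last axis.
    scale = x_scale.reshape(*xq.shape[:-1], 1)
    return (xq.to(torch.float32) * scale).to(torch.bfloat16)

The Triton kernel in this diff computes the same product, but tiles the work so that each program handles BLOCK_M rows and walks the K dimension in BLOCK_K-sized chunks.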

File tree

2 files changed (+135 lines, −0 lines)


fbgemm_gpu/experimental/gemm/test/fp8_gemm_test.py

Lines changed: 37 additions & 0 deletions
@@ -11,10 +11,12 @@
 from typing import Optional, Tuple
 
 import torch
+import triton
 
 if torch.cuda.is_available():
     from fbgemm_gpu.experimental.gemm.triton_gemm.fp8_gemm import (
         dequantize_fp8_block,
+        dequantize_fp8_row,
         matmul_fp8_block,
         matmul_fp8_row,
         quantize_fp8_block,
@@ -136,6 +138,41 @@ def _test_quantize_fp8_row(
             use_jagged=True,
         )
 
+    def test_dequantize_fp8_row(self) -> None:
+        def _test_dequantize_fp8_row(
+            shape: Tuple[int, ...],
+        ) -> None:
+            a = torch.randn(shape, dtype=torch.bfloat16, device="cuda")
+            a_fp8, a_scale = quantize_fp8_row(
+                a,
+                use_triton=True,
+            )
+
+            # Undo scaling.
+            a_bf16 = dequantize_fp8_row(a_fp8, a_scale)
+
+            ms = triton.testing.do_bench(
+                lambda: dequantize_fp8_row(a_fp8, a_scale),
+            )
+            print(f"Shape: {a.shape} MS: {ms}")
+            torch.testing.assert_close(a_bf16, a, atol=2e-1, rtol=1e-1)
+            self.assertTrue(
+                torch.allclose(
+                    a,
+                    a_bf16,
+                    atol=2e-1,
+                    rtol=1e-1,
+                )
+            )
+
+        for n_col in [1, 100, 1000]:
+            _test_dequantize_fp8_row((2, n_col))
+        # Test with batched input.
+        _test_dequantize_fp8_row((4, 2, 3))
+        shapes = [(4, 2, 3), (6, 4, 2, 3), (2, 3), (20, 30)]
+        for shape in shapes:
+            _test_dequantize_fp8_row(shape)
+
     def test_scale_fp8_row(self) -> None:
         def _test_scale_fp8_row(
             shape: Tuple[int, int],
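
The loose tolerances (atol=2e-1, rtol=1e-1) are expected: the quantize/dequantize round trip passes through an 8-bit floating-point representation, so the recovered BF16 tensor only approximates the input. The triton.testing.do_bench call prints a rough per-shape latency for reference; nothing is asserted on it.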

fbgemm_gpu/experimental/gemm/triton_gemm/fp8_gemm.py

Lines changed: 98 additions & 0 deletions
@@ -3099,6 +3099,104 @@ def _kernel_matmul_fp8_row_non_persistent(
     tl.atomic_add(C, acc, mask=mask)
 
 
+@triton.autotune(
+    configs=[Config({"BLOCK_M": 16, "BLOCK_K": 512, "NUM_STAGES": 2})],
+    key=["M", "K"],
+)
+@triton.jit
+def _kernel_dequantize_fp8_row(
+    xq_ptr,
+    x_scale_ptr,
+    x_dequant_ptr,
+    M,
+    K,
+    stride_xm,
+    stride_xk,
+    stride_xdqm,
+    stride_xdqk,
+    BLOCK_M: tl.constexpr,
+    BLOCK_K: tl.constexpr,
+    NUM_STAGES: tl.constexpr,
+    USE_INT64: tl.constexpr,
+):
+    """
+    Kernel to dequantize FP8 tensor to BF16 tensor.
+    Args:
+        xq_ptr (tl.constexpr): Pointer to FP8 tensor.
+        x_scale_ptr (tl.constexpr): Pointer to FP8 scale tensor.
+        x_dequant_ptr (tl.constexpr): Pointer to BF16 tensor.
+        M (tl.constexpr): M dimension of input tensor.
+        K (tl.constexpr): K dimension of input tensor (along which scales are applied).
+        BLOCK_K (tl.constexpr): Block size for the K dimension.
+    """
+    pid = tl.program_id(axis=0)
+    if USE_INT64:
+        pid = pid.to(tl.int64)
+    offs_m = pid * BLOCK_M + tl.arange(0, BLOCK_M)
+    offs_k = tl.arange(0, BLOCK_K)
+    scales = tl.load(x_scale_ptr + offs_m)
+
+    for _k in tl.range(0, tl.cdiv(K, BLOCK_K), num_stages=NUM_STAGES):
+        mask = (offs_m[:, None] < M) & (offs_k[None, :] < K)
+        xq = tl.load(
+            xq_ptr + offs_m[:, None] * stride_xm + offs_k[None, :] * stride_xk,
+            mask=mask,
+        )
+        x_dq = xq * scales[:, None]
+        tl.store(
+            x_dequant_ptr
+            + offs_m[:, None] * stride_xdqm
+            + offs_k[None, :] * stride_xdqk,
+            x_dq,
+            mask=mask,
+        )
+        offs_k += BLOCK_K
+
+
+def dequantize_fp8_row(
+    xq: torch.Tensor,
+    x_scale: torch.Tensor,
+) -> torch.Tensor:
+    """
+    Rowwise dequantize FP8 tensor to BF16 tensor along the last axis.
+
+    Args:
+        xq (torch.Tensor): FP8 tensor to be dequantized.
+        x_scale (torch.Tensor): FP8 scale tensor.
+
+    Returns:
+        torch.Tensor: Dequantized BF16 tensor.
+    """
+
+    assert (
+        xq.is_contiguous() and x_scale.is_contiguous()
+    ), "Input tensors must be contiguous"
+    x_dequant = torch.empty_like(xq, dtype=torch.bfloat16)
+
+    # Reshape to 2-d array keeping last dim only.
+    K = xq.shape[-1]
+    xq = xq.reshape(-1, K)
+    M = xq.shape[0]
+    use_int64 = xq.numel() > 2**31
+
+    def grid(meta):
+        return (triton.cdiv(M, meta["BLOCK_M"]),)
+
+    _kernel_dequantize_fp8_row[grid](
+        xq,
+        x_scale,
+        x_dequant,
+        M,
+        K,
+        xq.stride(0),
+        xq.stride(1),
+        xq.stride(0),  # Use squashed stride.
+        xq.stride(1),
+        USE_INT64=use_int64,
+    )
+    return x_dequant
+
+
 @triton.jit
 def _kernel_dequantize_fp8_block(
     xq_ptr,
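
Putting the two pieces together, a round trip with the new API looks roughly like the test above. This is a rough usage sketch, assuming a CUDA device and an FBGEMM build that exposes these experimental helpers:

import torch

from fbgemm_gpu.experimental.gemm.triton_gemm.fp8_gemm import (
    dequantize_fp8_row,
    quantize_fp8_row,
)

x = torch.randn(4, 2, 3, dtype=torch.bfloat16, device="cuda")
xq, x_scale = quantize_fp8_row(x, use_triton=True)  # FP8 values plus one scale per row
x_dq = dequantize_fp8_row(xq, x_scale)              # BF16 output with the same shape as x
torch.testing.assert_close(x_dq, x, atol=2e-1, rtol=1e-1)

Note that dequantize_fp8_row asserts contiguous inputs and internally flattens all leading dimensions into M, launching one Triton program per BLOCK_M rows.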
