
Commit a4146f3

SamGinzburg authored and facebook-github-bot committed
implement packed quantize row / dequantize row API (pytorch#3915)
Summary: X-link: facebookresearch/FBGEMM#1004

API for a packed version of quantize/dequantize row. This version returns a single, contiguous tensor in memory instead of returning two tensors, and operates on that contiguous tensor. Example usage:

```python
a = torch.randn(shape, dtype=torch.bfloat16, device="cuda")

packed_values = quantize_fp8_packed_row_raw(
    a,
    use_triton=True,
)

# Undo scaling.
a_bf16 = dequantize_fp8_packed_row(packed_values)
torch.testing.assert_close(a_bf16, a, atol=2e-1, rtol=1e-1)
```

A third API, `quantize_fp8_packed_row`, mimics the API of `quantize_fp8_row` (mainly for testing).

Reviewed By: jiawenliu64

Differential Revision: D72121939
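The example above exercises the raw packed tensor end to end. For the compatibility entry point `quantize_fp8_packed_row`, which mirrors `quantize_fp8_row` and returns separate FP8 values and per-row scales, a minimal sketch of the round trip is below, undoing the scales the same way the new test in this commit does. The shape, device, and reliance on keyword defaults are illustrative assumptions, not part of the committed API surface.

```python
import torch

from fbgemm_gpu.experimental.gemm.triton_gemm.fp8_gemm import quantize_fp8_packed_row

# Row-wise quantize a bf16 matrix; the packed_row wrapper mirrors
# quantize_fp8_row and returns separate FP8 values and per-row scales.
a = torch.randn(4, 128, dtype=torch.bfloat16, device="cuda")
a_fp8, a_scale = quantize_fp8_packed_row(a, use_triton=True)

# One scale per row: the scale shape is the input shape minus the last dim.
assert a_scale.shape == a.shape[:-1]

# Undo the scaling to approximately recover the input. FP8 is lossy, so the
# loose tolerances below match the ones used by the tests in this commit.
a_back = a_fp8.to(torch.bfloat16)
a_back *= a_scale.view(-1, 1)  # in-place multiply keeps bf16
torch.testing.assert_close(a_back, a, atol=2e-1, rtol=1e-1)
```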
1 parent: dcb347f

File tree: 2 files changed (+625, -0 lines)

2 files changed

+625
-0
lines changed

fbgemm_gpu/experimental/gemm/test/fp8_gemm_test.py

Lines changed: 152 additions & 0 deletions
```diff
@@ -16,10 +16,14 @@
 if torch.cuda.is_available():
     from fbgemm_gpu.experimental.gemm.triton_gemm.fp8_gemm import (
         dequantize_fp8_block,
+        dequantize_fp8_packed_row,
         dequantize_fp8_row,
         matmul_fp8_block,
         matmul_fp8_row,
         quantize_fp8_block,
+        # packed_row unpacks the values, packed_row_raw returns just the packed tensor
+        quantize_fp8_packed_row,
+        quantize_fp8_packed_row_raw,
         quantize_fp8_row,
         scale_fp8_row,
     )
@@ -138,6 +142,116 @@ def _test_quantize_fp8_row(
                 use_jagged=True,
             )

+    def test_quantize_fp8_packed_row(self) -> None:
+        def _test_quantize_fp8_packed_row(
+            shape: Tuple[int, ...],
+            use_triton: bool,
+            device: torch.device,
+            output_device: Optional[torch.device] = None,
+            use_jagged: bool = False,
+            use_scale_ub: bool = False,
+            transpose_inputs: bool = False,
+        ) -> None:
+            a = torch.randn(shape, dtype=torch.bfloat16, device=device)
+            inputs = [a]
+            # if transpose_inputs is true, get all possible dimension combinations
+            # of the input tensor and transposes each pair
+            if transpose_inputs:
+                dims = range(a.ndim)
+                for dim1, dim2 in itertools.combinations(dims, 2):
+                    dims_list = list(dims)
+                    dims_list[dim1], dims_list[dim2] = dims_list[dim2], dims_list[dim1]
+                    inputs.append(a.clone().permute(dims_list))
+            scale_ub = (
+                torch.tensor([1200], dtype=torch.float, device=device)
+                if use_scale_ub
+                else None
+            )
+            for input_a in inputs:
+                # Apply sparsification if specified.
+                zero_start_index_M = None
+                if use_jagged:
+                    # View input as [G, M, K] where G is the number of groups.
+                    grouped_input = input_a.view(
+                        -1, input_a.shape[-2], input_a.shape[-1]
+                    )
+                    m_vals = torch.randint(
+                        0, grouped_input.shape[1] + 1, (grouped_input.shape[0],)
+                    )
+                    mask = torch.arange(grouped_input.shape[-2]).expand(
+                        (grouped_input.shape[0], grouped_input.shape[1])
+                    ) >= m_vals.unsqueeze(-1)
+                    # Set corresponding values to 0.
+                    grouped_input[mask] = 0.0
+                    # Generate nonzero tensor in same layout as input.
+                    zero_start_index_M = torch.count_nonzero(
+                        torch.sum(grouped_input, dim=-1), dim=-1
+                    )
+
+                a_fp8, a_scale = quantize_fp8_packed_row(
+                    input_a,
+                    scale_ub=scale_ub,
+                    zero_start_index_M=zero_start_index_M,
+                    use_triton=use_triton,
+                    output_device=output_device,
+                )
+
+                # Undo scaling.
+                a_torch = a_fp8.to(torch.bfloat16)
+                broadcast_shape = list(a_torch.shape[:-1]) + [-1]
+
+                assert a_scale.shape == a_torch.shape[:-1]
+
+                a_torch *= a_scale.view(broadcast_shape)
+
+                self.assertTrue(
+                    torch.allclose(
+                        input_a.to(device=output_device),
+                        a_torch,
+                        atol=2e-1,
+                        rtol=1e-1,
+                    )
+                )
+
+        for n_col in range(1, 9000, 100):
+            _test_quantize_fp8_packed_row((2, n_col), True, torch.device("cuda"))
+        # Test with batched input.
+        _test_quantize_fp8_packed_row((4, 2, 3), True, torch.device("cuda"))
+        _test_quantize_fp8_packed_row((6, 4, 2, 3), True, torch.device("cuda"))
+        # Test with non-contiguous input
+        _test_quantize_fp8_packed_row(
+            (4, 2, 3), True, torch.device("cuda"), transpose_inputs=True
+        )
+        _test_quantize_fp8_packed_row(
+            (6, 4, 2, 3), True, torch.device("cuda"), transpose_inputs=True
+        )
+        _test_quantize_fp8_packed_row(
+            (2, 3), True, torch.device("cuda"), use_scale_ub=True
+        )
+        # Test with cpu
+        _test_quantize_fp8_packed_row(
+            (2, 3), False, torch.device("cpu"), torch.device("cuda")
+        )
+        _test_quantize_fp8_packed_row(
+            (2, 3), False, torch.device("cpu"), torch.device("cuda"), use_scale_ub=True
+        )
+        _test_quantize_fp8_packed_row((4, 2, 3), True, torch.device("cpu"))
+        _test_quantize_fp8_packed_row((6, 4, 2, 3), True, torch.device("cpu"))
+        # Test with zero_start_index_M
+        _test_quantize_fp8_packed_row(
+            (20, 30), True, torch.device("cuda"), use_jagged=True
+        )
+        _test_quantize_fp8_packed_row(
+            (6, 4, 2, 3), True, torch.device("cuda"), use_jagged=True
+        )
+        _test_quantize_fp8_packed_row(
+            (4, 2, 3),
+            True,
+            torch.device("cuda"),
+            transpose_inputs=True,
+            use_jagged=True,
+        )
+
     def test_dequantize_fp8_row(self) -> None:
         def _test_dequantize_fp8_row(
             shape: Tuple[int, ...],
@@ -173,6 +287,44 @@ def _test_dequantize_fp8_row(
         for shape in shapes:
             _test_dequantize_fp8_row(shape)

+    def test_dequantize_fp8_packed_row(self) -> None:
+        def _test_dequantize_fp8_packed_row(
+            shape: Tuple[int, ...],
+        ) -> None:
+            a = torch.randn(shape, dtype=torch.bfloat16, device="cuda")
+
+            packed_values = quantize_fp8_packed_row_raw(
+                a,
+                use_triton=True,
+            )
+
+            # Undo scaling.
+            a_bf16 = dequantize_fp8_packed_row(packed_values)
+
+            ms = triton.testing.do_bench(
+                lambda: dequantize_fp8_packed_row(packed_values),
+            )
+            print(f"Shape: {a.shape} MS: {ms}")
+
+            torch.testing.assert_close(a_bf16, a, atol=2e-1, rtol=1e-1)
+
+            self.assertTrue(
+                torch.allclose(
+                    a,
+                    a_bf16,
+                    atol=2e-1,
+                    rtol=1e-1,
+                )
+            )
+
+        for n_col in [1, 100, 1000]:
+            _test_dequantize_fp8_packed_row((2, n_col))
+        # Test with batched input.
+        _test_dequantize_fp8_packed_row((4, 2, 3))
+        shapes = [(4, 2, 3), (6, 4, 2, 3), (2, 3), (20, 30)]
+        for shape in shapes:
+            _test_dequantize_fp8_packed_row(shape)
+
     def test_scale_fp8_row(self) -> None:
         def _test_scale_fp8_row(
             shape: Tuple[int, int],
```
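Taken together, the existing two-tensor path and the new packed path can be sanity-checked side by side. The sketch below contrasts them using the tolerances from the tests above; it assumes `quantize_fp8_row` accepts `use_triton` as elsewhere in this test file and makes no claim that the two paths produce bit-identical results.

```python
import torch

from fbgemm_gpu.experimental.gemm.triton_gemm.fp8_gemm import (
    dequantize_fp8_packed_row,
    quantize_fp8_packed_row_raw,
    quantize_fp8_row,
)

x = torch.randn(8, 256, dtype=torch.bfloat16, device="cuda")

# Existing two-tensor path: separate FP8 values and per-row scales,
# undone by hand the same way the quantize tests above do it.
x_fp8, x_scale = quantize_fp8_row(x, use_triton=True)
x_back_row = x_fp8.to(torch.bfloat16)
x_back_row *= x_scale.view(-1, 1)

# New packed path: one contiguous tensor carrying values and scales,
# consumed directly by the packed dequantize.
packed = quantize_fp8_packed_row_raw(x, use_triton=True)
x_back_packed = dequantize_fp8_packed_row(packed)

# Both round trips should land within FP8 quantization error of the input
# (tolerances match the tests in this commit).
torch.testing.assert_close(x_back_row, x, atol=2e-1, rtol=1e-1)
torch.testing.assert_close(x_back_packed, x, atol=2e-1, rtol=1e-1)
```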
