
Commit 8de5079

jwfromm authored and facebook-github-bot committed
Optimize MX4 padding to minimize need for tuning (pytorch#3040)
Summary:

X-link: facebookresearch/FBGEMM#137
Pull Request resolved: pytorch#3040

D61447274 introduced a very cool way of doing 2D indexing over input tensors during MX4 quantization; however, it is fairly reliant on tuning configurations to get good performance. It turns out the MX4 use case has highly dynamic shapes, so we spend a huge amount of time tuning for those shapes.

After deep meditation I realized there's a much simpler indexing scheme we can use, similar to the 1D accesses we used previously but with shifts added for padding. With this approach we should get the best of both worlds: support for padding rows not divisible by group size, and minimal tuning while maintaining good performance.

After further experimentation, we can actually remove tuning entirely and just use a reasonably large `GROUP_LOAD`. This gives good performance across all shapes and removes any chance of tuning overhead. Empirically, `GROUP_LOAD=64` seems to be the sweet spot.

Differential Revision: D61816830
1 parent e31151e commit 8de5079
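As a rough illustration of the shifted 1D indexing the summary describes (hypothetical helper name and plain-Python arithmetic; the actual kernel does this inside Triton), each flat input index is offset by the padding accumulated from all preceding rows, so every row begins on a group boundary:

```python
def padded_offset(flat_idx: int, row_size: int, group_size: int) -> int:
    """Map a flat index over the unpadded input to its position in the
    padded output, where each row is padded up to a multiple of group_size.

    Hypothetical sketch of the indexing scheme, not FBGEMM's actual kernel.
    """
    # Round each row's length up to the next multiple of group_size.
    padded_row_size = ((row_size + group_size - 1) // group_size) * group_size
    row, col = divmod(flat_idx, row_size)
    # Shift by the total padding inserted at the end of all earlier rows.
    return row * padded_row_size + col
```

For example, with `row_size=5` and `group_size=4`, flat index 5 (row 1, column 0) lands at padded position 8, the start of the second padded row, without any 2D index computation in the hot loop.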

File tree

4 files changed: +155 −210 lines


fbgemm_gpu/fbgemm_gpu/triton/common.py

Lines changed: 22 additions & 19 deletions
@@ -8,6 +8,26 @@
 # pyre-unsafe
 from enum import IntEnum
 
+import torch
+
+
+# LUTS need to be allocated ahead of time and copied to GPU to avoid expensive copies later.
+if torch.version.cuda:
+    lut_device = "cuda"
+else:
+    lut_device = "cpu"
+
+E2M1_LUT = torch.tensor(
+    [0, 0.5, 1, 1.5, 2, 3, 4, 6, -0, -0.5, -1, -1.5, -2, -3, -4, -6],
+    dtype=torch.float32,
+    device=lut_device,
+)
+E3M0_LUT = torch.tensor(
+    [0, 0.25, 0.5, 1, 2, 4, 8, 16, -0, -0.25, -0.5, -1, -2, -4, -8, -16],
+    dtype=torch.float32,
+    device=lut_device,
+)
+
 
 class RoundingMode(IntEnum):
     """Rounding options for quantization."""
@@ -47,26 +67,9 @@ def get_mx4_lookup_table(ebits, mbits):
         The lookup table for the specified mx4 format.
     """
     if ebits == 2 and mbits == 1:
-        return [0, 0.5, 1, 1.5, 2, 3, 4, 6, -0, -0.5, -1, -1.5, -2, -3, -4, -6]
+        return E2M1_LUT
     elif ebits == 3 and mbits == 0:
-        return [
-            0,
-            0.25,
-            0.5,
-            1,
-            2,
-            4,
-            8,
-            16,
-            -0,
-            -0.25,
-            -0.5,
-            -1,
-            -2,
-            -4,
-            -8,
-            -16,
-        ]
+        return E3M0_LUT
    else:
        raise NotImplementedError(
            f"MX4 with ebits={ebits} and mbits={mbits} not supported."
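The E2M1 entries in the new LUT follow directly from the format's bit layout (1 sign bit, 2 exponent bits, 1 mantissa bit). A pure-Python sketch (a hypothetical helper for illustration, not part of fbgemm_gpu, which uses the precomputed tensor LUTs above) that regenerates the table:

```python
def e2m1_value(code: int) -> float:
    """Decode a 4-bit E2M1 code: bit 3 = sign, bits 2..1 = exponent, bit 0 = mantissa.

    Hypothetical helper showing where the LUT entries come from.
    """
    sign = -1.0 if code & 0b1000 else 1.0
    exp = (code >> 1) & 0b11
    mantissa = code & 0b1
    if exp == 0:
        # Subnormal: no implicit leading 1, so the value is mantissa * 0.5.
        return sign * mantissa * 0.5
    # Normal: implicit leading 1, exponent bias of 1.
    return sign * (1.0 + 0.5 * mantissa) * 2.0 ** (exp - 1)

# [e2m1_value(c) for c in range(16)] reproduces the E2M1_LUT values:
# [0, 0.5, 1, 1.5, 2, 3, 4, 6, -0, -0.5, -1, -1.5, -2, -3, -4, -6]
```

Precomputing these 16 values as a tensor on the target device, as the diff does, avoids rebuilding (and re-copying) the table on every quantization call.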
