
Commit 92f7785

cgufb authored and facebook-github-bot committed
Add abstract impl for Fused8BitRowwiseQuantizedToFloatOrHalf et al. (pytorch#715)
Summary:
X-link: pytorch#3640
Pull Request resolved: facebookresearch/FBGEMM#715

Reviewed By: q10, jianyuh

Differential Revision: D68817290

fbshipit-source-id: b57b475aa9ee746d8726945b28b71fafd93cfbbe
1 parent 69fd016 commit 92f7785
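The functions added below are abstract (meta) implementations: they compute only the output shape, dtype, and device of each operator, so fake-tensor tracing, torch.compile, and the autogenerated opcheck tests can reason about the ops without running a real kernel. A minimal usage sketch, assuming fbgemm_gpu is installed and importable (the op and module names come from this commit; everything else is illustrative):

import torch
from torch._subclasses.fake_tensor import FakeTensorMode

import fbgemm_gpu  # noqa: F401  -- loads the compiled fbgemm operator library
torch.ops.import_module("fbgemm_gpu.sparse_ops")  # registers the abstract impls

with FakeTensorMode():
    # One row = 32 quantized bytes + 4-byte scale + 4-byte bias = 40 uint8 columns.
    packed = torch.empty(8, 40, dtype=torch.uint8)
    out = torch.ops.fbgemm.Fused8BitRowwiseQuantizedToFloat(packed)
    # Shape and dtype come from the abstract impl; no CPU/CUDA kernel is dispatched.
    assert out.shape == (8, 32) and out.dtype == torch.float32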

File tree

3 files changed (+188 lines, -18 lines)

fbgemm_gpu/fbgemm_gpu/sparse_ops.py

Lines changed: 108 additions & 1 deletion
@@ -1016,6 +1016,90 @@ def fused_nbit_rowwise_quantized_sb_half_to_float_or_half(
     )
 
 
+def fused_8_bit_rowwise_quantized_to_float_or_half(
+    input_t: Tensor,
+    output_dtype: int = 0,
+    scale_bias_last: bool = True,
+    quant_padding_float_type: bool = True,
+) -> Tensor:
+    torch._check(
+        output_dtype
+        in [
+            SparseType.FP32.as_int(),
+            SparseType.FP16.as_int(),
+            SparseType.BF16.as_int(),
+        ]
+    )
+    torch._check(quant_padding_float_type or not scale_bias_last)
+    torch._check(input_t.dim() >= 2)
+    last_dim = input_t.dim() - 1
+    output_shape = list(input_t.shape)
+    ncols = input_t.size(last_dim)
+    quant_padding_size = 4 if quant_padding_float_type else 2
+    ncols_aligned = (
+        (ncols + quant_padding_size - 1) // quant_padding_size * quant_padding_size
+    )
+    output_columns = ncols_aligned - 2 * quant_padding_size
+    output_shape[last_dim] = output_columns
+    if output_dtype == SparseType.FP32.as_int():
+        return torch.empty(output_shape, dtype=torch.float32, device=input_t.device)
+    elif output_dtype == SparseType.FP16.as_int():
+        return torch.empty(output_shape, dtype=torch.float16, device=input_t.device)
+    else:  # output_dtype is SparseType.BF16
+        return torch.empty(output_shape, dtype=torch.bfloat16, device=input_t.device)
+
+
+def float_or_half_to_fused_8_bit_rowwise(
+    input_t: Tensor,
+) -> Tensor:
+    torch._check(input_t.dim() >= 2)
+    last_dim = input_t.dim() - 1
+    output_shape = list(input_t.shape)
+    ncols = input_t.size(last_dim)
+    ncols_aligned = (ncols + 4 - 1) // 4 * 4
+    output_columns = ncols_aligned + 2 * 4
+    output_shape[last_dim] = output_columns
+    return torch.empty(output_shape, dtype=torch.uint8, device=input_t.device)
+
+
+def fused_8_bit_rowwise_quantized_to_float(
+    input_t: Tensor,
+    scale_bias_last: bool = True,
+    quant_padding_float_type: bool = True,
+) -> Tensor:
+    torch._check(quant_padding_float_type or not scale_bias_last)
+    torch._check(input_t.dim() >= 2)
+    last_dim = input_t.dim() - 1
+    output_shape = list(input_t.shape)
+    ncols = input_t.size(last_dim)
+    quant_padding_size = 4 if quant_padding_float_type else 2
+    ncols_aligned = (
+        (ncols + quant_padding_size - 1) // quant_padding_size * quant_padding_size
+    )
+    output_columns = ncols_aligned - 2 * quant_padding_size
+    output_shape[last_dim] = output_columns
+    return torch.empty(output_shape, dtype=torch.float32, device=input_t.device)
+
+
+def fused_8_bit_rowwise_quantized_to_half(
+    input_t: Tensor,
+    scale_bias_last: bool = True,
+    quant_padding_float_type: bool = True,
+) -> Tensor:
+    torch._check(quant_padding_float_type or not scale_bias_last)
+    torch._check(input_t.dim() >= 2)
+    last_dim = input_t.dim() - 1
+    output_shape = list(input_t.shape)
+    ncols = input_t.size(last_dim)
+    quant_padding_size = 4 if quant_padding_float_type else 2
+    ncols_aligned = (
+        (ncols + quant_padding_size - 1) // quant_padding_size * quant_padding_size
+    )
+    output_columns = ncols_aligned - 2 * quant_padding_size
+    output_shape[last_dim] = output_columns
+    return torch.empty(output_shape, dtype=torch.float16, device=input_t.device)
+
+
 def _setup() -> None:
     # pyre-ignore[16]
     _setup.done = getattr(_setup, "done", False)
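For context (not part of the diff): in the fused 8-bit rowwise format each row stores its quantized uint8 values followed by a per-row scale and bias, each occupying 4 bytes when quant_padding_float_type is True and 2 bytes when it is False, so the meta functions above only have to reproduce that column arithmetic. A small worked sketch (the helper name is hypothetical):

def dequantized_ncols(ncols: int, quant_padding_float_type: bool = True) -> int:
    # Bytes used by each of the per-row scale and bias entries.
    pad = 4 if quant_padding_float_type else 2
    # Round the packed row width up to a multiple of the padding size ...
    ncols_aligned = (ncols + pad - 1) // pad * pad
    # ... then drop the trailing scale + bias bytes.
    return ncols_aligned - 2 * pad

# A 40-byte packed row with fp32 scale/bias dequantizes to 32 values, while
# fp16 scale/bias leaves 36; the quantize direction appends 2 * 4 bytes instead.
assert dequantized_ncols(40) == 32
assert dequantized_ncols(40, quant_padding_float_type=False) == 36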
@@ -1165,7 +1249,30 @@ def impl_autograd(op_name, fn, setup_context: Optional[Callable] = None) -> None
         "fbgemm::FusedNBitRowwiseQuantizedSBHalfToFloatOrHalf",
         fused_nbit_rowwise_quantized_sb_half_to_float_or_half,
     )
-
+    impl_abstract(
+        "fbgemm::Fused8BitRowwiseQuantizedToFloatOrHalf",
+        fused_8_bit_rowwise_quantized_to_float_or_half,
+    )
+    impl_abstract(
+        "fbgemm::FloatToFused8BitRowwiseQuantized",
+        float_or_half_to_fused_8_bit_rowwise,
+    )
+    impl_abstract(
+        "fbgemm::FloatOrHalfToFused8BitRowwiseQuantized",
+        float_or_half_to_fused_8_bit_rowwise,
+    )
+    impl_abstract(
+        "fbgemm::HalfToFused8BitRowwiseQuantized",
+        float_or_half_to_fused_8_bit_rowwise,
+    )
+    impl_abstract(
+        "fbgemm::Fused8BitRowwiseQuantizedToFloat",
+        fused_8_bit_rowwise_quantized_to_float,
+    )
+    impl_abstract(
+        "fbgemm::Fused8BitRowwiseQuantizedToHalf",
+        fused_8_bit_rowwise_quantized_to_half,
+    )
     _setup.done = True
 
 
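The impl_abstract helper above binds each fbgemm operator name to the corresponding Python shape function. As a rough standalone illustration (a hedged sketch under the assumption that the helper wraps torch.library, not FBGEMM's actual code), this is how such a registration typically looks; the demo::row_dequant op and its schema are made up, and recent PyTorch versions rename impl_abstract to register_fake:

import torch
from torch import Tensor

# Define a toy operator schema (hypothetical, for illustration only).
torch.library.define("demo::row_dequant", "(Tensor input) -> Tensor")

# Register its abstract/meta implementation: shape and dtype only, no kernel.
@torch.library.impl_abstract("demo::row_dequant")
def row_dequant_abstract(input: Tensor) -> Tensor:
    out_shape = list(input.shape)
    out_shape[-1] = out_shape[-1] - 8  # drop 8 trailing scale/bias bytes per row
    return torch.empty(out_shape, dtype=torch.float32, device=input.device)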

fbgemm_gpu/test/quantize/failures_dict_fast.json

Lines changed: 22 additions & 0 deletions
@@ -2,6 +2,12 @@
   "_description": "This is a dict containing failures for tests autogenerated by generate_opcheck_tests. For more details, please see https://docs.google.com/document/d/1Pj5HRZvdOq3xpFpbEjUZp2hBovhy7Wnxw14m6lF2154/edit",
   "_version": 1,
   "data": {
+    "fbgemm::FloatOrHalfToFused8BitRowwiseQuantized": {
+      "TestFused8BitRowwiseQuantizationConversion.test_faketensor__test_quantize_op": {
+        "comment": "",
+        "status": "xfail"
+      }
+    },
     "fbgemm::FloatOrHalfToFusedNBitRowwiseQuantizedSBHalf": {},
     "fbgemm::FloatToFused8BitRowwiseQuantized": {
       "SplitTableBatchedEmbeddingsTest.test_faketensor__test_forward_cpu_int8": {
@@ -23,6 +29,10 @@
       "SplitTableBatchedEmbeddingsTest.test_faketensor__test_nbit_forward_fused_pooled_emb_quant": {
         "comment": "",
         "status": "xfail"
+      },
+      "TestFused8BitRowwiseQuantizationConversion.test_faketensor__test_quantize_op": {
+        "comment": "",
+        "status": "xfail"
       }
     },
     "fbgemm::FloatToFusedNBitRowwiseQuantizedSBHalf": {
@@ -66,6 +76,12 @@
         "status": "xfail"
       }
     },
+    "fbgemm::Fused8BitRowwiseQuantizedToFloatOrHalf": {
+      "TestFused8BitRowwiseQuantizationConversion.test_faketensor__test_quantize_op": {
+        "comment": "",
+        "status": "xfail"
+      }
+    },
     "fbgemm::FusedNBitRowwiseQuantizedSBHalfToFloat": {
       "TestFusedNBitRowwiseQuantizationConversion.test_faketensor__test_quantize_and_dequantize_op": {
         "comment": "",
@@ -93,6 +109,12 @@
       }
     },
     "fbgemm::HFP8QuantizedToFloat": {},
+    "fbgemm::HalfToFused8BitRowwiseQuantized": {
+      "TestFused8BitRowwiseQuantizationConversion.test_faketensor__test_quantize_op": {
+        "comment": "",
+        "status": "xfail"
+      }
+    },
     "fbgemm::HalfToFusedNBitRowwiseQuantizedSBHalf": {
       "TestFusedNBitRowwiseQuantizationConversion.test_faketensor__test_quantize_and_dequantize_op": {
         "comment": "",

fbgemm_gpu/test/quantize/fused_8bit_rowwise_test.py

Lines changed: 58 additions & 17 deletions
@@ -28,13 +28,18 @@
 # pyre-fixme[16]: Module `common` has no attribute `open_source`.
 if open_source:
     # pyre-ignore[21]
-    from test_utils import gpu_available
+    from test_utils import gpu_available, optests
 else:
-    from fbgemm_gpu.test.test_utils import gpu_available
+    from fbgemm_gpu.test.test_utils import gpu_available, optests
+
+    torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops")
+
+    torch.ops.import_module("fbgemm_gpu.sparse_ops")
 
 no_long_tests: bool = False
 
 
+@optests.generate_opcheck_tests(fast=True)
 class TestFused8BitRowwiseQuantizationConversion(unittest.TestCase):
     # pyre-fixme[56]: Pyre was not able to infer the type of argument
     # `hypothesis.strategies.integers($parameter$min_value = 0, $parameter$max_value =
@@ -118,21 +123,7 @@ def test_quantize_op(
             reference[:, ncols + 4 : ncols + 8],
         )
 
-    # pyre-ignore [56]: Invalid decoration, was not able to infer the type of argument
-    @given(
-        nrows=st.integers(min_value=0, max_value=100),
-        ncols=st.sampled_from([32, 128, 256, 384, 512, 1024]),
-        output_dtype=st.sampled_from(
-            [SparseType.FP16, SparseType.FP32, SparseType.BF16]
-        ),
-        quant_padding_float_type=st.sampled_from(
-            [True, False],
-        ),
-        test_generic_op=st.booleans(),
-        test_cuda=st.booleans(),
-    )
-    @settings(deadline=10000, suppress_health_check=[HealthCheck.filter_too_much])
-    def test_quantize_and_dequantize_op(  # noqa: C901
+    def quantize_and_dequantize_op_test_helper(  # noqa: C901
         self,
         nrows: int,
         ncols: int,
@@ -289,6 +280,56 @@ def test_quantize_and_dequantize_op( # noqa: C901
             dequantized_data_trimmed.bfloat16(), reference.bfloat16()
         )
 
+    # pyre-ignore [56]: Invalid decoration, was not able to infer the type of argument
+    @given(
+        nrows=st.integers(min_value=0, max_value=100),
+        ncols=st.sampled_from([32, 128, 256, 384, 512, 1024]),
+        output_dtype=st.sampled_from(
+            [SparseType.FP16, SparseType.FP32, SparseType.BF16]
+        ),
+        quant_padding_float_type=st.sampled_from(
+            [True, False],
+        ),
+        test_generic_op=st.booleans(),
+    )
+    @settings(deadline=10000, suppress_health_check=[HealthCheck.filter_too_much])
+    def test_quantize_and_dequantize_op_cpu(  # noqa: C901
+        self,
+        nrows: int,
+        ncols: int,
+        output_dtype: SparseType,
+        quant_padding_float_type: bool,
+        test_generic_op: bool,
+    ) -> None:
+        self.quantize_and_dequantize_op_test_helper(
+            nrows, ncols, output_dtype, quant_padding_float_type, test_generic_op, False
+        )
+
+    # pyre-ignore [56]: Invalid decoration, was not able to infer the type of argument
+    @given(
+        nrows=st.integers(min_value=0, max_value=100),
+        ncols=st.sampled_from([32, 128, 256, 384, 512, 1024]),
+        output_dtype=st.sampled_from(
+            [SparseType.FP16, SparseType.FP32, SparseType.BF16]
+        ),
+        quant_padding_float_type=st.sampled_from(
+            [True, False],
+        ),
+        test_generic_op=st.booleans(),
+    )
+    @settings(deadline=10000, suppress_health_check=[HealthCheck.filter_too_much])
+    def test_quantize_and_dequantize_op_cuda(  # noqa: C901
+        self,
+        nrows: int,
+        ncols: int,
+        output_dtype: SparseType,
+        quant_padding_float_type: bool,
+        test_generic_op: bool,
+    ) -> None:
+        self.quantize_and_dequantize_op_test_helper(
+            nrows, ncols, output_dtype, quant_padding_float_type, test_generic_op, True
+        )
+
     @unittest.skipIf(no_long_tests, "Slow test, requires buck build to run.")  # noqa
     def test_quantize_and_dequantize_op_cuda_large_nrows(self) -> None:
         ncols = 256
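The @optests.generate_opcheck_tests(fast=True) decorator added above wraps each existing test in autogenerated operator checks, which is where the test_faketensor__test_quantize_op names in failures_dict_fast.json come from. The abstract impls can also be exercised directly; a hedged sketch, assuming a recent PyTorch that provides torch.library.opcheck and an installed fbgemm_gpu (a CPU build is enough), with arbitrary tensor sizes:

import torch

import fbgemm_gpu  # noqa: F401  -- loads the fbgemm operator library
torch.ops.import_module("fbgemm_gpu.sparse_ops")  # registers the abstract impls

# Quantize a small float matrix, then let opcheck compare the real kernel
# against the op's schema and its fake-tensor (abstract) implementation.
packed = torch.ops.fbgemm.FloatToFused8BitRowwiseQuantized(torch.randn(4, 32))
torch.library.opcheck(
    torch.ops.fbgemm.Fused8BitRowwiseQuantizedToFloat.default,
    (packed,),
)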
