@@ -1016,6 +1016,90 @@ def fused_nbit_rowwise_quantized_sb_half_to_float_or_half(
    )


+def fused_8_bit_rowwise_quantized_to_float_or_half(
+    input_t: Tensor,
+    output_dtype: int = 0,
+    scale_bias_last: bool = True,
+    quant_padding_float_type: bool = True,
+) -> Tensor:
+    torch._check(
+        output_dtype
+        in [
+            SparseType.FP32.as_int(),
+            SparseType.FP16.as_int(),
+            SparseType.BF16.as_int(),
+        ]
+    )
+    torch._check(quant_padding_float_type or not scale_bias_last)
+    torch._check(input_t.dim() >= 2)
+    last_dim = input_t.dim() - 1
+    output_shape = list(input_t.shape)
+    ncols = input_t.size(last_dim)
+    quant_padding_size = 4 if quant_padding_float_type else 2
+    ncols_aligned = (
+        (ncols + quant_padding_size - 1) // quant_padding_size * quant_padding_size
+    )
+    output_columns = ncols_aligned - 2 * quant_padding_size
+    output_shape[last_dim] = output_columns
+    if output_dtype == SparseType.FP32.as_int():
+        return torch.empty(output_shape, dtype=torch.float32, device=input_t.device)
+    elif output_dtype == SparseType.FP16.as_int():
+        return torch.empty(output_shape, dtype=torch.float16, device=input_t.device)
+    else:  # output_dtype is SparseType.BF16
+        return torch.empty(output_shape, dtype=torch.bfloat16, device=input_t.device)
+
+
+def float_or_half_to_fused_8_bit_rowwise(
+    input_t: Tensor,
+) -> Tensor:
+    torch._check(input_t.dim() >= 2)
+    last_dim = input_t.dim() - 1
+    output_shape = list(input_t.shape)
+    ncols = input_t.size(last_dim)
+    ncols_aligned = (ncols + 4 - 1) // 4 * 4
+    output_columns = ncols_aligned + 2 * 4
+    output_shape[last_dim] = output_columns
+    return torch.empty(output_shape, dtype=torch.uint8, device=input_t.device)
+
+
+def fused_8_bit_rowwise_quantized_to_float(
+    input_t: Tensor,
+    scale_bias_last: bool = True,
+    quant_padding_float_type: bool = True,
+) -> Tensor:
+    torch._check(quant_padding_float_type or not scale_bias_last)
+    torch._check(input_t.dim() >= 2)
+    last_dim = input_t.dim() - 1
+    output_shape = list(input_t.shape)
+    ncols = input_t.size(last_dim)
+    quant_padding_size = 4 if quant_padding_float_type else 2
+    ncols_aligned = (
+        (ncols + quant_padding_size - 1) // quant_padding_size * quant_padding_size
+    )
+    output_columns = ncols_aligned - 2 * quant_padding_size
+    output_shape[last_dim] = output_columns
+    return torch.empty(output_shape, dtype=torch.float32, device=input_t.device)
+
+
+def fused_8_bit_rowwise_quantized_to_half(
+    input_t: Tensor,
+    scale_bias_last: bool = True,
+    quant_padding_float_type: bool = True,
+) -> Tensor:
+    torch._check(quant_padding_float_type or not scale_bias_last)
+    torch._check(input_t.dim() >= 2)
+    last_dim = input_t.dim() - 1
+    output_shape = list(input_t.shape)
+    ncols = input_t.size(last_dim)
+    quant_padding_size = 4 if quant_padding_float_type else 2
+    ncols_aligned = (
+        (ncols + quant_padding_size - 1) // quant_padding_size * quant_padding_size
+    )
+    output_columns = ncols_aligned - 2 * quant_padding_size
+    output_shape[last_dim] = output_columns
+    return torch.empty(output_shape, dtype=torch.float16, device=input_t.device)
+
+
def _setup() -> None:
    # pyre-ignore[16]
    _setup.done = getattr(_setup, "done", False)
@@ -1165,7 +1249,30 @@ def impl_autograd(op_name, fn, setup_context: Optional[Callable] = None) -> None
        "fbgemm::FusedNBitRowwiseQuantizedSBHalfToFloatOrHalf",
        fused_nbit_rowwise_quantized_sb_half_to_float_or_half,
    )
-
+    impl_abstract(
+        "fbgemm::Fused8BitRowwiseQuantizedToFloatOrHalf",
+        fused_8_bit_rowwise_quantized_to_float_or_half,
+    )
+    impl_abstract(
+        "fbgemm::FloatToFused8BitRowwiseQuantized",
+        float_or_half_to_fused_8_bit_rowwise,
+    )
+    impl_abstract(
+        "fbgemm::FloatOrHalfToFused8BitRowwiseQuantized",
+        float_or_half_to_fused_8_bit_rowwise,
+    )
+    impl_abstract(
+        "fbgemm::HalfToFused8BitRowwiseQuantized",
+        float_or_half_to_fused_8_bit_rowwise,
+    )
+    impl_abstract(
+        "fbgemm::Fused8BitRowwiseQuantizedToFloat",
+        fused_8_bit_rowwise_quantized_to_float,
+    )
+    impl_abstract(
+        "fbgemm::Fused8BitRowwiseQuantizedToHalf",
+        fused_8_bit_rowwise_quantized_to_half,
+    )
    _setup.done = True

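For reference, the shape arithmetic these meta functions encode: each quantized row stores a uint8 payload padded to a multiple of `quant_padding_size` bytes, followed by a scale and a bias of `quant_padding_size` bytes each (the `torch._check(quant_padding_float_type or not scale_bias_last)` guard says that fp16-width scale/bias is only supported when scale and bias come first). A minimal standalone sketch of that arithmetic, with hypothetical helper names that are not part of fbgemm:

```python
def quantized_width(ncols: int) -> int:
    # Mirrors float_or_half_to_fused_8_bit_rowwise: pad the payload to a
    # multiple of 4 bytes, then append a 4-byte fp32 scale and bias.
    ncols_aligned = (ncols + 4 - 1) // 4 * 4
    return ncols_aligned + 2 * 4


def dequantized_width(ncols: int, quant_padding_float_type: bool = True) -> int:
    # Mirrors the *_to_float/half meta functions: strip the scale/bias bytes.
    pad = 4 if quant_padding_float_type else 2
    ncols_aligned = (ncols + pad - 1) // pad * pad
    return ncols_aligned - 2 * pad


# A 7-wide fp32 row quantizes to 16 uint8 columns (7 -> 8 after alignment,
# plus 8 bytes of scale/bias); dequantizing recovers the aligned width 8,
# not the original 7.
assert quantized_width(7) == 16
assert dequantized_width(quantized_width(7)) == 8
```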
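The registrations in the second hunk are what make these shape functions take effect: `impl_abstract` tells PyTorch how to compute output metadata for the fbgemm ops without running their kernels, so fake-tensor tracing (and hence `torch.compile`) can propagate shapes through them. Note that the Float, Half, and FloatOrHalf quantize ops all share `float_or_half_to_fused_8_bit_rowwise`, since they produce the same uint8 output layout. A hedged usage sketch, assuming an environment where `fbgemm_gpu` is installed so the ops and these registrations are loaded:

```python
import torch
from torch._subclasses.fake_tensor import FakeTensorMode

# Assumption: importing fbgemm_gpu loads the op definitions and triggers
# the abstract-impl registrations above.
import fbgemm_gpu  # noqa: F401

with FakeTensorMode():
    x = torch.empty(4, 7, dtype=torch.float32)
    # No real kernel runs here; the registered shape function computes the
    # output metadata: 4 rows x 16 uint8 columns for a 7-wide fp32 input.
    q = torch.ops.fbgemm.FloatToFused8BitRowwiseQuantized(x)
    print(q.shape, q.dtype)  # torch.Size([4, 16]) torch.uint8
```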