Commit 2c3e549

zhaozhul authored and facebook-github-bot committed
support fp16 dtypes for input weight and bias (pytorch#1017)
Summary:
X-link: pytorch#3931
Pull Request resolved: facebookresearch/FBGEMM#1017

ATT, this diff:
- Supports FP16 inputs for FP8 quantization and GEMM
- Adds an addmm fallback to CPU for FP8 GEMM, so that the module is compatible with publish-time processing

Reviewed By: sijiac

Differential Revision: D72479579

fbshipit-source-id: db9615095b1f246578fec39b92be6c2238ded5da
1 parent 59115fc commit 2c3e549

File tree

1 file changed: +3, -2 lines

fbgemm_gpu/experimental/gen_ai/src/quantize/quantize.cu

Lines changed: 3 additions & 2 deletions
@@ -1077,8 +1077,9 @@ std::vector<at::Tensor> quantize_fp8_per_row(
       "Invalid dim. The dim of input should be greater than or equal to 2");
   TORCH_CHECK(
       input.scalar_type() == torch::kBFloat16 ||
-          input.scalar_type() == torch::kFloat,
-      "Invalid datatype. input must be BF16 or FP32");
+          input.scalar_type() == torch::kFloat ||
+          input.scalar_type() == torch::kHalf,
+      "Invalid datatype. input must be BF16, FP16 or FP32");
   TORCH_CHECK(
       !stochastic_rounding || input.size(-1) % 4 == 0,
       "input row dim must be 4's multiple when stochastic_rounding is True");

0 commit comments