Skip to content

Eliminate MemCpyDtoH overhead for quantized fast_gemv kernel #3725

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions fbgemm_gpu/experimental/gen_ai/bench/quantize_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -410,20 +410,20 @@ def cuda(self) -> bool:
@register_quantize_op
class BF16Fp8OSSFastGemv(QuantizeOpBase):
"""
FP8 OSS fast gemv kernel.
BF16FP8 OSS fast gemv kernel.
"""

def quantize(self, x, w):
wq, w_scale = torch.ops.fbgemm.quantize_fp8_per_tensor(w)
return x, wq, w_scale

def compute(self, x, wq, w_scale):
out = torch.ops.fbgemm.bf16fp8bf16_fast_gemv(x, wq, w_scale, 0.0)
out = torch.ops.fbgemm.bf16fp8bf16_fast_gemv(x, wq, w_scale)
return out

def quantize_and_compute(self, x, w):
x, wq, w_scale = self.quantize(x, w)
return self.compute(x, wq, w_scale.item())
return self.compute(x, wq, w_scale)

@property
def name(self) -> str:
Expand Down Expand Up @@ -451,12 +451,12 @@ def quantize(self, x, w):
return xq, wq, w_scale, x_scale

def compute(self, xq, wq, w_scale, x_scale):
out = torch.ops.fbgemm.fp8fp8bf16_fast_gemv(xq, wq, w_scale * x_scale, 0.0)
out = torch.ops.fbgemm.fp8fp8bf16_fast_gemv(xq, wq, w_scale * x_scale)
return out

def quantize_and_compute(self, x, w):
xq, wq, w_scale, x_scale = self.quantize(x, w)
return self.compute(xq, wq, w_scale.item(), x_scale.item())
return self.compute(xq, wq, w_scale, x_scale)

@property
def name(self) -> str:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ dim3 get_best_block_dim(int m, int n, int k) {
} // namespace

at::Tensor
bf16fp8bf16_fast_gemv(at::Tensor X, at::Tensor W, double w_scale, double w_zp) {
bf16fp8bf16_fast_gemv(at::Tensor X, at::Tensor W, at::Tensor w_scale) {
// X: M x K
// W: N x K
auto m = X.size(0);
Expand All @@ -65,8 +65,7 @@ bf16fp8bf16_fast_gemv(at::Tensor X, at::Tensor W, double w_scale, double w_zp) {
reinterpret_cast<__nv_bfloat16*>(X.data_ptr()), // vec
reinterpret_cast<__nv_bfloat16*>(Y.data_ptr()), // res
k,
__float2half(float(w_scale)),
__float2half(float(w_zp)),
reinterpret_cast<float const*>(w_scale.data_ptr()),
num_per_thread);

C10_CUDA_KERNEL_LAUNCH_CHECK();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,7 @@ dim3 get_best_block_dim(int m, int n, int k) {
}
} // namespace

at::Tensor
fp8fp8bf16_fast_gemv(at::Tensor X, at::Tensor W, double scale, double zp) {
at::Tensor fp8fp8bf16_fast_gemv(at::Tensor X, at::Tensor W, at::Tensor scale) {
// X: M x K
// W: N x K
auto m = X.size(0);
Expand All @@ -65,8 +64,7 @@ fp8fp8bf16_fast_gemv(at::Tensor X, at::Tensor W, double scale, double zp) {
reinterpret_cast<cutlass::float_e4m3_t*>(X.data_ptr()), // vec
reinterpret_cast<__nv_bfloat16*>(Y.data_ptr()), // res
k,
__float2half(scale),
__float2half(zp),
reinterpret_cast<float const*>(scale.data_ptr()),
num_per_thread);

C10_CUDA_KERNEL_LAUNCH_CHECK();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -140,8 +140,7 @@ __global__ void gemv_quantized_bf16_fp8(
__nv_bfloat16* vec,
__nv_bfloat16* res,
unsigned int n,
half scale,
half zero_point,
float const* scale,
unsigned int num_per_thread) {
float sum = 0;
// each thread load num_per_thread elements from global
Expand All @@ -151,9 +150,6 @@ __global__ void gemv_quantized_bf16_fp8(
half4* mat4 = reinterpret_cast<half4*>(mat);
float4* vec4 = reinterpret_cast<float4*>(vec);

float zero_point_f = static_cast<float>(zero_point);
float scale_f = static_cast<float>(scale);

#pragma unroll
for (int iter = 0; iter < num_per_thread >> 3; iter++) {
unsigned int j = start_idx + iter * blockDim.x;
Expand All @@ -170,48 +166,40 @@ __global__ void gemv_quantized_bf16_fp8(
const fp8_2* mat_h4 = (fp8_2*)&mat_val.w;
sum +=
cutlass::NumericConverter<float, __nv_bfloat16>::convert(vec_h1->x) *
(cutlass::NumericConverter<float, cutlass::float_e4m3_t>::convert(
mat_h1->x) -
zero_point_f);
cutlass::NumericConverter<float, cutlass::float_e4m3_t>::convert(
mat_h1->x);
sum +=
cutlass::NumericConverter<float, __nv_bfloat16>::convert(vec_h1->y) *
(cutlass::NumericConverter<float, cutlass::float_e4m3_t>::convert(
mat_h1->y) -
zero_point_f);
cutlass::NumericConverter<float, cutlass::float_e4m3_t>::convert(
mat_h1->y);
sum +=
cutlass::NumericConverter<float, __nv_bfloat16>::convert(vec_h2->x) *
(cutlass::NumericConverter<float, cutlass::float_e4m3_t>::convert(
mat_h2->x) -
zero_point_f);
cutlass::NumericConverter<float, cutlass::float_e4m3_t>::convert(
mat_h2->x);
sum +=
cutlass::NumericConverter<float, __nv_bfloat16>::convert(vec_h2->y) *
(cutlass::NumericConverter<float, cutlass::float_e4m3_t>::convert(
mat_h2->y) -
zero_point_f);
cutlass::NumericConverter<float, cutlass::float_e4m3_t>::convert(
mat_h2->y);
sum +=
cutlass::NumericConverter<float, __nv_bfloat16>::convert(vec_h3->x) *
(cutlass::NumericConverter<float, cutlass::float_e4m3_t>::convert(
mat_h3->x) -
zero_point_f);
cutlass::NumericConverter<float, cutlass::float_e4m3_t>::convert(
mat_h3->x);
sum +=
cutlass::NumericConverter<float, __nv_bfloat16>::convert(vec_h3->y) *
(cutlass::NumericConverter<float, cutlass::float_e4m3_t>::convert(
mat_h3->y) -
zero_point_f);
cutlass::NumericConverter<float, cutlass::float_e4m3_t>::convert(
mat_h3->y);
sum +=
cutlass::NumericConverter<float, __nv_bfloat16>::convert(vec_h4->x) *
(cutlass::NumericConverter<float, cutlass::float_e4m3_t>::convert(
mat_h4->x) -
zero_point_f);
cutlass::NumericConverter<float, cutlass::float_e4m3_t>::convert(
mat_h4->x);
sum +=
cutlass::NumericConverter<float, __nv_bfloat16>::convert(vec_h4->y) *
(cutlass::NumericConverter<float, cutlass::float_e4m3_t>::convert(
mat_h4->y) -
zero_point_f);
cutlass::NumericConverter<float, cutlass::float_e4m3_t>::convert(
mat_h4->y);
}
}

sum *= scale_f;
sum *= (*scale);

sum = warpReduceSum(sum, blockDim.x);

Expand Down Expand Up @@ -248,8 +236,7 @@ __global__ void gemv_quantized_fp8_fp8(
cutlass::float_e4m3_t* vec,
__nv_bfloat16* res,
unsigned int n,
half scale,
half zero_point,
float const* scale,
unsigned int num_per_thread) {
float sum = 0;
// each thread load num_per_thread elements from global
Expand All @@ -259,10 +246,6 @@ __global__ void gemv_quantized_fp8_fp8(
half4* mat4 = reinterpret_cast<half4*>(mat);
half4* vec4 = reinterpret_cast<half4*>(vec);

float zero_point_f = static_cast<float>(
zero_point); // so far, we use a default 0 value zero_point
float scale_f = static_cast<float>(scale);

#pragma unroll
for (int iter = 0; iter < num_per_thread >> 3; iter++) {
unsigned int j = start_idx + iter * blockDim.x;
Expand All @@ -277,58 +260,42 @@ __global__ void gemv_quantized_fp8_fp8(
const fp8_2* mat_h2 = (fp8_2*)&mat_val.y;
const fp8_2* mat_h3 = (fp8_2*)&mat_val.z;
const fp8_2* mat_h4 = (fp8_2*)&mat_val.w;
sum += (cutlass::NumericConverter<float, cutlass::float_e4m3_t>::convert(
vec_h1->x) -
zero_point_f) *
(cutlass::NumericConverter<float, cutlass::float_e4m3_t>::convert(
mat_h1->x) -
zero_point_f);
sum += (cutlass::NumericConverter<float, cutlass::float_e4m3_t>::convert(
vec_h1->y) -
zero_point_f) *
(cutlass::NumericConverter<float, cutlass::float_e4m3_t>::convert(
mat_h1->y) -
zero_point_f);
sum += (cutlass::NumericConverter<float, cutlass::float_e4m3_t>::convert(
vec_h2->x) -
zero_point_f) *
(cutlass::NumericConverter<float, cutlass::float_e4m3_t>::convert(
mat_h2->x) -
zero_point_f);
sum += (cutlass::NumericConverter<float, cutlass::float_e4m3_t>::convert(
vec_h2->y) -
zero_point_f) *
(cutlass::NumericConverter<float, cutlass::float_e4m3_t>::convert(
mat_h2->y) -
zero_point_f);
sum += (cutlass::NumericConverter<float, cutlass::float_e4m3_t>::convert(
vec_h3->x) -
zero_point_f) *
(cutlass::NumericConverter<float, cutlass::float_e4m3_t>::convert(
mat_h3->x) -
zero_point_f);
sum += (cutlass::NumericConverter<float, cutlass::float_e4m3_t>::convert(
vec_h3->y) -
zero_point_f) *
(cutlass::NumericConverter<float, cutlass::float_e4m3_t>::convert(
mat_h3->y) -
zero_point_f);
sum += (cutlass::NumericConverter<float, cutlass::float_e4m3_t>::convert(
vec_h4->x) -
zero_point_f) *
(cutlass::NumericConverter<float, cutlass::float_e4m3_t>::convert(
mat_h4->x) -
zero_point_f);
sum += (cutlass::NumericConverter<float, cutlass::float_e4m3_t>::convert(
vec_h4->y) -
zero_point_f) *
(cutlass::NumericConverter<float, cutlass::float_e4m3_t>::convert(
mat_h4->y) -
zero_point_f);
sum += cutlass::NumericConverter<float, cutlass::float_e4m3_t>::convert(
vec_h1->x) *
cutlass::NumericConverter<float, cutlass::float_e4m3_t>::convert(
mat_h1->x);
sum += cutlass::NumericConverter<float, cutlass::float_e4m3_t>::convert(
vec_h1->y) *
cutlass::NumericConverter<float, cutlass::float_e4m3_t>::convert(
mat_h1->y);
sum += cutlass::NumericConverter<float, cutlass::float_e4m3_t>::convert(
vec_h2->x) *
cutlass::NumericConverter<float, cutlass::float_e4m3_t>::convert(
mat_h2->x);
sum += cutlass::NumericConverter<float, cutlass::float_e4m3_t>::convert(
vec_h2->y) *
cutlass::NumericConverter<float, cutlass::float_e4m3_t>::convert(
mat_h2->y);
sum += cutlass::NumericConverter<float, cutlass::float_e4m3_t>::convert(
vec_h3->x) *
cutlass::NumericConverter<float, cutlass::float_e4m3_t>::convert(
mat_h3->x);
sum += cutlass::NumericConverter<float, cutlass::float_e4m3_t>::convert(
vec_h3->y) *
cutlass::NumericConverter<float, cutlass::float_e4m3_t>::convert(
mat_h3->y);
sum += cutlass::NumericConverter<float, cutlass::float_e4m3_t>::convert(
vec_h4->x) *
cutlass::NumericConverter<float, cutlass::float_e4m3_t>::convert(
mat_h4->x);
sum += cutlass::NumericConverter<float, cutlass::float_e4m3_t>::convert(
vec_h4->y) *
cutlass::NumericConverter<float, cutlass::float_e4m3_t>::convert(
mat_h4->y);
}
}

sum *= (scale_f);
sum *= (*scale);

sum = warpReduceSum(sum, blockDim.x);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,17 +63,15 @@ __global__ void gemv_quantized_bf16_fp8(
__nv_bfloat16* vec,
__nv_bfloat16* res,
unsigned int n,
half scale,
half zero_point,
float const* scale,
unsigned int num_per_thread);

__global__ void gemv_quantized_fp8_fp8(
cutlass::float_e4m3_t* mat,
cutlass::float_e4m3_t* vec,
__nv_bfloat16* res,
unsigned int n,
half scale,
half zero_point,
float const* scale,
unsigned int num_per_thread);

__global__ void gemv_quantized_int4(
Expand Down
22 changes: 7 additions & 15 deletions fbgemm_gpu/experimental/gen_ai/src/quantize/quantize.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -123,10 +123,8 @@ at::Tensor f8f8bf16_cublas(
std::optional<at::Tensor> output = std::nullopt);
at::Tensor bf16_fast_gemv(at::Tensor X, at::Tensor W);
at::Tensor
bf16fp8bf16_fast_gemv(at::Tensor X, at::Tensor W, double w_scale, double w_zp);

at::Tensor
fp8fp8bf16_fast_gemv(at::Tensor X, at::Tensor W, double scale, double zp);
bf16fp8bf16_fast_gemv(at::Tensor X, at::Tensor W, at::Tensor w_scale);
at::Tensor fp8fp8bf16_fast_gemv(at::Tensor X, at::Tensor W, at::Tensor scale);

at::Tensor f8i4bf16_rowwise(
at::Tensor XQ,
Expand Down Expand Up @@ -192,10 +190,8 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
m.def(
"f8i4bf16_rowwise(Tensor XQ, Tensor WQ, Tensor x_scale, Tensor w_scale, Tensor w_zp) -> Tensor");
m.def("bf16_fast_gemv(Tensor X, Tensor W) -> Tensor");
m.def(
"bf16fp8bf16_fast_gemv(Tensor X, Tensor W, float w_scale, float w_zp) -> Tensor");
m.def(
"fp8fp8bf16_fast_gemv(Tensor X, Tensor W, float scale, float zp) -> Tensor");
m.def("bf16fp8bf16_fast_gemv(Tensor X, Tensor W, Tensor w_scale) -> Tensor");
m.def("fp8fp8bf16_fast_gemv(Tensor X, Tensor W, Tensor scale) -> Tensor");
m.def("f8f8bf16_lite(Tensor XQ, Tensor WQ, Tensor scale) -> Tensor");
m.def(
"bf16i4bf16_rowwise(Tensor X, Tensor WQ, Tensor w_scale, Tensor w_zp) -> Tensor");
Expand Down Expand Up @@ -424,19 +420,15 @@ at::Tensor bf16_fast_gemv_meta(at::Tensor X, at::Tensor W) {
at::Tensor bf16fp8bf16_fast_gemv_meta(
at::Tensor X,
at::Tensor W,
double /*w_scale*/,
double /*w_zp*/) {
at::Tensor /* w_scale */) {
const at::SymInt M = X.sym_size(0);
const at::SymInt N = W.sym_size(0);
auto Y = at::empty_symint({M, N}, X.options().dtype(at::kBFloat16));
return Y;
}

at::Tensor fp8fp8bf16_fast_gemv_meta(
at::Tensor X,
at::Tensor W,
double /*scale*/,
double /*zp*/) {
at::Tensor
fp8fp8bf16_fast_gemv_meta(at::Tensor X, at::Tensor W, at::Tensor /* scale */) {
const at::SymInt M = X.sym_size(0);
const at::SymInt N = W.sym_size(0);
auto Y = at::empty_symint({M, N}, X.options().dtype(at::kBFloat16));
Expand Down
4 changes: 2 additions & 2 deletions fbgemm_gpu/experimental/gen_ai/test/quantize/quantize_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1128,11 +1128,11 @@ def test_gemv(
w = torch.randn(size=(N, K), dtype=torch.bfloat16, device="cuda") * 0.01
if quantize_w and not quantize_x:
wq, w_scale = torch.ops.fbgemm.quantize_fp8_per_tensor(w)
z = gemv_op(x, wq, w_scale.item(), 0.0)
z = gemv_op(x, wq, w_scale)
elif quantize_w and quantize_x:
xq, x_scale = torch.ops.fbgemm.quantize_fp8_per_tensor(x)
wq, w_scale = torch.ops.fbgemm.quantize_fp8_per_tensor(w)
z = gemv_op(xq, wq, x_scale.item() * w_scale.item(), 0.0)
z = gemv_op(xq, wq, x_scale * w_scale)
else:
z = gemv_op(x, w)
z_ref = (x @ w.T).to(torch.bfloat16).to("cuda")
Expand Down
Loading