Add DeepGEMM blockwise GEMM in quantize bench #3746

Closed · wants to merge 2 commits
8 changes: 6 additions & 2 deletions fbgemm_gpu/experimental/gen_ai/bench/quantize_bench.py
@@ -169,9 +169,13 @@ def benchmark_grouped(
     # Compute the output given quantized values.
     output = quantize_op.compute(*quantized_vals)
-    # Some kernels may pad output, just take the first m values of each row.
-    output = [o[: m[i]] for i, o in enumerate(output)]
+    if isinstance(output, torch.Tensor) and output.ndim == 2:
+        # Output is stacked and needs to be split.
+        output = torch.split(output, m, dim=0)
+    else:
+        # Otherwise output may be padded or require unbinding.
+        output = [o[: m[i]] for i, o in enumerate(output)]
     # Compare the quantize op output to reference as a sanity check.

     for i in range(num_groups):
         metrics.sim += float(
             torch.mean(torch.pow(output[i] - out_ref[i], 2)).item()
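To make the benchmark change above concrete, here is a minimal sketch (with made-up shapes and group sizes) of the two output layouts the harness now handles: a single stacked 2D tensor is split back into per-group tensors using the list of m sizes, while a list of possibly padded per-group outputs is sliced group by group.

```python
import torch

# Hypothetical per-group row counts; the real values come from the benchmark's shapes.
m = [2, 3]

# Case 1: the op returns one stacked 2D tensor (e.g. grouped GEMMs with stacked inputs).
stacked = torch.randn(sum(m), 4)
split_out = torch.split(stacked, m, dim=0)  # tuple of tensors with shapes (2, 4) and (3, 4)

# Case 2: the op returns a list of per-group outputs that may be padded along dim 0.
padded = [torch.randn(4, 4), torch.randn(4, 4)]
sliced_out = [o[: m[i]] for i, o in enumerate(padded)]  # keep only the first m rows of each
```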
155 changes: 155 additions & 0 deletions fbgemm_gpu/experimental/gen_ai/bench/quantize_ops.py
@@ -21,6 +21,9 @@
    scale_fp8_row,
    triton_quantize_fp8_row,
)
from fbgemm_gpu.experimental.gemm.triton_gemm.grouped_gemm import (
    grouped_gemm_fp8_rowwise,
)
from tinygemm.utils import group_quantize_tensor

if torch.cuda.is_available() and torch.version.cuda:
@@ -35,6 +38,17 @@
except ImportError:
    MARLIN_ENABLED = False

try:
    from deep_gemm import (
        gemm_fp8_fp8_bf16_nt,
        m_grouped_gemm_fp8_fp8_bf16_nt_contiguous,
    )

    DEEPGEMM_ENABLED = True
except ImportError:
    DEEPGEMM_ENABLED = False


# Machete is also only supported internally at Meta for now.
try:
    from machete.machete import machete_gemm
@@ -712,6 +726,147 @@ def cuda(self) -> bool:
        return True


@register_quantize_op
class FP8TritonStackedGroupedGemm(QuantizeOpBase):
    """
    FP8 grouped matmul with rowwise scaling and stacked inputs implemented with Triton.
    """

    def preprocess(self, x, w):
        m_values = [i.shape[0] for i in x]
        # Record the per-group row counts of the stacked tensor.
        m_sizes = torch.tensor(m_values).to(dtype=torch.int32, device=x[0].device)
        # Quantize weights.
        wq, w_scale = zip(*[quantize_fp8_row(i) for i in w])
        # Group weights as single tensor.
        wq = torch.concat(wq, dim=0).contiguous()
        w_scale = torch.concat(w_scale, dim=0).contiguous()
        # Also view input as flattened.
        x = torch.concat(x, dim=0).contiguous()
        # Return processed tensors.
        return x, wq, w_scale, m_sizes

    def quantize(self, x, wq, w_scale, m_sizes):
        B = x.shape[0]
        xq, x_scale = triton_quantize_fp8_row(x)
        x_scale = x_scale.view(B, -1)
        return xq, wq, x_scale, w_scale, m_sizes

    def compute(self, xq, wq, x_scale, w_scale, m_sizes):
        return grouped_gemm_fp8_rowwise(xq, wq, m_sizes, x_scale, w_scale)

    def quantize_and_compute(self, x, wq, w_scale, m_sizes):
        xq, wq, x_scale, w_scale, m_sizes = self.quantize(x, wq, w_scale, m_sizes)
        return self.compute(xq, wq, x_scale, w_scale, m_sizes)

    @property
    def name(self) -> str:
        return "triton_grouped_stacked"

    @property
    def hip(self) -> bool:
        return True

    @property
    def cuda(self) -> bool:
        return True
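As a quick illustration of the stacked layout this op consumes (shapes here are made up), each group's activations are concatenated along dim 0 and m_sizes records how many rows belong to each group:

```python
import torch

# Two hypothetical groups with 2 and 3 rows and a shared K dimension of 8.
x = [torch.randn(2, 8), torch.randn(3, 8)]
m_sizes = torch.tensor([t.shape[0] for t in x], dtype=torch.int32)  # tensor([2, 3])
x_stacked = torch.concat(x, dim=0)  # shape (5, 8); rows 0-1 are group 0, rows 2-4 are group 1
assert x_stacked.shape[0] == int(m_sizes.sum())
```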


@register_quantize_op
class DeepGemmStacked(QuantizeOpBase):
    """
    FP8 grouped matmul with blockwise scaling implemented with DeepGemm.
    """

    def preprocess(self, x, w):
        m_values = [i.shape[0] for i in x]
        # Convert m_values into per-row group indices for the stacked tensor.
        indices = torch.arange(len(m_values))
        m_indices = indices.repeat_interleave(torch.tensor(m_values)).to(
            device=x[0].device, dtype=torch.int
        )
        # Quantize weights.
        wq, w_scale = zip(*[quantize_fp8_block(i, block_k=128, block_m=128) for i in w])
        # Group weights as single tensor.
        wq = torch.stack(wq, dim=0).contiguous()
        w_scale = torch.stack(w_scale, dim=0).contiguous()
        # Also view input as flattened.
        x = torch.concat(x, dim=0).contiguous()
        # Return processed tensors.
        return x, wq, w_scale, m_indices

    def quantize(self, x, wq, w_scale, m_indices):
        xq, x_scale = quantize_fp8_block(x, block_m=1, block_k=128)
        return xq, wq, x_scale, w_scale, m_indices

    def compute(self, xq, wq, x_scale, w_scale, m_indices):
        # Preallocate output.
        out = torch.empty(
            [xq.shape[0], wq.shape[1]], device=xq.device, dtype=torch.bfloat16
        )
        m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(
            (xq, x_scale), (wq, w_scale), out, m_indices
        )
        return out

    def quantize_and_compute(self, x, wq, w_scale, m_indices):
        xq, wq, x_scale, w_scale, m_indices = self.quantize(x, wq, w_scale, m_indices)
        return self.compute(xq, wq, x_scale, w_scale, m_indices)

    @property
    def name(self) -> str:
        return "deepgemm_stacked"

    @property
    def hip(self) -> bool:
        return False

    @property
    def cuda(self) -> bool:
        return DEEPGEMM_ENABLED
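The main difference from the Triton stacked op above is the grouping metadata: DeepGemm's contiguous grouped GEMM takes one group id per row of the stacked activation tensor rather than per-group sizes. A small sketch of the repeat_interleave construction used in preprocess, with made-up group sizes:

```python
import torch

# Hypothetical group sizes; preprocess derives these from the input shapes.
m_values = [2, 3]
indices = torch.arange(len(m_values))
m_indices = indices.repeat_interleave(torch.tensor(m_values))
print(m_indices)  # tensor([0, 0, 1, 1, 1]) -> rows 0-1 map to group 0, rows 2-4 to group 1
```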


@register_quantize_op
class DeepGemmBlockwise(QuantizeOpBase):
    """
    FP8 matmul with blockwise scaling implemented with DeepGemm.
    """

    def preprocess(self, x, w):
        # Quantize weights.
        wq, w_scale = quantize_fp8_block(w, block_m=128, block_k=128)
        # allocate output.
        out = torch.empty(
            x.shape[0], wq.shape[0], device=x.device, dtype=torch.bfloat16
        )
        # Return processed tensors.
        return x, wq, w_scale, out

    def quantize(self, x, wq, w_scale, out):
        xq, x_scale = quantize_fp8_block(x, block_m=1, block_k=128)
        return xq, wq, x_scale, w_scale, out

    def compute(self, xq, wq, x_scale, w_scale, out):
        gemm_fp8_fp8_bf16_nt((xq, x_scale), (wq, w_scale), out)
        return out

    def quantize_and_compute(self, x, wq, w_scale, out):
        xq, wq, x_scale, w_scale, out = self.quantize(x, wq, w_scale, out)
        return self.compute(xq, wq, x_scale, w_scale, out)

    @property
    def name(self) -> str:
        return "deepgemm_blockwise"

    @property
    def hip(self) -> bool:
        return False

    @property
    def cuda(self) -> bool:
        return True
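A hypothetical end-to-end sketch of how the benchmark harness exercises this op. It assumes DEEPGEMM_ENABLED, a CUDA device, problem dimensions compatible with DeepGemm's 128-wide blocking, and that the register_quantize_op decorator returns the class unchanged; the shapes are made up for illustration.

```python
import torch

# Assumes register_quantize_op returns the decorated class, so direct construction works.
op = DeepGemmBlockwise()

# Made-up problem size: M=64, N=128, K=256.
x = torch.randn(64, 256, device="cuda", dtype=torch.bfloat16)
w = torch.randn(128, 256, device="cuda", dtype=torch.bfloat16)

# preprocess quantizes the weights with 128x128 blocks and preallocates the bf16 output.
args = op.preprocess(x, w)
# quantize_and_compute quantizes the activations with 1x128 blocks and runs the DeepGemm kernel.
out = op.quantize_and_compute(*args)
print(out.shape)  # torch.Size([64, 128])
```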


@register_quantize_op
class BF16GroupedGemm(QuantizeOpBase):
    """