This repository was archived by the owner on Apr 24, 2025. It is now read-only.

Fix int4 quantization for llama and gemma #47

Merged · 5 commits · Jun 11, 2024
5 changes: 2 additions & 3 deletions examples/llama.py
@@ -4,13 +4,12 @@
 #

 from transformers import AutoTokenizer, TextStreamer
-from intel_npu_acceleration_library import NPUModelForCausalLM
-import torch
+from intel_npu_acceleration_library import NPUModelForCausalLM, int4

 model_id = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"

 model = NPUModelForCausalLM.from_pretrained(
-    model_id, use_cache=True, dtype=torch.int8
+    model_id, use_cache=True, dtype=int4, attn_implementation="sdpa"
 ).eval()
 tokenizer = AutoTokenizer.from_pretrained(model_id, use_default_system_prompt=True)
 tokenizer.pad_token_id = tokenizer.eos_token_id
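
For reference, the same int4 path can also be reached by compiling an already-loaded transformers model with the library's top-level compile entry point; the sketch below assumes that entry point forwards to the compiler.compile function touched in this PR and accepts the int4 NPUDtype the same way.

import intel_npu_acceleration_library as npu_lib
from intel_npu_acceleration_library import int4
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0")
# compile() applies the general optimizations, quantizes the weights to int4
# and lowers the layers to NPU kernels, returning the model in eval mode
model = npu_lib.compile(model, dtype=int4)
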
7 changes: 2 additions & 5 deletions examples/llama3.py
@@ -4,14 +4,11 @@
 #

 from transformers import AutoTokenizer, TextStreamer
-from intel_npu_acceleration_library import NPUModelForCausalLM
-import torch
+from intel_npu_acceleration_library import NPUModelForCausalLM, int4

 model_id = "meta-llama/Meta-Llama-3-8B-Instruct"

-model = NPUModelForCausalLM.from_pretrained(
-    model_id, dtype=torch.int8, use_cache=True
-).eval()
+model = NPUModelForCausalLM.from_pretrained(model_id, dtype=int4, use_cache=True).eval()
 tokenizer = AutoTokenizer.from_pretrained(model_id)
 streamer = TextStreamer(tokenizer, skip_special_tokens=True, skip_prompt=True)

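
A minimal end-to-end sketch of how the updated examples can be exercised; the prompt and generation settings are illustrative assumptions, not part of the diff.

from transformers import AutoTokenizer, TextStreamer
from intel_npu_acceleration_library import NPUModelForCausalLM, int4

model_id = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"

# Same call as in the updated examples/llama.py
model = NPUModelForCausalLM.from_pretrained(
    model_id, use_cache=True, dtype=int4, attn_implementation="sdpa"
).eval()
tokenizer = AutoTokenizer.from_pretrained(model_id, use_default_system_prompt=True)
tokenizer.pad_token_id = tokenizer.eos_token_id
streamer = TextStreamer(tokenizer, skip_special_tokens=True, skip_prompt=True)

# Illustrative prompt and generation settings
prompt = "What can an NPU accelerate on a laptop?"
inputs = tokenizer(prompt, return_tensors="pt")
_ = model.generate(**inputs, max_new_tokens=128, streamer=streamer)
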
38 changes: 26 additions & 12 deletions intel_npu_acceleration_library/compiler.py
@@ -40,14 +40,13 @@ def compile(
     # Prepare and optimize model for NPU
     with torch.no_grad():
         # General optimizations
-        apply_horizontal_fusion(model)
-        optimize_llama_attention(model, dtype)
+        apply_general_optimizations(model)
         if dtype in (int8, int4):
             # Quantize model
             model = quantize_model(model, dtype)

         # Model lowering to NPU ops
-        lower_linear(model, dtype)
+        create_npu_kernels(model)

     if dtype.is_floating_point and training:
         # Set model to evaluation only as quantized training is not supported yet
@@ -56,6 +55,25 @@
     return model.eval()


+def apply_general_optimizations(model: torch.nn.Module):
+    """Apply general optimizations to a torch.nn.Module.
+
+    Args:
+        model (torch.nn.Module): a pytorch nn.Module to compile and optimize for the npu
+    """
+    apply_horizontal_fusion(model)
+    optimize_llama_attention(model)
+
+
+def create_npu_kernels(model: torch.nn.Module):
+    """Create NPU kernels.
+
+    Args:
+        model (torch.nn.Module): a pytorch nn.Module to compile and optimize for the npu
+    """
+    lower_linear(model)
+
+
 def module_optimization(func: Callable) -> torch.nn.Module:
     """Optimize recursively a torch.nn.Module with a specific function.

@@ -89,15 +107,12 @@ def wrapper(model: torch.nn.Module, *args: Any, **kwargs: Any):


 @module_optimization
-def lower_linear(
-    name: str, layer: torch.nn.Module, dtype: torch.dtype
-) -> Union[torch.nn.Module, None]:
+def lower_linear(name: str, layer: torch.nn.Module) -> Union[torch.nn.Module, None]:
     """Lower torch.nn.Linear layer to NPU equivalent operators.

     Args:
         name (str): Layer name
         layer (torch.nn.Module): Original torch.nn.Linear module
-        dtype (torch.dtype): Target datatype

     Raises:
         RuntimeError: unsupported quantization bits
@@ -106,9 +121,9 @@ def lower_linear(
         Union[torch.nn.Module, None]: Return the new NPU operator or None
     """
     if isinstance(layer, torch.nn.Linear):
-        return nn.Linear.fromTorch(layer, dtype)
+        return nn.Linear.fromTorch(layer)
     if isinstance(layer, torch.nn.Conv2d):
-        return nn.Conv2d.fromTorch(layer, dtype)
+        return nn.Conv2d.fromTorch(layer)
     if isinstance(layer, WeightOnlyLinear):
         if layer.bits == 4:
             return nn.QuantizedLinear(
@@ -143,20 +158,19 @@ def apply_horizontal_fusion(

 @module_optimization
 def optimize_llama_attention(
-    name: str, layer: torch.nn.Module, dtype: torch.dtype
+    name: str, layer: torch.nn.Module
 ) -> Union[torch.nn.Module, None]:
     """Optimize LLAMA attention block.

     Args:
         name (str): Module name
         layer (torch.nn.Module): Original Module
-        dtype (torch.dtype): Target datatype

     Returns:
         Union[torch.nn.Module, None]: optimized llama module
     """
     if isinstance(layer, (LlamaAttention, GemmaAttention)):
-        return nn.LlamaAttention.fromTorch(layer, dtype)
+        return nn.LlamaAttention.fromTorch(layer)
     return None


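
A minimal sketch of how the two new helpers compose outside of compile(), mirroring what the new script/quantize_model.py below does; the toy model is an illustrative assumption.

import torch
import intel_npu_acceleration_library as npu_lib

# Illustrative toy model; any torch.nn.Module containing Linear layers would do
model = torch.nn.Sequential(
    torch.nn.Linear(256, 256),
    torch.nn.GELU(),
    torch.nn.Linear(256, 64),
)

with torch.no_grad():
    # Graph-level rewrites (horizontal fusion, llama/gemma attention) without quantization
    npu_lib.compiler.apply_general_optimizations(model)
    # Lower Linear/Conv2d/WeightOnlyLinear modules to NPU kernels
    npu_lib.compiler.create_npu_kernels(model)
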
9 changes: 9 additions & 0 deletions intel_npu_acceleration_library/dtypes.py
@@ -61,6 +61,15 @@ def __eq__(self, value: Union["NPUDtype", torch.dtype]) -> bool:
         else:
             return super().__eq__(value)

+    def __repr__(self) -> str:
+        """
+        Return a string representation of the NPUDtype object.
+
+        Returns:
+            str: The string representation of the NPUDtype object.
+        """
+        return self.name
+

 float16 = NPUDtype("fp16", 16, -65504, 65504, torch.float16)
 bfloat16 = NPUDtype("bfloat16", 16, -65504, 65504, torch.float16)
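
A tiny sketch of the effect, assuming the int4 instance exported by the package is an NPUDtype constructed with the name "int4":

from intel_npu_acceleration_library import int4

# With the new __repr__, the dtype prints as its short name instead of the
# default object representation, which keeps log messages readable.
print(repr(int4))  # expected to print: int4
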
2 changes: 1 addition & 1 deletion intel_npu_acceleration_library/modelling.py
@@ -52,7 +52,7 @@ def get_model_path(model_name: str, *args: Any, **kwargs: Any) -> Tuple[str, str
     cache_dir = get_cache_dir()
     mangled_model_name = get_mangled_model_name(model_name, *args, **kwargs)
     model_dir_path = os.path.join(cache_dir, mangled_model_name)
-    model_path = os.path.join(model_dir_path, "model.pt")
+    model_path = os.path.join(model_dir_path, "pytorch_npu_model.pt")
     return model_dir_path, model_path

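
A sketch of what the rename means for the on-disk cache; the model name and extra keyword arguments are illustrative.

from intel_npu_acceleration_library.modelling import get_model_path

model_dir_path, model_path = get_model_path(
    "TinyLlama/TinyLlama-1.1B-Chat-v1.0", dtype="int4", use_cache=True
)
# model_path now ends with "pytorch_npu_model.pt" instead of "model.pt"
print(model_path)
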
4 changes: 2 additions & 2 deletions intel_npu_acceleration_library/nn/conv.py
@@ -118,13 +118,13 @@ def forward(self, x) -> torch.Tensor:
         return out

     @staticmethod
-    def fromTorch(layer, dtype) -> "Conv2d":
+    def fromTorch(layer, dtype: torch.dtype = torch.float16) -> "Conv2d":
         """
         Create a Conv2d layer from a torch.nn.Conv2d layer.

         Args:
             layer (torch.nn.Conv2d): The torch Conv2d layer.
-            dtype (torch.dtype): Data type of the layer.
+            dtype (torch.dtype, optional): Data type of the layer.

         Returns:
             Conv2d: The converted Conv2d layer.
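
A short sketch of the new default, assuming Conv2d is importable from intel_npu_acceleration_library.nn the same way compiler.py uses it:

import torch
from intel_npu_acceleration_library.nn import Conv2d

torch_conv = torch.nn.Conv2d(3, 16, kernel_size=3)
# dtype is now optional and defaults to torch.float16
npu_conv = Conv2d.fromTorch(torch_conv)
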
121 changes: 121 additions & 0 deletions script/quantize_model.py
@@ -0,0 +1,121 @@
+#
+# Copyright © 2024 Intel Corporation
+# SPDX-License-Identifier: Apache 2.0
+#
+
+from transformers import AutoModelForCausalLM, AutoTokenizer
+import intel_npu_acceleration_library as npu_lib
+from neural_compressor.config import PostTrainingQuantConfig
+from neural_compressor.quantization import fit
+from neural_compressor.adaptor.torch_utils.auto_round import get_dataloader
+import argparse
+import torch
+import os
+
+
+def define_and_parse_arguments():
+    parser = argparse.ArgumentParser(description="Export a model to NPU")
+    parser.add_argument(
+        "-m",
+        "--model",
+        type=str,
+        required=True,
+        help="The name of the model to export",
+    )
+    parser.add_argument(
+        "-b",
+        "--bits",
+        type=int,
+        default=4,
+        help="The number of bits to use for quantization",
+    )
+    parser.add_argument(
+        "-o",
+        "--output-dir",
+        type=str,
+        default="models",
+        help="The directory where to save the exported model",
+    )
+    parser.add_argument(
+        "-s",
+        "--sequence-length",
+        type=int,
+        default=2048,
+        help="The sequence length to use for the dataloader",
+    )
+    parser.add_argument(
+        "-a",
+        "--algorithm",
+        type=str,
+        default="RTN",
+        help="The quantization algorithm to use",
+    )
+    return parser.parse_args()
+
+
+def export_model(
+    model_name: str,
+    algorithm: str,
+    bits: int = 4,
+    sequence_length: int = 2048,
+    output_dir: str = "models",
+):
+    """Quantize and export a model.
+
+    Args:
+        model_name (str): the name of the model to export
+        algorithm (str): the neural compressor quantization algorithm
+        bits (int, optional): the number of bits. Defaults to 4.
+        sequence_length (int, optional): the model sequence length. Defaults to 2048.
+        output_dir (str, optional): the output directory. Defaults to "models".
+    """
+    print(f"Exporting model {model_name} with {bits} bits")
+    output_folder = os.path.join(output_dir, model_name, algorithm, f"int{bits}")
+    os.makedirs(output_folder, exist_ok=True)
+
+    float_model = AutoModelForCausalLM.from_pretrained(model_name)
+    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
+    float_model.config.save_pretrained(output_folder)
+    tokenizer.save_pretrained(output_folder)
+
+    dataloader = get_dataloader(tokenizer, seqlen=sequence_length)
+
+    woq_conf = PostTrainingQuantConfig(
+        approach="weight_only",
+        op_type_dict={
+            ".*": {  # match all ops
+                "weight": {
+                    "dtype": "int",
+                    "bits": bits,
+                    "group_size": -1,
+                    "scheme": "sym",
+                    "algorithm": algorithm.upper(),
+                },
+                "activation": {
+                    "dtype": "fp16",
+                },
+            }
+        },
+    )
+
+    print("Apply generic model optimizations")
+    npu_lib.compiler.apply_general_optimizations(float_model)
+    print("Quantize model")
+    quantized_model = fit(model=float_model, conf=woq_conf, calib_dataloader=dataloader)
+    print("Export compressed model")
+    compressed_model = quantized_model.export_compressed_model(
+        scale_dtype=torch.float16, use_optimum_format=False
+    )
+
+    print("Create NPU kernels")
+    npu_lib.compiler.create_npu_kernels(compressed_model)  # lowers the model in place
+
+    torch.save(compressed_model, os.path.join(output_folder, "pytorch_npu_model.bin"))
+    print(f"Model successfully exported to {output_folder}")
+
+
+if __name__ == "__main__":
+    args = define_and_parse_arguments()
+    export_model(
+        args.model, args.algorithm, args.bits, args.sequence_length, args.output_dir
+    )
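
A sketch of driving the new export script and reloading its output; the model name and paths are illustrative, and reloading assumes the library is installed so torch.load can resolve the custom NPU module classes during unpickling.

# Example invocation, using the flags defined by the script's argument parser:
#   python script/quantize_model.py -m TinyLlama/TinyLlama-1.1B-Chat-v1.0 -b 4 -a RTN -o models

import torch
import intel_npu_acceleration_library  # noqa: F401  (needed so the saved NPU modules can be unpickled)

npu_model = torch.load("models/TinyLlama/TinyLlama-1.1B-Chat-v1.0/RTN/int4/pytorch_npu_model.bin")
npu_model.eval()
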
2 changes: 1 addition & 1 deletion src/bindings.cpp
@@ -198,4 +198,4 @@ intel_npu_acceleration_library_DLL_API ov::op::Op* scaled_dot_product_attention(
         ov::op::Op* attn_mask, bool is_causal) {
     return factory->scaled_dot_product_attention(query, key, value, attn_mask, is_causal);
 }
-};
+}