diff --git a/examples/llama.py b/examples/llama.py
index 188b5dd..9c2aaba 100644
--- a/examples/llama.py
+++ b/examples/llama.py
@@ -4,13 +4,12 @@
 #

 from transformers import AutoTokenizer, TextStreamer
-from intel_npu_acceleration_library import NPUModelForCausalLM
-import torch
+from intel_npu_acceleration_library import NPUModelForCausalLM, int4

 model_id = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"

 model = NPUModelForCausalLM.from_pretrained(
-    model_id, use_cache=True, dtype=torch.int8
+    model_id, use_cache=True, dtype=int4, attn_implementation="sdpa"
 ).eval()
 tokenizer = AutoTokenizer.from_pretrained(model_id, use_default_system_prompt=True)
 tokenizer.pad_token_id = tokenizer.eos_token_id
diff --git a/examples/llama3.py b/examples/llama3.py
index d611d03..dcdddf9 100644
--- a/examples/llama3.py
+++ b/examples/llama3.py
@@ -4,14 +4,11 @@
 #

 from transformers import AutoTokenizer, TextStreamer
-from intel_npu_acceleration_library import NPUModelForCausalLM
-import torch
+from intel_npu_acceleration_library import NPUModelForCausalLM, int4

 model_id = "meta-llama/Meta-Llama-3-8B-Instruct"

-model = NPUModelForCausalLM.from_pretrained(
-    model_id, dtype=torch.int8, use_cache=True
-).eval()
+model = NPUModelForCausalLM.from_pretrained(model_id, dtype=int4, use_cache=True).eval()
 tokenizer = AutoTokenizer.from_pretrained(model_id)

 streamer = TextStreamer(tokenizer, skip_special_tokens=True, skip_prompt=True)
diff --git a/intel_npu_acceleration_library/compiler.py b/intel_npu_acceleration_library/compiler.py
index 128b6c7..4e80d04 100644
--- a/intel_npu_acceleration_library/compiler.py
+++ b/intel_npu_acceleration_library/compiler.py
@@ -40,14 +40,13 @@ def compile(
     # Prepare and optimize model for NPU
     with torch.no_grad():
         # General optimizations
-        apply_horizontal_fusion(model)
-        optimize_llama_attention(model, dtype)
+        apply_general_optimizations(model)
         if dtype in (int8, int4):
             # Quantize model
             model = quantize_model(model, dtype)

         # Model lowering to NPU ops
-        lower_linear(model, dtype)
+        create_npu_kernels(model)

     if dtype.is_floating_point and training:
         # Set model to evaluation only as quantized training is not supported yet
@@ -56,6 +55,25 @@
     return model.eval()


+def apply_general_optimizations(model: torch.nn.Module):
+    """Apply general optimizations to a torch.nn.Module.
+
+    Args:
+        model (torch.nn.Module): a pytorch nn.Module to compile and optimize for the npu
+    """
+    apply_horizontal_fusion(model)
+    optimize_llama_attention(model)
+
+
+def create_npu_kernels(model: torch.nn.Module):
+    """Create NPU kernels.
+
+    Args:
+        model (torch.nn.Module): a pytorch nn.Module to compile and optimize for the npu
+    """
+    lower_linear(model)
+
+
 def module_optimization(func: Callable) -> torch.nn.Module:
     """Optimize recursively a torch.nn.Module with a specific function.

@@ -89,15 +107,12 @@ def wrapper(model: torch.nn.Module, *args: Any, **kwargs: Any):


 @module_optimization
-def lower_linear(
-    name: str, layer: torch.nn.Module, dtype: torch.dtype
-) -> Union[torch.nn.Module, None]:
+def lower_linear(name: str, layer: torch.nn.Module) -> Union[torch.nn.Module, None]:
     """Lower torch.nn.Linear layer to NPU equivalent operators.

     Args:
         name (str): Layer name
         layer (torch.nn.Module): Original torch.nn.Linear module
-        dtype (torch.dtype): Target datatype

     Raises:
         RuntimeError: unsupported quantization bits
@@ -106,9 +121,9 @@ def lower_linear(
         Union[torch.nn.Module, None]: Return the new NPU operator or None
     """
     if isinstance(layer, torch.nn.Linear):
-        return nn.Linear.fromTorch(layer, dtype)
+        return nn.Linear.fromTorch(layer)
     if isinstance(layer, torch.nn.Conv2d):
-        return nn.Conv2d.fromTorch(layer, dtype)
+        return nn.Conv2d.fromTorch(layer)
     if isinstance(layer, WeightOnlyLinear):
         if layer.bits == 4:
             return nn.QuantizedLinear(
@@ -143,20 +158,19 @@ def apply_horizontal_fusion(

 @module_optimization
 def optimize_llama_attention(
-    name: str, layer: torch.nn.Module, dtype: torch.dtype
+    name: str, layer: torch.nn.Module
 ) -> Union[torch.nn.Module, None]:
     """Optimize LLAMA attention block.

     Args:
         name (str): Module name
         layer (torch.nn.Module): Original Module
-        dtype (torch.dtype): Target datatype

     Returns:
         Union[torch.nn.Module, None]: optimized llama module
     """
     if isinstance(layer, (LlamaAttention, GemmaAttention)):
-        return nn.LlamaAttention.fromTorch(layer, dtype)
+        return nn.LlamaAttention.fromTorch(layer)
     return None


diff --git a/intel_npu_acceleration_library/dtypes.py b/intel_npu_acceleration_library/dtypes.py
index 1082fda..1f28569 100644
--- a/intel_npu_acceleration_library/dtypes.py
+++ b/intel_npu_acceleration_library/dtypes.py
@@ -61,6 +61,15 @@ def __eq__(self, value: Union["NPUDtype", torch.dtype]) -> bool:
         else:
             return super().__eq__(value)

+    def __repr__(self) -> str:
+        """
+        Return a string representation of the NPUDtype object.
+
+        Returns:
+            str: The string representation of the NPUDtype object.
+        """
+        return self.name
+

 float16 = NPUDtype("fp16", 16, -65504, 65504, torch.float16)
 bfloat16 = NPUDtype("bfloat16", 16, -65504, 65504, torch.float16)
diff --git a/intel_npu_acceleration_library/modelling.py b/intel_npu_acceleration_library/modelling.py
index e9510ec..420db3c 100644
--- a/intel_npu_acceleration_library/modelling.py
+++ b/intel_npu_acceleration_library/modelling.py
@@ -52,7 +52,7 @@ def get_model_path(model_name: str, *args: Any, **kwargs: Any) -> Tuple[str, str
     cache_dir = get_cache_dir()
     mangled_model_name = get_mangled_model_name(model_name, *args, **kwargs)
     model_dir_path = os.path.join(cache_dir, mangled_model_name)
-    model_path = os.path.join(model_dir_path, "model.pt")
+    model_path = os.path.join(model_dir_path, "pytorch_npu_model.pt")
     return model_dir_path, model_path


diff --git a/intel_npu_acceleration_library/nn/conv.py b/intel_npu_acceleration_library/nn/conv.py
index 9ad57bb..3d00a62 100644
--- a/intel_npu_acceleration_library/nn/conv.py
+++ b/intel_npu_acceleration_library/nn/conv.py
@@ -118,13 +118,13 @@ def forward(self, x) -> torch.Tensor:
         return out

     @staticmethod
-    def fromTorch(layer, dtype) -> "Conv2d":
+    def fromTorch(layer, dtype: torch.dtype = torch.float16) -> "Conv2d":
        """
        Create a Conv2d layer from a torch.nn.Conv2d layer.

        Args:
            layer (torch.nn.Conv2d): The torch Conv2d layer.
-            dtype (torch.dtype): Data type of the layer.
+            dtype (torch.dtype, optional): Data type of the layer.

        Returns:
            Conv2d: The converted Conv2d layer.
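
For context, a minimal sketch of how the two compiler entry points added above (apply_general_optimizations and create_npu_kernels) compose outside of compile(). The toy module is illustrative only, and running it assumes the library is installed on a machine with an Intel NPU:

import torch
import intel_npu_acceleration_library as npu_lib

# Illustrative toy network: any torch.nn.Module containing Linear layers would do.
toy = torch.nn.Sequential(
    torch.nn.Linear(128, 256), torch.nn.GELU(), torch.nn.Linear(256, 128)
).eval()

with torch.no_grad():
    # Horizontal fusion + attention rewrite; note there is no dtype argument any more.
    npu_lib.compiler.apply_general_optimizations(toy)
    # Lower Linear/Conv2d layers to NPU kernels, as compile() now does.
    npu_lib.compiler.create_npu_kernels(toy)
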
diff --git a/script/quantize_model.py b/script/quantize_model.py
new file mode 100644
index 0000000..2a294b3
--- /dev/null
+++ b/script/quantize_model.py
@@ -0,0 +1,121 @@
+#
+# Copyright © 2024 Intel Corporation
+# SPDX-License-Identifier: Apache 2.0
+#
+
+from transformers import AutoModelForCausalLM, AutoTokenizer
+import intel_npu_acceleration_library as npu_lib
+from neural_compressor.config import PostTrainingQuantConfig
+from neural_compressor.quantization import fit
+from neural_compressor.adaptor.torch_utils.auto_round import get_dataloader
+import argparse
+import torch
+import os
+
+
+def define_and_parse_arguments():
+    parser = argparse.ArgumentParser(description="Export a model to NPU")
+    parser.add_argument(
+        "-m",
+        "--model",
+        type=str,
+        required=True,
+        help="The name of the model to export",
+    )
+    parser.add_argument(
+        "-b",
+        "--bits",
+        type=int,
+        default=4,
+        help="The number of bits to use for quantization",
+    )
+    parser.add_argument(
+        "-o",
+        "--output-dir",
+        type=str,
+        default="models",
+        help="The directory where to save the exported model",
+    )
+    parser.add_argument(
+        "-s",
+        "--sequence-length",
+        type=int,
+        default=2048,
+        help="The sequence length to use for the dataloader",
+    )
+    parser.add_argument(
+        "-a",
+        "--algorithm",
+        type=str,
+        default="RTN",
+        help="The quantization algorithm to use",
+    )
+    return parser.parse_args()
+
+
+def export_model(
+    model_name: str,
+    algorithm: str,
+    bits: int = 4,
+    sequence_length: int = 2048,
+    output_dir: str = "models",
+):
+    """Quantize and export a model.
+
+    Args:
+        model_name (str): the name of the model to export
+        algorithm (str): the neural compressor quantization algorithm
+        bits (int, optional): the number of bits. Defaults to 4.
+        sequence_length (int, optional): the model sequence length. Defaults to 2048.
+        output_dir (str, optional): the output directory. Defaults to "models".
+    """
+    print(f"Exporting model {model_name} with {bits} bits")
+    output_folder = os.path.join(output_dir, model_name, algorithm, f"int{bits}")
+    os.makedirs(output_folder, exist_ok=True)
+
+    float_model = AutoModelForCausalLM.from_pretrained(model_name)
+    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
+    float_model.config.save_pretrained(output_folder)
+    tokenizer.save_pretrained(output_folder)
+
+    dataloader = get_dataloader(tokenizer, seqlen=sequence_length)
+
+    woq_conf = PostTrainingQuantConfig(
+        approach="weight_only",
+        op_type_dict={
+            ".*": {  # match all ops
+                "weight": {
+                    "dtype": "int",
+                    "bits": bits,
+                    "group_size": -1,
+                    "scheme": "sym",
+                    "algorithm": algorithm.upper(),
+                },
+                "activation": {
+                    "dtype": "fp16",
+                },
+            }
+        },
+    )
+
+    print("Apply generic model optimizations")
+    npu_lib.compiler.apply_general_optimizations(float_model)
+    print("Quantize model")
+    quantized_model = fit(model=float_model, conf=woq_conf, calib_dataloader=dataloader)
+    print("Export compressed model")
+    compressed_model = quantized_model.export_compressed_model(
+        scale_dtype=torch.float16, use_optimum_format=False
+    )
+
+    print("Create NPU kernels")
+    npu_lib.compiler.create_npu_kernels(compressed_model)
+
+    torch.save(compressed_model, os.path.join(output_folder, "pytorch_npu_model.bin"))
+    print(f"Model successfully exported to {output_folder}")
+
+
+if __name__ == "__main__":
+    args = define_and_parse_arguments()
+    export_model(
+        args.model, args.algorithm, args.bits, args.sequence_length, args.output_dir
+    )
diff --git a/src/bindings.cpp b/src/bindings.cpp
index ae576b1..7952fcd 100644
--- a/src/bindings.cpp
+++ b/src/bindings.cpp
@@ -198,4 +198,4 @@ intel_npu_acceleration_library_DLL_API ov::op::Op* scaled_dot_product_attention(
     ov::op::Op* attn_mask, bool is_causal) {
     return factory->scaled_dot_product_attention(query, key, value, attn_mask, is_causal);
 }
-};
\ No newline at end of file
+}
\ No newline at end of file
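
For reference, a possible way to load back what script/quantize_model.py exports. This is a sketch only: the folder layout mirrors the script's os.path.join(output_dir, model_name, algorithm, f"int{bits}"), and it assumes the torch.save'd transformers module deserializes cleanly and still exposes generate():

import os
import torch
from transformers import AutoTokenizer

# Folder written by, e.g.: python script/quantize_model.py -m TinyLlama/TinyLlama-1.1B-Chat-v1.0 -b 4 -a RTN
output_folder = os.path.join("models", "TinyLlama/TinyLlama-1.1B-Chat-v1.0", "RTN", "int4")

tokenizer = AutoTokenizer.from_pretrained(output_folder)  # saved next to the model by the script
model = torch.load(os.path.join(output_folder, "pytorch_npu_model.bin"))

inputs = tokenizer("What is an NPU?", return_tensors="pt")
output = model.generate(**inputs, max_new_tokens=32)  # assumes generate() survives serialization
print(tokenizer.decode(output[0], skip_special_tokens=True))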