This repository was archived by the owner on Apr 24, 2025. It is now read-only.

Support for Phi-3 MLP layer #84

Merged
Commits (26)
22c627c  Add support for phi-3 MLP layer (Jul 1, 2024)
ea4b27a  Updating support for Phi-3 MLP (Jul 2, 2024)
39c070c  Update for Phi-3 MLP testing (Jul 3, 2024)
2042fab  Merge branch 'main' into sarah/feature/phi3MLP_layer (Jul 5, 2024)
5660cc3  Merge branch 'intel:main' into sarah/feature/phi3MLP_layer (SarahByrneIntel, Jul 5, 2024)
727454e  Update for phi-3 mlp layer (Jul 8, 2024)
00a64f0  Merge branch 'sarah/feature/phi3MLP_layer' of https://github.com/Sara… (Jul 8, 2024)
100fe88  Merge branch 'intel:main' into sarah/feature/phi3MLP_layer (SarahByrneIntel, Jul 8, 2024)
ea4ea19  Remove old code for phi-3 mlp layer (Jul 8, 2024)
53c7b0d  Merge branch 'sarah/feature/phi3MLP_layer' of https://github.com/Sara… (Jul 8, 2024)
1fef8a4  Add type tensor op and quantisation support (Jul 12, 2024)
cc5d373  add support for model quantisation and code clean up (Jul 15, 2024)
ff47c1d  Merge branch 'main' into sarah/feature/phi3MLP_layer (SarahByrneIntel, Jul 15, 2024)
d2fe9fe  Fix for model quantization (Jul 15, 2024)
b7825e7  Add testing for phi-3 mlp quantisation (Jul 16, 2024)
c652859  Add phi-3 mlp test and enable model profiling toggling (Jul 17, 2024)
786c663  Update for model profiling toggle (Jul 17, 2024)
003d639  Add compile config feature (Jul 18, 2024)
c63c223  Fix test for compile config and remove old code (Jul 18, 2024)
e652eaa  Fix tests with compile config (Jul 18, 2024)
7f2faf9  Fix for compiler, updates for tests and examples, doc update (Jul 18, 2024)
4b5f857  Update for model examples and remove test code (Jul 18, 2024)
2718e13  Merge branch 'main' into sarah/feature/phi3MLP_layer (alessandropalla, Jul 19, 2024)
ae1fd61  Fix for quantization and remove unused code (Jul 19, 2024)
5d578a1  Merge branch 'sarah/feature/phi3MLP_layer' of https://github.com/Sara… (Jul 19, 2024)
2890299  Update for quantization of a model (Jul 19, 2024)
12 changes: 12 additions & 0 deletions intel_npu_acceleration_library/backend/tensor.py
@@ -948,6 +948,18 @@ def to(self, dtype: NPUDtype) -> "Tensor":
"""
return generate_op([self], "to", dtype)

def type(self, dtype: NPUDtype) -> "Tensor":
"""
Convert the tensor to the specified data type.

Args:
dtype (NPUDtype): The data type to convert the tensor to.

Returns:
Tensor: The converted tensor.
"""
return self.to(dtype)

@classmethod
def __torch_function__(
cls: Any,
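For context, here is a minimal sketch of how the new `type` alias could be used while building a graph. It is not part of the diff, and it assumes an Intel NPU driver is available and that `NNFactory.parameter` creates a graph input as elsewhere in this library.

```python
# Illustrative only: Tensor.type(dtype) simply delegates to Tensor.to(dtype),
# mirroring the familiar torch.Tensor.type() spelling.
from intel_npu_acceleration_library.backend import NNFactory
from intel_npu_acceleration_library.dtypes import float16

nf = NNFactory()
x = nf.parameter((16, 32))   # graph input tensor (assumed helper)
y = x.type(float16)          # same as x.to(float16); adds a "to" op to the graph
```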
123 changes: 109 additions & 14 deletions intel_npu_acceleration_library/compiler.py
@@ -6,14 +6,17 @@
from intel_npu_acceleration_library.optimizations import horizontal_fusion_linear
from transformers.models.llama.modeling_llama import LlamaMLP, LlamaAttention
from transformers.models.gemma.modeling_gemma import GemmaMLP, GemmaAttention
from transformers.models.phi3.modeling_phi3 import Phi3MLP
from neural_compressor.adaptor.torch_utils.model_wrapper import WeightOnlyLinear
from intel_npu_acceleration_library.quantization import quantize_model
from intel_npu_acceleration_library.dtypes import int8, int4
from intel_npu_acceleration_library.nn.module import NPUModuleWrapper
import intel_npu_acceleration_library.nn as nn
from torch._dynamo import register_backend
from typing import Union, Callable, Any
from typing import List
import torch
from functools import partial


def compile(
@@ -39,14 +42,26 @@ def compile(

# Prepare and optimize model for NPU
with torch.no_grad():
# General optimizations
apply_general_optimizations(model)
if dtype in (int8, int4):
# Quantize model
model = quantize_model(model, dtype)

# Model lowering to NPU ops
create_npu_kernels(model)
if isinstance(model, Phi3MLP):
# Apply optimizations to a single MLP block model
model = model

if dtype in (int8, int4):
# Quantize model
model = quantize_model(model, dtype)
weights_quantization(model)
Contributor: Why is there a specific branch for Phi3MLP?

Contributor Author: If only a single MLP block is passed in to be compiled, we don't want to pass it to the recursive function, as that would break it down into its individual layers. When the block is contained within a larger model, it is the model that is broken down, and the NPUModuleWrapper check prevents the blocks themselves from being broken down. That check does not apply when only a single block is compiled.


else:
# General optimizations
apply_general_optimizations(model)

if dtype in (int8, int4):
# Quantize model
model = quantize_model(model, dtype)
weights_quantization(model)

create_npu_kernels(model)

if dtype.is_floating_point and training:
# Set model to evaluation only as quantized training is not supported yet
@@ -63,6 +78,7 @@ def apply_general_optimizations(model: torch.nn.Module):
"""
apply_horizontal_fusion(model)
optimize_llama_attention(model)
optimize_phi3_MLP(model)


def create_npu_kernels(model: torch.nn.Module):
@@ -95,13 +111,17 @@ def wrapper(model: torch.nn.Module, *args: Any, **kwargs: Any):
kwargs (Any): keyword arguments

"""
for name, layer in model.named_children():
new_layer = func(name, layer, *args, **kwargs)
if new_layer:
model.add_module(name, new_layer)
wrapper(new_layer, *args, **kwargs)
else:
wrapper(layer, *args, **kwargs)
if not isinstance(model, NPUModuleWrapper):
for name, layer in model.named_children():
new_layer = func(name, layer, *args, **kwargs)

if new_layer:
model.add_module(name, new_layer)
if not isinstance(new_layer, NPUModuleWrapper):
wrapper(new_layer, *args, **kwargs)
else:
if not isinstance(layer, NPUModuleWrapper):
wrapper(layer, *args, **kwargs)

return wrapper

@@ -174,6 +194,81 @@ def optimize_llama_attention(
return None


@module_optimization
def optimize_phi3_MLP(
name: str, layer: torch.nn.Module
) -> Union[torch.nn.Module, None]:
"""Optimize Phi-3 MLP block.

Args:
name (str): Module name
layer (torch.nn.Module): Original Module

Returns:
Union[torch.nn.Module, None]: optimized Phi-3 module
"""
if layer.__class__.__name__ == "Phi3MLP":
return layer
return None


@module_optimization
def weights_quantization(
name: str, layer: torch.nn.Module
) -> Union[torch.nn.Module, None]:
"""Apply weights quantization.

Args:
name (str): Layer name
layer (torch.nn.Module): Original torch.nn.Linear module

Raises:
RuntimeError: unsupported quantization bits

Returns:
None: Returns None
"""
if isinstance(layer, WeightOnlyLinear):
if (layer.bits == 4) or (layer.bits == 8):
layer.forward = partial(forward, layer)
else:
raise RuntimeError(f"Unsupported quantization bits: {layer.bits}")
return None


def forward(self, input):
"""Override forward method for WeightOnlyLinear class.

Args:
input: The input tensor.

Returns:
torch.Tensor: The output tensor.
"""
if self.bits == 4:
# Unpack the int4 values
lower_int4 = self.qweight & 0x0F
lower_int4 = lower_int4 - (lower_int4 & 0x8) * 2
upper_int4 = (self.qweight >> 4) & 0x0F
upper_int4 = upper_int4 - (upper_int4 & 0x8) * 2

w = torch.stack((lower_int4, upper_int4), dim=2)
w = w.contiguous().view(self.qweight.shape[0], -1)

elif self.bits == 8:
w = self.qweight.view(torch.int8)

output = (
torch.nn.functional.linear(input.to(torch.float16), w.to(torch.float16), None)
* self.scales.T
)

if self.bias:
return output + self.bias

return output
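To make the nibble arithmetic in the int4 branch above concrete, here is a small standalone sketch (made-up values, plain PyTorch, not part of the PR) that unpacks two int4 weights per byte and sign-extends them the same way:

```python
import torch

# Each entry holds one packed byte: the low nibble is one int4 weight, the high nibble another.
packed = torch.tensor([[0xF7, 0x18]])      # nibbles: 7 and 0xF, then 0x8 and 1

lower = packed & 0x0F
lower = lower - (lower & 0x8) * 2          # sign-extend: values 8..15 become -8..-1
upper = (packed >> 4) & 0x0F
upper = upper - (upper & 0x8) * 2

w = torch.stack((lower, upper), dim=2).contiguous().view(packed.shape[0], -1)
print(w)                                   # tensor([[ 7, -1, -8,  1]])
```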


@register_backend
def npu(
gm: Union[torch.nn.Module, torch.fx.GraphModule], example_inputs: List[torch.Tensor]
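Following up on the review thread above, a minimal usage sketch of the new single-block path (illustrative, not part of the diff; it assumes an Intel NPU and mirrors the setup used in `script/profile_mlp.py` below):

```python
import torch
import intel_npu_acceleration_library
from intel_npu_acceleration_library.dtypes import int4
from transformers.models.phi3.modeling_phi3 import Phi3Config, Phi3MLP

conf = Phi3Config.from_pretrained("microsoft/Phi-3-mini-4k-instruct")
mlp = Phi3MLP(conf)

# A bare Phi3MLP takes the isinstance(model, Phi3MLP) branch in compile():
# it is quantized and lowered as a single unit instead of being recursed into layer by layer.
npu_mlp = intel_npu_acceleration_library.compile(mlp, int4)

out = npu_mlp(torch.rand((128, conf.hidden_size)))
```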
18 changes: 13 additions & 5 deletions intel_npu_acceleration_library/nn/module.py
@@ -4,6 +4,7 @@
#
from intel_npu_acceleration_library.backend import NNFactory, Tensor
from typing import MutableMapping, Sequence, Any, List
from torch.profiler import record_function
import numpy as np
import torch

@@ -104,12 +105,17 @@ def patch_modules(module: torch.nn.Module, model: NNFactory):
class Module(torch.nn.Module):
"""A PyTorch module that runs on the NPU."""

def __init__(self) -> None:
"""Initialize the module."""
def __init__(self, profile: bool = False) -> None:
"""Initialize the module.

Args:
profile (bool): Enable model profiling. Defaults to False.
"""
super().__init__()
self._nn_factory_cache: MutableMapping[str, NNFactory] = {}
self._npu_inference = False
self.npu_top_level_module = True
self.profile = profile

def extract_tensors_from_arguments(
self, args: Sequence[Any]
@@ -170,7 +176,7 @@ def create_model(
Returns:
NNFactory: The model.
"""
model = NNFactory()
model = NNFactory(profile=self.profile)

def create_args_from_list(args: Sequence[Any]) -> Sequence[Any]:
"""Create arguments from a list.
@@ -249,7 +255,8 @@ def _call_impl(self, *args: Any, **kwargs: Any) -> Any:
# Run the model by replacing the forward method with the factory_forward
old_forward = self.forward
self.forward = self.factory_forward # type: ignore
out = super()._call_impl(*args, **kwargs)
with record_function(f"npu_{self.__class__.__name__}"):
out = super()._call_impl(*args, **kwargs)

# Restore the original forward method
self.forward = old_forward # type: ignore
@@ -322,7 +329,8 @@ def forward(self, *args, **kwargs) -> torch.Tensor:
Returns:
torch.Tensor: The output tensor.
"""
return self.module(*args, **kwargs)
with record_function(f"npu_{self.module.__class__.__name__}"):
return self.module(*args, **kwargs)


def convert_to_npu_module(module: torch.nn.Module) -> Module:
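As a side note on the profiling changes above, a short self-contained sketch (plain PyTorch, with a hypothetical `TinyBlock` module) of how `record_function` labels show up in a profiler table, matching the `npu_<ClassName>` naming used in `Module` and `NPUModuleWrapper`:

```python
import torch
from torch.profiler import profile, ProfilerActivity, record_function

class TinyBlock(torch.nn.Module):
    """Hypothetical stand-in for an NPU-wrapped module."""

    def forward(self, x):
        # Same labelling scheme as in the diff: npu_<class name>
        with record_function(f"npu_{self.__class__.__name__}"):
            return x * 2

with profile(activities=[ProfilerActivity.CPU]) as prof:
    TinyBlock()(torch.rand(4))

# The table contains an "npu_TinyBlock" row alongside the regular aten ops.
print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=5))
```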
124 changes: 124 additions & 0 deletions script/profile_mlp.py
@@ -0,0 +1,124 @@
#
# Copyright © 2024 Intel Corporation
# SPDX-License-Identifier: Apache 2.0
#

from transformers.models.phi3.modeling_phi3 import Phi3Config, Phi3MLP
from intel_npu_acceleration_library.dtypes import int8, int4
from torch.profiler import profile, ProfilerActivity
from sklearn.metrics import r2_score
import intel_npu_acceleration_library
import argparse
import torch
import numpy as np


def main(
seq_len=128,
hidden_size=256,
intermediate_size=512,
dtype="float16",
_profile=False,
):

conf = Phi3Config.from_pretrained("microsoft/Phi-3-mini-4k-instruct")
conf.num_hidden_layers = 1
conf.hidden_size = hidden_size
conf.intermediate_size = intermediate_size

# Define a single Phi-3 MLP layer
mlp = Phi3MLP(conf)

hidden_states = torch.rand((seq_len, conf.hidden_size))

reference = mlp(hidden_states.to(torch.float32)).to(torch.float16)

if dtype == "float16":
dtype = torch.float16
elif dtype == "int8":
dtype = int8
elif dtype == "int4":
dtype = int4
else:
raise RuntimeError(f"Invalid dtype: {dtype}")

# Compile model
model = intel_npu_acceleration_library.compile(mlp, dtype)
if _profile:
model.profile = True

with profile(activities=[ProfilerActivity.CPU], record_shapes=True) as prof:
for _ in range(1000):
results = model(hidden_states)

print(
prof.key_averages(group_by_input_shape=True).table(
sort_by="cpu_time_total", row_limit=20
)
)

prof.export_chrome_trace("trace.json")

results = results.detach().numpy()
reference = reference.detach().numpy()

assert results.shape == reference.shape, "Output shape mismatch"
assert np.isfinite(reference).all(), "Pytorch Reference contains NaN or Inf"
assert np.isfinite(results).all(), "NPU output contains NaN or Inf"

if dtype == int4:
assert 1 - r2_score(reference, results) < 0.05
else:
assert 1 - r2_score(reference, results) < 0.001


def define_and_parse_args():
parser = argparse.ArgumentParser(description="Profiling a MLP layer in the NPU")
parser.add_argument(
"--seq-len",
type=int,
default=128,
help="Sequence length (default: %(default)s)",
)
parser.add_argument(
"--hidden-size",
type=int,
default=256,
help="Hidden size (default: %(default)s)",
)
parser.add_argument(
"--intermediate-size",
type=int,
default=512,
help="Intermediate size (default: %(default)s)",
)
parser.add_argument(
"--dtype",
default="float16",
choices=["float16", "int8", "int4"],
help="Select the target dtype (default: %(default)s)",
)
parser.add_argument(
"--profile",
action="store_true",
default=False,
help="Enable the profiling (default: False)",
)

return parser.parse_args()


if __name__ == "__main__":
args = define_and_parse_args()

print(
f"Profiling with sequence length {args.seq_len}, hidden size {args.hidden_size}, intermediate size {args.intermediate_size}, dtype {args.dtype}"
)

main(
seq_len=args.seq_len,
hidden_size=args.hidden_size,
intermediate_size=args.intermediate_size,
dtype=args.dtype,
_profile=args.profile,
)
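A typical invocation of this script, using the flags defined in `define_and_parse_args` above, would be `python script/profile_mlp.py --seq-len 128 --hidden-size 256 --intermediate-size 512 --dtype int4 --profile`.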