This repository was archived by the owner on Apr 24, 2025. It is now read-only.
Support for Phi-3 MLP layer #84
Merged
alessandropalla merged 26 commits into intel:main from SarahByrneIntel:sarah/feature/phi3MLP_layer on Jul 19, 2024 (+421 −64)
Commits (changes shown from 17 of 26):
22c627c  Add support for phi-3 MLP layer
ea4b27a  Updating support for Phi-3 MLP
39c070c  Update for Phi-3 MLP testing
2042fab  Merge branch 'main' into sarah/feature/phi3MLP_layer
5660cc3  Merge branch 'intel:main' into sarah/feature/phi3MLP_layer
727454e  Update for phi-3 mlp layer (SarahByrneIntel)
00a64f0  Merge branch 'sarah/feature/phi3MLP_layer' of https://github.com/Sara…
100fe88  Merge branch 'intel:main' into sarah/feature/phi3MLP_layer
ea4ea19  Remove old code for phi-3 mlp layer (SarahByrneIntel)
53c7b0d  Merge branch 'sarah/feature/phi3MLP_layer' of https://github.com/Sara…
1fef8a4  Add type tensor op and quantisation support
cc5d373  add support for model quantisation and code clean up
ff47c1d  Merge branch 'main' into sarah/feature/phi3MLP_layer
d2fe9fe  Fix for model quantization (SarahByrneIntel)
b7825e7  Add testing for phi-3 mlp quantisation
c652859  Add phi-3 mlp test and enable model profiling toggling
786c663  Update for model profiling toggle
003d639  Add compile config feature
c63c223  Fix test for compile config and remove old code
e652eaa  Fix tests with compile config
7f2faf9  Fix for compiler, updates for tests and examples, doc update
4b5f857  Update for model examples and remove test code
2718e13  Merge branch 'main' into sarah/feature/phi3MLP_layer
ae1fd61  Fix for quantization and remove unused code (alessandropalla)
5d578a1  Merge branch 'sarah/feature/phi3MLP_layer' of https://github.com/Sara…
2890299  Update for quantization of a model
One of the files added in this PR, a script that profiles a single Phi-3 MLP layer on the NPU:

#
# Copyright © 2024 Intel Corporation
# SPDX-License-Identifier: Apache 2.0
#

from transformers.models.phi3.modeling_phi3 import Phi3Config, Phi3MLP
from intel_npu_acceleration_library.dtypes import int8, int4
from torch.profiler import profile, ProfilerActivity
from sklearn.metrics import r2_score
import intel_npu_acceleration_library
import argparse
import torch
import numpy as np


def main(
    seq_len=128,
    hidden_size=256,
    intermediate_size=512,
    dtype="float16",
    _profile=False,
):

    conf = Phi3Config.from_pretrained("microsoft/Phi-3-mini-4k-instruct")
    conf.num_hidden_layers = 1
    conf.hidden_size = hidden_size
    conf.intermediate_size = intermediate_size

    # Define a single Phi-3 MLP layer
    mlp = Phi3MLP(conf)

    hidden_states = torch.rand((seq_len, conf.hidden_size))

    reference = mlp(hidden_states.to(torch.float32)).to(torch.float16)

    if dtype == "float16":
        dtype = torch.float16
    elif dtype == "int8":
        dtype = int8
    elif dtype == "int4":
        dtype = int4
    else:
        raise RuntimeError(f"Invalid dtype: {dtype}")

    # Compile model
    model = intel_npu_acceleration_library.compile(mlp, dtype)
    if _profile:
        model.profile = True

    with profile(activities=[ProfilerActivity.CPU], record_shapes=True) as prof:
        for _ in range(1000):
            results = model(hidden_states)

    print(
        prof.key_averages(group_by_input_shape=True).table(
            sort_by="cpu_time_total", row_limit=20
        )
    )

    prof.export_chrome_trace("trace.json")

    results = results.detach().numpy()
    reference = reference.detach().numpy()

    assert results.shape == reference.shape, "Output shape mismatch"
    assert np.isfinite(reference).all(), "Pytorch Reference contains NaN or Inf"
    assert np.isfinite(results).all(), "NPU output contains NaN or Inf"

    if dtype == int4:
        assert 1 - r2_score(reference, results) < 0.05
    else:
        assert 1 - r2_score(reference, results) < 0.001


def define_and_parse_args():
    parser = argparse.ArgumentParser(description="Profiling a MLP layer in the NPU")
    parser.add_argument(
        "--seq-len",
        type=int,
        default=128,
        help="Sequence length (default: %(default)s)",
    )
    parser.add_argument(
        "--hidden-size",
        type=int,
        default=256,
        help="Hidden size (default: %(default)s)",
    )
    parser.add_argument(
        "--intermediate-size",
        type=int,
        default=512,
        help="Intermediate size (default: %(default)s)",
    )
    parser.add_argument(
        "--dtype",
        default="float16",
        choices=["float16", "int8", "int4"],
        help="Select the target dtype (default: %(default)s)",
    )
    parser.add_argument(
        "--profile",
        action="store_true",
        default=False,
        help="Enable the profiling (default: False)",
    )

    return parser.parse_args()


if __name__ == "__main__":
    args = define_and_parse_args()

    print(
        f"Profiling with sequence length {args.seq_len}, hidden size {args.hidden_size}, intermediate size {args.intermediate_size}, dtype {args.dtype}"
    )

    main(
        seq_len=args.seq_len,
        hidden_size=args.hidden_size,
        intermediate_size=args.intermediate_size,
        dtype=args.dtype,
        _profile=args.profile,
    )
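For context on what the PR's quantization path enables end to end, a hedged sketch follows. It reuses the compile(...) entry point exactly as the example script above does; loading a full causal LM this way is an assumed usage for illustration, not code taken from this PR.

# Sketch: quantizing and compiling a whole Phi-3 model for the NPU with the
# same compile(...) call the example script uses on a single MLP block.
# Compiling a full AutoModelForCausalLM like this is an assumption, not PR code.
from transformers import AutoModelForCausalLM
import intel_npu_acceleration_library
from intel_npu_acceleration_library.dtypes import int8

model = AutoModelForCausalLM.from_pretrained("microsoft/Phi-3-mini-4k-instruct")
# Weights are quantized to int8; Phi3MLP blocks inside the model are lowered
# to the NPU as single units (see the review discussion below).
model = intel_npu_acceleration_library.compile(model, int8)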
Why is there a specific branch for Phi3MLP?
If only a single MLP block is passed in to be compiled, we don't want to hand it to the recursive function, because that function would break it down into its constituent layers. When the block is contained within a larger model, it is the model that gets broken down, and the NPUModuleWrapper check prevents the blocks themselves from being decomposed. That check never comes into play when a single block is compiled on its own, which is why the dedicated Phi3MLP branch is needed.
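To make that control flow concrete, here is a minimal sketch of the branching described above. NPUModuleWrapper is stubbed out, and compile_module is an illustrative stand-in for the library's recursive compile pass, not its actual implementation.

import torch
from transformers.models.phi3.modeling_phi3 import Phi3Config, Phi3MLP


class NPUModuleWrapper(torch.nn.Module):
    """Stub of the wrapper marking a module as already lowered to the NPU;
    the real class lives in intel_npu_acceleration_library."""

    def __init__(self, module: torch.nn.Module):
        super().__init__()
        self.module = module

    def forward(self, x):
        return self.module(x)


def compile_module(module: torch.nn.Module) -> torch.nn.Module:
    # Dedicated branch: a bare Phi3MLP passed straight to compile.
    # Recursing into it would decompose it into its Linear sublayers,
    # so the whole block is lowered as a single NPU unit instead.
    if isinstance(module, Phi3MLP):
        return NPUModuleWrapper(module)

    # General case: walk the module tree. Children that are already
    # NPU-wrapped are skipped, which is what stops a Phi3MLP inside a
    # larger model from being broken apart.
    for name, child in module.named_children():
        if isinstance(child, NPUModuleWrapper):
            continue
        setattr(module, name, compile_module(child))
    return module


if __name__ == "__main__":
    conf = Phi3Config(hidden_size=256, intermediate_size=512)
    block = compile_module(Phi3MLP(conf))
    print(type(block).__name__)  # NPUModuleWrapper: the block stayed whole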