-
Notifications
You must be signed in to change notification settings - Fork 375
Added refitting acceleration #2983
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
zewenli98
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
8bbd573 to
6f3142b
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There are some changes that do not conform to Python style guidelines:
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/_conversion.py 2024-08-08 20:53:00.452273+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/_conversion.py 2024-08-08 20:54:40.434855+00:00
@@ -167,7 +167,7 @@
serialized_engine=interpreter_result.serialized_engine,
input_binding_names=list(interpreter_result.input_names),
output_binding_names=list(interpreter_result.output_names),
name=name,
settings=settings,
- weight_name_map = weight_name_map
+ weight_name_map=weight_name_map,
)
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py 2024-08-08 20:53:00.452273+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py 2024-08-08 20:54:40.911400+00:00
@@ -502,11 +502,13 @@
with io.BytesIO() as engine_bytes:
engine_bytes.write(serialized_engine)
engine_str = engine_bytes.getvalue()
- return TRTInterpreterResult(engine_str, self._input_names, self._output_names, self.weight_name_map)
+ return TRTInterpreterResult(
+ engine_str, self._input_names, self._output_names, self.weight_name_map
+ )
def run_node(self, n: torch.fx.Node) -> torch.fx.Node:
self._cur_node_name = get_node_name(n)
self._cur_node = n
# add "_itensor_to_tensor_meta"
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py 2024-08-08 20:53:00.456273+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py 2024-08-08 20:54:41.476969+00:00
@@ -143,12 +143,11 @@
TorchTensorRTModule._pack_binding_names(self.output_binding_names),
str(int(self.hardware_compatible)),
self.encode_metadata(metadata),
]
)
-
-
+
def encode_metadata(self, metadata: Any) -> str:
metadata = copy.deepcopy(metadata)
metadata["settings"].torch_executed_ops = {
f"torch.ops.{op.__str__()}"
for op in metadata["settings"].torch_executed_opsThere was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There are some changes that do not conform to Python style guidelines:
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py 2024-08-08 20:59:52.444408+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py 2024-08-08 21:01:37.564015+00:00
@@ -502,11 +502,13 @@
with io.BytesIO() as engine_bytes:
engine_bytes.write(serialized_engine)
engine_str = engine_bytes.getvalue()
- return TRTInterpreterResult(engine_str, self._input_names, self._output_names, self.weight_name_map)
+ return TRTInterpreterResult(
+ engine_str, self._input_names, self._output_names, self.weight_name_map
+ )
def run_node(self, n: torch.fx.Node) -> torch.fx.Node:
self._cur_node_name = get_node_name(n)
self._cur_node = n
# add "_itensor_to_tensor_meta"
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py 2024-08-08 20:59:52.452408+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py 2024-08-08 21:01:38.143764+00:00
@@ -143,12 +143,11 @@
TorchTensorRTModule._pack_binding_names(self.output_binding_names),
str(int(self.hardware_compatible)),
self.encode_metadata(metadata),
]
)
-
-
+
def encode_metadata(self, metadata: Any) -> str:
metadata = copy.deepcopy(metadata)
metadata["settings"].torch_executed_ops = {
f"torch.ops.{op.__str__()}"
for op in metadata["settings"].torch_executed_ops8927b0c to
b054fbc
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There are some changes that do not conform to Python style guidelines:
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py 2024-08-08 21:05:58.675792+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py 2024-08-08 21:09:37.161892+00:00
@@ -502,11 +502,13 @@
with io.BytesIO() as engine_bytes:
engine_bytes.write(serialized_engine)
engine_str = engine_bytes.getvalue()
- return TRTInterpreterResult(engine_str, self._input_names, self._output_names, self.weight_name_map)
+ return TRTInterpreterResult(
+ engine_str, self._input_names, self._output_names, self.weight_name_map
+ )
def run_node(self, n: torch.fx.Node) -> torch.fx.Node:
self._cur_node_name = get_node_name(n)
self._cur_node = n
# add "_itensor_to_tensor_meta"
--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py 2024-08-08 21:05:58.679792+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/runtime/_TorchTensorRTModule.py 2024-08-08 21:09:37.741137+00:00
@@ -143,12 +143,11 @@
TorchTensorRTModule._pack_binding_names(self.output_binding_names),
str(int(self.hardware_compatible)),
self.encode_metadata(metadata),
]
)
-
-
+
def encode_metadata(self, metadata: Any) -> str:
metadata = copy.deepcopy(metadata)
metadata["settings"].torch_executed_ops = {
f"torch.ops.{op.__str__()}"
for op in metadata["settings"].torch_executed_ops66c99a4 to
6588edb
Compare
c242a15 to
b3aa04f
Compare
narendasan
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
| weight_name_map=interpreter_result.weight_name_map, | ||
| ) | ||
| except AssertionError: | ||
| logger.warning("Fast refit test failed. Removing the weight map caching.") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Where's the operation that you remove the weight map caching?
| """ | ||
|
|
||
| def find_weight( | ||
| weight_name: str, np_map: dict[str, Any], sd: dict[str, Any] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What does np_map mean?
Added refit acceleration to existing refit pipeline.
During the first time of compilation, the interpreter will cache the weight name mapping between weights in TRT engine and weights in state_dict. The compiler then will do a tentative refit to test whether fast refit is success or not. If not, the caching will be removed. Later on, during refitting, if this mapping cache is detected, the re-interpretation of the module is skipped.
If the fast refit fails, the refitter falls back to the regular refit, which re-interprets the module and does refitting accordingly.
Checklist: