
Commit 62c9588

Merge branch 'main' into enum_improvements

2 parents cf5cc93 + 9f46d39

36 files changed: +1239 −87 lines

.github/workflows/build-test-linux.yml

Lines changed: 7 additions & 7 deletions

@@ -66,7 +66,7 @@ jobs:
             pre-script: packaging/pre_build_script.sh
             post-script: packaging/post_build_script.sh
             smoke-test-script: packaging/smoke_test_script.sh
-    uses: pytorch/tensorrt/.github/workflows/linux-test.yml@main
+    uses: ./.github/workflows/linux-test.yml
     with:
       job-name: tests-py-torchscript-fe
       repository: "pytorch/tensorrt"
@@ -101,7 +101,7 @@ jobs:
             pre-script: packaging/pre_build_script.sh
             post-script: packaging/post_build_script.sh
             smoke-test-script: packaging/smoke_test_script.sh
-    uses: pytorch/tensorrt/.github/workflows/linux-test.yml@main
+    uses: ./.github/workflows/linux-test.yml
     with:
       job-name: tests-py-dynamo-converters
       repository: "pytorch/tensorrt"
@@ -129,7 +129,7 @@ jobs:
             pre-script: packaging/pre_build_script.sh
             post-script: packaging/post_build_script.sh
             smoke-test-script: packaging/smoke_test_script.sh
-    uses: pytorch/tensorrt/.github/workflows/linux-test.yml@main
+    uses: ./.github/workflows/linux-test.yml
     with:
       job-name: tests-py-dynamo-fe
       repository: "pytorch/tensorrt"
@@ -158,7 +158,7 @@ jobs:
             pre-script: packaging/pre_build_script.sh
             post-script: packaging/post_build_script.sh
             smoke-test-script: packaging/smoke_test_script.sh
-    uses: pytorch/tensorrt/.github/workflows/linux-test.yml@main
+    uses: ./.github/workflows/linux-test.yml
     with:
       job-name: tests-py-dynamo-serde
       repository: "pytorch/tensorrt"
@@ -186,7 +186,7 @@ jobs:
             pre-script: packaging/pre_build_script.sh
             post-script: packaging/post_build_script.sh
             smoke-test-script: packaging/smoke_test_script.sh
-    uses: pytorch/tensorrt/.github/workflows/linux-test.yml@main
+    uses: ./.github/workflows/linux-test.yml
     with:
       job-name: tests-py-torch-compile-be
       repository: "pytorch/tensorrt"
@@ -216,7 +216,7 @@ jobs:
             pre-script: packaging/pre_build_script.sh
             post-script: packaging/post_build_script.sh
             smoke-test-script: packaging/smoke_test_script.sh
-    uses: pytorch/tensorrt/.github/workflows/linux-test.yml@main
+    uses: ./.github/workflows/linux-test.yml
     with:
       job-name: tests-py-dynamo-core
       repository: "pytorch/tensorrt"
@@ -246,7 +246,7 @@ jobs:
             pre-script: packaging/pre_build_script.sh
             post-script: packaging/post_build_script.sh
             smoke-test-script: packaging/smoke_test_script.sh
-    uses: pytorch/tensorrt/.github/workflows/linux-test.yml@release/2.3
+    uses: ./.github/workflows/linux-test.yml
     with:
       job-name: tests-py-core
       repository: "pytorch/tensorrt"
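
Each of the seven test jobs previously pinned the reusable workflow to a remote ref (pytorch/tensorrt/.github/workflows/linux-test.yml@main, and @release/2.3 for tests-py-core). The local form ./.github/workflows/linux-test.yml resolves to the workflow file at the same commit as the caller, so every branch and PR is tested against its own copy of linux-test.yml rather than whatever currently sits on main.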

core/runtime/TRTEngine.cpp

Lines changed: 9 additions & 5 deletions

@@ -33,14 +33,16 @@ TRTEngine::TRTEngine(
     const RTDevice& cuda_device,
     const std::vector<std::string>& _in_binding_names,
     const std::vector<std::string>& _out_binding_names,
-    bool hardware_compatible)
+    bool hardware_compatible,
+    const std::string& serialized_metadata)
     : TRTEngine(
           "deserialized_trt",
           serialized_engine,
           cuda_device,
           _in_binding_names,
           _out_binding_names,
-          hardware_compatible) {}
+          hardware_compatible,
+          serialized_metadata) {}

 TRTEngine::TRTEngine(std::vector<std::string> serialized_info)
     : TRTEngine(
@@ -49,17 +51,19 @@ TRTEngine::TRTEngine(std::vector<std::string> serialized_info)
           RTDevice(serialized_info[DEVICE_IDX]),
           split(serialized_info[INPUT_BINDING_NAMES_IDX], BINDING_DELIM),
           split(serialized_info[OUTPUT_BINDING_NAMES_IDX], BINDING_DELIM),
-          static_cast<bool>(std::stoi(serialized_info[HW_COMPATIBLE_IDX]))) {}
+          static_cast<bool>(std::stoi(serialized_info[HW_COMPATIBLE_IDX])),
+          serialized_info[SERIALIZED_METADATA_IDX]) {}

 TRTEngine::TRTEngine(
     const std::string& mod_name,
     const std::string& serialized_engine,
     const RTDevice& cuda_device,
     const std::vector<std::string>& _in_binding_names,
     const std::vector<std::string>& _out_binding_names,
-    bool hardware_compatible) {
+    bool hardware_compatible,
+    const std::string& serialized_metadata) {
   this->hardware_compatible = hardware_compatible;
-
+  this->serialized_metadata = serialized_metadata;
   auto most_compatible_device = get_most_compatible_device(cuda_device, RTDevice(), hardware_compatible);
   TORCHTRT_CHECK(most_compatible_device, "No compatible device was found for instantiating TensorRT engine");
   device_info = most_compatible_device.value();

core/runtime/TRTEngine.h

Lines changed: 6 additions & 2 deletions

@@ -35,22 +35,26 @@ struct TRTEngine : torch::CustomClassHolder {
   std::vector<std::string> out_binding_names = {}; // ITO: PYT IDX

   bool hardware_compatible = false; // Whether the engine was compiled in hardware compatible mode
+  std::string serialized_metadata; // This is a base64 encoded pkl object used to store metadata such as settings used
+                                   // in compilation

   ~TRTEngine();
   TRTEngine(
       const std::string& serialized_engine,
       const RTDevice& cuda_device,
       const std::vector<std::string>& in_binding_names,
       const std::vector<std::string>& out_binding_names,
-      bool hardware_compatible = false);
+      bool hardware_compatible = false,
+      const std::string& serialized_metadata = "");
   TRTEngine(std::vector<std::string> serialized_info);
   TRTEngine(
       const std::string& mod_name,
       const std::string& serialized_engine,
       const RTDevice& cuda_device,
       const std::vector<std::string>& in_binding_names,
       const std::vector<std::string>& out_binding_names,
-      bool hardware_compatible = false);
+      bool hardware_compatible = false,
+      const std::string& serialized_metadata = "");
   TRTEngine& operator=(const TRTEngine& other);
   std::string to_str() const;
   static void verify_serialization_fmt(const std::vector<std::string>& serialized_info);
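
The header comment fixes the wire format of serialized_metadata: a base64-encoded pickle of compilation metadata carried as a plain string. A minimal Python sketch of that round trip (the helper names and the choice of payload are illustrative assumptions; only the base64-encoded-pickle format comes from the comment above):

import base64
import pickle

def encode_metadata(settings: object) -> str:
    # Pickle the metadata object, then base64-encode it so it travels as
    # one string entry of the serialized_info vector.
    return base64.b64encode(pickle.dumps(settings)).decode("utf-8")

def decode_metadata(serialized_metadata: str) -> object:
    # Inverse: base64-decode back to bytes, then unpickle.
    return pickle.loads(base64.b64decode(serialized_metadata.encode("utf-8")))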

core/runtime/register_jit_hooks.cpp

Lines changed: 10 additions & 1 deletion

@@ -102,7 +102,7 @@ static auto TORCHTRT_UNUSED TRTEngineTSRegistrtion =
       serialize_info[INPUT_BINDING_NAMES_IDX] = serialize_bindings(self->in_binding_names);
       serialize_info[OUTPUT_BINDING_NAMES_IDX] = serialize_bindings(self->out_binding_names);
       serialize_info[HW_COMPATIBLE_IDX] = self->hardware_compatible ? "1" : "0";
-
+      serialize_info[SERIALIZED_METADATA_IDX] = self->serialized_metadata;
       LOG_DEBUG("Serialized Hardware Compatibility: " << (self->hardware_compatible ? "Enabled" : "Disabled"));

       return serialize_info;
@@ -127,6 +127,15 @@ TORCH_LIBRARY(tensorrt, m) {
       });
   m.def(
       "get_logging_level", []() -> int64_t { return int64_t(util::logging::get_logger().get_reportable_log_level()); });
+  m.def("ABI_TARGET_IDX", []() -> int64_t { return ABI_TARGET_IDX; });
+  m.def("NAME_IDX", []() -> int64_t { return NAME_IDX; });
+  m.def("DEVICE_IDX", []() -> int64_t { return DEVICE_IDX; });
+  m.def("ENGINE_IDX", []() -> int64_t { return ENGINE_IDX; });
+  m.def("INPUT_BINDING_NAMES_IDX", []() -> int64_t { return INPUT_BINDING_NAMES_IDX; });
+  m.def("OUTPUT_BINDING_NAMES_IDX", []() -> int64_t { return OUTPUT_BINDING_NAMES_IDX; });
+  m.def("HW_COMPATIBLE_IDX", []() -> int64_t { return HW_COMPATIBLE_IDX; });
+  m.def("SERIALIZED_METADATA_IDX", []() -> int64_t { return SERIALIZED_METADATA_IDX; });
+  m.def("SERIALIZATION_LEN", []() -> int64_t { return SERIALIZATION_LEN; });
 }

 } // namespace
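
Registering each index as a zero-argument op in TORCH_LIBRARY(tensorrt, m) makes the serialization layout queryable from Python, so Python-side code can index into a serialized engine's info vector without hard-coding positions. A minimal sketch of calling those ops (assuming torch_tensorrt has been imported so the C++ runtime library is loaded):

import torch
import torch_tensorrt  # noqa: F401  # loads the runtime that registers torch.ops.tensorrt.*

# Each registered lambda returns its SerializedInfoIndex value as an int.
metadata_idx = torch.ops.tensorrt.SERIALIZED_METADATA_IDX()
info_len = torch.ops.tensorrt.SERIALIZATION_LEN()
print(f"metadata lives at index {metadata_idx} of {info_len} fields")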

core/runtime/runtime.h

Lines changed: 1 addition & 0 deletions

@@ -25,6 +25,7 @@ typedef enum {
   INPUT_BINDING_NAMES_IDX,
   OUTPUT_BINDING_NAMES_IDX,
   HW_COMPATIBLE_IDX,
+  SERIALIZED_METADATA_IDX,
   SERIALIZATION_LEN, // NEVER USED FOR DATA, USED TO DETERMINE LENGTH OF SERIALIZED INFO
 } SerializedInfoIndex;
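
Because SERIALIZED_METADATA_IDX is inserted before the SERIALIZATION_LEN sentinel, the sentinel's value grows by one automatically, so any length check written against SERIALIZATION_LEN picks up the new field without further changes. The flip side is that the layout shifts: engines serialized before this change carry one fewer field and presumably fail the format verification against the new length.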

docsrc/py_api/dynamo.rst

Lines changed: 1 addition & 1 deletion

@@ -24,7 +24,7 @@ Functions

 .. autofunction:: convert_module_to_trt_engine

-
+.. autofunction:: refit_module_weights


 Classes

examples/dynamo/README.rst

Lines changed: 1 addition & 0 deletions

@@ -11,4 +11,5 @@ a number of ways you can leverage this backend to accelerate inference.
 * :ref:`torch_compile_advanced_usage`: Advanced usage including making a custom backend to use directly with the ``torch.compile`` API
 * :ref:`torch_compile_stable_diffusion`: Compiling a Stable Diffusion model using ``torch.compile``
 * :ref:`custom_kernel_plugins`: Creating a plugin to use a custom kernel inside TensorRT engines
+* :ref:`refit_engine_example`: Refitting a compiled TensorRT Graph Module with updated weights
 * :ref:`vgg16_fp8_ptq`: Compiling a VGG16 model with FP8 and PTQ using ``torch.compile``

examples/dynamo/refit_engine_example.py

Lines changed: 98 additions & 0 deletions

@@ -0,0 +1,98 @@
+"""
+.. _refit_engine_example:
+
+Refitting a TensorRT Graph Module with Torch-TensorRT
+===================================================================
+
+We are going to demonstrate how a compiled TensorRT Graph Module can be refitted with updated weights.
+
+In many cases, we frequently update the weights of models, such as applying various LoRAs to Stable Diffusion or constantly A/B testing AI products.
+That poses challenges for TensorRT inference optimizations, as compiling TensorRT engines takes significant time, making repetitive compilation highly inefficient.
+Torch-TensorRT supports refitting TensorRT graph modules without re-compiling the engine, considerably accelerating the workflow.
+
+In this tutorial, we are going to walk through
+1. Compiling a PyTorch model to a TensorRT Graph Module
+2. Saving and loading a graph module
+3. Refitting the graph module
+"""
+
+# %%
+# Standard Workflow
+# -----------------------------
+
+# %%
+# Imports and model definition
+# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+import numpy as np
+import torch
+import torch_tensorrt as torch_trt
+import torchvision.models as models
+from torch_tensorrt.dynamo import refit_module_weights
+
+np.random.seed(0)
+torch.manual_seed(0)
+inputs = [torch.rand((1, 3, 224, 224)).to("cuda")]
+
+
+# %%
+# Compile the module for the first time and save it.
+# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+model = models.resnet18(pretrained=False).eval().to("cuda")
+exp_program = torch.export.export(model, tuple(inputs))
+enabled_precisions = {torch.float}
+debug = False
+workspace_size = 20 << 30
+min_block_size = 0
+use_python_runtime = False
+torch_executed_ops = {}
+trt_gm = torch_trt.dynamo.compile(
+    exp_program,
+    tuple(inputs),
+    use_python_runtime=use_python_runtime,
+    enabled_precisions=enabled_precisions,
+    debug=debug,
+    min_block_size=min_block_size,
+    torch_executed_ops=torch_executed_ops,
+    make_refitable=True,
+)  # Output is a torch.fx.GraphModule
+
+# Save the graph module as an exported program
+# This is only supported when use_python_runtime = False
+torch_trt.save(trt_gm, "./compiled.ep", inputs=inputs)
+
+
+# %%
+# Refit the module with updated model weights
+# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+# Create and compile the updated model
+model2 = models.resnet18(pretrained=True).eval().to("cuda")
+exp_program2 = torch.export.export(model2, tuple(inputs))
+
+
+compiled_trt_ep = torch_trt.load("./compiled.ep")
+
+# This returns a new module with updated weights
+new_trt_gm = refit_module_weights(
+    compiled_module=compiled_trt_ep,
+    new_weight_module=exp_program2,
+    inputs=inputs,
+)
+
+# Check the output
+expected_outputs, refitted_outputs = exp_program2.module()(*inputs), new_trt_gm(*inputs)
+for expected_output, refitted_output in zip(expected_outputs, refitted_outputs):
+    assert torch.allclose(
+        expected_output, refitted_output, 1e-2, 1e-2
+    ), "Refit result is not correct. Refit failed"
+
+print("Refit succeeded!")
+
+# %%
+# Alternative Workflow using Python Runtime
+# -----------------------------
+
+# Currently the Python runtime does not support engine serialization, so refitting must be done within the same runtime.
+# This use case is most useful when you need to switch between different sets of weights in the same runtime, such as with Stable Diffusion.
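
A minimal sketch of that in-process variant, reusing the objects defined in the example above; the assumption (not shown in the file) is that a module compiled with use_python_runtime=True can be handed straight to refit_module_weights without a save/load round trip:

# Compile with the Python runtime; the module stays in this process,
# since the Python runtime cannot serialize engines yet.
trt_gm_py = torch_trt.dynamo.compile(
    exp_program,
    tuple(inputs),
    use_python_runtime=True,
    make_refitable=True,
)

# Refit the live module directly with the new weights.
new_trt_gm_py = refit_module_weights(
    compiled_module=trt_gm_py,
    new_weight_module=exp_program2,
    inputs=inputs,
)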

py/torch_tensorrt/_compile.py

Lines changed: 1 addition & 1 deletion

@@ -353,7 +353,7 @@ def convert_method_to_trt_engine(

     return dynamo_convert_module_to_trt_engine(
         exp_program,
-        inputs=inputs,
+        inputs=tuple(inputs),
         enabled_precisions=enabled_precisions_set,
         **kwargs,
     )
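
Converting inputs to a tuple here matches the rest of the export-based path (note tuple(inputs) throughout the refit example above); torch.export-style entry points take example inputs as a tuple, and passing the list through unchanged presumably did not.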

py/torch_tensorrt/dynamo/__init__.py

Lines changed: 1 addition & 0 deletions

@@ -9,6 +9,7 @@
 if version.parse(sanitized_torch_version()) >= version.parse("2.1.dev"):
     from ._compiler import compile, convert_module_to_trt_engine
     from ._exporter import export
+    from ._refit import refit_module_weights
     from ._settings import CompilationSettings
     from ._SourceIR import SourceIR
     from ._tracer import trace

py/torch_tensorrt/dynamo/_compiler.py

Lines changed: 21 additions & 4 deletions

@@ -57,7 +57,7 @@ def compile(
         Set[torch.dtype | dtype] | Tuple[torch.dtype | dtype]
     ) = _defaults.ENABLED_PRECISIONS,
     engine_capability: EngineCapability = _defaults.ENGINE_CAPABILITY,
-    refit: bool = _defaults.REFIT,
+    make_refitable: bool = _defaults.MAKE_REFITABLE,
     debug: bool = _defaults.DEBUG,
     num_avg_timing_iters: int = _defaults.NUM_AVG_TIMING_ITERS,
     workspace_size: int = _defaults.WORKSPACE_SIZE,
@@ -162,6 +162,17 @@ def compile(
             stacklevel=2,
         )

+    if "refit" in kwargs.keys():
+        warnings.warn(
+            "Refit is deprecated. Please use make_refitable=True if you want to enable refitting of the engine.",
+            DeprecationWarning,
+            stacklevel=2,
+        )
+        if make_refitable:
+            raise ValueError("Use flag make_refitable only. Flag refit is deprecated.")
+        else:
+            make_refitable = kwargs["refit"]
+
     engine_capability = EngineCapability._from(engine_capability)

     if torch_executed_modules is not None and torch_executed_modules:
@@ -217,7 +228,7 @@
         "require_full_compilation": require_full_compilation,
         "disable_tf32": disable_tf32,
         "sparse_weights": sparse_weights,
-        "refit": refit,
+        "make_refitable": make_refitable,
         "engine_capability": engine_capability,
         "dla_sram_size": dla_sram_size,
         "dla_local_dram_size": dla_local_dram_size,
@@ -477,7 +488,7 @@ def convert_module_to_trt_engine(
     require_full_compilation: bool = _defaults.REQUIRE_FULL_COMPILATION,
     disable_tf32: bool = _defaults.DISABLE_TF32,
     sparse_weights: bool = _defaults.SPARSE_WEIGHTS,
-    refit: bool = _defaults.REFIT,
+    make_refitable: bool = _defaults.MAKE_REFITABLE,
     engine_capability: EngineCapability = _defaults.ENGINE_CAPABILITY,
     num_avg_timing_iters: int = _defaults.NUM_AVG_TIMING_ITERS,
     dla_sram_size: int = _defaults.DLA_SRAM_SIZE,
@@ -562,6 +573,12 @@
             DeprecationWarning,
             stacklevel=2,
         )
+    if "refit" in kwargs.keys():
+        warnings.warn(
+            "Refit is deprecated. Please use make_refitable=True if you want to enable refitting of the engine.",
+            DeprecationWarning,
+            stacklevel=2,
+        )

     input_list = list(inputs) if inputs is not None else []
     torch_executed_ops = torch_executed_ops if torch_executed_ops is not None else set()
@@ -590,7 +607,7 @@
         "require_full_compilation": require_full_compilation,
         "disable_tf32": disable_tf32,
         "sparse_weights": sparse_weights,
-        "refit": refit,
+        "make_refitable": make_refitable,
         "engine_capability": engine_capability,
         "num_avg_timing_iters": num_avg_timing_iters,
         "dla_sram_size": dla_sram_size,

py/torch_tensorrt/dynamo/_defaults.py

Lines changed: 1 addition & 1 deletion

@@ -26,7 +26,7 @@
 USE_PYTHON_RUNTIME = False
 USE_FAST_PARTITIONER = True
 ENABLE_EXPERIMENTAL_DECOMPOSITIONS = False
-REFIT = False
+MAKE_REFITABLE = False
 REQUIRE_FULL_COMPILATION = False
 DRYRUN = False
 HARDWARE_COMPATIBLE = False
