diff --git a/py/torch_tensorrt/dynamo/backend/backends.py b/py/torch_tensorrt/dynamo/backend/backends.py index 6be97d42d9..a68c08eb79 100644 --- a/py/torch_tensorrt/dynamo/backend/backends.py +++ b/py/torch_tensorrt/dynamo/backend/backends.py @@ -7,9 +7,9 @@ import torch import torch._dynamo as td import torch.utils._pytree as pytree +import torch_tensorrt from torch._dynamo.utils import detect_fake_mode from torch._functorch.aot_autograd import _aot_export_function -from torch._inductor.constant_folding import ConstantFolder, replace_node_with_constant from torch._ops import OpOverload from torch_tensorrt.dynamo import CompilationSettings from torch_tensorrt.dynamo.compile import compile_module @@ -17,6 +17,17 @@ from torch_tensorrt.dynamo.lowering._pre_aot_lowering import pre_aot_substitutions from torch_tensorrt.dynamo.utils import parse_dynamo_kwargs +from packaging import version + +# Modify import location of utilities based on Torch version +if version.parse(torch_tensorrt.sanitized_torch_version()) <= version.parse("2.1.0"): + from torch._inductor.freezing import ConstantFolder, replace_node_with_constant +else: + from torch._inductor.constant_folding import ( + ConstantFolder, + replace_node_with_constant, + ) + logger = logging.getLogger(__name__)