This repository was archived by the owner on Oct 16, 2023. It is now read-only.

Link TensorRT as backend for single device execution #82

Merged 1 commit on May 27, 2022
3 changes: 2 additions & 1 deletion energonai/context/config.py
@@ -27,7 +27,8 @@
     'backend':"nccl",
     'rm_padding': False,
     'seed' : 1024,
-    'verbose' : True
+    'verbose' : True,
+    'trt_sample' : None
 }


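For context, the new `trt_sample` key defaults to None, which leaves TensorRT conversion disabled. A user config opts in by supplying the example inputs that torch2trt will trace the model with — a minimal sketch, assuming a CUDA device and a purely illustrative input shape:

import torch

# Sketch of a user-side config value (hypothetical shape): torch2trt feeds
# these tensors through the model while building the TensorRT engine.
# Leaving trt_sample at its default None skips conversion entirely.
trt_sample = [torch.randn(1, 3, 224, 224, device='cuda')]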
31 changes: 20 additions & 11 deletions energonai/engine/rpc_worker.py
@@ -1,19 +1,14 @@
 import os
 import time
 import torch
-import inspect
 import torch.distributed.rpc as rpc
-import sys
 
 from colossalai.core import global_context as gpc
 from colossalai.context import ParallelMode
 from colossalai.logging import get_dist_logger
 
 from .rpc_utils import remote_cls_method, sync_cls_method, async_cls_method
 from .pipeline_wrapper import PipelineCommWrapper
 from .vit_pipeline_wrapper import ViTPipelineCommWrapper
-
-# from torch2trt import torch2trt
 from energonai.context import mcfg
-
 logger = get_dist_logger('energonai')

@@ -39,9 +34,12 @@ def top(self, key):
         return output
 
 
+
+
 class RPCWorker:
 
     def __init__(self, model_class, model_config, model_type, dtype, max_batch_size: int = 1) -> None:
+
         self.model_class = model_class
         self.model_config = model_config
         self.dtype = dtype
@@ -55,7 +53,7 @@ def __init__(self, model_class, model_config, model_type, dtype, max_batch_size:
 
         # self.trt_sample = None
         self._init_self()
-        self.return_dict = ReturnDict()
+        self.return_dict = ReturnDict()
 
     def _init_self(self):
         logger.info("Init model in rank {}".format(self.rank))
@@ -67,10 +65,21 @@ def _init_self(self):
 
         self.model.eval()
 
-        # if trt_sample is not None and gpc.get_world_size(ParallelMode.MODEL) > 1:
-        #     logger.error("Tensor Parallelism does not support TensorRT convert")
-        # elif trt_sample is not None and gpc.get_world_size(ParallelMode.MODEL) == 1:
-        #     model = torch2trt(model, [self.trt_sample])
+        if mcfg['trt_sample'] is not None:
+            try:
+                logger.info('Import Torch2Trt')
+                from torch2trt import torch2trt
+                from energonai.engine import trt_converter
+            except:
+                logger.error("Installation Required, \n \
+                    follow https://docs.nvidia.com/deeplearning/tensorrt/install-guide/index.html \
+                    and https://github.com/NVIDIA-AI-IOT/torch2trt")
+
+        if mcfg['trt_sample'] is not None and gpc.get_world_size(ParallelMode.MODEL) > 1:
+            logger.error("Tensor Parallelism does not support TensorRT convert")
+        elif mcfg['trt_sample'] is not None and gpc.get_world_size(ParallelMode.MODEL) == 1:
+            self.model = torch2trt(self.model, mcfg['trt_sample'])
+            logger.info("TensorRT convert complete.")
 
         try:
             self.model = pipe_wrapper[self.model_type](model=self.model, max_batch_size=self.max_batch_size, dtype=self.dtype)
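The worker above converts only when the model-parallel world size is 1, since torch2trt cannot trace a tensor-parallel model. The call itself is the standard torch2trt entry point; below is a self-contained sketch of the same flow, with a hypothetical stand-in model and illustrative shapes:

import torch
from torch2trt import torch2trt  # requires TensorRT and torch2trt to be installed

class TinyNet(torch.nn.Module):  # hypothetical stand-in for the served model
    def __init__(self):
        super().__init__()
        self.fc = torch.nn.Linear(16, 16)

    def forward(self, x):
        return self.fc(x)

model = TinyNet().cuda().eval()
sample = [torch.randn(1, 16, device='cuda')]  # plays the role of mcfg['trt_sample']

# torch2trt traces the model on the sample inputs and returns a TRTModule
# that is called exactly like the original module, as rpc_worker does above.
model_trt = torch2trt(model, sample)
with torch.no_grad():
    print(model_trt(sample[0]).shape)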
10 changes: 10 additions & 0 deletions energonai/engine/trt_converter.py
@@ -0,0 +1,10 @@
+from torch2trt.torch2trt import *
+
+@tensorrt_converter('torch.matmul')
+def convert_mul(ctx):
+    input_a = ctx.method_args[0]
+    input_b = ctx.method_args[1]
+    input_a_trt, input_b_trt = trt_(ctx.network, input_a, input_b)
+    output = ctx.method_return
+    layer = ctx.network.add_matrix_multiply(input_a_trt, trt.MatrixOperation.NONE, input_b_trt, trt.MatrixOperation.NONE)
+    output._trt = layer.get_output(0)
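torch2trt dispatches converters by the qualified name of the intercepted call, so merely importing this module registers the `torch.matmul` handler — the `from energonai.engine import trt_converter` in rpc_worker.py exists for exactly that side effect. A small sketch of how the converter could be exercised in isolation; the toy module and shapes are illustrative:

import torch
from torch2trt import torch2trt
import energonai.engine.trt_converter  # noqa: F401 -- import registers the matmul converter

class MatMul(torch.nn.Module):
    def forward(self, a, b):
        return torch.matmul(a, b)

a = torch.randn(1, 8, 16, device='cuda')
b = torch.randn(1, 16, 32, device='cuda')
model_trt = torch2trt(MatMul().cuda().eval(), [a, b])

# Compare the TensorRT engine's output against eager PyTorch.
print(torch.max(torch.abs(model_trt(a, b) - torch.matmul(a, b))))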