
Inference with TensorRT .engine file on python

I used NVIDIA's Transfer Learning Toolkit (TLT) to train a model and then used tlt-converter to convert the .etlt model into an .engine file.

I want to use this .engine file for inference in Python. But since I trained with TLT, I don't have any frozen graphs or .pb files, which is what all the TensorRT inference tutorials require.

I would like to know whether Python inference is possible on .engine files. If not, which conversions (UFF, ONNX) are supported to make this possible?

Sharan asked Dec 11 '19 07:12



1 Answer

Python inference is possible with .engine files. The example below loads a .trt file (the same serialized engine as an .engine file, just with a different extension) from disk and runs a single inference.

In this project, I converted an ONNX model to a TRT engine with the onnx2trt executable before using it. You can also convert a PyTorch model to TRT by using ONNX as an intermediate format.
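If you want to script the ONNX-to-engine step, here is a minimal sketch. Note that it uses trtexec (the command-line tool that ships with TensorRT) rather than the onnx2trt executable I used, and it only builds the command instead of running it. The helper name build_trtexec_command and the file names are made up for illustration; --onnx, --saveEngine and --fp16 are trtexec options.

```python
import shlex
from typing import List


def build_trtexec_command(onnx_path: str, engine_path: str, fp16: bool = False) -> List[str]:
    """Build (but do not run) a trtexec command that converts ONNX to an engine.

    Hypothetical helper: assumes trtexec is on PATH when the command is run.
    """
    cmd = ["trtexec", f"--onnx={onnx_path}", f"--saveEngine={engine_path}"]
    if fp16:
        # Optional: allow half-precision kernels when building the engine.
        cmd.append("--fp16")
    return cmd


if __name__ == "__main__":
    # File names are placeholders, not taken from the question.
    cmd = build_trtexec_command("model.onnx", "model.engine")
    print(shlex.join(cmd))
```

The resulting list can be passed to subprocess.run on a machine where TensorRT is installed. For a TLT model none of this is needed, since tlt-converter already produces the engine.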

import tensorrt as trt
import numpy as np
import os

import pycuda.driver as cuda
import pycuda.autoinit

class HostDeviceMem(object):
    def __init__(self, host_mem, device_mem):
        self.host = host_mem
        self.device = device_mem

    def __str__(self):
        return "Host:\n" + str(self.host) + "\nDevice:\n" + str(self.device)

    def __repr__(self):
        return self.__str__()

class TrtModel:
    def __init__(self,engine_path,max_batch_size=1,dtype=np.float32):
        self.engine_path = engine_path
        self.dtype = dtype
        self.logger = trt.Logger(trt.Logger.WARNING)
        self.runtime = trt.Runtime(self.logger)
        self.engine = self.load_engine(self.runtime, self.engine_path)
        self.max_batch_size = max_batch_size
        self.inputs, self.outputs, self.bindings, self.stream = self.allocate_buffers()
        self.context = self.engine.create_execution_context()

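    @staticmethod
    def load_engine(trt_runtime, engine_path):
        # Register TensorRT's plugins; engines that use plugin layers need this before deserialization.
        trt.init_libnvinfer_plugins(None, "")
        with open(engine_path, 'rb') as f:
            engine_data = f.read()
        engine = trt_runtime.deserialize_cuda_engine(engine_data)
        return engine
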
    def allocate_buffers(self):
        inputs = []
        outputs = []
        bindings = []
        stream = cuda.Stream()
        for binding in self.engine:
            # Number of elements for this binding, scaled by the maximum batch size.
            size = trt.volume(self.engine.get_binding_shape(binding)) * self.max_batch_size
            host_mem = cuda.pagelocked_empty(size, self.dtype)
            device_mem = cuda.mem_alloc(host_mem.nbytes)
            # The execution context needs every device pointer, in binding order.
            bindings.append(int(device_mem))

            if self.engine.binding_is_input(binding):
                inputs.append(HostDeviceMem(host_mem, device_mem))
            else:
                outputs.append(HostDeviceMem(host_mem, device_mem))
        return inputs, outputs, bindings, stream

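    def __call__(self, x: np.ndarray, batch_size=1):
        x = x.astype(self.dtype)
        # Copy the input into the page-locked host buffer before the host-to-device transfer.
        np.copyto(self.inputs[0].host, x.ravel())
        for inp in self.inputs:
            cuda.memcpy_htod_async(inp.device, inp.host, self.stream)
        self.context.execute_async(batch_size=batch_size, bindings=self.bindings, stream_handle=self.stream.handle)
        for out in self.outputs:
            cuda.memcpy_dtoh_async(out.host, out.device, self.stream)
        # Wait for the asynchronous copies and the inference to finish before reading the results.
        self.stream.synchronize()
        return [out.host.reshape(batch_size, -1) for out in self.outputs]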

if __name__ == "__main__":
    batch_size = 1
    trt_engine_path = os.path.join("..","models","main.trt")
    model = TrtModel(trt_engine_path)
    shape = model.engine.get_binding_shape(0)

    data = np.random.randint(0,255,(batch_size,*shape[1:]))/255
    result = model(data,batch_size)
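
The buffer sizing done in allocate_buffers can be sanity-checked without a GPU. Below is a rough sketch of the same arithmetic (element count of the binding shape, times the maximum batch size, times the element size in bytes). The function name buffer_nbytes and the example shape are made up for illustration. The check for negative dimensions is there because an engine built with dynamic shapes reports -1 for the dynamic dimensions, and a fixed-size allocation like the one above does not work for those.

```python
from math import prod
from typing import Sequence


def buffer_nbytes(binding_shape: Sequence[int], max_batch_size: int = 1, itemsize: int = 4) -> int:
    """Bytes needed for one binding: volume(shape) * max_batch_size * itemsize.

    Hypothetical helper; itemsize=4 corresponds to float32.
    """
    if any(dim < 0 for dim in binding_shape):
        # Dynamic dimensions show up as -1 and have no fixed volume.
        raise ValueError(f"dynamic dimension in binding shape {tuple(binding_shape)}")
    return prod(binding_shape) * max_batch_size * itemsize


if __name__ == "__main__":
    # Example shape is an assumption (a 3x224x224 image with a leading batch dimension of 1).
    print(buffer_nbytes((1, 3, 224, 224)))
```

If the number printed here does not match the size of the array you pass to the model, the copy into the host buffer will fail, so this is a quick way to catch a shape mismatch early.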

Stay safe y'all!

O Vuruşkaner answered Sep 28 '22 18:09
