Converting a Custom PyTorch Operator to TensorRT

Exporting the ONNX Model from PyTorch

Suppose we have the following model:

import torch
import torch.nn as nn


class MYPLUGINImpl(torch.autograd.Function):
    @staticmethod
    def symbolic(g, x, a):
        return g.op("MYPLUGIN", x, a,
                    g.op("Constant", value_t=torch.tensor([3, 2, 1], dtype=torch.int32)),
                    attr1_s="this is a string attribute",
                    attr2_i=[1, 2],
                    attr3_f=222
                    )
    @staticmethod
    def forward(ctx, x, a):
        return x + a


class MYPLUGIN(nn.Module):
    def __init__(self, n):
        super().__init__()
        self.param = nn.parameter.Parameter(torch.arange(n).float())
    def forward(self, x, a):
        return MYPLUGINImpl.apply(x, a)


class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(1, 1, 3, padding=1)
        self.conv.weight.data.fill_(1)
        self.conv.bias.data.fill_(0)
        self.myplugin = MYPLUGIN(3)
    def forward(self, x, a):
        x = self.conv(x)
        x = self.myplugin(x, a)
        return x


model = Model().eval()
input = torch.tensor([
    [
        [1, 1, 1],
        [1, 1, 1],
        [1, 1, 1],
    ],
], dtype=torch.float32).view(1, 1, 3, 3)
a = torch.tensor(2, dtype=torch.int32)
output = model(input, a)
print(f"inference output = {output}")

torch.onnx.export(
    model,
    (input, a), # args: the inputs passed to the model, wrapped in a tuple
    "myplugin.onnx", # output file path
    input_names=["image"], # name the input and output nodes so they are easier to find later
    output_names=["output"],
    opset_version=13, # which opset the operators are exported with (corresponds to symbolic_opset13)
    operator_export_type=torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK,
    #verbose=True, # print detailed export information
    # dynamic_axes marks the batch, height and width dimensions as dynamic (-1 in the ONNX graph);
    # usually only batch is made dynamic and the other dimensions are kept fixed
    # dynamic_axes={
    #     "image": {
    #         0: "batch", 2: "height", 3: "width"},
    #     "output": {
    #         0: "batch", 2: "height", 3: "width"},
    # },
)

Program output (the 3x3 convolution with an all-ones kernel and padding 1 sums each input neighborhood, giving 4, 6 and 9 at the corners, edges and center; the plugin then adds a = 2):

inference output = 
tensor([[[[ 6.,  8.,  6.],
          [ 8., 11.,  8.],
          [ 6.,  8.,  6.]]]], grad_fn=<MYPLUGINImplBackward>)

The structure of the exported ONNX model: https://i-blog.csdnimg.cn/direct/5203ced1a6bf4dee95b482080a33e5e3.png
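
To confirm what was actually exported, here is a small sketch (assuming the onnx Python package is installed) that prints the graph nodes and the attributes attached to the MYPLUGIN node:

import onnx

model = onnx.load("myplugin.onnx")
for node in model.graph.node:
    print(node.op_type, list(node.input), list(node.output))
    if node.op_type == "MYPLUGIN":
        # the attributes set in symbolic() (the _s/_i/_f suffixes only encode the type)
        for attr in node.attribute:
            print("  ", attr.name, onnx.helper.get_attribute_value(attr))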

Writing the TensorRT Custom Plugin

This walkthrough uses TensorRT-10.6.0.26. Since TensorRT is only partially open source, first download the TensorRT-10.6.0.26 binary release, then download the matching source code.
Create a myPlugin folder under TensorRT/plugin and add the following files:
myPlugin.h

#ifndef TRT_MYPLUGIN_H
#define TRT_MYPLUGIN_H

#include "NvInfer.h"
#include "NvInferPlugin.h"
#include "common/plugin.h"
#include <string>
#include <vector>

namespace nvinfer1
{
namespace plugin
{
class myPlugin : public nvinfer1::IPluginV2DynamicExt 
{          
public:
	myPlugin(const std::string name, const std::string attr1, float attr3);  // takes the layer name and attributes; used when building the engine
	myPlugin(const std::string name, const void* data, size_t length);  // takes the layer name and serialized data; used when deserializing the engine for inference

	int getNbOutputs() const noexcept override;
	virtual nvinfer1::DataType getOutputDataType(int32_t index,
		nvinfer1::DataType const* inputTypes, int32_t nbInputs) const noexcept override {           
		return inputTypes[0];
	}
	virtual nvinfer1::DimsExprs getOutputDimensions(int32_t outputIndex,
		const nvinfer1::DimsExprs* inputs, int32_t nbInputs, nvinfer1::IExprBuilder& exprBuilder) noexcept override;

	int initialize() noexcept override;
	void terminate() noexcept override;

	virtual size_t getWorkspaceSize(const nvinfer1::PluginTensorDesc* inputs,
		int32_t nbInputs, const nvinfer1::PluginTensorDesc* outputs,
		int32_t nbOutputs) const noexcept override {           
		return 0;
	};

	int enqueue(const nvinfer1::PluginTensorDesc* inputDesc, const nvinfer1::PluginTensorDesc* outputDesc,
		const void* const* inputs, void* const* outputs, void* workspace, cudaStream_t stream) noexcept override;

	size_t getSerializationSize() const noexcept override;
	void serialize(void* buffer)  const noexcept override;

	virtual void configurePlugin(const  nvinfer1::DynamicPluginTensorDesc* in, int32_t nbInputs,
		const  nvinfer1::DynamicPluginTensorDesc* out, int32_t nbOutputs) noexcept override;

	virtual bool supportsFormatCombination(int32_t pos, const nvinfer1::PluginTensorDesc* inOut, int32_t nbInputs,
		int32_t nbOutputs) noexcept override;

	const char* getPluginType() const noexcept override;
	const char* getPluginVersion() const noexcept override;

	void destroy() noexcept override;
	nvinfer1::IPluginV2DynamicExt* clone() const noexcept override;
	void setPluginNamespace(nvinfer1::AsciiChar const* pluginNamespace) noexcept override;
	const char* getPluginNamespace()const noexcept override;

private:
	const std::string mLayerName;
	std::string mattr1;
	float mattr3;
	size_t mInputVolume;
	std::string mNamespace;
};


class myPluginCreator : public nvinfer1::IPluginCreator 
{           
public:
	myPluginCreator();
	const char* getPluginName() const noexcept override;
	const char* getPluginVersion() const noexcept override;
	const nvinfer1::PluginFieldCollection* getFieldNames() noexcept override;

	nvinfer1::IPluginV2* createPlugin(nvinfer1::AsciiChar const* name,
		nvinfer1::PluginFieldCollection const* fc) noexcept override;
	nvinfer1::IPluginV2* deserializePlugin(nvinfer1::AsciiChar const* name,
		void const* serialData, size_t serialLength)noexcept override;
	void setPluginNamespace(nvinfer1::AsciiChar const* pluginNamespace) noexcept override;
	const char* getPluginNamespace() const noexcept override;
private:
	static nvinfer1::PluginFieldCollection mfc;
	static std::vector<nvinfer1::PluginField> mPluginAttributes;
	std::string mNamespace;
};
}
}

#endif // TRT_MYPLUGIN_H

myPlugin.cpp

#include "myPlugin.h"
#include <NvInfer.h>
#include <cstring>
#include <vector>
#include <cassert>

using namespace nvinfer1;
using namespace nvinfer1::pluginInternal;
using nvinfer1::plugin::myPlugin;
using nvinfer1::plugin::myPluginCreator;

namespace nvinfer1
{
namespace plugin
{
// CUDA launcher, implemented in myPlugin.cu
void myselu_inference(const float* x, const int* a, float* output, int n, cudaStream_t stream);

// Initialization of the creator's static members
nvinfer1::PluginFieldCollection myPluginCreator::mfc{};
std::vector<nvinfer1::PluginField> myPluginCreator::mPluginAttributes;

// Helper function used to serialize the plugin
template <typename T>
void writeToBuffer(char*& buffer, const T& val) 
{          
	*reinterpret_cast<T*>(buffer) = val;
	buffer += sizeof(T);
}

// Helper function used to deserialize the plugin
template <typename T>
T readFromBuffer(char const*& buffer) 
{          
	T val = *reinterpret_cast<const T*>(buffer);
	buffer += sizeof(T);
	return val;
}

// myPlugin class implementation
myPlugin::myPlugin(const std::string name, const std::string attr1, float attr3)
	:mLayerName(name), mattr1(attr1), mattr3(attr3)
{
    std::cout<<"myPlugin"<<std::endl;
};

myPlugin::myPlugin(const std::string name, const void* data, size_t length) 
	:mLayerName(name)
{
    std::cout<<"myPlugin()"<<std::endl;       
	// Deserialize in the same order as serialization
	char const* d = static_cast<char const*>(data);
	char const* a = d;
	int nstr = readFromBuffer<int>(d);
	mattr1 = std::string(d, d + nstr);
	d += nstr;
	mattr3 = readFromBuffer<float>(d);
	assert(d == (a + length));
};

char const* myPlugin::getPluginType() const noexcept
{   
	std::cout<<"getPluginType"<<std::endl;         
	return "MYPLUGIN";
}

char const* myPlugin::getPluginVersion() const noexcept
{
 	std::cout<<"getPluginVersion"<<std::endl;             
	return "1";
}

int myPlugin::getNbOutputs() const noexcept 
{
   	std::cout<<"getNbOutputs"<<std::endl;          
	return 1;
}

// Compute the output dimensions of this layer
nvinfer1::DimsExprs myPlugin::getOutputDimensions(int32_t outputIndex,
	const nvinfer1::DimsExprs* inputs, int32_t nbInputs, nvinfer1::IExprBuilder& exprBuilder) noexcept 
{
    std::cout<<"getOutputDimensions"<<std::endl;         
	// The layer does not change the input shape, so the output shape equals the input shape
	return inputs[0];
}

int myPlugin::initialize() noexcept
{
    std::cout<<"initialize"<<std::endl;          
	return 0;
}

int myPlugin::enqueue(const nvinfer1::PluginTensorDesc* inputDesc, const nvinfer1::PluginTensorDesc* outputDesc,
	const void* const* inputs, void* const* outputs, void* workspace, cudaStream_t stream) noexcept
{
    std::cout<<"enqueue"<<std::endl;   
             
	void* output = outputs[0];
	size_t volume = 1;
	for (int i = 0; i < inputDesc->dims.nbDims; ++i)
	 {            
		volume *= inputDesc->dims.d[i];
	}
	mInputVolume = volume;

	myselu_inference(static_cast<const float*>(inputs[0]),
		static_cast<const int*>(inputs[1]),
		static_cast<float*>(output),
		mInputVolume,
		stream
	);

	return 0;
}

size_t myPlugin::getSerializationSize() const noexcept
{
    std::cout<<"getSerializationSize"<<std::endl;               
	return sizeof(int) + mattr1.size() + sizeof(mattr3);
}

// Serialize this layer's parameters into the engine file
void myPlugin::serialize(void* buffer)  const noexcept
{
    std::cout<<"serialize"<<std::endl;                  
	char* d = static_cast<char*>(buffer);
	char const* a = d;
	int nstr = mattr1.size();
	writeToBuffer(d, nstr);
	memcpy(d, mattr1.data(), nstr);
	d += nstr;
	writeToBuffer(d, mattr3);
	assert(d == a + getSerializationSize());
}

// Report which data types and formats this plugin supports
bool myPlugin::supportsFormatCombination(int32_t pos, const nvinfer1::PluginTensorDesc* inOut, int32_t nbInputs,
	int32_t nbOutputs) noexcept
{
	std::cout<<"supportsFormatCombination"<<std::endl;     
    PLUGIN_ASSERT(pos < nbInputs + nbOutputs);

    if (pos == 0)
    {
        return (inOut[pos].type == nvinfer1::DataType::kFLOAT) 
            && (inOut[pos].format == nvinfer1::PluginFormat::kLINEAR);
    }
    else if (pos == 1)
    {
        return (inOut[pos].type == nvinfer1::DataType::kINT32) 
            && (inOut[pos].format == nvinfer1::PluginFormat::kLINEAR);
    }
	else if (pos == 2)
    {
        return (inOut[pos].type == nvinfer1::DataType::kINT32) 
            && (inOut[pos].format == nvinfer1::PluginFormat::kLINEAR);
    }
    // pos == 3 is the single output: float with linear layout
    return (inOut[pos].type == nvinfer1::DataType::kFLOAT)
        && (inOut[pos].format == nvinfer1::PluginFormat::kLINEAR);
}

void myPlugin::terminate() noexcept { }

void myPlugin::destroy() noexcept
{            
	// This gets called when the network containing plugin is destroyed
	delete this;
}

// Configure the plugin: the data types and formats this layer will actually run with
void myPlugin::configurePlugin(const  nvinfer1::DynamicPluginTensorDesc* in, int32_t nbInputs,
	const  nvinfer1::DynamicPluginTensorDesc* out, int32_t nbOutputs) noexcept
{
    std::cout<<"configurePlugin"<<std::endl;           
    try 
    {
        PLUGIN_ASSERT(nbInputs == 3 && nbOutputs == 1); // expect 3 inputs and 1 output
    } 
    catch (std::exception const& e) 
    {
        caughtError(e);
    }
}

// Clone the plugin
nvinfer1::IPluginV2DynamicExt* myPlugin::clone() const noexcept
{
    std::cout<<"clone"<<std::endl;      
	auto plugin = new myPlugin(mLayerName, mattr1, mattr3);
	plugin->setPluginNamespace(mNamespace.c_str());
	return plugin;
}

void myPlugin::setPluginNamespace(char const* libNamespace) noexcept
{
    std::cout<<"setPluginNamespace"<<std::endl;                 
	mNamespace = libNamespace;
}

char const* myPlugin::getPluginNamespace() const noexcept
{
    std::cout<<"getPluginNamespace"<<std::endl;               
	return mNamespace.c_str();
}

// Plugin creator
myPluginCreator::myPluginCreator()
{
    std::cout<<"myPluginCreator"<<std::endl;               
	// Describe the PluginField attributes that myPlugin expects
	mPluginAttributes.emplace_back(nvinfer1::PluginField("attr1", nullptr, nvinfer1::PluginFieldType::kCHAR));
	mPluginAttributes.emplace_back(nvinfer1::PluginField("attr3", nullptr, nvinfer1::PluginFieldType::kFLOAT32));
	// Collect the PluginField attributes
	mfc.nbFields = mPluginAttributes.size();
	mfc.fields = mPluginAttributes.data();
}

char const* myPluginCreator::getPluginName() const noexcept
{ 
	std::cout<<"getPluginName"<<std::endl;             
	return "MYPLUGIN";
}

char const* myPluginCreator::getPluginVersion() const noexcept
{
 	std::cout<<"getPluginVersion"<<std::endl;                
	return "1";
}

const nvinfer1::PluginFieldCollection* myPluginCreator::getFieldNames() noexcept
{
 	std::cout<<"getFieldNames"<<std::endl;               
	return &mfc;
}

// Create the plugin from the PluginFieldCollection (called when building the engine)
nvinfer1::IPluginV2* myPluginCreator::createPlugin(nvinfer1::AsciiChar const* name,
	nvinfer1::PluginFieldCollection const* fc) noexcept
{
 	std::cout<<"createPlugin"<<std::endl;               
	std::string attr1;
	float attr3 = 0.0f;
	const nvinfer1::PluginField* fields = fc->fields;

	// Parse fields from PluginFieldCollection
	for (int i = 0; i < fc->nbFields; ++i)
	 {          
		if (strcmp(fields[i].name, "attr1")==0) 
		{       
			assert(fields[i].type == nvinfer1::PluginFieldType::kCHAR);
			auto cp = static_cast<char const*>(fields[i].data);
			attr1 = std::string(cp, cp + fields[i].length);
		}
		else if (strcmp(fields[i].name, "attr3") == 0) 
		{        
			assert(fields[i].type == nvinfer1::PluginFieldType::kFLOAT32);
			attr3 = *(static_cast<const float*>(fields[i].data));
		}
	}
	return new myPlugin(name, attr1, attr3);
}

// Create the plugin by deserializing its parameters (called when loading the engine)
nvinfer1::IPluginV2* myPluginCreator::deserializePlugin(nvinfer1::AsciiChar const* name,
	void const* serialData, size_t serialLength)noexcept
{   
	std::cout<<"deserializePlugin"<<std::endl;      
	// This object will be deleted when the network is destroyed, which will
	// call myPlugin::destroy()
	return new myPlugin(name, serialData, serialLength);
}

void myPluginCreator::setPluginNamespace(char const* libNamespace) noexcept
{  
	std::cout<<"setPluginNamespace"<<std::endl;        
	mNamespace = libNamespace;
}

char const* myPluginCreator::getPluginNamespace() const noexcept
{    
	std::cout<<"getPluginNamespace"<<std::endl;      
	return mNamespace.c_str();
}
}
}

myPlugin.cu

#include "NvInfer.h"
#include <cuda_runtime.h>


namespace nvinfer1
{
namespace plugin
{
static __global__ void myselu_kernel(const float* x, const int* a, float* output, int n)
{
            
	int position = threadIdx.x + blockDim.x*blockIdx.x;
	if (position >= n) return;
	output[position] = x[position] + a[0];
}

void myselu_inference(const float* x, const int* a, float* output, int n, cudaStream_t stream)
{     
	const int nthreads = 512;
	int block_size = n > nthreads ? nthreads : n;
	int grid_size = (n + block_size - 1) / block_size;
	myselu_kernel<<<grid_size, block_size, 0, stream>>>(x, a, output, n);
}
}
}

CMakeLists.txt

file(GLOB SRCS *.cpp)
set(PLUGIN_SOURCES ${PLUGIN_SOURCES} ${SRCS})
set(PLUGIN_SOURCES ${PLUGIN_SOURCES} PARENT_SCOPE)
file(GLOB CU_SRCS *.cu)
set(PLUGIN_CU_SOURCES ${PLUGIN_CU_SOURCES} ${CU_SRCS})
set(PLUGIN_CU_SOURCES ${PLUGIN_CU_SOURCES} PARENT_SCOPE)

At the top of TensorRT/plugin/inferPlugin.cpp, add

#include "myPlugin/myPlugin.h"

and inside the initLibNvInferPlugins function add

initializePlugin<nvinfer1::plugin::myPluginCreator>(logger, libNamespace);

In TensorRT/plugin/CMakeLists.txt, add the following entry to the set(PLUGIN_LISTS ...) list:

myPlugin

In TensorRT/CMakeLists.txt, set TRT_LIB_DIR and TRT_OUT_DIR, then rebuild TensorRT.
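
After rebuilding, one way to check that MYPLUGIN really made it into the new libnvinfer_plugin is to load the library from Python and list the registered creators. A sketch, where the library path is only an example and should point to your rebuilt library:

import ctypes
import tensorrt as trt

# explicitly load the rebuilt plugin library (adjust the path to your TRT_OUT_DIR)
ctypes.CDLL("libnvinfer_plugin.so", mode=ctypes.RTLD_GLOBAL)

logger = trt.Logger(trt.Logger.WARNING)
trt.init_libnvinfer_plugins(logger, "")

for creator in trt.get_plugin_registry().plugin_creator_list:
    if creator.name == "MYPLUGIN":
        print("found", creator.name, "version", creator.plugin_version)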

TensorRT Inference Test

Run the following command to convert the ONNX model into a TensorRT engine (make sure the rebuilt libnvinfer_plugin is the one found at runtime, e.g. via LD_LIBRARY_PATH):

TensorRT-10.6.0.26/bin/trtexec --onnx=myplugin.onnx --saveEngine=myplugin.engine
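
Roughly the same conversion can also be sketched with the TensorRT Python API, which is convenient when the plugin registration should happen in the same process. This is only a minimal sketch and skips any builder configuration:

import tensorrt as trt

logger = trt.Logger(trt.Logger.WARNING)
trt.init_libnvinfer_plugins(logger, "")  # registers MYPLUGIN from libnvinfer_plugin

builder = trt.Builder(logger)
network = builder.create_network()
parser = trt.OnnxParser(network, logger)
with open("myplugin.onnx", "rb") as f:
    if not parser.parse(f.read()):
        for i in range(parser.num_errors):
            print(parser.get_error(i))
        raise SystemExit("failed to parse myplugin.onnx")

config = builder.create_builder_config()
engine_bytes = builder.build_serialized_network(network, config)
with open("myplugin.engine", "wb") as f:
    f.write(engine_bytes)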

Write the Python inference script. Here common refers to the common.py helper module shipped with the TensorRT Python samples:

import numpy as np
import tensorrt as trt
import common


logger = trt.Logger(trt.Logger.WARNING)
trt.init_libnvinfer_plugins(logger, "")
with open("myplugin.engine", "rb") as f, trt.Runtime(logger) as runtime:
    engine = runtime.deserialize_cuda_engine(f.read())
context = engine.create_execution_context()
inputs, outputs, bindings, stream = common.allocate_buffers(engine)

input = np.ones((3, 3), dtype=np.float32)
a = 2
np.copyto(inputs[0].host, input.ravel())
np.copyto(inputs[1].host, a)

output = common.do_inference(context, engine=engine, bindings=bindings, inputs=inputs, outputs=outputs, stream=stream)
print(output)
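
As a quick sanity check, assuming the single output buffer holds the flattened 1x1x3x3 result, the values can be compared against the PyTorch output from the export step:

expected = np.array([[6., 8., 6.],
                     [8., 11., 8.],
                     [6., 8., 6.]], dtype=np.float32)
print(np.allclose(np.asarray(output[0]).reshape(3, 3), expected))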