Converting a Custom PyTorch Operator to TensorRT
Exporting the ONNX model from PyTorch
Suppose we have the following model:
import torch
import torch.nn as nn


class MYPLUGINImpl(torch.autograd.Function):
    @staticmethod
    def symbolic(g, x, p):
        # Emit a custom "MYPLUGIN" node. The attribute-name suffixes declare
        # the attribute type: _s = string, _i = int(s), _f = float.
        return g.op("MYPLUGIN", x, p,
                    g.op("Constant", value_t=torch.tensor([3, 2, 1], dtype=torch.int32)),
                    attr1_s="这是字符串属性",
                    attr2_i=[1, 2],
                    attr3_f=222
                    )

    @staticmethod
    def forward(ctx, x, a):
        return x + a


class MYPLUGIN(nn.Module):
    def __init__(self, n):
        super().__init__()
        self.param = nn.parameter.Parameter(torch.arange(n).float())

    def forward(self, x, a):
        return MYPLUGINImpl.apply(x, a)


class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(1, 1, 3, padding=1)
        self.conv.weight.data.fill_(1)
        self.conv.bias.data.fill_(0)
        self.myplugin = MYPLUGIN(3)

    def forward(self, x, a):
        x = self.conv(x)
        x = self.myplugin(x, a)
        return x


model = Model().eval()
input = torch.tensor([
    [
        [1, 1, 1],
        [1, 1, 1],
        [1, 1, 1],
    ],
], dtype=torch.float32).view(1, 1, 3, 3)
a = torch.tensor(2, dtype=torch.int32)
output = model(input, a)
print(f"inference output = {output}")

torch.onnx.export(
    model,
    (input, a),             # args passed to the model; must be a tuple
    "myplugin.onnx",        # output file path
    input_names=["image"],  # name the input/output nodes for easier inspection
    output_names=["output"],
    opset_version=13,       # which opset the operators are exported with (corresponds to symbolic_opset13)
    # Fall back to ATen export for operators without an ONNX equivalent
    operator_export_type=torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK,
    # verbose=True,         # print detailed export information
    # dynamic_axes marks batch/height/width as dynamic (-1 in the ONNX model).
    # Usually only batch is made dynamic; avoid dynamic axes elsewhere.
    # dynamic_axes={
    #     "image": {0: "batch", 2: "height", 3: "width"},
    #     "output": {0: "batch", 2: "height", 3: "width"},
    # },
)
Program output:
inference output =
tensor([[[[ 6., 8., 6.],
[ 8., 11., 8.],
[ 6., 8., 6.]]]], grad_fn=<MYPLUGINImplBackward>)
The exported ONNX model consists of a Conv node followed by the custom MYPLUGIN node (the structure figure is omitted here).
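Without a graph viewer, the node list and attributes can also be printed with the onnx package; here is a minimal sketch, assuming onnx is installed. Note that the type suffixes used in symbolic() (attr1_s, attr2_i, attr3_f) only declare each attribute's type and are stripped from the names stored in the graph.
import onnx

m = onnx.load("myplugin.onnx")
for node in m.graph.node:
    # Expected: a Conv node followed by the custom MYPLUGIN node
    print(node.op_type, list(node.input), "->", list(node.output))
    for attr in node.attribute:
        print("  attribute:", attr.name)  # attr1, attr2, attr3 on MYPLUGIN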
Writing the TensorRT custom operator
This walkthrough uses TensorRT-10.6.0.26. TensorRT is only partially open source: first download the TensorRT-10.6.0.26 binary package, then download the corresponding source code. Create a myPlugin folder under TensorRT/plugin and add the following files:
myPlugin.h
#ifndef TRT_MYPLUGIN_H
#define TRT_MYPLUGIN_H
#include "NvInfer.h"
#include "NvInferPlugin.h"
#include "common/plugin.h"
#include <string>
#include <vector>
namespace nvinfer1
{
namespace plugin
{
class myPlugin : public nvinfer1::IPluginV2DynamicExt
{
public:
myPlugin(const std::string name, const std::string attr1, float attr3); // takes the op name and attributes; used when building the engine
myPlugin(const std::string name, const void* data, size_t length); // takes the op name and serialized engine data; used at inference time
int getNbOutputs() const noexcept override;
virtual nvinfer1::DataType getOutputDataType(int32_t index,
nvinfer1::DataType const* inputTypes, int32_t nbInputs) const noexcept override {
return inputTypes[0];
}
virtual nvinfer1::DimsExprs getOutputDimensions(int32_t outputIndex,
const nvinfer1::DimsExprs* inputs, int32_t nbInputs, nvinfer1::IExprBuilder& exprBuilder) noexcept override;
int initialize() noexcept override;
void terminate() noexcept override;
virtual size_t getWorkspaceSize(const nvinfer1::PluginTensorDesc* inputs,
int32_t nbInputs, const nvinfer1::PluginTensorDesc* outputs,
int32_t nbOutputs) const noexcept override {
return 0;
};
int enqueue(const nvinfer1::PluginTensorDesc* inputDesc, const nvinfer1::PluginTensorDesc* outputDesc,
const void* const* inputs, void* const* outputs, void* workspace, cudaStream_t stream) noexcept override;
size_t getSerializationSize() const noexcept override;
void serialize(void* buffer) const noexcept override;
virtual void configurePlugin(const nvinfer1::DynamicPluginTensorDesc* in, int32_t nbInputs,
const nvinfer1::DynamicPluginTensorDesc* out, int32_t nbOutputs) noexcept override;
virtual bool supportsFormatCombination(int32_t pos, const nvinfer1::PluginTensorDesc* inOut, int32_t nbInputs,
int32_t nbOutputs) noexcept override;
const char* getPluginType() const noexcept override;
const char* getPluginVersion() const noexcept override;
void destroy() noexcept override;
nvinfer1::IPluginV2DynamicExt* clone() const noexcept override;
void setPluginNamespace(nvinfer1::AsciiChar const* pluginNamespace) noexcept override;
const char* getPluginNamespace()const noexcept override;
private:
const std::string mLayerName;
std::string mattr1;
float mattr3;
size_t mInputVolume;
std::string mNamespace;
};
class myPluginCreator : public nvinfer1::IPluginCreator
{
public:
myPluginCreator();
const char* getPluginName() const noexcept override;
const char* getPluginVersion() const noexcept override;
const nvinfer1::PluginFieldCollection* getFieldNames() noexcept override;
nvinfer1::IPluginV2* createPlugin(nvinfer1::AsciiChar const* name,
nvinfer1::PluginFieldCollection const* fc) noexcept override;
nvinfer1::IPluginV2* deserializePlugin(nvinfer1::AsciiChar const* name,
void const* serialData, size_t serialLength)noexcept override;
void setPluginNamespace(nvinfer1::AsciiChar const* pluginNamespace) noexcept override;
const char* getPluginNamespace() const noexcept override;
private:
static nvinfer1::PluginFieldCollection mfc;
static std::vector<nvinfer1::PluginField> mPluginAttributes;
std::string mNamespace;
};
}
}
#endif // TRT_MYPLUGIN_H
myPlugin.cpp
#include "myPlugin.h"
#include <NvInfer.h>
#include <cstring>
#include <vector>
#include <cassert>
using namespace nvinfer1;
using namespace nvinfer1::pluginInternal;
using nvinfer1::plugin::myPlugin;
using nvinfer1::plugin::myPluginCreator;
namespace nvinfer1
{
namespace plugin
{
void myselu_inference(const float* x, const int* a, float* output, int n, cudaStream_t stream); // CUDA launcher, implemented in myPlugin.cu
// Initialize the creator's static members
nvinfer1::PluginFieldCollection myPluginCreator::mfc{};
std::vector<nvinfer1::PluginField> myPluginCreator::mPluginAttributes;
// Helper function for serializing plugin fields
template <typename T>
void writeToBuffer(char*& buffer, const T& val)
{
*reinterpret_cast<T*>(buffer) = val;
buffer += sizeof(T);
}
// Helper function for deserializing plugin fields
template <typename T>
T readFromBuffer(char const*& buffer)
{
T val = *reinterpret_cast<const T*>(buffer);
buffer += sizeof(T);
return val;
}
// myPlugin: engine-build-time constructor (from parsed attributes)
myPlugin::myPlugin(const std::string name, const std::string attr1, float attr3)
    : mLayerName(name), mattr1(attr1), mattr3(attr3)
{
    std::cout<<"myPlugin"<<std::endl;
}
// myPlugin: inference-time constructor (from serialized engine data)
myPlugin::myPlugin(const std::string name, const void* data, size_t length)
    : mLayerName(name)
{
    std::cout<<"myPlugin()"<<std::endl;
    // Deserialize in the same order as serialization
    char const* d = static_cast<char const*>(data);
    char const* a = d;
    int nstr = readFromBuffer<int>(d);
    mattr1 = std::string(d, d + nstr);
    d += nstr;
    mattr3 = readFromBuffer<float>(d);
    assert(d == (a + length));
}
char const* myPlugin::getPluginType() const noexcept
{
std::cout<<"getPluginType"<<std::endl;
return "MYPLUGIN";
}
char const* myPlugin::getPluginVersion() const noexcept
{
std::cout<<"getPluginVersion"<<std::endl;
return "1";
}
int myPlugin::getNbOutputs() const noexcept
{
std::cout<<"getNbOutputs"<<std::endl;
return 1;
}
// Compute this layer's output dimensions
nvinfer1::DimsExprs myPlugin::getOutputDimensions(int32_t outputIndex,
const nvinfer1::DimsExprs* inputs, int32_t nbInputs, nvinfer1::IExprBuilder& exprBuilder) noexcept
{
std::cout<<"getOutputDimensions"<<std::endl;
// The op is elementwise, so the output shape equals the input shape
return inputs[0];
}
int myPlugin::initialize() noexcept
{
std::cout<<"initialize"<<std::endl;
return 0;
}
int myPlugin::enqueue(const nvinfer1::PluginTensorDesc* inputDesc, const nvinfer1::PluginTensorDesc* outputDesc,
    const void* const* inputs, void* const* outputs, void* workspace, cudaStream_t stream) noexcept
{
    std::cout<<"enqueue"<<std::endl;
    void* output = outputs[0];
    // Element count of input 0 (shapes are fully resolved at enqueue time)
    size_t volume = 1;
    for (int i = 0; i < inputDesc->dims.nbDims; ++i)
    {
        volume *= inputDesc->dims.d[i];
    }
    mInputVolume = volume;
    // Launch the CUDA kernel: output = x + a (see myPlugin.cu)
    myselu_inference(static_cast<const float*>(inputs[0]),
        static_cast<const int*>(inputs[1]),
        static_cast<float*>(output),
        mInputVolume,
        stream
    );
    return 0;
}
size_t myPlugin::getSerializationSize() const noexcept
{
std::cout<<"getSerializationSize"<<std::endl;
return sizeof(int) + mattr1.size() + sizeof(mattr3);
}
// Serialize this layer's fields so they are embedded in the engine file
void myPlugin::serialize(void* buffer) const noexcept
{
std::cout<<"serialize"<<std::endl;
char* d = static_cast<char*>(buffer);
char const* a = d;
int nstr = mattr1.size();
writeToBuffer(d, nstr);
memcpy(d, mattr1.data(), nstr);
d += nstr;
writeToBuffer(d, mattr3);
assert(d == a + getSerializationSize());
}
// Report which data type/format combinations the plugin supports
// (positions 0..2 are the inputs x, a and the int32 constant; position 3 is the output)
bool myPlugin::supportsFormatCombination(int32_t pos, const nvinfer1::PluginTensorDesc* inOut, int32_t nbInputs,
    int32_t nbOutputs) noexcept
{
    std::cout<<"supportsFormatCombination"<<std::endl;
    PLUGIN_ASSERT(pos < nbInputs + nbOutputs);
    if (pos == 0)
    {
        return (inOut[pos].type == nvinfer1::DataType::kFLOAT)
            && (inOut[pos].format == nvinfer1::PluginFormat::kLINEAR);
    }
    else if (pos == 1 || pos == 2)
    {
        return (inOut[pos].type == nvinfer1::DataType::kINT32)
            && (inOut[pos].format == nvinfer1::PluginFormat::kLINEAR);
    }
    // position 3: the float output
    return (inOut[pos].type == nvinfer1::DataType::kFLOAT)
        && (inOut[pos].format == nvinfer1::PluginFormat::kLINEAR);
}
void myPlugin::terminate() noexcept { }
void myPlugin::destroy() noexcept
{
// This gets called when the network containing plugin is destroyed
delete this;
}
// Validate the configuration the builder selected for this layer
void myPlugin::configurePlugin(const nvinfer1::DynamicPluginTensorDesc* in, int32_t nbInputs,
const nvinfer1::DynamicPluginTensorDesc* out, int32_t nbOutputs) noexcept
{
std::cout<<"configurePlugin"<<std::endl;
try
{
PLUGIN_ASSERT(nbInputs == 3 && nbOutputs == 1); // expect 3 inputs (x, a, the int32 constant) and 1 output
}
catch (std::exception const& e)
{
caughtError(e);
}
}
// Clone the plugin
nvinfer1::IPluginV2DynamicExt* myPlugin::clone() const noexcept
{
std::cout<<"clone"<<std::endl;
auto plugin = new myPlugin(mLayerName, mattr1, mattr3);
plugin->setPluginNamespace(mNamespace.c_str());
return plugin;
}
void myPlugin::setPluginNamespace(char const* libNamespace) noexcept
{
std::cout<<"setPluginNamespace"<<std::endl;
mNamespace = libNamespace;
}
char const* myPlugin::getPluginNamespace() const noexcept
{
std::cout<<"getPluginNamespace"<<std::endl;
return mNamespace.c_str();
}
// Plugin creator
myPluginCreator::myPluginCreator()
{
std::cout<<"myPluginCreator"<<std::endl;
// Declare the PluginField parameters that myPlugin expects
mPluginAttributes.emplace_back(nvinfer1::PluginField("attr1", nullptr, nvinfer1::PluginFieldType::kCHAR));
mPluginAttributes.emplace_back(nvinfer1::PluginField("attr3", nullptr, nvinfer1::PluginFieldType::kFLOAT32));
// Collect the fields into the collection returned by getFieldNames()
mfc.nbFields = mPluginAttributes.size();
mfc.fields = mPluginAttributes.data();
}
char const* myPluginCreator::getPluginName() const noexcept
{
std::cout<<"getPluginName"<<std::endl;
return "MYPLUGIN";
}
char const* myPluginCreator::getPluginVersion() const noexcept
{
std::cout<<"getPluginVersion"<<std::endl;
return "1";
}
const nvinfer1::PluginFieldCollection* myPluginCreator::getFieldNames() noexcept
{
std::cout<<"getFieldNames"<<std::endl;
return &mfc;
}
// Create the plugin from the attributes parsed out of the ONNX node
nvinfer1::IPluginV2* myPluginCreator::createPlugin(nvinfer1::AsciiChar const* name,
nvinfer1::PluginFieldCollection const* fc) noexcept
{
std::cout<<"createPlugin"<<std::endl;
std::string attr1;
float attr3 = 0.0f; // default in case the field is absent
const nvinfer1::PluginField* fields = fc->fields;
// Parse fields from PluginFieldCollection
for (int i = 0; i < fc->nbFields; ++i)
{
if (strcmp(fields[i].name, "attr1")==0)
{
assert(fields[i].type == nvinfer1::PluginFieldType::kCHAR);
auto cp = static_cast<char const*>(fields[i].data);
attr1 = std::string(cp, cp + fields[i].length);
}
else if (strcmp(fields[i].name, "attr3") == 0)
{
assert(fields[i].type == nvinfer1::PluginFieldType::kFLOAT32);
attr3 = *(static_cast<const float*>(fields[i].data));
}
}
return new myPlugin(name, attr1, attr3);
}
// Recreate the plugin by deserializing its parameters
nvinfer1::IPluginV2* myPluginCreator::deserializePlugin(nvinfer1::AsciiChar const* name,
void const* serialData, size_t serialLength)noexcept
{
std::cout<<"deserializePlugin"<<std::endl;
// This object will be deleted when the network is destroyed, which will
// call myPlugin::destroy()
return new myPlugin(name, serialData, serialLength);
}
void myPluginCreator::setPluginNamespace(char const* libNamespace) noexcept
{
std::cout<<"setPluginNamespace"<<std::endl;
mNamespace = libNamespace;
}
char const* myPluginCreator::getPluginNamespace() const noexcept
{
std::cout<<"getPluginNamespace"<<std::endl;
return mNamespace.c_str();
}
}
}
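For reference, the byte layout written by serialize() and read back by the deserialization constructor can be modeled in a few lines of Python. A minimal sketch, assuming a little-endian layout (which matches what writeToBuffer/readFromBuffer produce on x86):
import struct

def serialize(attr1: bytes, attr3: float) -> bytes:
    # [int32: byte length of attr1][attr1 bytes][float32: attr3]
    return struct.pack("<i", len(attr1)) + attr1 + struct.pack("<f", attr3)

def deserialize(buf: bytes):
    (nstr,) = struct.unpack_from("<i", buf, 0)
    attr1 = buf[4:4 + nstr]
    (attr3,) = struct.unpack_from("<f", buf, 4 + nstr)
    assert 4 + nstr + 4 == len(buf)  # mirrors the d == a + length check
    return attr1.decode("utf-8"), attr3

blob = serialize("这是字符串属性".encode("utf-8"), 222.0)
print(deserialize(blob))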
myPlugin.cu
#include "NvInfer.h"
#include <cuda_runtime.h>
namespace nvinfer1
{
namespace plugin
{
// Elementwise kernel: output[i] = x[i] + a[0], where a is the scalar int32 input
static __global__ void myselu_kernel(const float* x, const int* a, float* output, int n)
{
int position = threadIdx.x + blockDim.x*blockIdx.x;
if (position >= n) return;
output[position] = x[position] + a[0];
}
// Launcher: one thread per element, at most 512 threads per block
void myselu_inference(const float* x, const int* a, float* output, int n, cudaStream_t stream)
{
const int nthreads = 512;
int block_size = n > nthreads ? nthreads : n;
int grid_size = (n + block_size - 1) / block_size;
myselu_kernel<<<grid_size, block_size, 0, stream>>>(x, a, output, n);
}
}
}
CMakeLists.txt
file(GLOB SRCS *.cpp)
set(PLUGIN_SOURCES ${PLUGIN_SOURCES} ${SRCS})
set(PLUGIN_SOURCES ${PLUGIN_SOURCES} PARENT_SCOPE)
file(GLOB CU_SRCS *.cu)
set(PLUGIN_CU_SOURCES ${PLUGIN_CU_SOURCES} ${CU_SRCS})
set(PLUGIN_CU_SOURCES ${PLUGIN_CU_SOURCES} PARENT_SCOPE)
At the top of TensorRT/plugin/inferPlugin.cpp, add
#include "myPlugin/myPlugin.h"
and inside the initLibNvInferPlugins function add
initializePlugin<nvinfer1::plugin::myPluginCreator>(logger, libNamespace);
so that the creator is registered with the global plugin registry. In TensorRT/plugin/CMakeLists.txt, add
myPlugin
to the set(PLUGIN_LISTS ...) entry. Finally, set TRT_LIB_DIR and TRT_OUT_DIR in the top-level TensorRT/CMakeLists.txt and rebuild TensorRT.
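After rebuilding, it is worth confirming that the creator is actually registered. A minimal sketch using the TensorRT Python API; this assumes the tensorrt Python module resolves to the rebuilt libnvinfer_plugin (e.g. via LD_LIBRARY_PATH or LD_PRELOAD), and the field values are the ones from the ONNX export:
import numpy as np
import tensorrt as trt

logger = trt.Logger(trt.Logger.WARNING)
trt.init_libnvinfer_plugins(logger, "")  # registers all creators, including myPluginCreator

registry = trt.get_plugin_registry()
creator = registry.get_plugin_creator("MYPLUGIN", "1")
print("creator found:", creator is not None)

# Instantiate the plugin the same way the ONNX parser would
fc = trt.PluginFieldCollection([
    trt.PluginField("attr1", np.frombuffer("这是字符串属性".encode(), dtype=np.int8), trt.PluginFieldType.CHAR),
    trt.PluginField("attr3", np.array([222.0], dtype=np.float32), trt.PluginFieldType.FLOAT32),
])
plugin = creator.create_plugin("myplugin_test", fc)
print("plugin type:", plugin.plugin_type)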
TensorRT inference test
Run the following command to convert the ONNX model into an engine (make sure trtexec loads the rebuilt libnvinfer_plugin, for example by putting its output directory first on LD_LIBRARY_PATH; otherwise the parser will not find the MYPLUGIN creator):
TensorRT-10.6.0.26/bin/trtexec --onnx=myplugin.onnx --saveEngine=myplugin.engine
Write the Python inference script (common is the common.py helper module shipped with the TensorRT samples):
import numpy as np
import tensorrt as trt
import common  # common.py helper module from the TensorRT samples

logger = trt.Logger(trt.Logger.WARNING)
trt.init_libnvinfer_plugins(logger, "")

with open("myplugin.engine", "rb") as f, trt.Runtime(logger) as runtime:
    engine = runtime.deserialize_cuda_engine(f.read())

context = engine.create_execution_context()
inputs, outputs, bindings, stream = common.allocate_buffers(engine)
input = np.ones((3, 3), dtype=np.float32)
a = 2
np.copyto(inputs[0].host, input.ravel())
np.copyto(inputs[1].host, a)
output = common.do_inference(context, engine=engine, bindings=bindings, inputs=inputs, outputs=outputs, stream=stream)
print(output)
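Since the convolution uses an all-ones 3x3 kernel with zero bias and padding 1, and the plugin adds a, the engine output can be checked against a small NumPy reference; a minimal sketch:
import numpy as np

x = np.ones((3, 3), dtype=np.float32)
a = 2
padded = np.pad(x, 1)  # zero padding, matching nn.Conv2d(..., padding=1)
conv = np.zeros_like(x)
for i in range(3):
    for j in range(3):
        conv[i, j] = padded[i:i + 3, j:j + 3].sum()  # all-ones 3x3 kernel, zero bias
expected = conv + a
print(expected)  # [[ 6.  8.  6.] [ 8. 11.  8.] [ 6.  8.  6.]]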