Python code for converting a custom neural network from a PyTorch .pt checkpoint to ONNX (multiple inputs + dynamic dimensions). It documents the whole .pt-to-ONNX workflow; with minor changes it can be applied to your own model.
1. Write the preprocessing code
The preprocessing is the same as for the PyTorch model.
def preprocess(img):
    # BGR -> RGB, then HWC -> CHW
    img = (cv2.cvtColor(img, cv2.COLOR_BGR2RGB)).transpose(2, 0, 1)
    # add the batch dimension: NCHW
    img = np.expand_dims(img, 0)
    sh_im = img.shape
    # pad odd height/width by repeating the last row/column
    if sh_im[2] % 2 == 1:
        img = np.concatenate((img, img[:, :, -1, :][:, :, np.newaxis, :]), axis=2)
    if sh_im[3] % 2 == 1:
        img = np.concatenate((img, img[:, :, :, -1][:, :, :, np.newaxis]), axis=3)
    img = normalize(img)
    img = torch.Tensor(img)
    return img
2. Export the ONNX model with torch.onnx.export
def export_onnx(net, model_path, img, nsigma, onnx_outPath):
    nsigma /= 255.
    if torch.cuda.is_available():
        state_dict = torch.load(model_path)
        model = net.cuda()
        dtype = torch.cuda.FloatTensor
    else:
        state_dict = torch.load(model_path, map_location='cpu')
        state_dict = remove_dataparallel_wrapper(state_dict)
        model = net
        dtype = torch.FloatTensor
    img = Variable(img.type(dtype))
    nsigma = Variable(torch.FloatTensor([nsigma]).type(dtype))
    # The parameter names in my pretrained weights differ from the network's names;
    # if they match, you can call load_state_dict(state_dict) directly
    new_state_dict = {}
    for k, v in state_dict.items():
        new_state_dict[k[7:]] = v
    model.load_state_dict(new_state_dict)
    # Input/output name lists for ONNX; add more entries for more inputs/outputs
    input_list = ['input', 'nsigma']
    output_list = ['output']
    # Export the ONNX model.
    # dynamic_axes declares dynamic dimensions; set it if your input/output shapes vary,
    # otherwise only tensors of the fixed export shape can be fed in.
    # The keys of dynamic_axes must match the names in input_names/output_names.
    torch.onnx.export(model, (img, nsigma), onnx_outPath, verbose=True, opset_version=11, export_params=True,
                      input_names=input_list, output_names=output_list,
                      dynamic_axes={'input': {0: 'batch', 1: 'channel', 2: 'height', 3: 'width'},
                                    'output': {0: 'batch', 1: 'channel', 2: 'height', 3: 'width'}})
Export result:
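To confirm that the dynamic axes were actually recorded in the exported graph, the input and output shapes can be inspected with the onnx package. This is a minimal sketch assuming the output path used above (print_io_shapes is just an illustrative helper name); axes exported as dynamic appear as named dim_param entries instead of fixed sizes.

import onnx

def print_io_shapes(onnx_model_path):
    model = onnx.load(onnx_model_path)
    for tensor in list(model.graph.input) + list(model.graph.output):
        # dynamic axes show up as dim_param (e.g. 'batch'), fixed axes as dim_value
        dims = [d.dim_param if d.dim_param else d.dim_value
                for d in tensor.type.tensor_type.shape.dim]
        print(tensor.name, dims)

print_io_shapes("D:/python/ffdnet-pytorch/models/myonnx.onnx")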
3. Check the exported model
This step checks the nodes of the ONNX graph. If an ONNX operator later turns out to be unsupported when converting to a TensorRT engine, this makes it easier to locate the node and modify the unsupported operator.
def check_onnx(onnx_model_path):
    model = onnx.load(onnx_model_path)
    onnx.checker.check_model(model)
    print(onnx.helper.printable_graph(model.graph))
(The printable_graph output and a Netron visualization of the exported model are omitted here.)
4. Run the ONNX model and check that the output matches
def run_onnx(onnx_model_path, test_img, nsigma):
    nsigma /= 255.
    with torch.no_grad():
        # CUDA inference is assumed here (torch.cuda.FloatTensor)
        img = Variable(test_img.type(torch.cuda.FloatTensor))
        nsigma = Variable(torch.FloatTensor([nsigma]).type(torch.cuda.FloatTensor))
        # Choose GPU or CPU execution
        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        providers = ['CUDAExecutionProvider'] if device.type != "cpu" else ['CPUExecutionProvider']
        # Run the ONNX model through an onnxruntime InferenceSession
        ort_session = ort.InferenceSession(onnx_model_path, providers=providers)
        # the feed keys must match the input_names used at export time
        output = ort_session.run(output_names=['output'],
                                 input_feed={'input': np.array(img.cpu(), dtype=np.float32),
                                             'nsigma': np.array(nsigma.cpu(), dtype=np.float32)})
    return output
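To verify that the exported model really matches the original network, the PyTorch output and the ONNX Runtime output can be compared element-wise. The sketch below is illustrative: it assumes the FFDNet forward pass takes (img, nsigma) in the same order as in torch.onnx.export above, assumes CUDA is available (as run_onnx does), and the tolerances are typical float32 choices rather than values from this article.

def compare_outputs(model, onnx_model_path, img, nsigma):
    # reference output from the PyTorch model; nsigma scaled to [0, 1] as in export_onnx
    dtype = torch.cuda.FloatTensor if torch.cuda.is_available() else torch.FloatTensor
    img_t = img.type(dtype)
    nsigma_t = torch.FloatTensor([nsigma / 255.]).type(dtype)
    with torch.no_grad():
        torch_out = model(img_t, nsigma_t).cpu().numpy()
    # ONNX Runtime output; run_onnx returns a list, the first entry is 'output'
    onnx_out = run_onnx(onnx_model_path, img, nsigma)[0]
    print("max abs diff:", np.abs(torch_out - onnx_out).max())
    np.testing.assert_allclose(torch_out, onnx_out, rtol=1e-3, atol=1e-5)

Feeding an image with a different (even) height and width through run_onnx is also a quick way to confirm that the dynamic axes work.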
5. Post-process the ONNX output and display it as an OpenCV image
def postprocess(img, img_noise_estime):
    out = torch.clamp(img - img_noise_estime, 0., 1.)
    outimg = variable_to_cv2_image(out)
    cv2.imshow("output", outimg)
    cv2.waitKey(0)
6. Write the main function for testing
def main():
    ##############################
    #
    # ONNX model export
    #
    ##############################
    # Path to the .pt/.pth weights: your own path + mypt.pt
    model_path = "D:/python/ffdnet-pytorch/models/net_rgb.pth"
    # Data fed in when exporting the ONNX model, used to record the network's computation
    export_feed_path = "D:/python/ffdnet-pytorch/noisy.png"
    # Output path for the exported ONNX model
    onnx_outpath = "D:/python/ffdnet-pytorch/models/myonnx.onnx"
    # Instantiate your own network and set its input parameters
    net = FFDNet(num_input_channels=3)
    nsigma = 25
    # ONNX export
    img = cv2.imread(export_feed_path)
    input = preprocess(img)
    export_onnx(net, model_path, input, nsigma, onnx_outpath)
    print("export success!")
    ##############################
    #
    # Check the ONNX model
    #
    ##############################
    check_onnx(onnx_outpath)
    # Visualize the network with netron, which renders the recorded inference graph from its nodes
    netron.start(onnx_outpath)
    ##############################
    #
    # Run the ONNX model
    #
    ##############################
    # The flow here is: preprocess the data ---> call run_onnx ---> post-process the model output
    # That code is not repeated here; see the sketch below
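For completeness, here is a minimal sketch of that omitted step, simply chaining the functions defined above. The explicit paths and the nsigma value mirror the ones in main(); converting the ONNX output back to a torch tensor before calling postprocess is an assumption about how the pieces fit together, not code from the original article.

# Sketch of the omitted inference step (assumes CUDA is available, as run_onnx does)
onnx_path = "D:/python/ffdnet-pytorch/models/myonnx.onnx"
test_img = preprocess(cv2.imread("D:/python/ffdnet-pytorch/noisy.png"))
noise_estimate = run_onnx(onnx_path, test_img, 25)[0]      # the 'output' head of the ONNX model
postprocess(test_img, torch.from_numpy(noise_estimate))    # subtract, clamp and display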
The full code:
import time
import netron
import cv2
import torch
import onnx
import numpy as np
from torch.autograd import Variable
import onnxruntime as ort
from models import FFDNet
from utils import remove_dataparallel_wrapper, normalize, variable_to_cv2_image
# Preprocessing code, same as for the PyTorch model
def preprocess(img):
    img = (cv2.cvtColor(img, cv2.COLOR_BGR2RGB)).transpose(2, 0, 1)
    img = np.expand_dims(img, 0)
    sh_im = img.shape
    if sh_im[2] % 2 == 1:
        img = np.concatenate((img, img[:, :, -1, :][:, :, np.newaxis, :]), axis=2)
    if sh_im[3] % 2 == 1:
        img = np.concatenate((img, img[:, :, :, -1][:, :, :, np.newaxis]), axis=3)
    img = normalize(img)
    img = torch.Tensor(img)
    return img
# ONNX export code: loads the PyTorch .pt weights and exports the ONNX model
def export_onnx(net, model_path, img, nsigma, onnx_outPath):
    nsigma /= 255.
    if torch.cuda.is_available():
        state_dict = torch.load(model_path)
        model = net.cuda()
        dtype = torch.cuda.FloatTensor
    else:
        state_dict = torch.load(model_path, map_location='cpu')
        state_dict = remove_dataparallel_wrapper(state_dict)
        model = net
        dtype = torch.FloatTensor
    img = Variable(img.type(dtype))
    nsigma = Variable(torch.FloatTensor([nsigma]).type(dtype))
    # The parameter names in my pretrained weights differ from the network's names;
    # if they match, you can call load_state_dict(state_dict) directly
    new_state_dict = {}
    for k, v in state_dict.items():
        new_state_dict[k[7:]] = v
    model.load_state_dict(new_state_dict)
    # Input/output name lists for ONNX; add more entries for more inputs/outputs
    input_list = ['input', 'nsigma']
    output_list = ['output']
    # Export the ONNX model.
    # dynamic_axes declares dynamic dimensions; set it if your input/output shapes vary,
    # otherwise only tensors of the fixed export shape can be fed in.
    # The keys of dynamic_axes must match the names in input_names/output_names.
    torch.onnx.export(model, (img, nsigma), onnx_outPath, verbose=True, opset_version=11, export_params=True,
                      input_names=input_list, output_names=output_list,
                      dynamic_axes={'input': {0: 'batch', 1: 'channel', 2: 'height', 3: 'width'},
                                    'output': {0: 'batch', 1: 'channel', 2: 'height', 3: 'width'}})
# Check the ONNX graph nodes; if an operator is later unsupported when converting to a TensorRT engine,
# this makes it easier to locate the node and modify the unsupported operator
def check_onnx(onnx_model_path):
    model = onnx.load(onnx_model_path)
    onnx.checker.check_model(model)
    print(onnx.helper.printable_graph(model.graph))
# Inference code for the ONNX model; used to check that its output matches the PyTorch model
def run_onnx(onnx_model_path, test_img, nsigma):
    nsigma /= 255.
    with torch.no_grad():
        # CUDA inference is assumed here (torch.cuda.FloatTensor)
        img = Variable(test_img.type(torch.cuda.FloatTensor))
        nsigma = Variable(torch.FloatTensor([nsigma]).type(torch.cuda.FloatTensor))
        # Choose GPU or CPU execution
        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        providers = ['CUDAExecutionProvider'] if device.type != "cpu" else ['CPUExecutionProvider']
        # Run the ONNX model through an onnxruntime InferenceSession
        ort_session = ort.InferenceSession(onnx_model_path, providers=providers)
        # the feed keys must match the input_names used at export time
        output = ort_session.run(output_names=['output'],
                                 input_feed={'input': np.array(img.cpu(), dtype=np.float32),
                                             'nsigma': np.array(nsigma.cpu(), dtype=np.float32)})
    return output
# Post-processing code: turns the ONNX model output into a displayable OpenCV image
# Same as the post-processing of the PyTorch model
def postprocess(img, img_noise_estime):
    out = torch.clamp(img - img_noise_estime, 0., 1.)
    outimg = variable_to_cv2_image(out)
    cv2.imshow("output", outimg)
    cv2.waitKey(0)
def main():
    ##############################
    #
    # ONNX model export
    #
    ##############################
    # Path to the .pt/.pth weights: your own path + mypt.pt
    model_path = "D:/python/ffdnet-pytorch/models/net_rgb.pth"
    # Data fed in when exporting the ONNX model, used to record the network's computation
    export_feed_path = "D:/python/ffdnet-pytorch/noisy.png"
    # Output path for the exported ONNX model
    onnx_outpath = "D:/python/ffdnet-pytorch/models/myonnx.onnx"
    # Instantiate your own network and set its input parameters
    net = FFDNet(num_input_channels=3)
    nsigma = 25
    # ONNX export
    img = cv2.imread(export_feed_path)
    input = preprocess(img)
    export_onnx(net, model_path, input, nsigma, onnx_outpath)
    print("export success!")
    ##############################
    #
    # Check the ONNX model
    #
    ##############################
    onnx_outpath = "D:/python/ffdnet-pytorch/models/myonnx.onnx"
    check_onnx(onnx_outpath)
    # Visualize the network with netron, which renders the recorded inference graph from its nodes
    netron.start(onnx_outpath)
    ##############################
    #
    # Run the ONNX model
    #
    ##############################
    # The flow here is: preprocess the data ---> call run_onnx ---> post-process the model output
    # That code is not repeated here (see the sketch under step 6 above)
if __name__ == '__main__':
    main()
That concludes this article on converting a PyTorch model to an ONNX model in Python (multiple inputs + dynamic dimensions).