Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

怎样转化为onnx模型 #26

Open
huaxiangwangman opened this issue Feb 14, 2022 · 3 comments
Open

怎样转化为onnx模型 #26

huaxiangwangman opened this issue Feb 14, 2022 · 3 comments

Comments

@huaxiangwangman
Copy link

No description provided.

@taohan10200
Copy link
Owner

我们提供的是PyTorch保存的模型,ONNX格式的模型可通过我们开源的模型参数自行转换。

@csz-006
Copy link

csz-006 commented Sep 7, 2023

请问你转换成功了吗,可以看下转onnx的代码嘛

@csz-006
Copy link

csz-006 commented Nov 2, 2023

这个是转换为onnx的代码,需要注意输入为1*3*h*w
import os
from tkinter.messagebox import NO
import torch
import torch.onnx
import torch.nn as nn
import onnxruntime as ort
import numpy as np
import torch.nn.functional as F
from tqdm import tqdm
import onnx
from onnxsim import simplify
from model.locator import Crowd_locator
from collections import OrderedDict

# os.environ["CUDA_VISIBLE_DEVICES"]= "1"

# Comma-separated CUDA device id(s) made visible to this process. The same
# string is passed to Crowd_locator and is split on "," in onnx_export to
# decide how the checkpoint keys are loaded.
GPU_ID = "0"
os.environ["CUDA_VISIBLE_DEVICES"] = GPU_ID
# Turn on cuDNN benchmark mode for this process.
torch.backends.cudnn.benchmark = True

def onnx_export(model_path,
                onnx_name="HRnet_Crowd_count_512_1024_opset12.onnx",
                input_size=(512, 1024)):
    """Export a trained Crowd_locator checkpoint to an ONNX file.

    The network is built from the module-level globals ``netName`` and
    ``GPU_ID`` (``netName`` is assigned in the script entry point, so it must
    exist before this function is called), moved to CUDA, loaded with the
    weights found at ``model_path`` and traced with an all-zero dummy image.

    Args:
        model_path: Path to the ``.pth`` file holding the model state dict.
        onnx_name: Destination path of the exported ONNX file.
        input_size: ``(height, width)`` of the dummy input; the traced tensor
            has shape ``(1, 3, height, width)``. The export uses static axes,
            so the resulting model only accepts this size.
    """
    net = Crowd_locator(netName, GPU_ID, pretrained=False)
    net.cuda()

    # Load onto the CPU first so that a checkpoint saved on a CUDA device that
    # is not visible in this process can still be read; load_state_dict then
    # copies the values into the parameters that already live on the GPU.
    state_dict = torch.load(model_path, map_location="cpu")
    if len(GPU_ID.split(",")) > 1:
        net.load_state_dict(state_dict)
    else:
        # Single-GPU run: drop every "module." marker from the checkpoint keys
        # before loading.
        # NOTE(review): this removes the marker anywhere in the key, not only
        # as a prefix -- presumably required by how Crowd_locator wraps its
        # sub-networks; confirm before narrowing it.
        stripped_state = OrderedDict(
            (key.replace("module.", ""), value)
            for key, value in state_dict.items()
        )
        net.load_state_dict(stripped_state)
    net.eval()

    # Print the name of every sub-module of the network.
    for name, _ in net.named_modules():
        print(name, "nnnnn")

    height, width = input_size
    dummy_input = torch.zeros(1, 3, height, width).cuda()

    torch.onnx.export(
        net, dummy_input, onnx_name,
        verbose=True,
        input_names=["image"],
        output_names=["predict"],
        opset_version=12,
        dynamic_axes=None,
    )

def onnx_sim(onnx_path,
             save_path="HRnet_Crowd_count_512_1024_opset12-sim.onnx"):
    """Simplify an exported ONNX graph with onnx-simplifier and save it.

    Args:
        onnx_path: Path of the ONNX model to simplify.
        save_path: Destination path of the simplified model.

    Raises:
        RuntimeError: If ``simplify`` reports that the simplified model could
            not be validated against the original.
    """
    model_onnx = onnx.load_model(onnx_path)
    model_sim, check = simplify(model_onnx)
    if not check:
        # Do not write a model the simplifier itself could not validate.
        raise RuntimeError(
            "Simplified ONNX model could not be validated: " + str(onnx_path)
        )
    onnx.save(model_sim, save_path)
    print("模型静态图简化完成")

if __name__ == "__main__":
    # Module-level global read by onnx_export when it builds the network.
    # Backbones named by the original author: VGG16_FPN, HR_Net.
    netName = "HR_Net"
    model_path = "/IIM/Preweights/NWPU-HR-ep_241_F1_0.802_Pre_0.841_Rec_0.766_mae_55.6_mse_330.9.pth"

    # Input path for the optional simplification step below.
    # NOTE(review): this is not the path onnx_export writes to by default --
    # confirm it points at the exported file before enabling onnx_sim.
    onnx_path = "/IIM/Preweights/1024_HRnet_Crowd_count_512_1024_opset12.onnx"

    onnx_export(model_path)
    # onnx_sim(onnx_path)
    print("Done")

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

No branches or pull requests

3 participants