ONNX(Open Neural Network Exchange)是一個(gè)開源項(xiàng)目,旨在建立一個(gè)開放的標(biāo)準(zhǔn),使深度學(xué)習(xí)模型可以在不同的軟件平臺(tái)和工具之間輕松移動(dòng)和重用。
ONNX模型可以用于各種應(yīng)用場景,例如機(jī)器翻譯、圖像識別、語音識別、自然語言處理等。
由于ONNX模型的互操作性,開發(fā)人員可以使用不同的框架來訓(xùn)練,模型可以更容易地在不同的框架之間轉(zhuǎn)換,例如從PyTorch轉(zhuǎn)換到TensorFlow,或從TensorFlow轉(zhuǎn)換到MXNet等。然后將其部署到不同的環(huán)境中,例如云端、邊緣設(shè)備或移動(dòng)設(shè)備等。
ONNX還提供了一組工具和庫,幫助開發(fā)人員更容易地創(chuàng)建、訓(xùn)練和部署深度學(xué)習(xí)模型。
ONNX模型是由多個(gè)節(jié)點(diǎn)(node)組成的圖(graph),每個(gè)節(jié)點(diǎn)代表一個(gè)操作或一個(gè)張量(tensor)。ONNX模型還包含了一些元數(shù)據(jù),例如模型的版本、輸入和輸出張量的名稱等。
onnx官網(wǎng)
ONNX | Home
pytorch官方使用onnx模型格式舉例
(optional) Exporting a Model from PyTorch to ONNX and Running it using ONNX Runtime — PyTorch Tutorials 2.2.0+cu121 documentation
TensorFlow官方使用onnx模型格式舉例
https://github.com/onnx/tutorials/blob/master/tutorials/TensorflowToOnnx-1.ipynb
Netron可視化模型結(jié)構(gòu)工具
Netron
你可通過該工具看到onnx具體的模型結(jié)構(gòu),點(diǎn)擊每層都能看到其對應(yīng)的內(nèi)容信息
onnxRuntime? |?提供各種編程語言推導(dǎo)onnx格式模型的接口
ONNX Runtime | Home
比如我需要在java環(huán)境下調(diào)用一個(gè)onnx模型,我可以先導(dǎo)入onnxRuntime的依賴,對數(shù)據(jù)預(yù)處理后,調(diào)用onnx格式模型正向傳播導(dǎo)出數(shù)據(jù),然后將數(shù)據(jù)處理成我要的數(shù)據(jù)。?
onnxRuntime也提供了其他編程語言的接口,如C++、C#、JavaScript、python等等。
實(shí)際案例舉例
python部分
python下利用ultralytics從網(wǎng)上下載并導(dǎo)出yolov8的onnx格式模型,用java調(diào)用onnxruntim接口,正向傳播推導(dǎo)模型數(shù)據(jù)。
pip install ultralytics
from ultralytics import YOLO
# 加載模型
model = YOLO('yolov8n.pt') # 加載官方模型
#加載自定義訓(xùn)練的模型
#model = YOLO('F:\\File\\AI\\Object\\yolov8_test\\runs\\detect\\train\\weights\\best.pt')
# 導(dǎo)出模型
model.export(format='onnx')
java部分
前提安裝java的opencv(Get Started - OpenCV),我這安裝的是opencv480
maven依賴
<dependencies>
<dependency>
<groupId>com.microsoft.onnxruntime</groupId>
<artifactId>onnxruntime</artifactId>
<version>1.12.0</version>
</dependency>
<!-- 加載lib目錄下的opencv包 -->
<dependency>
<groupId>org.opencv</groupId>
<artifactId>opencv</artifactId>
<version>4.8.0</version>
<scope>system</scope>
<!--通過路徑加載OpenCV480的jar包-->
<systemPath>${basedir}/lib/opencv-480.jar</systemPath>
</dependency>
<dependency>
<groupId>com.alibaba</groupId>
<artifactId>fastjson</artifactId>
<version>2.0.32</version>
</dependency>
</dependencies>
java完整代碼
package com.sky;
//天宇 2023/12/21 20:23:13
import ai.onnxruntime.*;
import com.alibaba.fastjson.JSONObject;
import org.opencv.core.*;
import org.opencv.core.Point;
import org.opencv.highgui.HighGui;
import org.opencv.imgcodecs.Imgcodecs;
import org.opencv.imgproc.Imgproc;
import java.nio.FloatBuffer;
import java.text.DecimalFormat;
import java.util.*;
import java.util.List;
/**
* onnx學(xué)習(xí)筆記 GTianyu
*/
public class onnxLoadTest01 {
public static OrtEnvironment env;
public static OrtSession session;
public static JSONObject names;
public static long count;
public static long channels;
public static long netHeight;
public static long netWidth;
public static float srcw;
public static float srch;
public static float confThreshold = 0.25f;
public static float nmsThreshold = 0.5f;
static Mat src;
public static void load(String path) {
String weight = path;
try{
env = OrtEnvironment.getEnvironment();
session = env.createSession(weight, new OrtSession.SessionOptions());
OnnxModelMetadata metadata = session.getMetadata();
Map<String, NodeInfo> infoMap = session.getInputInfo();
TensorInfo nodeInfo = (TensorInfo)infoMap.get("images").getInfo();
String nameClass = metadata.getCustomMetadata().get("names");
System.out.println("getProducerName="+metadata.getProducerName());
System.out.println("getGraphName="+metadata.getGraphName());
System.out.println("getDescription="+metadata.getDescription());
System.out.println("getDomain="+metadata.getDomain());
System.out.println("getVersion="+metadata.getVersion());
System.out.println("getCustomMetadata="+metadata.getCustomMetadata());
System.out.println("getInputInfo="+infoMap);
System.out.println("nodeInfo="+nodeInfo);
System.out.println(nameClass);
names = JSONObject.parseObject(nameClass.replace("\"","\"\""));
count = nodeInfo.getShape()[0];//1 模型每次處理一張圖片
channels = nodeInfo.getShape()[1];//3 模型通道數(shù)
netHeight = nodeInfo.getShape()[2];//640 模型高
netWidth = nodeInfo.getShape()[3];//640 模型寬
System.out.println(names.get(0));
// 加載opencc需要的動(dòng)態(tài)庫
System.loadLibrary(Core.NATIVE_LIBRARY_NAME);
}
catch (Exception e){
e.printStackTrace();
System.exit(0);
}
}
public static Map<Object, Object> predict(String imgPath) throws Exception {
src=Imgcodecs.imread(imgPath);
return predictor();
}
public static Map<Object, Object> predict(Mat mat) throws Exception {
src=mat;
return predictor();
}
public static OnnxTensor transferTensor(Mat dst){
Imgproc.cvtColor(dst, dst, Imgproc.COLOR_BGR2RGB);
dst.convertTo(dst, CvType.CV_32FC1, 1. / 255);
float[] whc = new float[ Long.valueOf(channels).intValue() * Long.valueOf(netWidth).intValue() * Long.valueOf(netHeight).intValue() ];
dst.get(0, 0, whc);
float[] chw = whc2cwh(whc);
OnnxTensor tensor = null;
try {
tensor = OnnxTensor.createTensor(env, FloatBuffer.wrap(chw), new long[]{count,channels,netWidth,netHeight});
}
catch (Exception e){
e.printStackTrace();
System.exit(0);
}
return tensor;
}
//寬 高 類型 to 類 寬 高
public static float[] whc2cwh(float[] src) {
float[] chw = new float[src.length];
int j = 0;
for (int ch = 0; ch < 3; ++ch) {
for (int i = ch; i < src.length; i += 3) {
chw[j] = src[i];
j++;
}
}
return chw;
}
public static Map<Object, Object> predictor() throws Exception{
srcw = src.width();
srch = src.height();
System.out.println("width:"+srcw+" hight:"+srch);
System.out.println("resize: \n width:"+netWidth+" hight:"+netHeight);
float scaleW=srcw/netWidth;
float scaleH=srch/netHeight;
// resize
Mat dst=new Mat();
Imgproc.resize(src, dst, new Size(netWidth, netHeight));
// 轉(zhuǎn)換成Tensor數(shù)據(jù)格式
OnnxTensor tensor = transferTensor(dst);
OrtSession.Result result = session.run(Collections.singletonMap("images", tensor));
System.out.println("res Data: "+result.get(0));
OnnxTensor res = (OnnxTensor)result.get(0);
float[][][] dataRes = (float[][][])res.getValue();
float[][] data = dataRes[0];
// 將矩陣轉(zhuǎn)置
// 先將xywh部分轉(zhuǎn)置
float rawData[][]=new float[data[0].length][6];
System.out.println(data.length-1);
for(int i=0;i<4;i++){
for(int j=0;j<data[0].length;j++){
rawData[j][i]=data[i][j];
}
}
// 保存每個(gè)檢查框置信值最高的類型置信值和該類型下標(biāo)
for(int i=0;i<data[0].length;i++){
for(int j=4;j<data.length;j++){
if(rawData[i][4]<data[j][i]){
rawData[i][4]=data[j][i]; //置信值
rawData[i][5]=j-4; //類型編號
}
}
}
List<ArrayList<Float>> boxes=new LinkedList<ArrayList<Float>>();
ArrayList<Float> box=null;
// 置信值過濾,xywh轉(zhuǎn)xyxy
for(float[] d:rawData){
// 置信值過濾
if(d[4]>confThreshold){
// xywh(xy為中心點(diǎn))轉(zhuǎn)xyxy
d[0]=d[0]-d[2]/2;
d[1]=d[1]-d[3]/2;
d[2]=d[0]+d[2];
d[3]=d[1]+d[3];
// 置信值符合的進(jìn)行插入法排序保存
box=new ArrayList<Float>();
for(float num:d) {
box.add(num);
}
if(boxes.size()==0){
boxes.add(box);
}else {
int i;
for(i=0;i<boxes.size();i++){
if(box.get(4)>boxes.get(i).get(4)){
boxes.add(i,box);
break;
}
}
// 插入到最后
if(i==boxes.size()){
boxes.add(box);
}
}
}
}
// 每個(gè)框分別有x1、x1、x2、y2、conf、class
//System.out.println(boxes);
// 非極大值抑制
int[] indexs=new int[boxes.size()];
Arrays.fill(indexs,1); //用于標(biāo)記1保留,0刪除
for(int cur=0;cur<boxes.size();cur++){
if(indexs[cur]==0){
continue;
}
ArrayList<Float> curMaxConf=boxes.get(cur); //當(dāng)前框代表該類置信值最大的框
for(int i=cur+1;i<boxes.size();i++){
if(indexs[i]==0){
continue;
}
float classIndex=boxes.get(i).get(5);
// 兩個(gè)檢測框都檢測到同一類數(shù)據(jù),通過iou來判斷是否檢測到同一目標(biāo),這就是非極大值抑制
if(classIndex==curMaxConf.get(5)){
float x1=curMaxConf.get(0);
float y1=curMaxConf.get(1);
float x2=curMaxConf.get(2);
float y2=curMaxConf.get(3);
float x3=boxes.get(i).get(0);
float y3=boxes.get(i).get(1);
float x4=boxes.get(i).get(2);
float y4=boxes.get(i).get(3);
//將幾種不相交的情況排除。提示:x1y1、x2y2、x3y3、x4y4對應(yīng)兩框的左上角和右下角
if(x1>x4||x2<x3||y1>y4||y2<y3){
continue;
}
// 兩個(gè)矩形的交集面積
float intersectionWidth =Math.max(x1, x3) - Math.min(x2, x4);
float intersectionHeight=Math.max(y1, y3) - Math.min(y2, y4);
float intersectionArea =Math.max(0,intersectionWidth * intersectionHeight);
// 兩個(gè)矩形的并集面積
float unionArea = (x2-x1)*(y2-y1)+(x4-x3)*(y4-y3)-intersectionArea;
// 計(jì)算IoU
float iou = intersectionArea / unionArea;
// 對交并比超過閾值的標(biāo)記
indexs[i]=iou>nmsThreshold?0:1;
//System.out.println(cur+" "+i+" class"+curMaxConf.get(5)+" "+classIndex+" u:"+unionArea+" i:"+intersectionArea+" iou:"+ iou);
}
}
}
List<ArrayList<Float>> resBoxes=new LinkedList<ArrayList<Float>>();
for(int index=0;index<indexs.length;index++){
if(indexs[index]==1) {
resBoxes.add(boxes.get(index));
}
}
boxes=resBoxes;
System.out.println("boxes.size : "+boxes.size());
for(ArrayList<Float> box1:boxes){
box1.set(0,box1.get(0)*scaleW);
box1.set(1,box1.get(1)*scaleH);
box1.set(2,box1.get(2)*scaleW);
box1.set(3,box1.get(3)*scaleH);
}
System.out.println("boxes: "+boxes);
//detect(boxes);
Map<Object,Object> map=new HashMap<Object,Object>();
map.put("boxes",boxes);
map.put("classNames",names);
return map;
}
public static Mat showDetect(Map<Object,Object> map){
List<ArrayList<Float>> boxes=(List<ArrayList<Float>>)map.get("boxes");
JSONObject names=(JSONObject) map.get("classNames");
Imgproc.resize(src,src,new Size(srcw,srch));
// 畫框,加數(shù)據(jù)
for(ArrayList<Float> box:boxes){
float x1=box.get(0);
float y1=box.get(1);
float x2=box.get(2);
float y2=box.get(3);
float config=box.get(4);
String className=(String)names.get((int)box.get(5).intValue());;
Point point1=new Point(x1,y1);
Point point2=new Point(x2,y2);
Imgproc.rectangle(src,point1,point2,new Scalar(0,0,255),2);
String conf=new DecimalFormat("#.###").format(config);
Imgproc.putText(src,className+" "+conf,new Point(x1,y1-5),0,0.5,new Scalar(255,0,0),1);
}
HighGui.imshow("image",src);
HighGui.waitKey();
return src;
}
public static void main(String[] args) throws Exception {
String modelPath="C:\\Users\\tianyu\\IdeaProjects\\test1\\src\\main\\java\\com\\sky\\best.onnx";
String path="C:\\Users\\tianyu\\IdeaProjects\\test1\\src\\main\\resources\\img\\img.png";
onnxLoadTest01.load(modelPath);
Map<Object,Object> map=onnxLoadTest01.predict(path);
showDetect(map);
}
}
效果:
參考文獻(xiàn):文章來源:http://www.zghlxwxcb.cn/news/detail-815211.html
使用 java-onnx 部署 yolovx 目標(biāo)檢測_java onnx-CSDN博客文章來源地址http://www.zghlxwxcb.cn/news/detail-815211.html
到了這里,關(guān)于ONNX格式模型 學(xué)習(xí)筆記 (onnxRuntime部署)---用java調(diào)用yolov8模型來舉例的文章就介紹完了。如果您還想了解更多內(nèi)容,請?jiān)谟疑辖撬阉鱐OY模板網(wǎng)以前的文章或繼續(xù)瀏覽下面的相關(guān)文章,希望大家以后多多支持TOY模板網(wǎng)!