1、數(shù)據(jù)集格式及存放
mmdet支持COCO格式和VOC格式,能用COCO格式,還是建議COCO的。網(wǎng)上有YOLO轉(zhuǎn)COCO,VOC轉(zhuǎn)COCO,可以自己轉(zhuǎn)換。
在mmdetection代碼的根目錄下,創(chuàng)建 data/coco
文件夾,按照coco的格式排放好數(shù)據(jù)集。annotations
下面是標(biāo)簽文件,train2017
、val2017
、test2017
是圖片。
2、修改兩處
第一處: mmdet/core/evalution/class_names.py
代碼下的 def coco_classes()
的 return 內(nèi)容改為自己數(shù)據(jù)集的類別;
第二處:mmdet/datasets/coco.py
代碼下的 class CocoDataset(CustomDataset)
的 CLASSES 改為自己數(shù)據(jù)集的類別;
注意:修改兩處后,一定要在根目錄下,輸入命令:python setup.py install build
重新編譯代碼,要不然類別會(huì)沒有載入,還是原coco類別,訓(xùn)練異常。
3、用訓(xùn)練命令生成配置文件
python tools/train.py configs/cascade_rcnn/cascade_rcnn_r50_fpn_1x_coco.py --work-dir work_dirs
其中,work_dirs是自己在根目錄新建的工作目錄,訓(xùn)練文件存儲(chǔ)在這里。
注意,此時(shí)運(yùn)行命令之后,并不是直接訓(xùn)練就可以不管了!我們還有參數(shù)設(shè)置沒改!這里輸入訓(xùn)練命令,只是需要它生成一個(gè)配置文件,便于我們改參數(shù)!
打開配置文件 cascade_rcnn_r50_fpn_1x_coco.py :
(1)修改 num_classes
,將其改為自己數(shù)據(jù)類別(直接全局搜索,有3處,都要改);
(2)修改 data_root
路徑和訓(xùn)練集、驗(yàn)證集、測(cè)試集的圖片和標(biāo)簽路徑,如下圖:
(3)修改訓(xùn)練圖片大小和學(xué)習(xí)率
修改下處代碼,可以更改圖片大小
img_scale = (1333, 800),
batch_size, mmdet默認(rèn)的方式是由 GPU 數(shù)量與 samples_per_gpu 參數(shù)決定:samples_per_gpu:
每個(gè)gpu讀取的圖像數(shù)量(意思不就是batch_size=2),該參數(shù)和訓(xùn)練時(shí)的gpu數(shù)量決定了訓(xùn)練時(shí)的batch_size。(為什么這么說(shuō)呢?因?yàn)閙mdet是8個(gè)GPU訓(xùn)練的,那么總的batch就是 8 *samples_per_gpu=16
,即訓(xùn)練時(shí)是batch_size為16) 。
但我們通常是只有一個(gè)gpu, 該參數(shù)設(shè)置為 2, 意思就是我們訓(xùn)練的 batch_size為2;
workers_per_gpu:
讀取數(shù)據(jù)時(shí)每個(gè)gpu分配的線程數(shù) ,一般設(shè)置為 2即可;(我感覺既然用單個(gè)GPU,設(shè)置到8也無(wú)妨吧?我還沒試)
學(xué)習(xí)率設(shè)置:
mmdet 默認(rèn)的學(xué)習(xí)率是基于8個(gè)gpu,而且默認(rèn)是1個(gè)GPU處理2個(gè)圖像(就上面說(shuō)的samples_per_gpu為2),可以這樣理解:
8個(gè)GPU,每個(gè)GPU處理2張圖片,那么真實(shí)訓(xùn)練總的一個(gè)batch就包括16張圖片,學(xué)習(xí)率為0.02;
4個(gè)GPU,每個(gè)GPU處理2張圖片,那么真實(shí)訓(xùn)練總的一個(gè)batch就包括8張圖片,學(xué)習(xí)率為0.01;
1個(gè)GPU,每個(gè)GPU處理2張圖片,那么真實(shí)訓(xùn)練總的一個(gè)batch就包括2張圖片,學(xué)習(xí)率為0.0025;
1個(gè)GPU,每個(gè)GPU處理1張圖片,那么真實(shí)訓(xùn)練總的一個(gè)batch就包括1張圖片,學(xué)習(xí)率為0.00125;
(4)使用預(yù)訓(xùn)練模型
提前從github上下載預(yù)訓(xùn)練模型,新建一個(gè)checkpoints文件夾下,放到里面。(模型下載鏈接:https://github.com/open-mmlab/mmdetection/blob/master/docs/en/model_zoo.md)
然后修改以下代碼:
# 原本是 load_from = None ,修改為
load_from = 'checkpoints/fcascade_rcnn_r50_fpn_1x_coco_20200316-3dc56deb.pth’
(5)訓(xùn)練輪數(shù),保存模型間隔,日志保存參數(shù)
4、正式訓(xùn)練開始
?。。】辞宄窂?!使用的是更改過的配置文件訓(xùn)練?。?!
python tools/train.py work_dirs/cascade_rcnn_r50_fpn_1x_coco.py
5、報(bào)錯(cuò)記錄
在第三步生成配置文件時(shí),遇到以下報(bào)錯(cuò):
AssertionError: The num_classes (10) in Shared2FCBBoxHead of
MMDataParallel does not matches the length of CLASSES 80) in
CocoDataset
即使在修改 coco.py 和 class_names.py 后運(yùn)行 python setup.py install仍然無(wú)法解決;
解決方法:
根據(jù)報(bào)錯(cuò)信息,找到自己虛擬環(huán)境的/mmdet/datasets/coco.py
和mmdet/core/evaluation/class_names.py
,再次修改CocoDataset()
和 coco_classes()
l兩處(跟第二步一樣,其實(shí)打開,就能看到虛擬環(huán)境下的并沒有修改成功)
參考鏈接:AssertionError: The num_classes (3) in Shared2FCBBoxHead of
MMDataParallel does not
matches
6、模型評(píng)價(jià)測(cè)試(VOC指標(biāo)mAP、COCO指標(biāo)AP)
(1)生成中間件
python tools/test.py work_dirs/cascade_rcnn_r50_fpn_1x_coco.py work_dirs/epoch_20.pth --out results.pkl
- work_dirs/cascade_rcnn_r50_fpn_1x_coco.py 模型配置文件(跟訓(xùn)練時(shí)的一樣)
- work_dirs/epoch_20.pth: 訓(xùn)練好的模型(我是訓(xùn)練了20epoch)
-
--out
指定 results.pkl 輸出目錄,可以自己指定輸出目錄
(2)使用COCO標(biāo)準(zhǔn)評(píng)估指標(biāo)
python tools/analysis_tools/eval_metric.py work_dirs/cascade_rcnn_r50_fpn_1x_coco.py results.pkl --eval=bbox
-
--eval
,COCO數(shù)據(jù)集可選參數(shù)有:bbox 、segm、proposal ;對(duì)VOC數(shù)據(jù)集可選參數(shù)有:mAP
(3)使用VOC標(biāo)準(zhǔn)評(píng)估指標(biāo)
# results.pkl 的順序別放錯(cuò),在中間。
python tools/voc_eval.py results.pkl work_dirs/cascade_rcnn_r50_fpn_1x_coco.py
- voc_eval.py 文件 mmdetection 2.X 版本刪除了,可以去老版本1.X 找找
7、繪制每個(gè)類別bbox 的結(jié)果曲線圖并保存
(1)使用 test.py 生成 results.bbox.json 文件(在根目錄下,路徑可自己指定)
python tools/test.py work_dirs/cascade_rcnn_r50_fpn_1x_coco.py work_dirs/epoch_20.pth --format-only --options "jsonfile_prefix=./results"
(2)獲得COCO bbox錯(cuò)誤結(jié)果每個(gè)類別,保存分析結(jié)果圖像到目錄results/
python tools/analysis_tools/coco_error_analysis.py results.bbox.json results --ann=data/coco/annotations/instances_val2017.json
- results.bbox.json:上一步生成的文件
- results: 結(jié)果曲線圖的生成目錄, 此處將生成到results/ 目錄下
- –ann=data/coco/annotations/instances_val2017.json: 數(shù)據(jù)集標(biāo)注文件存放路徑
8、統(tǒng)計(jì)模型參數(shù)量和FLOPs
python tools/analysis_tools/get_flops.py work_dirs/cascade_rcnn_r50_fpn_1x_coco.py --shape 640 640
-
--shape
參數(shù)指定輸入圖片尺寸
9 計(jì)算混淆矩陣
python tools/analysis_tools/confusion_matrix.py work_dirs/cascade_rcnn_r50_fpn_1x_coco.py results.pkl coco_confusion_matrix/
- 需要三個(gè)參數(shù),配置文件、pkl文件、輸出目錄
10 畫PR曲線
plot_pr_curve.py 代碼來(lái)自:https://blog.csdn.net/weixin_44966641/article/details/124558532
import os
import sys
import mmcv
import numpy as np
import argparse
import matplotlib.pyplot as plt
from pycocotools.coco import COCO
from pycocotools.cocoeval import COCOeval
from mmcv import Config
from mmdet.datasets import build_dataset
def plot_pr_curve(config_file, result_file, out_pic, metric="bbox"):
"""plot precison-recall curve based on testing results of pkl file.
Args:
config_file (list[list | tuple]): config file path.
result_file (str): pkl file of testing results path.
metric (str): Metrics to be evaluated. Options are
'bbox', 'segm'.
"""
cfg = Config.fromfile(config_file)
# turn on test mode of dataset
if isinstance(cfg.data.test, dict):
cfg.data.test.test_mode = True
elif isinstance(cfg.data.test, list):
for ds_cfg in cfg.data.test:
ds_cfg.test_mode = True
# build dataset
dataset = build_dataset(cfg.data.test)
# load result file in pkl format
pkl_results = mmcv.load(result_file)
# convert pkl file (list[list | tuple | ndarray]) to json
json_results, _ = dataset.format_results(pkl_results)
# initialize COCO instance
coco = COCO(annotation_file=cfg.data.test.ann_file)
coco_gt = coco
coco_dt = coco_gt.loadRes(json_results[metric])
# initialize COCOeval instance
coco_eval = COCOeval(coco_gt, coco_dt, metric)
coco_eval.evaluate()
coco_eval.accumulate()
coco_eval.summarize()
# extract eval data
precisions = coco_eval.eval["precision"]
'''
precisions[T, R, K, A, M]
T: iou thresholds [0.5 : 0.05 : 0.95], idx from 0 to 9
R: recall thresholds [0 : 0.01 : 1], idx from 0 to 100
K: category, idx from 0 to ...
A: area range, (all, small, medium, large), idx from 0 to 3
M: max dets, (1, 10, 100), idx from 0 to 2
'''
pr_array1 = precisions[0, :, 0, 0, 2]
pr_array2 = precisions[1, :, 0, 0, 2]
pr_array3 = precisions[2, :, 0, 0, 2]
pr_array4 = precisions[3, :, 0, 0, 2]
pr_array5 = precisions[4, :, 0, 0, 2]
pr_array6 = precisions[5, :, 0, 0, 2]
pr_array7 = precisions[6, :, 0, 0, 2]
pr_array8 = precisions[7, :, 0, 0, 2]
pr_array9 = precisions[8, :, 0, 0, 2]
pr_array10 = precisions[9, :, 0, 0, 2]
x = np.arange(0.0, 1.01, 0.01)
# plot PR curve
plt.plot(x, pr_array1, label="iou=0.5")
plt.plot(x, pr_array2, label="iou=0.55")
plt.plot(x, pr_array3, label="iou=0.6")
plt.plot(x, pr_array4, label="iou=0.65")
plt.plot(x, pr_array5, label="iou=0.7")
plt.plot(x, pr_array6, label="iou=0.75")
plt.plot(x, pr_array7, label="iou=0.8")
plt.plot(x, pr_array8, label="iou=0.85")
plt.plot(x, pr_array9, label="iou=0.9")
plt.plot(x, pr_array10, label="iou=0.95")
plt.xlabel("recall")
plt.ylabel("precison")
plt.xlim(0, 1.0)
plt.ylim(0, 1.01)
plt.grid(True)
plt.legend(loc="lower left")
plt.savefig(out_pic)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument('config', help='config file path')
parser.add_argument('pkl_result_file', help='pkl result file path')
parser.add_argument('--out', default='pr_curve.png')
parser.add_argument('--eval', default='bbox')
cfg = parser.parse_args()
plot_pr_curve(config_file=cfg.config, result_file=cfg.pkl_result_file, out_pic=cfg.out, metric=cfg.eval)
輸入命令:
python plot_pr_curve.py work_dirs/cascade_rcnn_r50_fpn_1x_coco.py results.pkl
11 查看完整config配置文件
python tools/misc/print_config.py work_dirs/cascade_rcnn_r50_fpn_1x_coco.py
12 核查數(shù)據(jù)增強(qiáng)的結(jié)果是否正確
python tools/misc/browse_dataset.py work_dirs/cascade_rcnn_r50_fpn_1x_coco.py --output-dir work_dirs/
8、參考鏈接
https://blog.csdn.net/qq_35077107/article/details/124768460?spm=1001.2014.3001.5502文章來(lái)源:http://www.zghlxwxcb.cn/news/detail-463234.html
https://blog.csdn.net/weixin_44966641/article/details/124558532文章來(lái)源地址http://www.zghlxwxcb.cn/news/detail-463234.html
到了這里,關(guān)于【MMDetection】——訓(xùn)練個(gè)人數(shù)據(jù)集的文章就介紹完了。如果您還想了解更多內(nèi)容,請(qǐng)?jiān)谟疑辖撬阉鱐OY模板網(wǎng)以前的文章或繼續(xù)瀏覽下面的相關(guān)文章,希望大家以后多多支持TOY模板網(wǎng)!