In earlier posts I implemented pruning for YOLOv4, YOLOR, and YOLOX; after several days of work, pruning for YOLOv5 is now working as well. Related links:
YOLOv4 pruning (the theory behind the pruning details is written up there): YOLOv4 pruning
YOLOX pruning: YOLOX pruning
YOLOR pruning: YOLOR pruning
Paper: Pruning Filters for Efficient ConvNets
Note: this article only implements a pruning method for YOLOv5. How much to prune and which layers to prune depend on your own requirements and dataset; the final accuracy is not guaranteed.
If you need other YOLOv5 material, you can refer to my other articles:
Modifying the YOLOv5 network via yaml
Defining a custom network model with yaml
This article implements the following:
1. Training on your own dataset
2. Pruning arbitrary convolution layers
3. Fine-tuning after pruning
4. Inference with the pruned model
Code:
1. Training on your own dataset
Place your prepared dataset under the dataset folder, with the following directory structure:
dataset
|-- Annotations
|-- ImageSets
|-- images
|-- labels
Annotations holds the xml label files, images holds the images, ImageSets holds four txt files (generated automatically when the code runs later), and labels holds the txt files converted from the xml.
1. Run makeTXT.py. This generates four files under the ImageSets folder: trainval.txt, test.txt, train.txt, and val.txt (if you open these txt files, they contain only image names), as sketched below.
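For reference, the split makeTXT.py performs looks roughly like this (a minimal sketch; the ratios and variable names here are illustrative, so check the script in the repo for the exact values):

import os
import random

# illustrative split ratios; the repo's makeTXT.py may use different values
trainval_ratio, train_ratio = 0.9, 0.9

names = [f[:-4] for f in os.listdir('dataset/Annotations') if f.endswith('.xml')]
random.shuffle(names)

n_tv = int(len(names) * trainval_ratio)
trainval, test = names[:n_tv], names[n_tv:]
n_tr = int(n_tv * train_ratio)
train, val = trainval[:n_tr], trainval[n_tr:]

os.makedirs('dataset/ImageSets', exist_ok=True)
for fname, split in [('trainval.txt', trainval), ('test.txt', test),
                     ('train.txt', train), ('val.txt', val)]:
    with open(os.path.join('dataset/ImageSets', fname), 'w') as f:
        f.write('\n'.join(split) + '\n')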
2. Open voc_label.py and change classes=[""] to your own class names; for example, if you are training on cats and dogs it becomes classes=["dog","cat"]. Then run the script. It generates one txt file per image under labels, in the form below: the leading 0 is the class index (I have only one class here), and the four numbers that follow are the box parameters, all normalized; in the YOLO txt format these are center_x, center_y, w and h.
0 0.4723557692307693 0.5408653846153847 0.34375 0.8990384615384616
0 0.8834134615384616 0.5793269230769231 0.21875 0.8221153846153847
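For reference, the conversion voc_label.py performs from VOC pixel coordinates to this normalized format is the standard one (a minimal sketch of the usual convert helper; the variable names are illustrative):

def convert(size, box):
    # size: (img_w, img_h); box: VOC (xmin, xmax, ymin, ymax) in pixels
    img_w, img_h = size
    xmin, xmax, ymin, ymax = box
    cx = (xmin + xmax) / 2.0 / img_w   # normalized box center x
    cy = (ymin + ymax) / 2.0 / img_h   # normalized box center y
    w = (xmax - xmin) / img_w          # normalized box width
    h = (ymax - ymin) / img_h          # normalized box height
    return cx, cy, w, h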
3. Create a mydata.yaml file under the data folder with the content below (you can also copy coco.yaml over).
You only need to modify nc and names: nc is the number of classes and names is the list of class names.
train: ./dataset/train.txt
val: ./dataset/val.txt
test: ./dataset/test.txt

# number of classes
nc: 1

# class names
names: ['target']
4. Enter the arguments in the terminal and start training.
Taking yolov5s as an example:
python train.py --weights yolov5s.pt --cfg models/yolov5s.yaml --data data/mydata.yaml
                 from  n    params  module                                  arguments
  0                -1  1      3520  models.common.Conv                      [3, 32, 6, 2, 2]
  1                -1  1     18560  models.common.Conv                      [32, 64, 3, 2]
  2                -1  1     18816  models.common.C3                        [64, 64, 1]
  3                -1  1     73984  models.common.Conv                      [64, 128, 3, 2]
  4                -1  2    115712  models.common.C3                        [128, 128, 2]
  5                -1  1    295424  models.common.Conv                      [128, 256, 3, 2]
  6                -1  3    625152  models.common.C3                        [256, 256, 3]
  7                -1  1   1180672  models.common.Conv                      [256, 512, 3, 2]
  8                -1  1   1182720  models.common.C3                        [512, 512, 1]
  9                -1  1    656896  models.common.SPPF                      [512, 512, 5]
 10                -1  1    131584  models.common.Conv                      [512, 256, 1, 1]
 11                -1  1         0  torch.nn.modules.upsampling.Upsample    [None, 2, 'nearest']
 12           [-1, 6]  1         0  models.common.Concat                    [1]
 13                -1  1    361984  models.common.C3                        [512, 256, 1, False]
 14                -1  1     33024  models.common.Conv                      [256, 128, 1, 1]
 15                -1  1         0  torch.nn.modules.upsampling.Upsample    [None, 2, 'nearest']
 16           [-1, 4]  1         0  models.common.Concat                    [1]
 17                -1  1     90880  models.common.C3                        [256, 128, 1, False]
 18                -1  1    147712  models.common.Conv                      [128, 128, 3, 2]
 19          [-1, 14]  1         0  models.common.Concat                    [1]
 20                -1  1    296448  models.common.C3                        [256, 256, 1, False]
 21                -1  1    590336  models.common.Conv                      [256, 256, 3, 2]
 22          [-1, 10]  1         0  models.common.Concat                    [1]
 23                -1  1   1182720  models.common.C3                        [512, 512, 1, False]
 24      [17, 20, 23]  1     16182  models.yolo.Detect                      [1, [[10, 13, 16, 30, 33, 23], [30, 61, 62, 45, 59, 119], [116, 90, 156, 198, 373, 326]], [128, 256, 512]]
Model Summary: 270 layers, 7022326 parameters, 7022326 gradients, 15.8 GFLOPs
Starting training for 300 epochs...
     Epoch   gpu_mem       box       obj       cls    labels  img_size
     0/299    0.589G    0.0779   0.03841         0         4       640:   6%|████▋                    | 23/359 [00:23<04:15, 1.31it/s]
Once you see output like this, training is underway.
2. Pruning arbitrary convolution layers
Before using the pruning functionality, install the pruning library. Install version 0.2.7; some readers reported problems with 0.2.8. Log messages produced during pruning are saved automatically under the logs folder; I set each log file's size limit to 1 MB, which you can change if needed.
pip install torch_pruning==0.2.7
Unlike my earlier pruning projects, the weights YOLOv5 saves after training already contain the complete model, i.e. it uses torch.save(model, ...) rather than torch.save(model.state_dict(), ...), so there is no need to save the network structure separately.
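To make the difference concrete, here is a minimal sketch (the paths are illustrative). With the full-model checkpoint the pruned structure comes back for free; with a state_dict you would first have to rebuild the exact pruned architecture in code:

import torch

# YOLOv5-style checkpoint: the pickled object already contains the module,
# so loading restores the (possibly pruned) network structure and weights
ckpt = torch.load('runs/train/exp/weights/best.pt', map_location='cpu')
model = ckpt['model'].float()

# state_dict-style checkpoint: tensors only, structure must exist beforehand
# model = MyModel(...)          # must match the pruned architecture exactly
# model.load_state_dict(torch.load('best_state_dict.pt'))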
The model pruning code is in tools/prunmodel.py. You only need to find and modify the section below. Here I prune all the convolution layers in the backbone as an example; if you want to prune other layers, adapt it to your needs. included_layers holds the layers that will be pruned:
"""
這里寫要剪枝的層
"""
included_layers = []
for layer in model.model[:10]:
if type(layer) is Conv:
included_layers.append(layer.conv)
elif type(layer) is C3:
included_layers.append(layer.cv1.conv)
included_layers.append(layer.cv2.conv)
included_layers.append(layer.cv3.conv)
elif type(layer) is SPPF:
included_layers.append(layer.cv1.conv)
included_layers.append(layer.cv2.conv)
Next, find the line below. amount is the pruning ratio; change it as needed. One thing to understand: this ratio applies only to the layers you selected for pruning, not to the whole network from end to end. Some readers picked one layer, set a 50% ratio, and then wondered why the model was still about the same size; they had confused the per-layer ratio with pruning the entire network by 50%.
pruning_plan = DG.get_pruning_plan(m, tp.prune_conv, idxs=strategy(m.weight, amount=0.8))
Then call the pruning function, passing in the path to your own trained weights:
layer_pruning('../runs/train/exp/weights/best.pt')
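For orientation, the overall flow inside layer_pruning is roughly the following (a condensed sketch against the torch_pruning 0.2.7 API; the real function in tools/prunmodel.py additionally handles logging and checkpoint details):

import torch
import torch_pruning as tp

def layer_pruning_sketch(weights_path):
    ckpt = torch.load(weights_path, map_location='cpu')
    model = ckpt['model'].float().cpu()

    # build the dependency graph so coupled layers (BN, concat,
    # downstream convs) are pruned consistently with each conv
    DG = tp.DependencyGraph()
    DG.build_dependency(model, example_inputs=torch.randn(1, 3, 640, 640))

    strategy = tp.strategy.L1Strategy()  # L1-norm filter ranking (Li et al.)
    for m in included_layers:            # the list selected above
        pruning_plan = DG.get_pruning_plan(
            m, tp.prune_conv, idxs=strategy(m.weight, amount=0.8))
        pruning_plan.exec()

    torch.save({'model': model}, 'model_data/layer_pruning.pt')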
When you see output of the following form, pruning succeeded. The pruned weights are saved under model_data as layer_pruning.pt.
A note here: the saved weights file contains not only the network structure and weight values but also the optimizer state. Saving only the structure and weights also works and gives a slightly smaller pt file; I keep everything by default so the format stays consistent with the official pt files.
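If you do want the smaller variant, stripping the checkpoint down is straightforward (a minimal sketch; the keys follow the official YOLOv5 checkpoint layout, and the output file name is illustrative):

import torch

ckpt = torch.load('model_data/layer_pruning.pt', map_location='cpu')
slim = {'model': ckpt['model']}  # keep only the structure and weights
torch.save(slim, 'model_data/layer_pruning_slim.pt')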
-------------
[ <DEP: prune_conv => prune_conv on model.9.cv2.conv (Conv2d(208, 512, kernel_size=(1, 1), stride=(1, 1), bias=False))>, Index=[0, 1, 2, 3, 7, 8, 10, 11, 12, 13, 16, 17, 18, 19, 21, 22, 23, 25, 27, 28, 29, 30, 31, 32, 33, 34, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 53, 54, 56, 57, 58, 59, 60, 61, 62, 63, 65, 67, 69, 70, 71, 72, 73, 74, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 89, 90, 91, 92, 95, 96, 97, 99, 100, 102, 103, 104, 105, 106, 107, 109, 110, 111, 113, 114, 115, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, 128, 129, 130, 132, 133, 135, 137, 139, 142, 143, 144, 146, 148, 150, 152, 153, 154, 155, 156, 157, 158, 159, 161, 162, 163, 164, 165, 166, 167, 168, 169, 170, 171, 173, 174, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211, 212, 213, 215, 216, 217, 219, 220, 221, 222, 223, 224, 225, 226, 228, 229, 230, 232, 233, 234, 235, 236, 237, 239, 240, 241, 242, 243, 246, 247, 248, 249, 251, 252, 253, 254, 257, 258, 259, 260, 263, 264, 265, 266, 267, 268, 270, 271, 272, 273, 274, 275, 276, 277, 278, 280, 281, 282, 283, 284, 285, 286, 287, 288, 292, 293, 294, 295, 296, 297, 299, 301, 302, 303, 306, 307, 308, 309, 310, 311, 312, 313, 314, 315, 317, 318, 321, 322, 323, 324, 325, 326, 327, 329, 330, 331, 332, 334, 335, 338, 339, 341, 342, 343, 344, 346, 347, 349, 351, 353, 354, 355, 356, 357, 358, 359, 361, 362, 363, 364, 365, 366, 368, 369, 370, 372, 373, 374, 375, 378, 379, 381, 382, 383, 385, 386, 387, 388, 389, 390, 391, 392, 393, 395, 396, 397, 398, 399, 401, 402, 403, 404, 405, 407, 408, 411, 413, 414, 415, 416, 418, 419, 420, 421, 422, 423, 425, 426, 427, 428, 429, 430, 431, 432, 433, 434, 435, 436, 437, 438, 440, 441, 442, 443, 444, 445, 446, 448, 449, 451, 452, 453, 454, 455, 456, 457, 458, 459, 461, 463, 465, 466, 468, 470, 472, 473, 474, 475, 476, 477, 478, 479, 480, 482, 483, 484, 485, 486, 487, 488, 490, 491, 492, 493, 494, 495, 496, 497, 498, 499, 500, 502, 503, 505, 506, 507, 510, 511], NumPruned=85072]
[ <DEP: prune_conv => prune_batchnorm on model.9.cv2.bn (BatchNorm2d(512, eps=0.001, momentum=0.03, affine=True, track_running_stats=True))>, Index=[0, 1, 2, 3, 7, 8, 10, 11, 12, 13, 16, 17, 18, 19, 21, 22, 23, 25, 27, 28, 29, 30, 31, 32, 33, 34, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 53, 54, 56, 57, 58, 59, 60, 61, 62, 63, 65, 67, 69, 70, 71, 72, 73, 74, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 89, 90, 91, 92, 95, 96, 97, 99, 100, 102, 103, 104, 105, 106, 107, 109, 110, 111, 113, 114, 115, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, 128, 129, 130, 132, 133, 135, 137, 139, 142, 143, 144, 146, 148, 150, 152, 153, 154, 155, 156, 157, 158, 159, 161, 162, 163, 164, 165, 166, 167, 168, 169, 170, 171, 173, 174, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211, 212, 213, 215, 216, 217, 219, 220, 221, 222, 223, 224, 225, 226, 228, 229, 230, 232, 233, 234, 235, 236, 237, 239, 240, 241, 242, 243, 246, 247, 248, 249, 251, 252, 253, 254, 257, 258, 259, 260, 263, 264, 265, 266, 267, 268, 270, 271, 272, 273, 274, 275, 276, 277, 278, 280, 281, 282, 283, 284, 285, 286, 287, 288, 292, 293, 294, 295, 296, 297, 299, 301, 302, 303, 306, 307, 308, 309, 310, 311, 312, 313, 314, 315, 317, 318, 321, 322, 323, 324, 325, 326, 327, 329, 330, 331, 332, 334, 335, 338, 339, 341, 342, 343, 344, 346, 347, 349, 351, 353, 354, 355, 356, 357, 358, 359, 361, 362, 363, 364, 365, 366, 368, 369, 370, 372, 373, 374, 375, 378, 379, 381, 382, 383, 385, 386, 387, 388, 389, 390, 391, 392, 393, 395, 396, 397, 398, 399, 401, 402, 403, 404, 405, 407, 408, 411, 413, 414, 415, 416, 418, 419, 420, 421, 422, 423, 425, 426, 427, 428, 429, 430, 431, 432, 433, 434, 435, 436, 437, 438, 440, 441, 442, 443, 444, 445, 446, 448, 449, 451, 452, 453, 454, 455, 456, 457, 458, 459, 461, 463, 465, 466, 468, 470, 472, 473, 474, 475, 476, 477, 478, 479, 480, 482, 483, 484, 485, 486, 487, 488, 490, 491, 492, 493, 494, 495, 496, 497, 498, 499, 500, 502, 503, 505, 506, 507, 510, 511], NumPruned=818]
[ <DEP: prune_batchnorm => _prune_elementwise_op on _ElementWiseOp()>, Index=[0, 1, 2, 3, 7, 8, 10, 11, 12, 13, 16, 17, 18, 19, 21, 22, 23, 25, 27, 28, 29, 30, 31, 32, 33, 34, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 53, 54, 56, 57, 58, 59, 60, 61, 62, 63, 65, 67, 69, 70, 71, 72, 73, 74, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 89, 90, 91, 92, 95, 96, 97, 99, 100, 102, 103, 104, 105, 106, 107, 109, 110, 111, 113, 114, 115, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, 128, 129, 130, 132, 133, 135, 137, 139, 142, 143, 144, 146, 148, 150, 152, 153, 154, 155, 156, 157, 158, 159, 161, 162, 163, 164, 165, 166, 167, 168, 169, 170, 171, 173, 174, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211, 212, 213, 215, 216, 217, 219, 220, 221, 222, 223, 224, 225, 226, 228, 229, 230, 232, 233, 234, 235, 236, 237, 239, 240, 241, 242, 243, 246, 247, 248, 249, 251, 252, 253, 254, 257, 258, 259, 260, 263, 264, 265, 266, 267, 268, 270, 271, 272, 273, 274, 275, 276, 277, 278, 280, 281, 282, 283, 284, 285, 286, 287, 288, 292, 293, 294, 295, 296, 297, 299, 301, 302, 303, 306, 307, 308, 309, 310, 311, 312, 313, 314, 315, 317, 318, 321, 322, 323, 324, 325, 326, 327, 329, 330, 331, 332, 334, 335, 338, 339, 341, 342, 343, 344, 346, 347, 349, 351, 353, 354, 355, 356, 357, 358, 359, 361, 362, 363, 364, 365, 366, 368, 369, 370, 372, 373, 374, 375, 378, 379, 381, 382, 383, 385, 386, 387, 388, 389, 390, 391, 392, 393, 395, 396, 397, 398, 399, 401, 402, 403, 404, 405, 407, 408, 411, 413, 414, 415, 416, 418, 419, 420, 421, 422, 423, 425, 426, 427, 428, 429, 430, 431, 432, 433, 434, 435, 436, 437, 438, 440, 441, 442, 443, 444, 445, 446, 448, 449, 451, 452, 453, 454, 455, 456, 457, 458, 459, 461, 463, 465, 466, 468, 470, 472, 473, 474, 475, 476, 477, 478, 479, 480, 482, 483, 484, 485, 486, 487, 488, 490, 491, 492, 493, 494, 495, 496, 497, 498, 499, 500, 502, 503, 505, 506, 507, 510, 511], NumPruned=0]
[ <DEP: _prune_elementwise_op => _prune_elementwise_op on _ElementWiseOp()>, Index=[0, 1, 2, 3, 7, 8, 10, 11, 12, 13, 16, 17, 18, 19, 21, 22, 23, 25, 27, 28, 29, 30, 31, 32, 33, 34, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 53, 54, 56, 57, 58, 59, 60, 61, 62, 63, 65, 67, 69, 70, 71, 72, 73, 74, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 89, 90, 91, 92, 95, 96, 97, 99, 100, 102, 103, 104, 105, 106, 107, 109, 110, 111, 113, 114, 115, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, 128, 129, 130, 132, 133, 135, 137, 139, 142, 143, 144, 146, 148, 150, 152, 153, 154, 155, 156, 157, 158, 159, 161, 162, 163, 164, 165, 166, 167, 168, 169, 170, 171, 173, 174, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211, 212, 213, 215, 216, 217, 219, 220, 221, 222, 223, 224, 225, 226, 228, 229, 230, 232, 233, 234, 235, 236, 237, 239, 240, 241, 242, 243, 246, 247, 248, 249, 251, 252, 253, 254, 257, 258, 259, 260, 263, 264, 265, 266, 267, 268, 270, 271, 272, 273, 274, 275, 276, 277, 278, 280, 281, 282, 283, 284, 285, 286, 287, 288, 292, 293, 294, 295, 296, 297, 299, 301, 302, 303, 306, 307, 308, 309, 310, 311, 312, 313, 314, 315, 317, 318, 321, 322, 323, 324, 325, 326, 327, 329, 330, 331, 332, 334, 335, 338, 339, 341, 342, 343, 344, 346, 347, 349, 351, 353, 354, 355, 356, 357, 358, 359, 361, 362, 363, 364, 365, 366, 368, 369, 370, 372, 373, 374, 375, 378, 379, 381, 382, 383, 385, 386, 387, 388, 389, 390, 391, 392, 393, 395, 396, 397, 398, 399, 401, 402, 403, 404, 405, 407, 408, 411, 413, 414, 415, 416, 418, 419, 420, 421, 422, 423, 425, 426, 427, 428, 429, 430, 431, 432, 433, 434, 435, 436, 437, 438, 440, 441, 442, 443, 444, 445, 446, 448, 449, 451, 452, 453, 454, 455, 456, 457, 458, 459, 461, 463, 465, 466, 468, 470, 472, 473, 474, 475, 476, 477, 478, 479, 480, 482, 483, 484, 485, 486, 487, 488, 490, 491, 492, 493, 494, 495, 496, 497, 498, 499, 500, 502, 503, 505, 506, 507, 510, 511], NumPruned=0]
[ <DEP: _prune_elementwise_op => prune_related_conv on model.10.conv (Conv2d(512, 256, kernel_size=(1, 1), stride=(1, 1), bias=False))>, Index=[0, 1, 2, 3, 7, 8, 10, 11, 12, 13, 16, 17, 18, 19, 21, 22, 23, 25, 27, 28, 29, 30, 31, 32, 33, 34, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 53, 54, 56, 57, 58, 59, 60, 61, 62, 63, 65, 67, 69, 70, 71, 72, 73, 74, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 89, 90, 91, 92, 95, 96, 97, 99, 100, 102, 103, 104, 105, 106, 107, 109, 110, 111, 113, 114, 115, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, 128, 129, 130, 132, 133, 135, 137, 139, 142, 143, 144, 146, 148, 150, 152, 153, 154, 155, 156, 157, 158, 159, 161, 162, 163, 164, 165, 166, 167, 168, 169, 170, 171, 173, 174, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211, 212, 213, 215, 216, 217, 219, 220, 221, 222, 223, 224, 225, 226, 228, 229, 230, 232, 233, 234, 235, 236, 237, 239, 240, 241, 242, 243, 246, 247, 248, 249, 251, 252, 253, 254, 257, 258, 259, 260, 263, 264, 265, 266, 267, 268, 270, 271, 272, 273, 274, 275, 276, 277, 278, 280, 281, 282, 283, 284, 285, 286, 287, 288, 292, 293, 294, 295, 296, 297, 299, 301, 302, 303, 306, 307, 308, 309, 310, 311, 312, 313, 314, 315, 317, 318, 321, 322, 323, 324, 325, 326, 327, 329, 330, 331, 332, 334, 335, 338, 339, 341, 342, 343, 344, 346, 347, 349, 351, 353, 354, 355, 356, 357, 358, 359, 361, 362, 363, 364, 365, 366, 368, 369, 370, 372, 373, 374, 375, 378, 379, 381, 382, 383, 385, 386, 387, 388, 389, 390, 391, 392, 393, 395, 396, 397, 398, 399, 401, 402, 403, 404, 405, 407, 408, 411, 413, 414, 415, 416, 418, 419, 420, 421, 422, 423, 425, 426, 427, 428, 429, 430, 431, 432, 433, 434, 435, 436, 437, 438, 440, 441, 442, 443, 444, 445, 446, 448, 449, 451, 452, 453, 454, 455, 456, 457, 458, 459, 461, 463, 465, 466, 468, 470, 472, 473, 474, 475, 476, 477, 478, 479, 480, 482, 483, 484, 485, 486, 487, 488, 490, 491, 492, 493, 494, 495, 496, 497, 498, 499, 500, 502, 503, 505, 506, 507, 510, 511], NumPruned=104704]
190594 parameters will be pruned
-------------
2022-09-29 12:30:50.396 | INFO | __main__:layer_pruning:75 - Params: 7022326 => 3056461
2022-09-29 12:30:50.691 | INFO | __main__:layer_pruning:89 - pruning finished
If you only want to prune a single layer, you can write it like this:
included_layers = [model.model[3].conv]  # prune just this one convolution layer
Detection results can still be produced this way; see the verification sketch below.
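To verify the effect of whichever configuration you choose, count the parameters before and after pruning; this is exactly what the Params line in the log above reports. A minimal sketch:

def count_params(model):
    return sum(p.numel() for p in model.parameters())

n_before = count_params(model)  # e.g. 7022326 for yolov5s with nc=1
# ... build and execute the pruning plans ...
n_after = count_params(model)   # e.g. 3056461 after pruning the backbone
print(f'Params: {n_before} => {n_after}')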
3. Fine-tuning after pruning
This should be distinguished from sparse training, since many people asked in my earlier projects whether sparse training is included. The channel pruning here is offline: it prunes a model that has already been trained. Pruning while training is online pruning, and that training process is what sparse training refers to, so the two are different.
Fine-tuning after pruning works the same as ordinary training; you just add the --pt flag. The command is:
python train.py --weights model_data/layer_pruning.pt --data data/mydata.yaml --pt
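The --pt flag matters because the pruned channel counts no longer match yolov5s.yaml, so the model cannot be rebuilt from the cfg. Conceptually the loading branch behaves like this (a schematic sketch; the exact code in this repo's train.py may differ):

if opt.pt:
    # pruned model: restore the full pickled module, keeping the pruned
    # structure instead of rebuilding the network from the yaml cfg
    ckpt = torch.load(weights, map_location=device)
    model = ckpt['model'].float().to(device)
else:
    model = Model(cfg, ch=3, nc=nc).to(device)  # normal build from yaml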
4. Inference with the pruned model
Inference with the pruned model is the same as normal inference.
python detect.py --weights model_data/layer_pruning.pt --source [path to your images]
Let me stress this again! This article only gives you the wheel; the final pruning results depend on your own requirements and actual experiments. In my case, fine-tuning after pruning the whole backbone by 80% performed poorly, while pruning the layers after the SPPF worked somewhat better; many people online also report that pruning the backbone does not work well.
Code:
GitHub - YINYIPENG-EN/Pruning_for_YOLOV5_pytorch
Problems encountered:
1. If the fine-tuning after pruning reuses the parameters from the original optimizer state, it may raise the following error:
Training errors out partway through: RuntimeError: The size of tensor a (512) must match the size of tensor b (103) at non-singleton dimension 1
Solution: one possible cause is that SGD was used originally but Adam is now used for training; another is that the network structure changed after pruning, so some entries of the old optimizer state can no longer be loaded. You can load only the entries whose value shapes match the corresponding keys, or train with a freshly initialized (default) optimizer; try both and see which works better, as in the sketch below.
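A minimal sketch of the shape-matched idea, using SGD momentum buffers as the example (adapt the buffer names for other optimizers; this is an illustration, not the repo's exact code):

import torch

optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.937)
old_opt_state = ckpt.get('optimizer')  # state saved before pruning

if old_opt_state is not None:
    try:
        optimizer.load_state_dict(old_opt_state)
        # drop per-parameter buffers whose shape no longer matches the
        # pruned parameters; these are what trigger the RuntimeError above
        for p, s in list(optimizer.state.items()):
            buf = s.get('momentum_buffer')
            if buf is not None and buf.shape != p.shape:
                del optimizer.state[p]
    except (ValueError, RuntimeError):
        # parameter groups changed too much: start from a fresh optimizer
        optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.937)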