国产 无码 综合区,色欲AV无码国产永久播放,无码天堂亚洲国产AV,国产日韩欧美女同一区二区

yolov5剪枝與知識蒸餾【附代碼】

這篇具有很好參考價值的文章主要介紹了yolov5剪枝與知識蒸餾【附代碼】。希望對大家有所幫助。如果存在錯誤或未考慮完全的地方,請大家不吝賜教,您也可以點擊"舉報違法"按鈕提交疑問。

剪枝和知識蒸餾均屬于模型輕量化設計,剪枝是將已有網(wǎng)絡通過剪枝的手段得到輕量化網(wǎng)絡,可分為非結(jié)構(gòu)化剪枝結(jié)構(gòu)化剪,該技術(shù)可以免去人為設計輕量網(wǎng)絡,而是通過計算各個權(quán)重或者通道的貢獻度大小,剪去貢獻度小的權(quán)重或通道,再經(jīng)過微調(diào)訓練恢復精度,得到最終的模型,這種方法自然也是可以的,但在某些任務中,如果剪枝較多效果會很差,即便微調(diào)訓練也恢復不了多少精度。

本文所用到的剪枝是通道剪枝(結(jié)構(gòu)化剪枝),可以參考我另外一篇博客(這篇文章被多個開源社區(qū)收藏,所以值得一試):YOLOv5通道剪枝,同時在我其他博客中還實現(xiàn)了YOLOV4,YOLOX,YOLOR,YOLOV7等剪枝,歡迎點贊收藏。

知識蒸餾是在一個精度高的大模型和一個精度低的小模型之間建立損失函數(shù),將大模型"壓縮"到小模型中【并不是嚴格意義上的壓縮】。這也是近兩年用的比較多的手段,之前的知識的蒸餾均是在分類網(wǎng)絡中進行,現(xiàn)在也開始應用于目標檢測。分類網(wǎng)絡的知識蒸餾可以參考:知識蒸餾,自蒸餾

目標檢測的知識蒸餾參考:SSD知識蒸餾

知識蒸餾的蒸餾方式有在線式和離線式,還可分為特征蒸餾和邏輯蒸餾。在這里我公布的代碼是離線式的邏輯蒸餾。

目錄

項目說明

環(huán)境說明

1.訓練自己的數(shù)據(jù)集

2.對任意卷積層進行剪枝

3.剪枝后的訓練

4.剪枝后的模型預測

5.知識蒸餾訓練

代碼


項目說明

1.訓練自己的數(shù)據(jù)集

2.對任意卷積層進行剪枝

3.剪枝后的訓練

4.剪枝后的模型預測

5.利用知識蒸餾對剪枝后模型進行訓練

環(huán)境說明

gitpython>=3.1.30
matplotlib>=3.3
numpy>=1.18.5
opencv-python>=4.1.1
Pillow>=7.1.2
psutil ?# system resources
PyYAML>=5.3.1
requests>=2.23.0
scipy>=1.4.1
thop>=0.1.1 ?# FLOPs computation
torch>=1.7.0 ?# see https://pytorch.org/get-started/locally (recommended)
torchvision>=0.8.1
tqdm>=4.64.0
ultralytics>=8.0.100
torch_pruning==0.2.7
pandas>=1.1.4
seaborn>=0.11.0

1.訓練自己的數(shù)據(jù)集

將自己制作好的數(shù)據(jù)集放在dataset文件下,目錄形式如下:

dataset

|-- Annotations

|-- ImageSets

|-- images

|-- labels

?Annotations是存放xml標簽文件的,images是存放圖像的ImageSets存放四個txt文件【后面運行代碼的時候會自動生成】,labels是將xml轉(zhuǎn)txt文件。

1.運行makeTXT.py。這將會在ImageSets文件夾下生成 trainval.txt,test.txt,train.txt,val.txt四個文件【如果你打開這些txt文件,里面僅有圖像的名字】。

2.打開voc_label.py,并修改代碼 classes=[""]填入自己的類名,比如你的是訓練貓和狗,那么就是classes=["dog","cat"],然后運行該程序。此時會在labels文件下生成對應每個圖像的txt文件,形式如下:【最前面的0是類對應的索引,我這里只有一個類,后面的四個數(shù)為box的參數(shù),均歸一化以后的,分別表示box的左上和右下坐標,等訓練的時候會處理成center_x,center_y,w, h】。形式如下。

0 0.4723557692307693 0.5408653846153847 0.34375 0.8990384615384616
0 0.8834134615384616 0.5793269230769231 0.21875 0.8221153846153847?

3.在data文件夾下新建一個mydata.yaml文件。內(nèi)容如下【你也可以把coco.yaml復制過來】。

你只需要修改nc以及names即可,nc是類的數(shù)量,names是類的名字。

train: ./dataset/train.txt
val: ./dataset/val.txt
test: ./dataset/test.txt

# number of classes
nc: 1

# class names
names: ['target']

4.終端輸入?yún)?shù),開始訓練。

以yolov5s為例:

python train.py --weights yolov5s.pt --cfg models/yolov5s.yaml --data data/mydata.yaml

from n params module arguments 0 -1 1 3520 models.common.Conv [3, 32, 6, 2, 2] 1 -1 1 18560 models.common.Conv [32, 64, 3, 2] 2 -1 1 18816 models.common.C3 [64, 64, 1] 3 -1 1 73984 models.common.Conv [64, 128, 3, 2] 4 -1 2 115712 models.common.C3 [128, 128, 2] 5 -1 1 295424 models.common.Conv [128, 256, 3, 2] 6 -1 3 625152 models.common.C3 [256, 256, 3] 7 -1 1 1180672 models.common.Conv [256, 512, 3, 2] 8 -1 1 1182720 models.common.C3 [512, 512, 1] 9 -1 1 656896 models.common.SPPF [512, 512, 5] 10 -1 1 131584 models.common.Conv [512, 256, 1, 1] 11 -1 1 0 torch.nn.modules.upsampling.Upsample [None, 2, 'nearest'] 12 [-1, 6] 1 0 models.common.Concat [1] 13 -1 1 361984 models.common.C3 [512, 256, 1, False] 20 -1 1 296448 models.common.C3 [256, 256, 1, False] 21 -1 1 590336 models.common.Conv [256, 256, 3, 2] 22 [-1, 10] 1 0 models.common.Concat [1] 23 -1 1 1182720 models.common.C3 [512, 512, 1, False] 24 [17, 20, 23] 1 16182 models.yolo.Detect [1, [[10, 13, 16, 30, 33, 23], [30, 61, 62, 45, 59, 119], [116, 90, 156, 198, 373, 326]], [128, 256, 512]] Model Summary: 270 layers, 7022326 parameters, 7022326 gradients, 15.8 GFLOPs

Starting training for 300 epochs...

Epoch gpu_mem box obj cls labels img_size 0/299 0.589G 0.0779 0.03841 0 4 640: 6%|████▋ | 23/359 [00:23<04:15, 1.31it/s]

?看到以上信息就開始訓練了。

2.對任意卷積層進行剪枝

在利用剪枝功能前,需要安裝一下剪枝的庫。需要安裝0.2.7版本,0.2.8有粉絲說有問題。剪枝時的一些log信息會自動保存在logs文件夾下,每個log的大小我設置的為1MB,如果有其他需要大家可以更改。

pip install torch_pruning==0.2.7

YOLOv5與我之前寫過的剪枝不同,v5在訓練保存后的權(quán)重本身就保存了完整的model,即用的是torch.save(model,...),而不是torch.save(model.state_dict(),...),因此不需要單獨在對網(wǎng)絡結(jié)構(gòu)保存一次。

模型剪枝代碼在tools/prunmodel.py。你只需要找到這部分代碼進行修改:我這里是以剪枝整個backbone的卷積層為例,如果你要剪枝的是其他層按需修改.included_layers內(nèi)就是你要剪枝的層。

    """
    這里寫要剪枝的層
    """
    included_layers = []
    for layer in model.model[:10]:
        if type(layer) is Conv:
            included_layers.append(layer.conv)
        elif type(layer) is C3:
            included_layers.append(layer.cv1.conv)
            included_layers.append(layer.cv2.conv)
            included_layers.append(layer.cv3.conv)
        elif type(layer) is SPPF:
            included_layers.append(layer.cv1.conv)
            included_layers.append(layer.cv2.conv)

接下來在找到下面這行代碼,amount為剪枝率,同樣也是按需修改。【這里需要明白的一點,這里的剪枝率僅是對你要剪枝的所有層剪枝這么多,并不是把網(wǎng)絡從頭到尾全部剪,有些粉絲說我選了一層,剪枝率50%,怎么模型還那么大,沒啥變化,這個就是他搞混了,他以為是對整個網(wǎng)絡剪枝50%】。

pruning_plan = DG.get_pruning_plan(m, tp.prune_conv, idxs=strategy(m.weight, amount=0.8))

?接下來調(diào)用剪枝函數(shù),傳入?yún)?shù)為自己的訓練好的權(quán)重文件路徑。

layer_pruning('../runs/train/exp/weights/best.pt')

見到如下形式,就說明剪枝成功了,剪枝以后的權(quán)重會保存在model_data下,名字為layer_pruning.pt。

這里需要說明一下,保存的權(quán)重文件中不僅包含了網(wǎng)絡結(jié)構(gòu)和權(quán)值內(nèi)容,還有優(yōu)化器的權(quán)值,如果僅僅保存網(wǎng)絡結(jié)構(gòu)和權(quán)值也是可以的,這樣pt會更小一點,我這里默認都保存是為了和官方pt格式一致。

-------------
[ <DEP: prune_conv => prune_conv on model.9.cv2.conv (Conv2d(208, 512, kernel_size=(1, 1), stride=(1, 1), bias=False))>, Index=[0, 1, 2, 3, 7, 8, 10, 11, 12, 13, 16, 17, 18, 19, 21, 22, 23, 25, 27, 28, 29, 30, 31, 32, 33, 34, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 53, 54, 56, 57, 58, 59, 60, 61, 62, 63, 65, 67, 69, 70, 71, 72, 73, 74, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 89, 90, 91, 92, 95, 96, 97, 99, 100, 102, 103, 104, 105, 106, 107, 109, 110, 111, 113, 114, 115, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, 128, 129, 130, 132, 133, 135, 137, 139, 142, 143, 144, 146, 148, 150, 152, 153, 154, 155, 156, 157, 158, 159, 161, 162, 163, 164, 165, 166, 167, 168, 169, 170, 171, 173, 174, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211, 212, 213, 215, 216, 217, 219, 220, 221, 222, 223, 224, 225, 226, 228, 229, 230, 232, 233, 234, 235, 236, 237, 239, 240, 241, 242, 243, 246, 247, 248, 249, 251, 252, 253, 254, 257, 258, 259, 260, 263, 264, 265, 266, 267, 268, 270, 271, 272, 273, 274, 275, 276, 277, 278, 280, 281, 282, 283, 284, 285, 286, 287, 288, 292, 293, 294, 295, 296, 297, 299, 301, 302, 303, 306, 307, 308, 309, 310, 311, 312, 313, 314, 315, 317, 318, 321, 322, 323, 324, 325, 326, 327, 329, 330, 331, 332, 334, 335, 338, 339, 341, 342, 343, 344, 346, 347, 349, 351, 353, 354, 355, 356, 357, 358, 359, 361, 362, 363, 364, 365, 366, 368, 369, 370, 372, 373, 374, 375, 378, 379, 381, 382, 383, 385, 386, 387, 388, 389, 390, 391, 392, 393, 395, 396, 397, 398, 399, 401, 402, 403, 404, 405, 407, 408, 411, 413, 414, 415, 416, 418, 419, 420, 421, 422, 423, 425, 426, 427, 428, 429, 430, 431, 432, 433, 434, 435, 436, 437, 438, 440, 441, 442, 443, 444, 445, 446, 448, 449, 451, 452, 453, 454, 455, 456, 457, 458, 459, 461, 463, 465, 466, 468, 470, 472, 473, 474, 475, 476, 477, 478, 479, 480, 482, 483, 484, 485, 486, 487, 488, 490, 491, 492, 493, 494, 495, 496, 497, 498, 499, 500, 502, 503, 505, 506, 507, 510, 511], NumPruned=85072]
[ <DEP: prune_conv => prune_batchnorm on model.9.cv2.bn (BatchNorm2d(512, eps=0.001, momentum=0.03, affine=True, track_running_stats=True))>, Index=[0, 1, 2, 3, 7, 8, 10, 11, 12, 13, 16, 17, 18, 19, 21, 22, 23, 25, 27, 28, 29, 30, 31, 32, 33, 34, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 53, 54, 56, 57, 58, 59, 60, 61, 62, 63, 65, 67, 69, 70, 71, 72, 73, 74, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 89, 90, 91, 92, 95, 96, 97, 99, 100, 102, 103, 104, 105, 106, 107, 109, 110, 111, 113, 114, 115, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, 128, 129, 130, 132, 133, 135, 137, 139, 142, 143, 144, 146, 148, 150, 152, 153, 154, 155, 156, 157, 158, 159, 161, 162, 163, 164, 165, 166, 167, 168, 169, 170, 171, 173, 174, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211, 212, 213, 215, 216, 217, 219, 220, 221, 222, 223, 224, 225, 226, 228, 229, 230, 232, 233, 234, 235, 236, 237, 239, 240, 241, 242, 243, 246, 247, 248, 249, 251, 252, 253, 254, 257, 258, 259, 260, 263, 264, 265, 266, 267, 268, 270, 271, 272, 273, 274, 275, 276, 277, 278, 280, 281, 282, 283, 284, 285, 286, 287, 288, 292, 293, 294, 295, 296, 297, 299, 301, 302, 303, 306, 307, 308, 309, 310, 311, 312, 313, 314, 315, 317, 318, 321, 322, 323, 324, 325, 326, 327, 329, 330, 331, 332, 334, 335, 338, 339, 341, 342, 343, 344, 346, 347, 349, 351, 353, 354, 355, 356, 357, 358, 359, 361, 362, 363, 364, 365, 366, 368, 369, 370, 372, 373, 374, 375, 378, 379, 381, 382, 383, 385, 386, 387, 388, 389, 390, 391, 392, 393, 395, 396, 397, 398, 399, 401, 402, 403, 404, 405, 407, 408, 411, 413, 414, 415, 416, 418, 419, 420, 421, 422, 423, 425, 426, 427, 428, 429, 430, 431, 432, 433, 434, 435, 436, 437, 438, 440, 441, 442, 443, 444, 445, 446, 448, 449, 451, 452, 453, 454, 455, 456, 457, 458, 459, 461, 463, 465, 466, 468, 470, 472, 473, 474, 475, 476, 477, 478, 479, 480, 482, 483, 484, 485, 486, 487, 488, 490, 491, 492, 493, 494, 495, 496, 497, 498, 499, 500, 502, 503, 505, 506, 507, 510, 511], NumPruned=818]
[ <DEP: prune_batchnorm => _prune_elementwise_op on _ElementWiseOp()>, Index=[0, 1, 2, 3, 7, 8, 10, 11, 12, 13, 16, 17, 18, 19, 21, 22, 23, 25, 27, 28, 29, 30, 31, 32, 33, 34, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 53, 54, 56, 57, 58, 59, 60, 61, 62, 63, 65, 67, 69, 70, 71, 72, 73, 74, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 89, 90, 91, 92, 95, 96, 97, 99, 100, 102, 103, 104, 105, 106, 107, 109, 110, 111, 113, 114, 115, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, 128, 129, 130, 132, 133, 135, 137, 139, 142, 143, 144, 146, 148, 150, 152, 153, 154, 155, 156, 157, 158, 159, 161, 162, 163, 164, 165, 166, 167, 168, 169, 170, 171, 173, 174, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211, 212, 213, 215, 216, 217, 219, 220, 221, 222, 223, 224, 225, 226, 228, 229, 230, 232, 233, 234, 235, 236, 237, 239, 240, 241, 242, 243, 246, 247, 248, 249, 251, 252, 253, 254, 257, 258, 259, 260, 263, 264, 265, 266, 267, 268, 270, 271, 272, 273, 274, 275, 276, 277, 278, 280, 281, 282, 283, 284, 285, 286, 287, 288, 292, 293, 294, 295, 296, 297, 299, 301, 302, 303, 306, 307, 308, 309, 310, 311, 312, 313, 314, 315, 317, 318, 321, 322, 323, 324, 325, 326, 327, 329, 330, 331, 332, 334, 335, 338, 339, 341, 342, 343, 344, 346, 347, 349, 351, 353, 354, 355, 356, 357, 358, 359, 361, 362, 363, 364, 365, 366, 368, 369, 370, 372, 373, 374, 375, 378, 379, 381, 382, 383, 385, 386, 387, 388, 389, 390, 391, 392, 393, 395, 396, 397, 398, 399, 401, 402, 403, 404, 405, 407, 408, 411, 413, 414, 415, 416, 418, 419, 420, 421, 422, 423, 425, 426, 427, 428, 429, 430, 431, 432, 433, 434, 435, 436, 437, 438, 440, 441, 442, 443, 444, 445, 446, 448, 449, 451, 452, 453, 454, 455, 456, 457, 458, 459, 461, 463, 465, 466, 468, 470, 472, 473, 474, 475, 476, 477, 478, 479, 480, 482, 483, 484, 485, 486, 487, 488, 490, 491, 492, 493, 494, 495, 496, 497, 498, 499, 500, 502, 503, 505, 506, 507, 510, 511], NumPruned=0]
[ <DEP: _prune_elementwise_op => _prune_elementwise_op on _ElementWiseOp()>, Index=[0, 1, 2, 3, 7, 8, 10, 11, 12, 13, 16, 17, 18, 19, 21, 22, 23, 25, 27, 28, 29, 30, 31, 32, 33, 34, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 53, 54, 56, 57, 58, 59, 60, 61, 62, 63, 65, 67, 69, 70, 71, 72, 73, 74, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 89, 90, 91, 92, 95, 96, 97, 99, 100, 102, 103, 104, 105, 106, 107, 109, 110, 111, 113, 114, 115, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, 128, 129, 130, 132, 133, 135, 137, 139, 142, 143, 144, 146, 148, 150, 152, 153, 154, 155, 156, 157, 158, 159, 161, 162, 163, 164, 165, 166, 167, 168, 169, 170, 171, 173, 174, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211, 212, 213, 215, 216, 217, 219, 220, 221, 222, 223, 224, 225, 226, 228, 229, 230, 232, 233, 234, 235, 236, 237, 239, 240, 241, 242, 243, 246, 247, 248, 249, 251, 252, 253, 254, 257, 258, 259, 260, 263, 264, 265, 266, 267, 268, 270, 271, 272, 273, 274, 275, 276, 277, 278, 280, 281, 282, 283, 284, 285, 286, 287, 288, 292, 293, 294, 295, 296, 297, 299, 301, 302, 303, 306, 307, 308, 309, 310, 311, 312, 313, 314, 315, 317, 318, 321, 322, 323, 324, 325, 326, 327, 329, 330, 331, 332, 334, 335, 338, 339, 341, 342, 343, 344, 346, 347, 349, 351, 353, 354, 355, 356, 357, 358, 359, 361, 362, 363, 364, 365, 366, 368, 369, 370, 372, 373, 374, 375, 378, 379, 381, 382, 383, 385, 386, 387, 388, 389, 390, 391, 392, 393, 395, 396, 397, 398, 399, 401, 402, 403, 404, 405, 407, 408, 411, 413, 414, 415, 416, 418, 419, 420, 421, 422, 423, 425, 426, 427, 428, 429, 430, 431, 432, 433, 434, 435, 436, 437, 438, 440, 441, 442, 443, 444, 445, 446, 448, 449, 451, 452, 453, 454, 455, 456, 457, 458, 459, 461, 463, 465, 466, 468, 470, 472, 473, 474, 475, 476, 477, 478, 479, 480, 482, 483, 484, 485, 486, 487, 488, 490, 491, 492, 493, 494, 495, 496, 497, 498, 499, 500, 502, 503, 505, 506, 507, 510, 511], NumPruned=0]
[ <DEP: _prune_elementwise_op => prune_related_conv on model.10.conv (Conv2d(512, 256, kernel_size=(1, 1), stride=(1, 1), bias=False))>, Index=[0, 1, 2, 3, 7, 8, 10, 11, 12, 13, 16, 17, 18, 19, 21, 22, 23, 25, 27, 28, 29, 30, 31, 32, 33, 34, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 53, 54, 56, 57, 58, 59, 60, 61, 62, 63, 65, 67, 69, 70, 71, 72, 73, 74, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 89, 90, 91, 92, 95, 96, 97, 99, 100, 102, 103, 104, 105, 106, 107, 109, 110, 111, 113, 114, 115, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, 128, 129, 130, 132, 133, 135, 137, 139, 142, 143, 144, 146, 148, 150, 152, 153, 154, 155, 156, 157, 158, 159, 161, 162, 163, 164, 165, 166, 167, 168, 169, 170, 171, 173, 174, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 208, 209, 210, 211, 212, 213, 215, 216, 217, 219, 220, 221, 222, 223, 224, 225, 226, 228, 229, 230, 232, 233, 234, 235, 236, 237, 239, 240, 241, 242, 243, 246, 247, 248, 249, 251, 252, 253, 254, 257, 258, 259, 260, 263, 264, 265, 266, 267, 268, 270, 271, 272, 273, 274, 275, 276, 277, 278, 280, 281, 282, 283, 284, 285, 286, 287, 288, 292, 293, 294, 295, 296, 297, 299, 301, 302, 303, 306, 307, 308, 309, 310, 311, 312, 313, 314, 315, 317, 318, 321, 322, 323, 324, 325, 326, 327, 329, 330, 331, 332, 334, 335, 338, 339, 341, 342, 343, 344, 346, 347, 349, 351, 353, 354, 355, 356, 357, 358, 359, 361, 362, 363, 364, 365, 366, 368, 369, 370, 372, 373, 374, 375, 378, 379, 381, 382, 383, 385, 386, 387, 388, 389, 390, 391, 392, 393, 395, 396, 397, 398, 399, 401, 402, 403, 404, 405, 407, 408, 411, 413, 414, 415, 416, 418, 419, 420, 421, 422, 423, 425, 426, 427, 428, 429, 430, 431, 432, 433, 434, 435, 436, 437, 438, 440, 441, 442, 443, 444, 445, 446, 448, 449, 451, 452, 453, 454, 455, 456, 457, 458, 459, 461, 463, 465, 466, 468, 470, 472, 473, 474, 475, 476, 477, 478, 479, 480, 482, 483, 484, 485, 486, 487, 488, 490, 491, 492, 493, 494, 495, 496, 497, 498, 499, 500, 502, 503, 505, 506, 507, 510, 511], NumPruned=104704]
190594 parameters will be pruned
-------------

2022-09-29 12:30:50.396 | INFO ? ? | __main__:layer_pruning:75 - ? Params: 7022326 => 3056461

2022-09-29 12:30:50.691 | INFO ? ? | __main__:layer_pruning:89 - 剪枝完成

?如果你僅僅就想剪一層,可以這樣寫:

included_layers = [model.model[3].conv] # 僅僅想剪一個卷積層

3.剪枝后的訓練

這里需要和稀疏訓練區(qū)別一下,因為很多人在之前項目中問我有沒有稀疏訓練。我這里的通道剪枝是離線式的,也就是針對已經(jīng)訓練好的模型進行剪枝,而邊訓練邊剪枝是在線式剪枝,這個訓練過程也就是稀疏訓練,所以還是有區(qū)別的。

訓練后的剪枝訓練與訓練部分是一樣的,只不過加一個pt參數(shù)而已。命令如下:

python train.py --weights model_data/layer_pruning.pt --data data/mydata.yaml --pt 

4.剪枝后的模型預測

剪枝后的預測,和正常預測一樣。

python detect.py --weights model_data/layer_pruning.pt --source [你的圖像路徑]

這里再說明一下!!本文章只是給大家造個輪子,具體最終的剪枝效果,需要根據(jù)自己的需求以及實際效果來實現(xiàn),我對整個backbone剪枝80%后的微調(diào)訓練反正是效果很不好,對SPPF后其他的層剪枝還稍微好點,網(wǎng)上也有很多人說對backbone剪枝效果不行。

5.知識蒸餾訓練

項目需求:想用知識蒸餾做剪枝后網(wǎng)絡的微調(diào)訓練

教師網(wǎng)絡:未剪枝前的

學生網(wǎng)絡:剪枝后的

由于學生網(wǎng)絡是剪枝后的,因此可以脫離模型的yaml配置文件。

本項目的知識蒸餾是邏輯蒸餾(沒有做特征層的蒸餾)。

模型實例化代碼

s_ckpt = torch.load(s_weights, map_location=device)
s_model = s_ckpt['model']  # 學生網(wǎng)絡

# 教師網(wǎng)絡的創(chuàng)建
t_ckpt = torch.load(t_weights, map_location=device)
t_model = Model(t_cfg, ch=3, nc=nc, anchors=hyp.get('anchors')).to(device)  # teacher model create

蒸餾的關(guān)鍵代碼

其中d_weight是蒸餾權(quán)重??梢愿鶕?jù)自己的實際情況調(diào)整。

s_pred = s_model(imgs)  # student forward
_, t_pred = t_model(imgs)  # teacher forward
s_hard_loss, loss_items = compute_loss(s_pred, targets.to(device))  # student hard loss
d_outputs_loss = compute_distillation_output_loss(s_pred, t_pred, s_model, d_weight=10)
loss = d_outputs_loss + s_hard_loss

--t_weights:教師網(wǎng)絡權(quán)重路徑

--s_weights:學生網(wǎng)絡權(quán)重路徑

--data:data.yaml路徑

--kd:開啟蒸餾訓練

python train_dil.py --t_weights best.pt --s_weights layer_pruning.pt --data data/mydata.yaml --batch-size 16 --kd

訓練后的結(jié)果會保存在runs/train/exp_kd中


代碼

GitHub - YINYIPENG-EN/Knowledge_distillation_Pruning_Yolov5: 本項目支持對剪枝后的yolov5模型進行知識蒸餾訓練(This project supports knowledge distillation training for the pruned YOLOv5 model)


補充說明:測試效果要根據(jù)實際應用場景、數(shù)據(jù)集、網(wǎng)絡模型等有關(guān),本文章發(fā)布的代碼并不是萬能的~?


2024.01.28更新功能:添加了用已訓練好的模型自動標注數(shù)據(jù)集,歡迎使用文章來源地址http://www.zghlxwxcb.cn/news/detail-454357.html

到了這里,關(guān)于yolov5剪枝與知識蒸餾【附代碼】的文章就介紹完了。如果您還想了解更多內(nèi)容,請在右上角搜索TOY模板網(wǎng)以前的文章或繼續(xù)瀏覽下面的相關(guān)文章,希望大家以后多多支持TOY模板網(wǎng)!

本文來自互聯(lián)網(wǎng)用戶投稿,該文觀點僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務,不擁有所有權(quán),不承擔相關(guān)法律責任。如若轉(zhuǎn)載,請注明出處: 如若內(nèi)容造成侵權(quán)/違法違規(guī)/事實不符,請點擊違法舉報進行投訴反饋,一經(jīng)查實,立即刪除!

領支付寶紅包贊助服務器費用

相關(guān)文章

  • 【深度學習之模型優(yōu)化】模型剪枝、模型量化、知識蒸餾概述

    【深度學習之模型優(yōu)化】模型剪枝、模型量化、知識蒸餾概述

    ? ? ? ??模型部署優(yōu)化這個方向其實比較寬泛。從模型完成訓練,到最終將模型部署到實際硬件上,整個流程中會涉及到很多不同層面的工作,每一個環(huán)節(jié)對技術(shù)點的要求也不盡相同。但本質(zhì)的工作無疑是通過減小模型大小,提高推理速度等,使得模型能夠成功部署在各個硬

    2024年01月23日
    瀏覽(16)
  • YOLOv5剪枝??| 模型剪枝實戰(zhàn)篇

    YOLOv5剪枝??| 模型剪枝實戰(zhàn)篇

    本篇博文所用代碼為開源項目修改得到,且不適合基礎太差的同學。 本篇文章主要講解代碼的使用方式,手把手帶你實現(xiàn)YOLOv5模型剪枝操作。 0. 環(huán)境準備 終端鍵入:

    2024年02月05日
    瀏覽(47)
  • yolov5s模型剪枝詳細過程(v6.0)

    yolov5s模型剪枝詳細過程(v6.0)

    本文參考github上大神的開源剪枝項目進行學習與分享,具體鏈接放在文后,希望與大家多多交流! 在官方源碼上訓練yolov5模型,支持v6.0分支的n/s/m/l模型,我這里使用的是v5s,得到后將項目clone到本機上 cd進入文件夾后,新建runs文件夾,將訓練好的模型放入runs/your_train/weigh

    2024年02月03日
    瀏覽(25)
  • 從0開始做yolov5模型剪枝

    從0開始做yolov5模型剪枝

    【整個流程中,在正常train,sparityTrain,prune,finetune遇到10多個的問題,包括AttributeError、ModuleNotFoundError、RuntimeError、SyntaxError、TypeError等問題的解決方法,詳見內(nèi)容】 為了將現(xiàn)有模型移植到ARM平臺,同時保證模型準確率的基礎上,減少模型的算力消耗和推理時間。 之前有做

    2024年02月11日
    瀏覽(22)
  • 改進的yolov5目標檢測-yolov5替換骨干網(wǎng)絡-yolo剪枝(TensorRT及NCNN部署)

    改進的yolov5目標檢測-yolov5替換骨干網(wǎng)絡-yolo剪枝(TensorRT及NCNN部署)

    2022.10.30 復現(xiàn)TPH-YOLOv5 2022.10.31 完成替換backbone為Ghostnet 2022.11.02 完成替換backbone為Shufflenetv2 2022.11.05 完成替換backbone為Mobilenetv3Small 2022.11.10 完成EagleEye對YOLOv5系列剪枝支持 2022.11.14 完成MQBench對YOLOv5系列量化支持 2022.11.16 完成替換backbone為EfficientNetLite-0 2022.11.26 完成替換backbone為

    2024年01月17日
    瀏覽(28)
  • Yolov5口罩佩戴實時檢測項目(模型剪枝+opencv+python推理)

    Yolov5口罩佩戴實時檢測項目(模型剪枝+opencv+python推理)

    如果只是想體驗項目,請直接跳轉(zhuǎn)到本文第2節(jié),或者跳轉(zhuǎn)到我的facemask_detect。 剪枝的代碼可以查看我的github:yolov5-6.2-pruning 第1章是講述如何得到第2章用到的onnx格式的模型文件(我的項目里直接提供了這個文件)。 第2章開始講述如何使用cv2.dnn加載onnx文件并推理yolov5n模型

    2023年04月08日
    瀏覽(24)
  • yolov8(目標檢測、圖像分割、關(guān)鍵點檢測)知識蒸餾:logit和feature-based蒸餾方法的實現(xiàn)

    yolov8(目標檢測、圖像分割、關(guān)鍵點檢測)知識蒸餾:logit和feature-based蒸餾方法的實現(xiàn)

    在目標檢測中,知識蒸餾的原理主要是利用教師模型(通常是大型的深度神經(jīng)網(wǎng)絡)的豐富知識來指導學生模型(輕量級的神經(jīng)網(wǎng)絡)的學習過程。通過蒸餾,學生模型能夠在保持較高性能的同時,減小模型的復雜度和計算成本。 知識蒸餾實現(xiàn)的方式有多種,但核心目標是將

    2024年04月28日
    瀏覽(96)
  • 量化、蒸餾、分解、剪枝

    ????????量化、蒸餾、分解和剪枝都是用于深度學習模型壓縮和優(yōu)化的算法。 ???????? 量化 是一種用于減少深度學習模型計算量和內(nèi)存消耗的技術(shù)。在深度學習中,模型通常使用高精度的浮點數(shù)表示參數(shù)和激活值,但這種表示方式會占用大量的內(nèi)存和計算資源。而量

    2024年02月05日
    瀏覽(17)
  • 知識蒸餾實戰(zhàn)代碼教學二(代碼實戰(zhàn)部分)

    知識蒸餾實戰(zhàn)代碼教學二(代碼實戰(zhàn)部分)

    ? ? ? ? (1)首先我們要先訓練出較大模型既teacher模型。(在圖中沒有出現(xiàn)) ? ? ? ? (2)再對teacher模型進行蒸餾,此時我們已經(jīng)有一個訓練好的teacher模型,所以我們能很容易知道teacher模型輸入特征x之后,預測出來的結(jié)果teacher_preds標簽。 ? ? ? ? (3)此時,求到老師

    2024年02月20日
    瀏覽(19)
  • 知識蒸餾實戰(zhàn)代碼教學一(原理部分)

    知識蒸餾實戰(zhàn)代碼教學一(原理部分)

    ????????知識蒸餾(Knowledge Distillation)源自于一篇由Hinton等人于2015年提出的論文《Distilling the Knowledge in a Neural Network》。這個方法旨在將一個大型、復雜的模型的知識(通常稱為教師模型)轉(zhuǎn)移到一個小型、簡化的模型(通常稱為學生模型)中。通過這種方式,學生模型

    2024年02月20日
    瀏覽(24)

覺得文章有用就打賞一下文章作者

支付寶掃一掃打賞

博客贊助

微信掃一掃打賞

請作者喝杯咖啡吧~博客贊助

支付寶掃一掃領取紅包,優(yōu)惠每天領

二維碼1

領取紅包

二維碼2

領紅包