這兩天自己手寫了一個(gè)可以簡單實(shí)現(xiàn)通道剪枝的代碼,在這篇文章中也會對代碼進(jìn)行講解,方便大家在自己代碼中的使用。
如果還想學(xué)習(xí)YOLO系列的剪枝代碼,可以參考我其他文章,下面的這些文章都是我根據(jù)通道剪枝的論文在YOLO上進(jìn)行的實(shí)現(xiàn),而本篇文章是我自己寫的,也是希望能幫助一些想學(xué)剪枝的人入門,希望多多支持:
YOLOv4剪枝
YOLOX剪枝
YOLOR剪枝
YOLOv5剪枝
YOLOv7剪枝
目錄
網(wǎng)絡(luò)定義
剪枝代碼詳解?
計(jì)算各通道貢獻(xiàn)度
對貢獻(xiàn)度進(jìn)行排序?
計(jì)算要剪掉的通道數(shù)量
新建卷積層?
?權(quán)重的重分配
新卷積代替model中的舊卷積?
新建BN層?
剪枝前后網(wǎng)絡(luò)結(jié)構(gòu)以及參數(shù)對比?
?完整代碼
更新內(nèi)容:
?剪枝:
繪制2D權(quán)重:
繪制3D權(quán)重?
還有一點(diǎn)需要說明,本篇文章現(xiàn)僅支持卷積層的剪枝(后續(xù)會不斷更新其他網(wǎng)絡(luò)類型),暫未加入其他類型的剪枝,比如BN層,所以各位在嘗試的需要注意一下,不然容易報(bào)錯(新版本已經(jīng)支持BN層的輕量化處理,已在github中同步更新)。接下來步入正題。
通道剪枝屬于結(jié)構(gòu)化剪枝的一種,該方法可以根據(jù)各通道權(quán)重大小來進(jìn)行修剪??梢詫⒛切┴暙I(xiàn)度小的通道刪除,僅保留貢獻(xiàn)度大的通道,最終得到修剪后的新卷積,以此減少參數(shù),同時(shí)也希望較少的減少精度損失。
一般情況會用L1或者L2來計(jì)算各通道權(quán)重,然后對通道進(jìn)行排序后再剪枝。
網(wǎng)絡(luò)定義
首先我們先定義一個(gè)全卷積網(wǎng)絡(luò)(僅有卷積層和激活函數(shù)層),該網(wǎng)絡(luò)由8層卷積構(gòu)成,代碼如下:
class Model(nn.Module):
def __init__(self, in_channels):
super(Model, self).__init__()
self.conv1 = nn.Conv2d(in_channels, 32, kernel_size=3, padding=1, bias=False)
self.act1 = nn.ReLU(inplace=True)
self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1, bias=False)
self.act2 = nn.ReLU(inplace=True)
self.conv3 = nn.Conv2d(64, 128, kernel_size=3, padding=1, bias=False)
self.act3 = nn.ReLU(inplace=True)
self.conv4 = nn.Conv2d(128, 256, 3, 1, 1, bias=False)
self.act4 = nn.ReLU(inplace=True)
self.conv5 = nn.Conv2d(256, 512, 3, 1, 1, bias=False)
self.act5 = nn.ReLU(inplace=True)
self.conv6 = nn.Conv2d(512, 1024, 3, 1, 1, bias=False)
self.act6 = nn.ReLU(inplace=True)
self.conv7 = nn.Conv2d(1024, 2048, 3, 1, 1, bias=False)
self.act7 = nn.ReLU(inplace=True)
self.conv8 = nn.Conv2d(2048, 4096, 3, 1, 1, bias=False)
def forward(self, x):
x = self.conv1(x)
x = self.act1(x)
x = self.conv2(x)
x = self.act2(x)
x = self.conv3(x)
x = self.act3(x)
x = self.conv4(x)
x = self.act4(x)
x = self.conv5(x)
x = self.act5(x)
x = self.conv6(x)
x = self.act6(x)
x = self.conv7(x)
x = self.act7(x)
out = self.conv8(x)
return out
剪枝代碼詳解?
接下來是根據(jù)剪枝的思想寫剪枝函數(shù)(完整的代碼我會在文末附上)。
定義剪枝函數(shù)prune,這里傳入兩個(gè)參數(shù),model:即傳入我們前面定義的網(wǎng)絡(luò)。percentage:剪枝率,比如當(dāng)percentage為0.5的時(shí)候表示對該卷積50%的通道進(jìn)行剪枝。這里的importance是一個(gè)字典類型,用來存儲各個(gè)卷積層通道L1值。
def prune(model, percentage):
# 計(jì)算每個(gè)通道的L1-norm并排序
importance = {}
計(jì)算各通道貢獻(xiàn)度
model.named_modules()可以獲得模型每層的名字以及該層的類型,比如對前面定義的模型進(jìn)行遍歷時(shí),name='conv1',module=nn.Conv2d。
通過isinstance用來判斷我們剪枝的類型,我這里寫的是nn.Conv2d,表示對卷積進(jìn)行剪枝(暫未加入BN層)。?
torch.norm是可以計(jì)算范數(shù),我們傳入的數(shù)據(jù)是該層的所有通道的權(quán)值,1表示L1-norm,如果你寫2就是2范數(shù),dim=(1,2,3)是對該維度進(jìn)行計(jì)算。因?yàn)槲覀兙矸e核的shape是[out_channels,in_channels,kernel_size,kernel_size],比如conv1的shape就是[32,3,3,3],因此dim=(1,2,3)。
所以下述代碼表示:判斷網(wǎng)絡(luò)各層屬性是否為卷積層,如果是卷積,那么在輸出通道維度上計(jì)算該卷積各通道的L1范數(shù)。
for name, module in model.named_modules():
if isinstance(module, nn.Conv2d):
importance[name] = torch.norm(module.weight.data, 1, dim=(1, 2, 3))
計(jì)算值如下(這里只舉一個(gè)層為例):
?{'conv1': tensor([2.3424, 2.3291, 2.2797, 3.1257, 2.7289, 2.4918, 2.4897, 2.9199, 2.0484,
? ? ? ? 2.4627, 2.5531, 2.2539, 2.4477, 2.3570, 2.5563, 2.9574, 2.7499, 2.0182,
? ? ? ? 2.8837, 2.5835, 2.8180, 2.2055, 3.0783, 2.7072, 2.8927, 2.4416, 2.7805,
? ? ? ? 2.7791, 2.6328, 2.8975, 2.9354, 2.6887])}
對貢獻(xiàn)度進(jìn)行排序?
這一行代碼就是對上面計(jì)算的L1進(jìn)行排序,只不過這里返回的sorted_channels是各個(gè)通道的索引。
# 對通道進(jìn)行排序,返回索引
sorted_channels = np.argsort(np.concatenate([x.cpu().numpy().flatten() for x in importance[name]]))
?得到的排序結(jié)果如下(從小到大排序),注意返回的是通道索引:
[17 ?8 21 11 ?2 ?1 ?0 13 25 12 ?9 ?6 ?5 10 14 19 28 31 23 ?4 16 27 26 20, 18 24 29 ?7 30 15 22 ?3]
計(jì)算要剪掉的通道數(shù)量
?num_channels_to_prune是要剪掉的通道數(shù)量,比如此時(shí)我設(shè)置的剪枝率為0.5,conv1的輸出通道為32,那么剪去50%就是16個(gè)。
# 要剪掉的通道數(shù)量
num_channels_to_prune = int(len(sorted_channels) * percentage)
下面為輸出結(jié)果,表示conv1層要剪16個(gè)通道?
2023-04-19 09:05:42.241 | INFO ? ? | __main__:prune:70 - The number of channels that need to be cut off in the conv1 layer is 16
?這16個(gè)通道索引為:
conv1 layer pruning channel index is [17 ?8 21 11 ?2 ?1 ?0 13 25 12 ?9 ?6 ?5 10 14 19]
新建卷積層?
new_module是新建的卷積層,該卷積層用來接收剪枝后的結(jié)果。
這里需要注意一點(diǎn)的是,我這里輸入通道in_channels用的是3 if module.in_channels==3 else in_channels,這是因?yàn)槿绻热缒銓onv1剪枝后,那么該層的輸出通道會改變,下一層的conv2的輸入通道如果不變的化會報(bào)shape的錯誤,因?yàn)橄聦拥妮斎胧巧蠈拥妮敵?,因此每層剪枝的時(shí)候需要記錄一下通道的變化。然后其他屬性不變。
new_module = nn.Conv2d(in_channels=3 if module.in_channels == 3 else in_channels, # *
out_channels=module.out_channels - num_channels_to_prune,
kernel_size=module.kernel_size,
stride=module.stride,
padding=module.padding,
dilation=module.dilation,
groups=module.groups,
bias=(module.bias is not None)
).to(next(model.parameters()).device)
in_channels = new_module.out_channels # 因?yàn)榍耙粚拥妮敵鐾ǖ罆绊懴乱粚拥妮斎胪ǖ?/code>
此時(shí)創(chuàng)建的new_module為,可以看到新建的卷積輸出通道為16:
Conv2d(3, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
同時(shí)可以看一下這個(gè)new_module卷積部分默認(rèn)的權(quán)重參數(shù)(注意留意一下這里,后面要做對比的):
Parameter containing:
tensor([[[[ 0.1232, ?0.0262, -0.0958],
? ? ? ? ? [ 0.0085, -0.1569, -0.1070],
? ? ? ? ? [-0.1693, -0.1114, -0.1518]],? ? ? ? ?[[-0.0057, ?0.1428, ?0.0811],
? ? ? ? ? [ 0.0324, -0.1620, -0.1143],
? ? ? ? ? [-0.0407, ?0.1052, -0.1360]],? ? ? ? ?[[-0.1781, -0.0648, -0.1358],
? ? ? ? ? [-0.0793, -0.0506, -0.1243],
? ? ? ? ? [ 0.1060, ?0.0986, ?0.0328]]],
?權(quán)重的重分配
由于前num_channels_to_prune是我們剪枝不要的,因此只保留后面的通道,所以通過module.weight.data[num_channels_to_prune:,:c1,...]將原來的權(quán)重傳給新卷積。
# 重新分配權(quán)重 權(quán)重的shape[out_channels, in_channels, k, k]
c2, c1, _, _ = new_module.weight.data.shape
new_module.weight.data[...] = module.weight.data[num_channels_to_prune:, :c1, ...]
if module.bias is not None:
new_module.bias.data[...] = module.bias.data[sorted_channels[num_channels_to_prune:]]
先看一下conv1中原來的權(quán)值:
conv1:對應(yīng)代碼中的module
tensor([[[[-0.0095, -0.1064, -0.0761],
? ? ? ? ? [-0.0687, ?0.1567, ?0.0410],
? ? ? ? ? [-0.1303, -0.0556, ?0.0263]],? ? ? ? ?[[ 0.1690, -0.0342, ?0.0444],
? ? ? ? ? [ 0.0423, ?0.1286, ?0.1294],
? ? ? ? ? [-0.1861, ?0.1208, ?0.1759]],? ? ? ? ?[[ 0.1747, -0.0429, ?0.0311],
? ? ? ? ? [ 0.1235, -0.1835, -0.0983],
? ? ? ? ? [-0.1890, -0.1257, ?0.0798]]],
再來看一下權(quán)值重新分配,可以和上面未傳入?yún)?shù)的new_module做對比,是不是發(fā)現(xiàn)現(xiàn)在權(quán)值已經(jīng)更新了:
此時(shí)的new_module :
tensor([[[[-0.0095, -0.1064, -0.0761],
? ? ? ? ? [-0.0687, ?0.1567, ?0.0410],
? ? ? ? ? [-0.1303, -0.0556, ?0.0263]],? ? ? ? ?[[ 0.1690, -0.0342, ?0.0444],
? ? ? ? ? [ 0.0423, ?0.1286, ?0.1294],
? ? ? ? ? [-0.1861, ?0.1208, ?0.1759]],? ? ? ? ?[[ 0.1747, -0.0429, ?0.0311],
? ? ? ? ? [ 0.1235, -0.1835, -0.0983],
? ? ? ? ? [-0.1890, -0.1257, ?0.0798]]],
?
通過上述過程就產(chǎn)生了新的剪枝后的卷積了。
新卷積代替model中的舊卷積?
最后就是用新的卷積new_module替換我們網(wǎng)絡(luò)中舊的卷積。僅一行代碼就可以解決。
setattr(prune_model, f"{name}", new_module)
可以看一下打印,此時(shí)的model中的conv1輸出通道變成了16,說明剪枝并替換成功。?
Model(
? (conv1): Conv2d(3, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
? (act1): ReLU(inplace=True)
新建BN層?
如果有BN層,那么對BN層也做輕量化處理。過程與上面卷積層一樣。同時(shí)用新BN層替換舊的。
elif isinstance(module, nn.BatchNorm2d):
new_bn = nn.BatchNorm2d(num_features=new_module.out_channels,
eps=module.eps,
momentum=module.momentum,
affine=module.affine,
track_running_stats=module.track_running_stats).to(next(model.parameters()).device)
new_bn.weight.data[...] = module.weight.data[sorted_channels[num_channels_to_prune:]]
if module.bias is not None:
new_bn.bias.data[...] = module.bias.data[sorted_channels[num_channels_to_prune:]]
# 用新bn替換舊bn
setattr(prune_model, f"{name}", new_bn)
剪枝前后網(wǎng)絡(luò)結(jié)構(gòu)以及參數(shù)對比?
現(xiàn)在可以對比一下剪枝前后打印的網(wǎng)絡(luò)解構(gòu),已經(jīng)能夠發(fā)現(xiàn)剪枝后各層通道數(shù)量減少了一半。
剪枝前:
model: Model(
? (conv1): Conv2d(3, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
? (act1): ReLU(inplace=True)
? (conv2): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
? (act2): ReLU(inplace=True)
? (conv3): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
? (act3): ReLU(inplace=True)
? (conv4): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
? (act4): ReLU(inplace=True)
? (conv5): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
? (act5): ReLU(inplace=True)
? (conv6): Conv2d(512, 1024, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
? (act6): ReLU(inplace=True)
? (conv7): Conv2d(1024, 2048, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
? (act7): ReLU(inplace=True)
? (conv8): Conv2d(2048, 4096, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
)
剪枝后:
pruned model: Model(
? (conv1): Conv2d(3, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
? (act1): ReLU(inplace=True)
? (conv2): Conv2d(16, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
? (act2): ReLU(inplace=True)
? (conv3): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
? (act3): ReLU(inplace=True)
? (conv4): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
? (act4): ReLU(inplace=True)
? (conv5): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
? (act5): ReLU(inplace=True)
? (conv6): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
? (act6): ReLU(inplace=True)
? (conv7): Conv2d(512, 1024, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
? (act7): ReLU(inplace=True)
? (conv8): Conv2d(1024, 2048, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
)?
?再看一下剪枝前后參數(shù)對比:
可以看到參數(shù)少了不少。
Number of parameter: 100.66M
Number of pruned model parameter: 25.16M
?完整代碼
import numpy as np
import torch
import torch.nn as nn
from loguru import logger
def count_params(module):
return sum([p.numel() for p in module.parameters()])
class Model(nn.Module):
def __init__(self, in_channels):
super(Model, self).__init__()
self.conv1 = nn.Conv2d(in_channels, 32, kernel_size=3, padding=1, bias=False)
self.act1 = nn.ReLU(inplace=True)
self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1, bias=False)
self.act2 = nn.ReLU(inplace=True)
self.conv3 = nn.Conv2d(64, 128, kernel_size=3, padding=1, bias=False)
self.act3 = nn.ReLU(inplace=True)
self.conv4 = nn.Conv2d(128, 256, 3, 1, 1, bias=False)
self.act4 = nn.ReLU(inplace=True)
self.conv5 = nn.Conv2d(256, 512, 3, 1, 1, bias=False)
self.act5 = nn.ReLU(inplace=True)
self.conv6 = nn.Conv2d(512, 1024, 3, 1, 1, bias=False)
self.act6 = nn.ReLU(inplace=True)
self.conv7 = nn.Conv2d(1024, 2048, 3, 1, 1, bias=False)
self.act7 = nn.ReLU(inplace=True)
self.conv8 = nn.Conv2d(2048, 4096, 3, 1, 1, bias=False)
def forward(self, x):
x = self.conv1(x)
x = self.act1(x)
x = self.conv2(x)
x = self.act2(x)
x = self.conv3(x)
x = self.act3(x)
x = self.conv4(x)
x = self.act4(x)
x = self.conv5(x)
x = self.act5(x)
x = self.conv6(x)
x = self.act6(x)
x = self.conv7(x)
x = self.act7(x)
out = self.conv8(x)
return out
def prune(model, percentage):
# 計(jì)算每個(gè)通道的L1-norm并排序
importance_conv = {}
prune_model = model
for name, module in model.named_modules():
if isinstance(module, (nn.Conv2d, nn.BatchNorm2d)):
# torch.norm用于計(jì)算張量的范數(shù),可以計(jì)算每個(gè)通道上的L1范數(shù) conv.weight.data shape [out_channels,in_channels, k,k]
if isinstance(module, nn.Conv2d):
importance_conv[name] = torch.norm(module.weight.data, 1, dim=(1, 2, 3))
# 對通道進(jìn)行排序,返回索引
sorted_channels = np.argsort(np.concatenate([x.cpu().numpy().flatten() for x in importance_conv[name]]))
# logger.info(f"{name} layer channel sorting results {sorted_channels}")
# 要剪掉的通道數(shù)量
num_channels_to_prune = int(len(sorted_channels) * percentage)
logger.info(
f"The number of channels that need to be cut off in the {name} layer is {num_channels_to_prune}")
logger.info(f"{name} layer pruning channel index is {sorted_channels[:num_channels_to_prune]}")
new_module = nn.Conv2d(in_channels=3 if module.in_channels == 3 else in_channels,
out_channels=module.out_channels - num_channels_to_prune,
kernel_size=module.kernel_size,
stride=module.stride,
padding=module.padding,
dilation=module.dilation,
groups=module.groups,
bias=(module.bias is not None)
).to(next(model.parameters()).device)
in_channels = new_module.out_channels # 因?yàn)榍耙粚拥妮敵鐾ǖ罆绊懴乱粚拥妮斎胪ǖ? # 重新分配權(quán)重 權(quán)重的shape[out_channels, in_channels, k, k]
c2, c1, _, _ = new_module.weight.data.shape
new_module.weight.data[...] = module.weight.data[num_channels_to_prune:, :c1, ...]
if module.bias is not None:
new_module.bias.data[...] = module.bias.data[sorted_channels[num_channels_to_prune:]]
# 用新卷積替換舊卷積
setattr(prune_model, f"{name}", new_module)
elif isinstance(module, nn.BatchNorm2d):
new_bn = nn.BatchNorm2d(num_features=new_module.out_channels,
eps=module.eps,
momentum=module.momentum,
affine=module.affine,
track_running_stats=module.track_running_stats).to(next(model.parameters()).device)
new_bn.weight.data[...] = module.weight.data[sorted_channels[num_channels_to_prune:]]
if module.bias is not None:
new_bn.bias.data[...] = module.bias.data[sorted_channels[num_channels_to_prune:]]
# 用新bn替換舊bn
setattr(prune_model, f"{name}", new_bn)
return prune_model
model = Model(3)
total_param = count_params(model)
torch.save(model, "model.pth")
print(f'\033[5;33m model: {model}\033[0m')
x = torch.randn(1, 3, 32, 32)
prune_model = prune(model, 0.5)
print(f'\033[1;36m pruned model: {prune_model}\033[0m')
total_prune_param = count_params(prune_model)
print("Number of parameter: %.2fM" % (total_param / 1e6))
print("Number of pruned model parameter: %.2fM" % (total_prune_param / 1e6))
torch.save(prune_model, "pruned.pth")
out = prune_model(x)
上面代碼中有兩行需要注意,torch.save(prune_model)而不是torch.save(prune_model.state_dict())【兩者的區(qū)別是前者會將網(wǎng)絡(luò)模型和權(quán)值全部報(bào)錯,后者只保存權(quán)值,這點(diǎn)必須注意,如果要實(shí)現(xiàn)微調(diào)訓(xùn)練必須用前者進(jìn)行保存,不然會報(bào)keys的shape問題】。out = prune_model(x)是用來判斷剪枝后的模型能否正常輸出。
如果你網(wǎng)絡(luò)的最后一層的輸出通道為num_classes,那建議你最后一層不要剪枝,不然就影響了分類輸出。
更新內(nèi)容:
2023.04.21更新內(nèi)容:
?對上述剪枝代碼進(jìn)行了整理,同時(shí)加入了2D和3D權(quán)重的繪制。
prunmodel_.py參數(shù)說明:
--prune:是否開啟剪枝功能
--percent:剪枝率,默認(rèn)0.5
--save:是否保存模型
--plt:繪制2D卷積權(quán)重圖
--plt_3d:繪制3D卷積權(quán)重圖
--layer_name:需要繪制的權(quán)重層名字
項(xiàng)目代碼鏈接:
GitHub - YINYIPENG-EN/deeplearning_channel_prune: pytorch環(huán)境下卷積層的通道剪枝
?剪枝:
python prunmodel_.py --prune --percent 0.5
繪制2D權(quán)重:
這里以繪制conv1為例
python prunmodel_.py --plt --layer_name 'conv1.weight'
繪制3D權(quán)重?
python prunmodel_.py --plt_3d --layer_name 'conv1.weight'
2023.04.22更新內(nèi)容:
????????支持BN層的輕量化,可實(shí)現(xiàn)對VGG網(wǎng)絡(luò)的剪枝。文章來源:http://www.zghlxwxcb.cn/news/detail-450510.html
后續(xù)將不定時(shí)更新其他類型的剪枝,希望多多支持~~文章來源地址http://www.zghlxwxcb.cn/news/detail-450510.html
到了這里,關(guān)于卷積神經(jīng)網(wǎng)絡(luò)輕量化教程之通道剪枝【附代碼】的文章就介紹完了。如果您還想了解更多內(nèi)容,請?jiān)谟疑辖撬阉鱐OY模板網(wǎng)以前的文章或繼續(xù)瀏覽下面的相關(guān)文章,希望大家以后多多支持TOY模板網(wǎng)!