一、概述
剪枝(Pruning)的一些概念:
- 當(dāng)提及神經(jīng)網(wǎng)絡(luò)的"參數(shù)"時(shí),大多數(shù)情況指的是網(wǎng)絡(luò)的學(xué)習(xí)型參數(shù),也就是權(quán)重矩陣weights和偏置bias;
- 現(xiàn)代網(wǎng)絡(luò)的參數(shù)量大概在百萬(wàn)至數(shù)十億之間,因此實(shí)際上在一個(gè)網(wǎng)絡(luò)中也并不是所有權(quán)值都是重要的,剪枝的作用就是削減那些不重要權(quán)重矩陣的一種直接壓縮模型的方式;
- 對(duì)于一個(gè)已經(jīng)訓(xùn)練好的模型,切斷或刪除某些連接,同時(shí)保證不對(duì)精度造成重大影響,這樣得到的模型就是一個(gè)參數(shù)較少的剪枝模型;
- 從生物學(xué)的角度來(lái)說(shuō),人類(lèi)在成長(zhǎng)過(guò)程中突觸會(huì)減少,但思維能力反而更強(qiáng)了;
- 和dropout的區(qū)別:dropout具有隨機(jī)性,剪枝具有針對(duì)性;
下面看一下剪枝的實(shí)際操作圖:
二、策略
剪枝主要有以下幾種方法:
1、迭代式剪枝:訓(xùn)練權(quán)重——剪枝(根據(jù)閾值)——重新訓(xùn)練權(quán)重【最常用】
2、動(dòng)態(tài)剪枝:剪枝和訓(xùn)練同時(shí)進(jìn)行,在網(wǎng)絡(luò)的優(yōu)化目標(biāo)中加入權(quán)重的稀疏正則項(xiàng),使得網(wǎng)絡(luò)訓(xùn)練時(shí)部分權(quán)重趨近于0;
3、對(duì)推理過(guò)程中單個(gè)目標(biāo)剪枝;
總結(jié):大多數(shù)的剪枝方法實(shí)際上是迭代的方式進(jìn)行的,因?yàn)樾藜艉笾匦掠?xùn)練,可以讓模型因修剪操作導(dǎo)致的精度下降恢復(fù)過(guò)來(lái),然后在進(jìn)行下一次修剪,直到達(dá)到精度下降的閾值,就不再修剪;
策略對(duì)比圖:
從圖中可以看出,單純剪枝到50%精度就開(kāi)始下降,剪枝后訓(xùn)練到80%精度才開(kāi)始下降,迭代進(jìn)行剪枝到90%精度才下降;
拓展:
實(shí)際上剪枝的大類(lèi)分為幾種:
1、非結(jié)構(gòu)化剪枝:也就是上述介紹的將不重要的權(quán)重置為0;
2、結(jié)構(gòu)化剪枝:將模型的一個(gè)完整結(jié)構(gòu)剪除,比如channels、filters、layers;
3、自動(dòng)化剪枝:NAS,需要大量的算力支持;
三、優(yōu)缺點(diǎn)
優(yōu)點(diǎn):
-
可以應(yīng)用在訓(xùn)練期間或訓(xùn)練結(jié)束后;
-
對(duì)于任意一個(gè)結(jié)構(gòu),可以自主控制推理時(shí)間/模型大小與準(zhǔn)確率之間的平衡;
-
可應(yīng)用于卷積層和全連接層;
缺點(diǎn):
- 沒(méi)有直接切換到一個(gè)更好的網(wǎng)絡(luò)來(lái)的有效;
四、代碼案例
首先需要明確,剪枝是需要對(duì)模型層做一定修改的;
本次代碼是基于小模型LeNet進(jìn)行剪枝實(shí)驗(yàn);
1、對(duì)模型結(jié)構(gòu)中的Liner層進(jìn)行修改,添加mask這個(gè)變量(自定義MaskedLinear層)
class MaskedLinear(Module):
def __init__(self, in_features, out_features, bias=True):
super(MaskedLinear, self).__init__()
self.in_features = in_features
self.out_features = out_features
# 將weight轉(zhuǎn)換為可學(xué)習(xí)的變量
self.weight = Parameter(torch.Tensor(out_features, in_features))
# 初始化mask的值為1,并轉(zhuǎn)換為可學(xué)習(xí)的變量
self.mask = Parameter(torch.ones([out_features, in_features]), requires_grad=False)
if bias:
# 對(duì)bias進(jìn)行初始化
self.bias = Parameter(torch.Tensor(out_features))
else:
# 將bias設(shè)置為空
self.register_parameter('bias', None)
self.reset_parameters()
# 參數(shù)初始化
def reset_parameters(self):
stdv = 1. / math.sqrt(self.weight.size(1))
self.weight.data.uniform_(-stdv, stdv)
if self.bias is not None:
self.bias.data.uniform_(-stdv, stdv)
# 前向傳播(實(shí)際上也是使用標(biāo)準(zhǔn)的Liner層)
def forward(self, input):
# 其中的weight、mask都定義成可變的可學(xué)習(xí)變量
return F.linear(input, self.weight * self.mask, self.bias)
LeNet的定義沒(méi)有做任何修改,也就是幾層全連接層,就不在這里進(jìn)行代碼展示了;
2、對(duì)模型每一層學(xué)習(xí)到的參數(shù)進(jìn)行處理
for name, p in model.named_parameters():
if 'mask' in name:
continue
# 模型參數(shù)
tensor = p.data.cpu().numpy()
# 梯度信息
grad_tensor = p.grad.data.cpu().numpy()
# 將參數(shù)的值為0的,梯度也更新為0
grad_tensor = np.where(tensor == 0, 0, grad_tensor)
p.grad.data = torch.from_numpy(grad_tensor).to(device)
3、統(tǒng)計(jì)每一層參數(shù)的非零數(shù)量,可用于展示剪枝的效果
def print_nonzeros(model):
nonzero = total = 0
for name, p in model.named_parameters():
if 'mask' in name:
continue
tensor = p.data.cpu().numpy()
# 用numpy中的函數(shù)統(tǒng)計(jì)tensor中非0值的數(shù)量
nz_count = np.count_nonzero(tensor)
total_params = np.prod(tensor.shape)
nonzero += nz_count
total += total_params
print(f'{name:20} | nonzeros = {nz_count:7} / {total_params:7} ({100 * nz_count / total_params:6.2f}%) | total_pruned = {total_params - nz_count :7} | shape = {tensor.shape}')
print(f'alive: {nonzero}, pruned : {total - nonzero}, total: {total}, Compression rate : {total/nonzero:10.2f}x ({100 * (total-nonzero) / total:6.2f}% pruned)')
4、實(shí)現(xiàn)剪枝的具體操作
# 參數(shù)s控制剪枝的力度
def prune_by_std(self, s=0.25):
for name, module in self.named_modules():
if name in ['fc1', 'fc2', 'fc3']:
# 取weight值得標(biāo)準(zhǔn)差乘以s
threshold = np.std(module.weight.data.cpu().numpy()) * s
# 打印每一層計(jì)算標(biāo)準(zhǔn)差閾值后得結(jié)果
print(f'Pruning with threshold : {threshold} for layer {name}')
# 得到閾值后進(jìn)行剪枝
module.prune(threshold)
# 具體實(shí)現(xiàn)剪枝的函數(shù)
def prune(self, threshold):
weight_dev = self.weight.device
# mask就是一開(kāi)始傳入的參數(shù),全為1
mask_dev = self.mask.device
# Convert Tensors to numpy and calculate
tensor = self.weight.data.cpu().numpy()
mask = self.mask.data.cpu().numpy()
# 更新mask(小于閾值的時(shí)候?yàn)?,不小于閾值的還是為1)
new_mask = np.where(abs(tensor) < threshold, 0, mask)
# weight和新的mask進(jìn)行矩陣相乘
self.weight.data = torch.from_numpy(tensor * new_mask).to(weight_dev)
# 更新對(duì)應(yīng)的mask
self.mask.data = torch.from_numpy(new_mask).to(mask_dev)
說(shuō)明:這里進(jìn)行剪枝后,模型的精度會(huì)有下降,需要進(jìn)行重新訓(xùn)練;
重新訓(xùn)練直接用原來(lái)的優(yōu)化器參數(shù)訓(xùn)練即可,此時(shí)置為0的weight也不再參與梯度優(yōu)化;
五、結(jié)果展示
剪枝前,經(jīng)過(guò)了100個(gè)epoch:
此時(shí)精度為95.23%,wight參數(shù)全部不為0;
經(jīng)過(guò)剪枝后:
此時(shí)可以看出,精度下降到85.08%,但weight的數(shù)值縮小了接近22倍,大大減少了參數(shù)量;
剪枝后再重新訓(xùn)練100個(gè)epoch:
此時(shí)精度又回到了97%,甚至比剪枝前還高,并且壓縮度也保持不變;
總結(jié)
剪枝的操作總結(jié)下來(lái)分為幾步:
模型的訓(xùn)練 —— 修改要剪枝的層(添加同weight維度的mask) —— 進(jìn)行剪枝后推理 —— 根據(jù)剪枝后的權(quán)重重新訓(xùn)練
下圖給到了剪枝的一個(gè)建議:
個(gè)人理解:剪枝本質(zhì)就是忽略那些低于閾值的參數(shù),從而減少參數(shù)量,使得模型得到壓縮;文章來(lái)源:http://www.zghlxwxcb.cn/news/detail-798635.html
實(shí)際上在每一種結(jié)構(gòu)中都可以用到剪枝,弊端就是工作量較大,需要針對(duì)不同層進(jìn)行修改,并且還要重新訓(xùn)練,如果剪枝的力度過(guò)大,可能導(dǎo)致和剪枝前精度相差過(guò)大;文章來(lái)源地址http://www.zghlxwxcb.cn/news/detail-798635.html
到了這里,關(guān)于【模型壓縮】(二)—— 剪枝的文章就介紹完了。如果您還想了解更多內(nèi)容,請(qǐng)?jiān)谟疑辖撬阉鱐OY模板網(wǎng)以前的文章或繼續(xù)瀏覽下面的相關(guān)文章,希望大家以后多多支持TOY模板網(wǎng)!