国产 无码 综合区,色欲AV无码国产永久播放,无码天堂亚洲国产AV,国产日韩欧美女同一区二区

如何使用torch.nn.utils.prune稀疏神經(jīng)網(wǎng)絡(luò),以及如何擴(kuò)展它以實(shí)現(xiàn)自己的自定義剪裁技術(shù)

這篇具有很好參考價(jià)值的文章主要介紹了如何使用torch.nn.utils.prune稀疏神經(jīng)網(wǎng)絡(luò),以及如何擴(kuò)展它以實(shí)現(xiàn)自己的自定義剪裁技術(shù)。希望對(duì)大家有所幫助。如果存在錯(cuò)誤或未考慮完全的地方,請(qǐng)大家不吝賜教,您也可以點(diǎn)擊"舉報(bào)違法"按鈕提交疑問(wèn)。

模型剪裁教程

最新的深度學(xué)習(xí)技術(shù)依賴于難以部署的過(guò)度參數(shù)化模型。 相反,已知生物神經(jīng)網(wǎng)絡(luò)使用有效的稀疏連通性。 為了減少內(nèi)存,電池和硬件消耗,同時(shí)又不犧牲精度,在設(shè)備上部署輕量級(jí)模型并通過(guò)私有設(shè)備上計(jì)算來(lái)確保私密性,確定通過(guò)減少模型中的參數(shù)數(shù)量來(lái)壓縮模型的最佳技術(shù)很重要。 在研究方面,剪裁用于研究參數(shù)過(guò)度配置和參數(shù)不足網(wǎng)絡(luò)在學(xué)習(xí)動(dòng)態(tài)方面的差異,以研究幸運(yùn)的稀疏子網(wǎng)絡(luò)的作用(“彩票”),以及初始化,作為破壞性的神經(jīng)結(jié)構(gòu)搜索技術(shù)等等。

在本教程中,您將學(xué)習(xí)如何使用torch.nn.utils.prune稀疏神經(jīng)網(wǎng)絡(luò),以及如何擴(kuò)展它以實(shí)現(xiàn)自己的自定義剪裁技術(shù)。

要求

"torch>=1.4.0a0+8e8a5e0"

import torch
from torch import nn
import torch.nn.utils.prune as prune
import torch.nn.functional as F

創(chuàng)建模型

在本教程中,我們使用 LeCun 等人,1998 年的 LeNet 架構(gòu)。

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class LeNet(nn.Module):
    def __init__(self):
        super(LeNet, self).__init__()
        # 1 input image channel, 6 output channels, 3x3 square conv kernel
        self.conv1 = nn.Conv2d(1, 6, 3)
        self.conv2 = nn.Conv2d(6, 16, 3)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)  # 5x5 image dimension
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))
        x = F.max_pool2d(F.relu(self.conv2(x)), 2)
        x = x.view(-1, int(x.nelement() / x.shape[0]))
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

model = LeNet().to(device=device)

檢查模塊

讓我們檢查一下 LeNet 模型中的(未剪裁)conv1層。 現(xiàn)在它將包含兩個(gè)參數(shù)weightbias,并且沒(méi)有緩沖區(qū)。

module = model.conv1
print(list(module.named_parameters()))

出:

[('weight', Parameter containing:
tensor([[[[ 0.1552,  0.0102, -0.1944],
          [ 0.0263,  0.1374, -0.3139],
          [ 0.2838,  0.1943,  0.0948]]],

        [[[-0.0296, -0.2514,  0.1300],
          [ 0.0756, -0.3155, -0.2900],
          [-0.1840,  0.1143, -0.0120]]],

        [[[-0.2383, -0.3022,  0.2295],
          [-0.0050,  0.2485, -0.3230],
          [-0.1317, -0.0054,  0.2659]]],

        [[[-0.0932,  0.1316,  0.0670],
          [ 0.0572, -0.1845,  0.0870],
          [ 0.1372,  0.1080,  0.0324]]],

        [[[ 0.0908, -0.3280,  0.0365],
          [-0.3108,  0.2317, -0.2271],
          [ 0.1171,  0.2113, -0.2259]]],

        [[[ 0.0407,  0.0512,  0.0954],
          [-0.0437,  0.0302, -0.1317],
          [ 0.2573,  0.0626,  0.0883]]]], device='cuda:0', requires_grad=True)), ('bias', Parameter containing:
tensor([-0.1803,  0.1331, -0.3267,  0.3173, -0.0349,  0.1828], device='cuda:0',
       requires_grad=True))]

print(list(module.named_buffers()))

出:

[]

剪裁模塊

要剪裁模塊(在此示例中,為 LeNet 架構(gòu)的conv1層),請(qǐng)首先從torch.nn.utils.prune中可用的那些技術(shù)中選擇一種剪裁技術(shù)(或通過(guò)子類化BasePruningMethod實(shí)現(xiàn)您自己的東西)。 然后,指定模塊和該模塊中要剪裁的參數(shù)的名稱。 最后,使用所選剪裁技術(shù)所需的適當(dāng)關(guān)鍵字參數(shù),指定剪裁參數(shù)。

在此示例中,我們將在conv1層中名為weight的參數(shù)中隨機(jī)剪裁 30% 的連接。 模塊作為第一個(gè)參數(shù)傳遞給函數(shù); name使用其字符串標(biāo)識(shí)符在該模塊中標(biāo)識(shí)參數(shù); amount表示與剪裁的連接百分比(如果是介于 0 和 1 之間的浮點(diǎn)數(shù)),或表示與剪裁的連接的絕對(duì)數(shù)量(如果它是非負(fù)整數(shù))。

prune.random_unstructured(module, name="weight", amount=0.3)

剪裁是通過(guò)從參數(shù)中刪除weight并將其替換為名為weight_orig的新參數(shù)(即,將"_orig"附加到初始參數(shù)name)來(lái)進(jìn)行的。 weight_orig存儲(chǔ)未剪裁的張量版本。 bias未剪裁,因此它將保持完整。

print(list(module.named_parameters()))

出:

[('bias', Parameter containing:
tensor([-0.1803,  0.1331, -0.3267,  0.3173, -0.0349,  0.1828], device='cuda:0',
       requires_grad=True)), ('weight_orig', Parameter containing:
tensor([[[[ 0.1552,  0.0102, -0.1944],
          [ 0.0263,  0.1374, -0.3139],
          [ 0.2838,  0.1943,  0.0948]]],

        [[[-0.0296, -0.2514,  0.1300],
          [ 0.0756, -0.3155, -0.2900],
          [-0.1840,  0.1143, -0.0120]]],

        [[[-0.2383, -0.3022,  0.2295],
          [-0.0050,  0.2485, -0.3230],
          [-0.1317, -0.0054,  0.2659]]],

        [[[-0.0932,  0.1316,  0.0670],
          [ 0.0572, -0.1845,  0.0870],
          [ 0.1372,  0.1080,  0.0324]]],

        [[[ 0.0908, -0.3280,  0.0365],
          [-0.3108,  0.2317, -0.2271],
          [ 0.1171,  0.2113, -0.2259]]],

        [[[ 0.0407,  0.0512,  0.0954],
          [-0.0437,  0.0302, -0.1317],
          [ 0.2573,  0.0626,  0.0883]]]], device='cuda:0', requires_grad=True))]

通過(guò)以上選擇的剪裁技術(shù)生成的剪裁掩碼將保存為名為weight_mask的模塊緩沖區(qū)(即,將"_mask"附加到初始參數(shù)name)。

print(list(module.named_buffers()))

出:

[('weight_mask', tensor([[[[1., 1., 0.],
          [0., 0., 1.],
          [1., 0., 1.]]],

        [[[1., 1., 1.],
          [1., 1., 1.],
          [1., 1., 1.]]],

        [[[1., 1., 0.],
          [1., 0., 0.],
          [1., 0., 1.]]],

        [[[1., 1., 1.],
          [1., 0., 1.],
          [1., 1., 1.]]],

        [[[1., 1., 1.],
          [0., 0., 1.],
          [1., 1., 1.]]],

        [[[1., 0., 0.],
          [1., 0., 1.],
          [1., 0., 0.]]]], device='cuda:0'))]

為了使正向傳播不更改即可工作,需要存在weight屬性。 在torch.nn.utils.prune中實(shí)現(xiàn)的剪裁技術(shù)計(jì)算權(quán)重的剪裁版本(通過(guò)將掩碼與原始參數(shù)組合)并將它們存儲(chǔ)在屬性weight中。 注意,這不再是module的參數(shù),現(xiàn)在只是一個(gè)屬性。

print(module.weight)

出:

tensor([[[[ 0.1552,  0.0102, -0.0000],
          [ 0.0000,  0.0000, -0.3139],
          [ 0.2838,  0.0000,  0.0948]]],

        [[[-0.0296, -0.2514,  0.1300],
          [ 0.0756, -0.3155, -0.2900],
          [-0.1840,  0.1143, -0.0120]]],

        [[[-0.2383, -0.3022,  0.0000],
          [-0.0050,  0.0000, -0.0000],
          [-0.1317, -0.0000,  0.2659]]],

        [[[-0.0932,  0.1316,  0.0670],
          [ 0.0572, -0.0000,  0.0870],
          [ 0.1372,  0.1080,  0.0324]]],

        [[[ 0.0908, -0.3280,  0.0365],
          [-0.0000,  0.0000, -0.2271],
          [ 0.1171,  0.2113, -0.2259]]],

        [[[ 0.0407,  0.0000,  0.0000],
          [-0.0437,  0.0000, -0.1317],
          [ 0.2573,  0.0000,  0.0000]]]], device='cuda:0',
       grad_fn=<MulBackward0>)

最后,使用 PyTorch 的forward_pre_hooks在每次向前傳遞之前應(yīng)用剪裁。 具體來(lái)說(shuō),當(dāng)剪裁module時(shí)(如我們?cè)诖颂幩龅哪菢樱鼘榕c之關(guān)聯(lián)的每個(gè)參數(shù)獲取forward_pre_hook進(jìn)行剪裁。 在這種情況下,由于到目前為止我們只剪裁了名稱為weight的原始參數(shù),因此只會(huì)出現(xiàn)一個(gè)鉤子。

print(module._forward_pre_hooks)

出:

OrderedDict([(0, <torch.nn.utils.prune.RandomUnstructured object at 0x7fda78275e48>)])

為了完整起見(jiàn),我們現(xiàn)在也可以剪裁bias,以查看module的參數(shù),緩沖區(qū),掛鉤和屬性如何變化。 僅出于嘗試另一種剪裁技術(shù)的目的,在此我們按 L1 范數(shù)剪裁偏差中的 3 個(gè)最小條目,如l1_unstructured剪裁函數(shù)中所實(shí)現(xiàn)的。

prune.l1_unstructured(module, name="bias", amount=3)

現(xiàn)在,我們希望命名參數(shù)同時(shí)包含weight_orig(從前)和bias_orig。 緩沖區(qū)將包括weight_maskbias_mask。 兩個(gè)張量的剪裁后的版本將作為模塊屬性存在,并且該模塊現(xiàn)在將具有兩個(gè)forward_pre_hooks

print(list(module.named_parameters()))

出:

[('weight_orig', Parameter containing:
tensor([[[[ 0.1552,  0.0102, -0.1944],
          [ 0.0263,  0.1374, -0.3139],
          [ 0.2838,  0.1943,  0.0948]]],

        [[[-0.0296, -0.2514,  0.1300],
          [ 0.0756, -0.3155, -0.2900],
          [-0.1840,  0.1143, -0.0120]]],

        [[[-0.2383, -0.3022,  0.2295],
          [-0.0050,  0.2485, -0.3230],
          [-0.1317, -0.0054,  0.2659]]],

        [[[-0.0932,  0.1316,  0.0670],
          [ 0.0572, -0.1845,  0.0870],
          [ 0.1372,  0.1080,  0.0324]]],

        [[[ 0.0908, -0.3280,  0.0365],
          [-0.3108,  0.2317, -0.2271],
          [ 0.1171,  0.2113, -0.2259]]],

        [[[ 0.0407,  0.0512,  0.0954],
          [-0.0437,  0.0302, -0.1317],
          [ 0.2573,  0.0626,  0.0883]]]], device='cuda:0', requires_grad=True)), ('bias_orig', Parameter containing:
tensor([-0.1803,  0.1331, -0.3267,  0.3173, -0.0349,  0.1828], device='cuda:0',
       requires_grad=True))]

print(list(module.named_buffers()))

出:

[('weight_mask', tensor([[[[1., 1., 0.],
          [0., 0., 1.],
          [1., 0., 1.]]],

        [[[1., 1., 1.],
          [1., 1., 1.],
          [1., 1., 1.]]],

        [[[1., 1., 0.],
          [1., 0., 0.],
          [1., 0., 1.]]],

        [[[1., 1., 1.],
          [1., 0., 1.],
          [1., 1., 1.]]],

        [[[1., 1., 1.],
          [0., 0., 1.],
          [1., 1., 1.]]],

        [[[1., 0., 0.],
          [1., 0., 1.],
          [1., 0., 0.]]]], device='cuda:0')), ('bias_mask', tensor([0., 0., 1., 1., 0., 1.], device='cuda:0'))]

print(module.bias)

出:

tensor([-0.0000,  0.0000, -0.3267,  0.3173, -0.0000,  0.1828], device='cuda:0',
       grad_fn=<MulBackward0>)

print(module._forward_pre_hooks)

出:

OrderedDict([(0, <torch.nn.utils.prune.RandomUnstructured object at 0x7fda78275e48>), (1, <torch.nn.utils.prune.L1Unstructured object at 0x7fda80bbe470>)])

迭代式剪裁

一個(gè)模塊中的同一參數(shù)可以被多次剪裁,各種剪裁調(diào)用的效果等于連接應(yīng)用的各種蒙版的組合。 PruningContainercompute_mask方法可處理新遮罩與舊遮罩的組合。

例如,假設(shè)我們現(xiàn)在想進(jìn)一步剪裁module.weight,這一次是使用沿著張量的第 0 軸的結(jié)構(gòu)化剪裁(第 0 軸對(duì)應(yīng)于卷積層的輸出通道,并且對(duì)于conv1具有 6 維) ,基于渠道的 L2 規(guī)范。 這可以通過(guò)ln_structuredn=2dim=0函數(shù)來(lái)實(shí)現(xiàn)。

prune.ln_structured(module, name="weight", amount=0.5, n=2, dim=0)

# As we can verify, this will zero out all the connections corresponding to
# 50% (3 out of 6) of the channels, while preserving the action of the
# previous mask.
print(module.weight)

出:

tensor([[[[ 0.0000,  0.0000, -0.0000],
          [ 0.0000,  0.0000, -0.0000],
          [ 0.0000,  0.0000,  0.0000]]],

        [[[-0.0296, -0.2514,  0.1300],
          [ 0.0756, -0.3155, -0.2900],
          [-0.1840,  0.1143, -0.0120]]],

        [[[-0.2383, -0.3022,  0.0000],
          [-0.0050,  0.0000, -0.0000],
          [-0.1317, -0.0000,  0.2659]]],

        [[[-0.0000,  0.0000,  0.0000],
          [ 0.0000, -0.0000,  0.0000],
          [ 0.0000,  0.0000,  0.0000]]],

        [[[ 0.0908, -0.3280,  0.0365],
          [-0.0000,  0.0000, -0.2271],
          [ 0.1171,  0.2113, -0.2259]]],

        [[[ 0.0000,  0.0000,  0.0000],
          [-0.0000,  0.0000, -0.0000],
          [ 0.0000,  0.0000,  0.0000]]]], device='cuda:0',
       grad_fn=<MulBackward0>)

現(xiàn)在,對(duì)應(yīng)的鉤子將為torch.nn.utils.prune.PruningContainer類型,并將存儲(chǔ)應(yīng)用于weight參數(shù)的剪裁歷史。

for hook in module._forward_pre_hooks.values():
    if hook._tensor_name == "weight":  # select out the correct hook
        break

print(list(hook))  # pruning history in the container

出:

[<torch.nn.utils.prune.RandomUnstructured object at 0x7fda78275e48>, <torch.nn.utils.prune.LnStructured object at 0x7fda80071828>]

序列化剪裁的模型

所有相關(guān)的張量,包括掩碼緩沖區(qū)和用于計(jì)算剪裁的張量的原始參數(shù),都存儲(chǔ)在模型的state_dict中,因此可以根據(jù)需要輕松地序列化和保存。

print(model.state_dict().keys())

出:

odict_keys(['conv1.weight_orig', 'conv1.bias_orig', 'conv1.weight_mask', 'conv1.bias_mask', 'conv2.weight', 'conv2.bias', 'fc1.weight', 'fc1.bias', 'fc2.weight', 'fc2.bias', 'fc3.weight', 'fc3.bias'])

刪除剪裁重新參數(shù)化

要使剪裁永久化,請(qǐng)刪除weight_origweight_mask的重新參數(shù)化,然后刪除forward_pre_hook,我們可以使用torch.nn.utils.pruneremove函數(shù)。 請(qǐng)注意,這不會(huì)撤消剪裁,好像從未發(fā)生過(guò)。 而是通過(guò)將參數(shù)weight重新分配給模型參數(shù)(剪裁后的版本)來(lái)使其永久不變。

刪除重新參數(shù)化之前:

print(list(module.named_parameters()))

出:

[('weight_orig', Parameter containing:
tensor([[[[ 0.1552,  0.0102, -0.1944],
          [ 0.0263,  0.1374, -0.3139],
          [ 0.2838,  0.1943,  0.0948]]],

        [[[-0.0296, -0.2514,  0.1300],
          [ 0.0756, -0.3155, -0.2900],
          [-0.1840,  0.1143, -0.0120]]],

        [[[-0.2383, -0.3022,  0.2295],
          [-0.0050,  0.2485, -0.3230],
          [-0.1317, -0.0054,  0.2659]]],

        [[[-0.0932,  0.1316,  0.0670],
          [ 0.0572, -0.1845,  0.0870],
          [ 0.1372,  0.1080,  0.0324]]],

        [[[ 0.0908, -0.3280,  0.0365],
          [-0.3108,  0.2317, -0.2271],
          [ 0.1171,  0.2113, -0.2259]]],

        [[[ 0.0407,  0.0512,  0.0954],
          [-0.0437,  0.0302, -0.1317],
          [ 0.2573,  0.0626,  0.0883]]]], device='cuda:0', requires_grad=True)), ('bias_orig', Parameter containing:
tensor([-0.1803,  0.1331, -0.3267,  0.3173, -0.0349,  0.1828], device='cuda:0',
       requires_grad=True))]

print(list(module.named_buffers()))

出:

[('weight_mask', tensor([[[[0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.]]],

        [[[1., 1., 1.],
          [1., 1., 1.],
          [1., 1., 1.]]],

        [[[1., 1., 0.],
          [1., 0., 0.],
          [1., 0., 1.]]],

        [[[0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.]]],

        [[[1., 1., 1.],
          [0., 0., 1.],
          [1., 1., 1.]]],

        [[[0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.]]]], device='cuda:0')), ('bias_mask', tensor([0., 0., 1., 1., 0., 1.], device='cuda:0'))]

print(module.weight)

出:

tensor([[[[ 0.0000,  0.0000, -0.0000],
          [ 0.0000,  0.0000, -0.0000],
          [ 0.0000,  0.0000,  0.0000]]],

        [[[-0.0296, -0.2514,  0.1300],
          [ 0.0756, -0.3155, -0.2900],
          [-0.1840,  0.1143, -0.0120]]],

        [[[-0.2383, -0.3022,  0.0000],
          [-0.0050,  0.0000, -0.0000],
          [-0.1317, -0.0000,  0.2659]]],

        [[[-0.0000,  0.0000,  0.0000],
          [ 0.0000, -0.0000,  0.0000],
          [ 0.0000,  0.0000,  0.0000]]],

        [[[ 0.0908, -0.3280,  0.0365],
          [-0.0000,  0.0000, -0.2271],
          [ 0.1171,  0.2113, -0.2259]]],

        [[[ 0.0000,  0.0000,  0.0000],
          [-0.0000,  0.0000, -0.0000],
          [ 0.0000,  0.0000,  0.0000]]]], device='cuda:0',
       grad_fn=<MulBackward0>)

刪除重新參數(shù)化后:

prune.remove(module, 'weight')
print(list(module.named_parameters()))

出:

[('bias_orig', Parameter containing:
tensor([-0.1803,  0.1331, -0.3267,  0.3173, -0.0349,  0.1828], device='cuda:0',
       requires_grad=True)), ('weight', Parameter containing:
tensor([[[[ 0.0000,  0.0000, -0.0000],
          [ 0.0000,  0.0000, -0.0000],
          [ 0.0000,  0.0000,  0.0000]]],

        [[[-0.0296, -0.2514,  0.1300],
          [ 0.0756, -0.3155, -0.2900],
          [-0.1840,  0.1143, -0.0120]]],

        [[[-0.2383, -0.3022,  0.0000],
          [-0.0050,  0.0000, -0.0000],
          [-0.1317, -0.0000,  0.2659]]],

        [[[-0.0000,  0.0000,  0.0000],
          [ 0.0000, -0.0000,  0.0000],
          [ 0.0000,  0.0000,  0.0000]]],

        [[[ 0.0908, -0.3280,  0.0365],
          [-0.0000,  0.0000, -0.2271],
          [ 0.1171,  0.2113, -0.2259]]],

        [[[ 0.0000,  0.0000,  0.0000],
          [-0.0000,  0.0000, -0.0000],
          [ 0.0000,  0.0000,  0.0000]]]], device='cuda:0', requires_grad=True))]

print(list(module.named_buffers()))

出:

[('bias_mask', tensor([0., 0., 1., 1., 0., 1.], device='cuda:0'))]

剪裁模型中的多個(gè)參數(shù)

通過(guò)指定所需的剪裁技術(shù)和參數(shù),我們可以輕松地剪裁網(wǎng)絡(luò)中的多個(gè)張量,也許根據(jù)它們的類型,如在本示例中將看到的那樣。

new_model = LeNet()
for name, module in new_model.named_modules():
    # prune 20% of connections in all 2D-conv layers
    if isinstance(module, torch.nn.Conv2d):
        prune.l1_unstructured(module, name='weight', amount=0.2)
    # prune 40% of connections in all linear layers
    elif isinstance(module, torch.nn.Linear):
        prune.l1_unstructured(module, name='weight', amount=0.4)

print(dict(new_model.named_buffers()).keys())  # to verify that all masks exist

出:

dict_keys(['conv1.weight_mask', 'conv2.weight_mask', 'fc1.weight_mask', 'fc2.weight_mask', 'fc3.weight_mask'])

全局剪裁

到目前為止,我們僅查看了通常稱為“局部”剪裁的情況,即通過(guò)比較每個(gè)條目的統(tǒng)計(jì)信息(權(quán)重,激活度,梯度等)來(lái)逐個(gè)剪裁模型中的張量的做法。 到該張量中的其他條目。 但是,一種通用且可能更強(qiáng)大的技術(shù)是通過(guò)刪除(例如)刪除整個(gè)模型中最低的 20% 的連接,而不是刪除每一層中最低的 20% 的連接來(lái)一次剪裁模型。 這很可能導(dǎo)致每個(gè)層的剪裁百分比不同。 讓我們看看如何使用torch.nn.utils.prune中的global_unstructured進(jìn)行操作。

model = LeNet()

parameters_to_prune = (
    (model.conv1, 'weight'),
    (model.conv2, 'weight'),
    (model.fc1, 'weight'),
    (model.fc2, 'weight'),
    (model.fc3, 'weight'),
)

prune.global_unstructured(
    parameters_to_prune,
    pruning_method=prune.L1Unstructured,
    amount=0.2,
)

現(xiàn)在,我們可以檢查每個(gè)剪裁參數(shù)中引起的稀疏性,該稀疏性將不等于每層中的 20%。 但是,全局稀疏度將(大約)為 20%。

print(
    "Sparsity in conv1.weight: {:.2f}%".format(
        100\. * float(torch.sum(model.conv1.weight == 0))
        / float(model.conv1.weight.nelement())
    )
)
print(
    "Sparsity in conv2.weight: {:.2f}%".format(
        100\. * float(torch.sum(model.conv2.weight == 0))
        / float(model.conv2.weight.nelement())
    )
)
print(
    "Sparsity in fc1.weight: {:.2f}%".format(
        100\. * float(torch.sum(model.fc1.weight == 0))
        / float(model.fc1.weight.nelement())
    )
)
print(
    "Sparsity in fc2.weight: {:.2f}%".format(
        100\. * float(torch.sum(model.fc2.weight == 0))
        / float(model.fc2.weight.nelement())
    )
)
print(
    "Sparsity in fc3.weight: {:.2f}%".format(
        100\. * float(torch.sum(model.fc3.weight == 0))
        / float(model.fc3.weight.nelement())
    )
)
print(
    "Global sparsity: {:.2f}%".format(
        100\. * float(
            torch.sum(model.conv1.weight == 0)
            + torch.sum(model.conv2.weight == 0)
            + torch.sum(model.fc1.weight == 0)
            + torch.sum(model.fc2.weight == 0)
            + torch.sum(model.fc3.weight == 0)
        )
        / float(
            model.conv1.weight.nelement()
            + model.conv2.weight.nelement()
            + model.fc1.weight.nelement()
            + model.fc2.weight.nelement()
            + model.fc3.weight.nelement()
        )
    )
)

出:

Sparsity in conv1.weight: 3.70%
Sparsity in conv2.weight: 8.10%
Sparsity in fc1.weight: 22.05%
Sparsity in fc2.weight: 12.29%
Sparsity in fc3.weight: 8.45%
Global sparsity: 20.00%

使用自定義剪裁函數(shù)擴(kuò)展torch.nn.utils.prune

要實(shí)現(xiàn)自己的剪裁函數(shù),可以通過(guò)繼承BasePruningMethod基類的子類來(lái)擴(kuò)展nn.utils.prune模塊,這與所有其他剪裁方法一樣。 基類為您實(shí)現(xiàn)以下方法:__call__,apply_mask,apply,pruneremove。 除了一些特殊情況外,您無(wú)需為新的剪裁技術(shù)重新實(shí)現(xiàn)這些方法。 但是,您將必須實(shí)現(xiàn)__init__(構(gòu)造器)和compute_mask(有關(guān)如何根據(jù)剪裁技術(shù)的邏輯為給定張量計(jì)算掩碼的說(shuō)明)。 另外,您將必須指定此技術(shù)實(shí)現(xiàn)的剪裁類型(支持的選項(xiàng)為globalstructuredunstructured)。 需要確定在迭代應(yīng)用剪裁的情況下如何組合蒙版。 換句話說(shuō),當(dāng)剪裁預(yù)剪裁的參數(shù)時(shí),當(dāng)前的剪裁技術(shù)應(yīng)作用于參數(shù)的未剪裁部分。 指定PRUNING_TYPE將使PruningContainer(處理剪裁掩碼的迭代應(yīng)用)正確識(shí)別要剪裁的參數(shù)。

例如,假設(shè)您要實(shí)現(xiàn)一種剪裁技術(shù),以剪裁張量中的所有其他條目(或者-如果先前已剪裁過(guò)張量,則剪裁張量的其余未剪裁部分)。 這將是PRUNING_TYPE='unstructured',因?yàn)樗饔糜趯又械膯蝹€(gè)連接,而不作用于整個(gè)單元/通道('structured'),或作用于不同的參數(shù)('global')。

class FooBarPruningMethod(prune.BasePruningMethod):
    """Prune every other entry in a tensor
    """
    PRUNING_TYPE = 'unstructured'

    def compute_mask(self, t, default_mask):
        mask = default_mask.clone()
        mask.view(-1)[::2] = 0
        return mask

現(xiàn)在,要將其應(yīng)用于nn.Module中的參數(shù),還應(yīng)該提供一個(gè)簡(jiǎn)單的函數(shù)來(lái)實(shí)例化該方法并將其應(yīng)用。

def foobar_unstructured(module, name):
    """Prunes tensor corresponding to parameter called `name` in `module`
    by removing every other entry in the tensors.
    Modifies module in place (and also return the modified module)
    by:
    1) adding a named buffer called `name+'_mask'` corresponding to the
    binary mask applied to the parameter `name` by the pruning method.
    The parameter `name` is replaced by its pruned version, while the
    original (unpruned) parameter is stored in a new parameter named
    `name+'_orig'`.

    Args:
        module (nn.Module): module containing the tensor to prune
        name (string): parameter name within `module` on which pruning
                will act.

    Returns:
        module (nn.Module): modified (i.e. pruned) version of the input
            module

    Examples:
        >>> m = nn.Linear(3, 4)
        >>> foobar_unstructured(m, name='bias')
    """
    FooBarPruningMethod.apply(module, name)
    return module

試試吧!

model = LeNet()
foobar_unstructured(model.fc3, name='bias')

print(model.fc3.bias_mask)

出:文章來(lái)源地址http://www.zghlxwxcb.cn/news/detail-525414.html

tensor([0., 1., 0., 1., 0., 1., 0., 1., 0., 1.])

到了這里,關(guān)于如何使用torch.nn.utils.prune稀疏神經(jīng)網(wǎng)絡(luò),以及如何擴(kuò)展它以實(shí)現(xiàn)自己的自定義剪裁技術(shù)的文章就介紹完了。如果您還想了解更多內(nèi)容,請(qǐng)?jiān)谟疑辖撬阉鱐OY模板網(wǎng)以前的文章或繼續(xù)瀏覽下面的相關(guān)文章,希望大家以后多多支持TOY模板網(wǎng)!

本文來(lái)自互聯(lián)網(wǎng)用戶投稿,該文觀點(diǎn)僅代表作者本人,不代表本站立場(chǎng)。本站僅提供信息存儲(chǔ)空間服務(wù),不擁有所有權(quán),不承擔(dān)相關(guān)法律責(zé)任。如若轉(zhuǎn)載,請(qǐng)注明出處: 如若內(nèi)容造成侵權(quán)/違法違規(guī)/事實(shí)不符,請(qǐng)點(diǎn)擊違法舉報(bào)進(jìn)行投訴反饋,一經(jīng)查實(shí),立即刪除!

領(lǐng)支付寶紅包贊助服務(wù)器費(fèi)用

相關(guān)文章

覺(jué)得文章有用就打賞一下文章作者

支付寶掃一掃打賞

博客贊助

微信掃一掃打賞

請(qǐng)作者喝杯咖啡吧~博客贊助

支付寶掃一掃領(lǐng)取紅包,優(yōu)惠每天領(lǐng)

二維碼1

領(lǐng)取紅包

二維碼2

領(lǐng)紅包