前言
今天在這里紀(jì)錄一下如何對(duì)torch網(wǎng)絡(luò)的層進(jìn)行更改:變更,增加,刪除與查找
這里拿VGG16網(wǎng)絡(luò)舉例,先看一下網(wǎng)絡(luò)結(jié)構(gòu)
import torch
import torch.nn as nn
from torchvision import models
net = models.vgg11(pretrained=True)
一、在網(wǎng)絡(luò)中添加一層:
net網(wǎng)絡(luò)是一個(gè)樹型結(jié)構(gòu), net下面有三個(gè)結(jié)點(diǎn),分別是(features, avgpoll, classifier), 我們先在features結(jié)點(diǎn)添加一層’lastlayer’層
net.features.add_module('lastlayer', nn.Conv2d(512,512, kernel_size=3, stride=1, padding=1))
- 在classifier結(jié)點(diǎn)添加一個(gè)線性層:
net.classifier.add_module('Linear', nn.Linear(1000, 10))
二、修改網(wǎng)絡(luò)中的某一層
- 以features 結(jié)點(diǎn)舉例
net.features[8] = nn.Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
- 以classifier結(jié)點(diǎn)舉例
net.classifier[6] = nn.Linear(1000, 5)
注意: 這里我嘗試對(duì)Linear這一層進(jìn)行更新, 但是Linear名字是字符串, 提取不出來,所以應(yīng)該在之前添加網(wǎng)絡(luò)時(shí)候, 名字不要取字符串, 否則會(huì)報(bào)錯(cuò) ‘ 'str' object cannot be interpreted as an integer’。
三、網(wǎng)絡(luò)層的刪除
方法一:使用關(guān)鍵字del刪除層(推薦)
刪除前
model = prepare_vitmodel('mae_visualize_vit_large_ganloss.pth', 'vit_large_patch16')
del model.head # 刪除層
model
刪除后
方法二:將層設(shè)置為空層
以features舉例 classifier結(jié)點(diǎn)的操作相同,這里直接使用nn.Sequential()對(duì)改層設(shè)置為空即可
net.features[13] = nn.Sequential()
文章來源:http://www.zghlxwxcb.cn/news/detail-649834.html
四、網(wǎng)絡(luò)層的切片
net.features = nn.Sequential(*list(net.features.children())[:-4])
可以看到后面4層被去除了, 就是說可以使用列表切片的方法來刪除網(wǎng)絡(luò)層
net.classifier 對(duì)應(yīng) net.classifier.children()
net.features 對(duì)應(yīng) net.features.children()文章來源地址http://www.zghlxwxcb.cn/news/detail-649834.html
五、網(wǎng)絡(luò)層的凍結(jié)
#凍結(jié)指定層的預(yù)訓(xùn)練參數(shù):
net.feature[26].weight.requires_grad = False
到了這里,關(guān)于pytorch對(duì)網(wǎng)絡(luò)層的增加,刪除,變更和切片的文章就介紹完了。如果您還想了解更多內(nèi)容,請(qǐng)?jiān)谟疑辖撬阉鱐OY模板網(wǎng)以前的文章或繼續(xù)瀏覽下面的相關(guān)文章,希望大家以后多多支持TOY模板網(wǎng)!