一、前言
模型權(quán)值平均是一種用于改善深度神經(jīng)網(wǎng)絡(luò)泛化性能的技術(shù)。通過(guò)對(duì)訓(xùn)練過(guò)程中不同時(shí)間步的模型權(quán)值進(jìn)行平均,可以得到更寬的極值點(diǎn)(optima)并提高模型的泛化能力。 在PyTorch中,官方提供了實(shí)現(xiàn)模型權(quán)值平均的方法。
這里我們首先介紹指數(shù)移動(dòng)平均(EMA)方法,它使用一個(gè)衰減系數(shù)來(lái)平衡當(dāng)前權(quán)值和先前平均權(quán)值。其次,介紹了隨機(jī)加權(quán)平均(SWA)方法,它通過(guò)將當(dāng)前權(quán)值與先前平均權(quán)值進(jìn)行加權(quán)平均來(lái)更新權(quán)值。最后,介紹了Tanh自適應(yīng)指數(shù)移動(dòng)EMA算法(T_ADEMA),它使用Tanh函數(shù)來(lái)調(diào)整衰減系數(shù),以更好地適應(yīng)訓(xùn)練過(guò)程中的不同階段。
為了方便使用這些權(quán)值平均方法,我將官方的代碼寫成了一個(gè)基類AveragingBaseModel,以此引出EMAModel、SWAModel和T_ADEMAModel等方法。這些類可以用于包裝原始模型,并在訓(xùn)練過(guò)程中更新平均權(quán)值。 為了驗(yàn)證這些權(quán)值平均方法的效果,我還在ResNet18模型上進(jìn)行了簡(jiǎn)單的實(shí)驗(yàn)。實(shí)驗(yàn)結(jié)果表明,使用權(quán)值平均方法可以提高模型的準(zhǔn)確率,尤其是在訓(xùn)練后期。
但請(qǐng)注意,博客中所提供的代碼示例僅用于演示權(quán)值平均的原理和PyTorch的實(shí)現(xiàn)方式,并不能保證在所有情況下都能取得理想的效果。實(shí)際應(yīng)用中,還需要根據(jù)具體任務(wù)和數(shù)據(jù)集來(lái)選擇適合的權(quán)值平均方法和參數(shù)設(shè)置。
二、算法介紹
基類實(shí)現(xiàn)
這里我們的基類完全是參照于torch源碼部分,僅僅進(jìn)行了一點(diǎn)細(xì)微的修改。
它首先通過(guò)de_parallel函數(shù)將原始模型轉(zhuǎn)換為單個(gè)GPU模型。de_parallel函數(shù)用于處理并行模型,將其轉(zhuǎn)換為單個(gè)GPU模型。然后,它將轉(zhuǎn)換后的模型復(fù)制到適當(dāng)?shù)脑O(shè)備(CPU或GPU)上(這一步很重要,問(wèn)題大多數(shù)就是因?yàn)橛?jì)算不匹配),并注冊(cè)一個(gè)名為n_averaged的緩沖區(qū),用于跟蹤已平均的次數(shù)。
在forward方法中,它簡(jiǎn)單地將調(diào)用傳遞給轉(zhuǎn)換后的模型。update方法首先獲取當(dāng)前模型和新模型的參數(shù),并將它們轉(zhuǎn)換為可迭代對(duì)象,用于更新平均權(quán)值。它接受一個(gè)新的模型作為參數(shù),并將其與當(dāng)前模型(已平均的權(quán)值)進(jìn)行比較。
from copy import deepcopy
from pyzjr.core.general import is_parallel
import itertools
from torch.nn import Module
def de_parallel(model):
"""
將并行模型(DataParallel 或 DistributedDataParallel)轉(zhuǎn)換為單 GPU 模型。
"""
return model.module if is_parallel(model) else model
class AveragingBaseModel(Module):
def __init__(self, model, cuda=False, avg_fn=None, use_buffers=False):
super(AveragingBaseModel, self).__init__()
device = 'cuda' if cuda and torch.cuda.is_available() else 'cpu'
self.module = deepcopy(de_parallel(model))
self.module = self.module.to(device)
self.register_buffer('n_averaged',
torch.tensor(0, dtype=torch.long, device=device))
self.avg_fn = avg_fn
self.use_buffers = use_buffers
def forward(self, *args, **kwargs):
return self.module(*args, **kwargs)
def update(self, model):
self_param = itertools.chain(self.module.parameters(), self.module.buffers() if self.use_buffers else [])
model_param = itertools.chain(model.parameters(), model.buffers() if self.use_buffers else [])
self_param_detached = [p.detach() for p in self_param]
model_param_detached = [p.detach().to(p_averaged.device) for p, p_averaged in zip(model_param, self_param_detached)]
if self.n_averaged == 0:
for p_averaged, p_model in zip(self_param_detached, model_param_detached):
p_averaged.copy_(p_model)
if self.n_averaged > 0:
for p_averaged, p_model in zip(self_param_detached, model_param_detached):
n_averaged = self.n_averaged.to(p_averaged.device)
p_averaged.copy_(self.avg_fn(p_averaged, p_model, n_averaged))
if not self.use_buffers:
for b_swa, b_model in zip(self.module.buffers(), model.buffers()):
b_swa.copy_(b_model.to(b_swa.device).detach())
self.n_averaged += 1
若當(dāng)前模型尚未進(jìn)行過(guò)平均(即n_averaged為0),則直接將新模型的參數(shù)復(fù)制到當(dāng)前模型中。若當(dāng)前模型已經(jīng)進(jìn)行過(guò)平均,則通過(guò)avg_fn函數(shù)計(jì)算當(dāng)前模型和新模型的加權(quán)平均,并將結(jié)果復(fù)制到當(dāng)前模型中。如果use_buffers為True,則會(huì)將緩沖區(qū)從新模型復(fù)制到當(dāng)前模型。最后,n_averaged增加1,表示已進(jìn)行一次平均。
指數(shù)移動(dòng)平均(EMA)
EMA被用于根據(jù)當(dāng)前參數(shù)和之前的平均參數(shù)來(lái)更新平均參數(shù)。其計(jì)算公式如下所示:
這里的EMA param是當(dāng)前的平均參數(shù),current param是當(dāng)前的參數(shù),decay是一個(gè)介于0和1之間的衰減因子,它用于控制當(dāng)前參數(shù)對(duì)平均參數(shù)的貢獻(xiàn)程度。decay越接近1,平均參數(shù)對(duì)當(dāng)前參數(shù)的影響就越小,反之亦是。
def get_ema_avg_fn(decay=0.999):
@torch.no_grad()
def ema_update(ema_param, current_param, num_averaged):
return decay * ema_param + (1 - decay) * current_param
return ema_update
class EMAModel(AveragingBaseModel):
def __init__(self, model, cuda = False, decay=0.9, use_buffers=False):
super().__init__(model=model, cuda=cuda, avg_fn=get_ema_avg_fn(decay), use_buffers=use_buffers)
隨機(jī)加權(quán)平均(SWA)
SWA通過(guò)對(duì)神經(jīng)網(wǎng)絡(luò)的權(quán)重進(jìn)行平均來(lái)改善模型的泛化能力。其計(jì)算公式如下所示:
SWA param是新的平均參數(shù),averaged param是之前的平均參數(shù),current param是當(dāng)前的參數(shù),num avg是已經(jīng)平均的參數(shù)數(shù)量。
def get_swa_avg_fn():
@torch.no_grad()
def swa_update(averaged_param, current_param, num_averaged):
return averaged_param + (current_param - averaged_param) / (num_averaged + 1)
return swa_update
class SWAModel(AveragingBaseModel):
def __init__(self, model, cuda = False,use_buffers=False):
super().__init__(model=model, cuda=cuda, avg_fn=get_swa_avg_fn(), use_buffers=use_buffers)
Tanh自適應(yīng)指數(shù)移動(dòng)EMA算法(T_ADEMA)
這一個(gè)是在查詢資料的時(shí)候,找到的一篇論文描述的,是否有效,還得經(jīng)過(guò)實(shí)驗(yàn)才對(duì)。
全文閱讀--XML全文閱讀--中國(guó)知網(wǎng) (cnki.net)
論文表示是為了在神經(jīng)網(wǎng)絡(luò)訓(xùn)練過(guò)程中根據(jù)不同的訓(xùn)練階段更有效地過(guò)濾噪聲,所提出的公式:
T_ADEMA param是新的平均參數(shù),avg param是之前的平均參數(shù),current param是當(dāng)前的參數(shù),num avg是已經(jīng)平均的參數(shù)數(shù)量。alpha是一個(gè)控制衰減速率的超參數(shù)。通過(guò)將參數(shù)數(shù)量作為輸入傳遞給切線函數(shù)的參數(shù),動(dòng)態(tài)地計(jì)算衰減因子。切線函數(shù)(tanh)的輸出范圍為[-1, 1],隨著參數(shù)數(shù)量的增加,衰減因子會(huì)逐漸趨近于1。由于切線函數(shù)的特性,當(dāng)參數(shù)數(shù)量較小時(shí),衰減因子接近于0;當(dāng)參數(shù)數(shù)量較大時(shí),衰減因子接近于1。
def get_t_adema(alpha=0.9):
num_averaged = [0] # 使用列表包裝可變對(duì)象,以在閉包中引用
@torch.no_grad()
def t_adema_update(averaged_param, current_param, num_averageds):
num_averaged[0] += 1
decay = alpha * torch.tanh(torch.tensor(num_averaged[0], dtype=torch.float32))
tadea_update = decay * averaged_param + (1 - decay) * current_param
return tadea_update
return t_adema_update
class T_ADEMAModel(AveragingBaseModel):
def __init__(self, model, cuda=False, alpha=0.9, use_buffers=False):
super().__init__(model=model, cuda=cuda, avg_fn=get_t_adema(alpha), use_buffers=use_buffers)
三、構(gòu)建一個(gè)簡(jiǎn)單的實(shí)驗(yàn)測(cè)試
這一部分我正在做實(shí)驗(yàn),下面是調(diào)用了一個(gè)簡(jiǎn)單的resnet18網(wǎng)絡(luò),看看邏輯上面是否有錯(cuò)。
if __name__=="__main__":
# 創(chuàng)建 ResNet18 模型
import torch
import torchvision.models as models
from torch.utils.data import DataLoader
from tqdm import tqdm
from torch.optim.swa_utils import AveragedModel
class RandomDataset(torch.utils.data.Dataset):
def __init__(self, size=224):
self.data = torch.randn(size, 3, 224, 224)
self.labels = torch.randint(0, 2, (size,))
def __getitem__(self, index):
return self.data[index], self.labels[index]
def __len__(self):
return len(self.data)
model = models.resnet18(pretrained=False)
# model = model.to('cuda')
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
criterion = torch.nn.CrossEntropyLoss()
# 創(chuàng)建數(shù)據(jù)加載器
train_dataset = RandomDataset()
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
# 定義權(quán)重平均模型
swa_model = SWAModel(model, cuda=True)
ema_model = EMAModel(model, cuda=True)
t_adema_model = T_ADEMAModel(model, cuda=True)
for epoch in range(5):
for inputs, labels in tqdm(train_loader, desc=f"Epoch {epoch + 1}/{5}"):
optimizer.zero_grad()
outputs = model(inputs)
loss = criterion(outputs, labels)
loss.backward()
optimizer.step()
# 更新權(quán)重平均模型
ema_model.update(model)
swa_model.update(model)
t_adema_model.update(model)
# 測(cè)試模型
test_dataset = RandomDataset(size=100)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)
def evaluate(model):
model.eval() # 切換到評(píng)估模式
correct = 0
total = 0
with torch.no_grad():
for inputs, labels in test_loader:
inputs, labels = inputs.to('cuda'), labels.to('cuda')
outputs = model(inputs)
_, predicted = torch.max(outputs.data, 1)
total += labels.size(0)
correct += (predicted == labels).sum().item()
accuracy = correct / total
print(f"模型準(zhǔn)確率:{accuracy * 100:.2f}%")
# 原模型測(cè)試
print("Model Evaluation:")
evaluate(model.to('cuda')) #
# 測(cè)試權(quán)重平均模型
print("SWAModel Evaluation:")
evaluate(swa_model.to('cuda'))
print("EMAModel Evaluation:")
evaluate(ema_model.to('cuda'))
print("T-ADEMAModel Evaluation:")
evaluate(t_adema_model.to('cuda'))
運(yùn)行效果:
Model Evaluation: 模型準(zhǔn)確率:46.00% SWAModel Evaluation: 模型準(zhǔn)確率:54.00% EMAModel Evaluation: 模型準(zhǔn)確率:58.00% T - ADEMAModel Evaluation: 模型準(zhǔn)確率:58.00%
僅僅是測(cè)試是否能夠跑通,過(guò)程中也有比原模型要低的時(shí)候,而且權(quán)值平均主要是用于訓(xùn)練中后期,所以有沒(méi)有效果應(yīng)該需要自己去做實(shí)驗(yàn)。文章來(lái)源:http://www.zghlxwxcb.cn/news/detail-808874.html
當(dāng)前你可以下載pip install pyzjr==1.2.9,調(diào)用from pyzjr.nn import EMAModel運(yùn)行。文章來(lái)源地址http://www.zghlxwxcb.cn/news/detail-808874.html
到了這里,關(guān)于模型的權(quán)值平均的原理和Pytorch的實(shí)現(xiàn)的文章就介紹完了。如果您還想了解更多內(nèi)容,請(qǐng)?jiān)谟疑辖撬阉鱐OY模板網(wǎng)以前的文章或繼續(xù)瀏覽下面的相關(guān)文章,希望大家以后多多支持TOY模板網(wǎng)!