国产 无码 综合区,色欲AV无码国产永久播放,无码天堂亚洲国产AV,国产日韩欧美女同一区二区

模型的權(quán)值平均的原理和Pytorch的實(shí)現(xiàn)

這篇具有很好參考價(jià)值的文章主要介紹了模型的權(quán)值平均的原理和Pytorch的實(shí)現(xiàn)。希望對(duì)大家有所幫助。如果存在錯(cuò)誤或未考慮完全的地方,請(qǐng)大家不吝賜教,您也可以點(diǎn)擊"舉報(bào)違法"按鈕提交疑問(wèn)。

一、前言

模型權(quán)值平均是一種用于改善深度神經(jīng)網(wǎng)絡(luò)泛化性能的技術(shù)。通過(guò)對(duì)訓(xùn)練過(guò)程中不同時(shí)間步的模型權(quán)值進(jìn)行平均,可以得到更寬的極值點(diǎn)(optima)并提高模型的泛化能力。 在PyTorch中,官方提供了實(shí)現(xiàn)模型權(quán)值平均的方法。

這里我們首先介紹指數(shù)移動(dòng)平均(EMA)方法,它使用一個(gè)衰減系數(shù)來(lái)平衡當(dāng)前權(quán)值和先前平均權(quán)值。其次,介紹了隨機(jī)加權(quán)平均(SWA)方法,它通過(guò)將當(dāng)前權(quán)值與先前平均權(quán)值進(jìn)行加權(quán)平均來(lái)更新權(quán)值。最后,介紹了Tanh自適應(yīng)指數(shù)移動(dòng)EMA算法(T_ADEMA),它使用Tanh函數(shù)來(lái)調(diào)整衰減系數(shù),以更好地適應(yīng)訓(xùn)練過(guò)程中的不同階段。

為了方便使用這些權(quán)值平均方法,我將官方的代碼寫成了一個(gè)基類AveragingBaseModel,以此引出EMAModel、SWAModel和T_ADEMAModel等方法。這些類可以用于包裝原始模型,并在訓(xùn)練過(guò)程中更新平均權(quán)值。 為了驗(yàn)證這些權(quán)值平均方法的效果,我還在ResNet18模型上進(jìn)行了簡(jiǎn)單的實(shí)驗(yàn)。實(shí)驗(yàn)結(jié)果表明,使用權(quán)值平均方法可以提高模型的準(zhǔn)確率,尤其是在訓(xùn)練后期。

但請(qǐng)注意,博客中所提供的代碼示例僅用于演示權(quán)值平均的原理和PyTorch的實(shí)現(xiàn)方式,并不能保證在所有情況下都能取得理想的效果。實(shí)際應(yīng)用中,還需要根據(jù)具體任務(wù)和數(shù)據(jù)集來(lái)選擇適合的權(quán)值平均方法和參數(shù)設(shè)置。

二、算法介紹

基類實(shí)現(xiàn)

這里我們的基類完全是參照于torch源碼部分,僅僅進(jìn)行了一點(diǎn)細(xì)微的修改。

它首先通過(guò)de_parallel函數(shù)將原始模型轉(zhuǎn)換為單個(gè)GPU模型。de_parallel函數(shù)用于處理并行模型,將其轉(zhuǎn)換為單個(gè)GPU模型。然后,它將轉(zhuǎn)換后的模型復(fù)制到適當(dāng)?shù)脑O(shè)備(CPU或GPU)上(這一步很重要,問(wèn)題大多數(shù)就是因?yàn)橛?jì)算不匹配),并注冊(cè)一個(gè)名為n_averaged的緩沖區(qū),用于跟蹤已平均的次數(shù)。

在forward方法中,它簡(jiǎn)單地將調(diào)用傳遞給轉(zhuǎn)換后的模型。update方法首先獲取當(dāng)前模型和新模型的參數(shù),并將它們轉(zhuǎn)換為可迭代對(duì)象,用于更新平均權(quán)值。它接受一個(gè)新的模型作為參數(shù),并將其與當(dāng)前模型(已平均的權(quán)值)進(jìn)行比較。

from copy import deepcopy
from pyzjr.core.general import is_parallel
import itertools
from torch.nn import Module

def de_parallel(model):
    """
    將并行模型(DataParallel 或 DistributedDataParallel)轉(zhuǎn)換為單 GPU 模型。
    """
    return model.module if is_parallel(model) else model

class AveragingBaseModel(Module):
    def __init__(self, model, cuda=False, avg_fn=None, use_buffers=False):
        super(AveragingBaseModel, self).__init__()
        device = 'cuda' if cuda and torch.cuda.is_available() else 'cpu'
        self.module = deepcopy(de_parallel(model))
        self.module = self.module.to(device)
        self.register_buffer('n_averaged',
                             torch.tensor(0, dtype=torch.long, device=device))
        self.avg_fn = avg_fn
        self.use_buffers = use_buffers

    def forward(self, *args, **kwargs):
        return self.module(*args, **kwargs)

    def update(self, model):
        self_param = itertools.chain(self.module.parameters(), self.module.buffers() if self.use_buffers else [])
        model_param = itertools.chain(model.parameters(), model.buffers() if self.use_buffers else [])

        self_param_detached = [p.detach() for p in self_param]
        model_param_detached = [p.detach().to(p_averaged.device) for p, p_averaged in zip(model_param, self_param_detached)]

        if self.n_averaged == 0:
            for p_averaged, p_model in zip(self_param_detached, model_param_detached):
                p_averaged.copy_(p_model)

        if self.n_averaged > 0:
            for p_averaged, p_model in zip(self_param_detached, model_param_detached):
                n_averaged = self.n_averaged.to(p_averaged.device)
                p_averaged.copy_(self.avg_fn(p_averaged, p_model, n_averaged))

        if not self.use_buffers:
            for b_swa, b_model in zip(self.module.buffers(), model.buffers()):
                b_swa.copy_(b_model.to(b_swa.device).detach())

        self.n_averaged += 1

若當(dāng)前模型尚未進(jìn)行過(guò)平均(即n_averaged為0),則直接將新模型的參數(shù)復(fù)制到當(dāng)前模型中。若當(dāng)前模型已經(jīng)進(jìn)行過(guò)平均,則通過(guò)avg_fn函數(shù)計(jì)算當(dāng)前模型和新模型的加權(quán)平均,并將結(jié)果復(fù)制到當(dāng)前模型中。如果use_buffers為True,則會(huì)將緩沖區(qū)從新模型復(fù)制到當(dāng)前模型。最后,n_averaged增加1,表示已進(jìn)行一次平均。

指數(shù)移動(dòng)平均(EMA)

EMA被用于根據(jù)當(dāng)前參數(shù)和之前的平均參數(shù)來(lái)更新平均參數(shù)。其計(jì)算公式如下所示:

模型的權(quán)值平均的原理和Pytorch的實(shí)現(xiàn),Pytorch學(xué)習(xí)及實(shí)戰(zhàn),pytorch復(fù)現(xiàn),pytorch,人工智能,python

這里的EMA param是當(dāng)前的平均參數(shù),current param是當(dāng)前的參數(shù),decay是一個(gè)介于0和1之間的衰減因子,它用于控制當(dāng)前參數(shù)對(duì)平均參數(shù)的貢獻(xiàn)程度。decay越接近1,平均參數(shù)對(duì)當(dāng)前參數(shù)的影響就越小,反之亦是。

def get_ema_avg_fn(decay=0.999):
    @torch.no_grad()
    def ema_update(ema_param, current_param, num_averaged):
        return decay * ema_param + (1 - decay) * current_param
    return ema_update

class EMAModel(AveragingBaseModel):
    def __init__(self, model, cuda = False, decay=0.9, use_buffers=False):
        super().__init__(model=model, cuda=cuda, avg_fn=get_ema_avg_fn(decay), use_buffers=use_buffers)

隨機(jī)加權(quán)平均(SWA)

SWA通過(guò)對(duì)神經(jīng)網(wǎng)絡(luò)的權(quán)重進(jìn)行平均來(lái)改善模型的泛化能力。其計(jì)算公式如下所示:

模型的權(quán)值平均的原理和Pytorch的實(shí)現(xiàn),Pytorch學(xué)習(xí)及實(shí)戰(zhàn),pytorch復(fù)現(xiàn),pytorch,人工智能,python

SWA param是新的平均參數(shù),averaged param是之前的平均參數(shù),current param是當(dāng)前的參數(shù),num avg是已經(jīng)平均的參數(shù)數(shù)量。

def get_swa_avg_fn():
    @torch.no_grad()
    def swa_update(averaged_param, current_param, num_averaged):
        return averaged_param + (current_param - averaged_param) / (num_averaged + 1)
    return swa_update

class SWAModel(AveragingBaseModel):
    def __init__(self, model, cuda = False,use_buffers=False):
        super().__init__(model=model, cuda=cuda, avg_fn=get_swa_avg_fn(), use_buffers=use_buffers)

Tanh自適應(yīng)指數(shù)移動(dòng)EMA算法(T_ADEMA)

這一個(gè)是在查詢資料的時(shí)候,找到的一篇論文描述的,是否有效,還得經(jīng)過(guò)實(shí)驗(yàn)才對(duì)。

全文閱讀--XML全文閱讀--中國(guó)知網(wǎng) (cnki.net)

論文表示是為了在神經(jīng)網(wǎng)絡(luò)訓(xùn)練過(guò)程中根據(jù)不同的訓(xùn)練階段更有效地過(guò)濾噪聲,所提出的公式:

模型的權(quán)值平均的原理和Pytorch的實(shí)現(xiàn),Pytorch學(xué)習(xí)及實(shí)戰(zhàn),pytorch復(fù)現(xiàn),pytorch,人工智能,python

T_ADEMA param是新的平均參數(shù),avg param是之前的平均參數(shù),current param是當(dāng)前的參數(shù),num avg是已經(jīng)平均的參數(shù)數(shù)量。alpha是一個(gè)控制衰減速率的超參數(shù)。通過(guò)將參數(shù)數(shù)量作為輸入傳遞給切線函數(shù)的參數(shù),動(dòng)態(tài)地計(jì)算衰減因子。切線函數(shù)(tanh)的輸出范圍為[-1, 1],隨著參數(shù)數(shù)量的增加,衰減因子會(huì)逐漸趨近于1。由于切線函數(shù)的特性,當(dāng)參數(shù)數(shù)量較小時(shí),衰減因子接近于0;當(dāng)參數(shù)數(shù)量較大時(shí),衰減因子接近于1。

def get_t_adema(alpha=0.9):
    num_averaged = [0]  # 使用列表包裝可變對(duì)象,以在閉包中引用
    @torch.no_grad()
    def t_adema_update(averaged_param, current_param, num_averageds):
        num_averaged[0] += 1
        decay = alpha * torch.tanh(torch.tensor(num_averaged[0], dtype=torch.float32))
        tadea_update = decay * averaged_param + (1 - decay) * current_param
        return tadea_update
    return t_adema_update

class T_ADEMAModel(AveragingBaseModel):
    def __init__(self, model, cuda=False, alpha=0.9, use_buffers=False):
        super().__init__(model=model, cuda=cuda, avg_fn=get_t_adema(alpha), use_buffers=use_buffers)

三、構(gòu)建一個(gè)簡(jiǎn)單的實(shí)驗(yàn)測(cè)試

這一部分我正在做實(shí)驗(yàn),下面是調(diào)用了一個(gè)簡(jiǎn)單的resnet18網(wǎng)絡(luò),看看邏輯上面是否有錯(cuò)。

if __name__=="__main__":
    # 創(chuàng)建 ResNet18 模型
    import torch
    import torchvision.models as models
    from torch.utils.data import DataLoader
    from tqdm import tqdm
    from torch.optim.swa_utils import AveragedModel

    class RandomDataset(torch.utils.data.Dataset):
        def __init__(self, size=224):
            self.data = torch.randn(size, 3, 224, 224)
            self.labels = torch.randint(0, 2, (size,))

        def __getitem__(self, index):
            return self.data[index], self.labels[index]

        def __len__(self):
            return len(self.data)


    model = models.resnet18(pretrained=False)
    # model = model.to('cuda')
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    criterion = torch.nn.CrossEntropyLoss()

    # 創(chuàng)建數(shù)據(jù)加載器
    train_dataset = RandomDataset()
    train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)

    # 定義權(quán)重平均模型
    swa_model = SWAModel(model, cuda=True)
    ema_model = EMAModel(model, cuda=True)
    t_adema_model = T_ADEMAModel(model, cuda=True)

    for epoch in range(5):
        for inputs, labels in tqdm(train_loader, desc=f"Epoch {epoch + 1}/{5}"):
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            # 更新權(quán)重平均模型
            ema_model.update(model)
            swa_model.update(model)
            t_adema_model.update(model)

    # 測(cè)試模型
    test_dataset = RandomDataset(size=100)
    test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)


    def evaluate(model):
        model.eval()  # 切換到評(píng)估模式
        correct = 0
        total = 0

        with torch.no_grad():
            for inputs, labels in test_loader:
                inputs, labels = inputs.to('cuda'), labels.to('cuda')
                outputs = model(inputs)
                _, predicted = torch.max(outputs.data, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()

        accuracy = correct / total
        print(f"模型準(zhǔn)確率:{accuracy * 100:.2f}%")

    # 原模型測(cè)試
    print("Model Evaluation:")
    evaluate(model.to('cuda'))   #
    # 測(cè)試權(quán)重平均模型
    print("SWAModel Evaluation:")
    evaluate(swa_model.to('cuda'))

    print("EMAModel Evaluation:")
    evaluate(ema_model.to('cuda'))

    print("T-ADEMAModel Evaluation:")
    evaluate(t_adema_model.to('cuda'))

運(yùn)行效果:

Model Evaluation:
模型準(zhǔn)確率:46.00%
SWAModel Evaluation:
模型準(zhǔn)確率:54.00%
EMAModel Evaluation:
模型準(zhǔn)確率:58.00%
T - ADEMAModel Evaluation:
模型準(zhǔn)確率:58.00%

僅僅是測(cè)試是否能夠跑通,過(guò)程中也有比原模型要低的時(shí)候,而且權(quán)值平均主要是用于訓(xùn)練中后期,所以有沒(méi)有效果應(yīng)該需要自己去做實(shí)驗(yàn)。

當(dāng)前你可以下載pip install pyzjr==1.2.9,調(diào)用from pyzjr.nn import EMAModel運(yùn)行。文章來(lái)源地址http://www.zghlxwxcb.cn/news/detail-808874.html

到了這里,關(guān)于模型的權(quán)值平均的原理和Pytorch的實(shí)現(xiàn)的文章就介紹完了。如果您還想了解更多內(nèi)容,請(qǐng)?jiān)谟疑辖撬阉鱐OY模板網(wǎng)以前的文章或繼續(xù)瀏覽下面的相關(guān)文章,希望大家以后多多支持TOY模板網(wǎng)!

本文來(lái)自互聯(lián)網(wǎng)用戶投稿,該文觀點(diǎn)僅代表作者本人,不代表本站立場(chǎng)。本站僅提供信息存儲(chǔ)空間服務(wù),不擁有所有權(quán),不承擔(dān)相關(guān)法律責(zé)任。如若轉(zhuǎn)載,請(qǐng)注明出處: 如若內(nèi)容造成侵權(quán)/違法違規(guī)/事實(shí)不符,請(qǐng)點(diǎn)擊違法舉報(bào)進(jìn)行投訴反饋,一經(jīng)查實(shí),立即刪除!

領(lǐng)支付寶紅包贊助服務(wù)器費(fèi)用

相關(guān)文章

覺(jué)得文章有用就打賞一下文章作者

支付寶掃一掃打賞

博客贊助

微信掃一掃打賞

請(qǐng)作者喝杯咖啡吧~博客贊助

支付寶掃一掃領(lǐng)取紅包,優(yōu)惠每天領(lǐng)

二維碼1

領(lǐng)取紅包

二維碼2

領(lǐng)紅包