国产 无码 综合区,色欲AV无码国产永久播放,无码天堂亚洲国产AV,国产日韩欧美女同一区二区

【FedAvg論文筆記】&【代碼復(fù)現(xiàn)】

這篇具有很好參考價(jià)值的文章主要介紹了【FedAvg論文筆記】&【代碼復(fù)現(xiàn)】。希望對大家有所幫助。如果存在錯(cuò)誤或未考慮完全的地方,請大家不吝賜教,您也可以點(diǎn)擊"舉報(bào)違法"按鈕提交疑問。

目錄

一、FedAvg原始論文筆記

1、聯(lián)邦優(yōu)化問題:?

2、聯(lián)邦平均算法:

FedSGD算法:

FedAvg算法:

實(shí)驗(yàn)結(jié)果:

3、代碼解釋

?3.1、main_fed.py主函數(shù)

3.2、Fed.py:

3.3、Nets.py:模型定義

3.4、option.py超參數(shù)設(shè)置

3.5、sampling.py:

3.6、update.py :局部更新

3.7、main_nn.py對照組 普通的nn


一、FedAvg原始論文筆記

聯(lián)邦平均算法經(jīng)典論文:McMahan B, Moore E, Ramage D, et al. Communication-efficient learning of deep networks from decentralized data[C]//Artificial intelligence and statistics. PMLR, 2017: 1273-1282.

我們知道聯(lián)邦學(xué)習(xí)的思想就在于分布式的機(jī)器學(xué)習(xí),同時(shí)兼顧了數(shù)據(jù)安全問題。而聯(lián)邦平均算法是其中最典型的算法之一,F(xiàn)edAvg算法將每個(gè)客戶端上的本地隨機(jī)梯度下降和執(zhí)行模型的平均服務(wù)器結(jié)合在一起。

1、聯(lián)邦優(yōu)化問題:?

?1、數(shù)據(jù)非獨(dú)立同分布

?2、數(shù)據(jù)分布的不平衡性

?3、用戶規(guī)模大

?4、通信有限

其中最重要的就是要理解什么是客戶端數(shù)據(jù)集非獨(dú)立同分布?

舉個(gè)栗子,假設(shè)某數(shù)據(jù)集A的train data中有5(1-5)個(gè)類別的手寫數(shù)字250張,client1 本地?cái)?shù)據(jù)集只有1、2手寫數(shù)字50張(此時(shí)的1數(shù)據(jù)集占比為1/5),client2擁有的2、3、4、5手寫圖片張200(4/5),可想而知他們利用本地?cái)?shù)據(jù)集進(jìn)行學(xué)習(xí),client1只能學(xué)習(xí)到1,2。client2只能學(xué)習(xí)到2、3、4再通過依靠數(shù)據(jù)集占比的權(quán)重聚合后,所得到的全局模型對1的學(xué)習(xí)能力會變得更弱。從這個(gè)例子來看,客戶端數(shù)據(jù)集非獨(dú)立同分布提現(xiàn)了樣本類別少,不能代表全局樣本的分布。

更有復(fù)雜的情況,樣本標(biāo)簽混亂,不單一的情況下,數(shù)據(jù)集非獨(dú)立同分布情況會更嚴(yán)重。

2、聯(lián)邦平均算法:

我們需要注意的是,相比于傳統(tǒng)的數(shù)據(jù)中心處理模式,在聯(lián)邦學(xué)習(xí)中,客戶端本地的計(jì)算量和服務(wù)器中聚合模型所花費(fèi)的計(jì)算量是花費(fèi)很小的,但客戶端與服務(wù)器之間的通信代價(jià)較大,故文中提出兩種方法以降低通信成本:

1、增加并行性(即使用更多的客戶端獨(dú)立訓(xùn)練模型

2、增加每個(gè)客戶端計(jì)算量

首先本文提出FedSGD算法:

FedSGD算法

對K個(gè)客戶端的數(shù)據(jù)計(jì)算其損失梯度,(F(Wt)表示在模型wt下數(shù)據(jù)的損失函數(shù)):【FedAvg論文筆記】&【代碼復(fù)現(xiàn)】

聚合K客戶端的損失梯度,得到t+1輪模型參數(shù):【FedAvg論文筆記】&【代碼復(fù)現(xiàn)】

而FedAvg算法就是在在本地執(zhí)行了多次的FedSGD,在選定一定比例的客戶端參加訓(xùn)練,而不是全部(實(shí)驗(yàn)部分會指出,全部的客戶端參加比部分客戶端才加的收斂速度慢,模型精度低。)

FedAvg算法:

在客戶端進(jìn)行局部模型的更新:【FedAvg論文筆記】&【代碼復(fù)現(xiàn)】

在服務(wù)器將局部模型上傳,只進(jìn)行一個(gè)平均算法:【FedAvg論文筆記】&【代碼復(fù)現(xiàn)】

可以看出,該算法將計(jì)算量放在了本地客戶端,而服務(wù)器只用于聚合平均。故我們可以在平均步驟之前進(jìn)行多次局部模型的更新。(這兒不防思考一下,這個(gè)次數(shù)是不是越多越好,我們知道過少本地?cái)?shù)據(jù)集樣本,過多的本地迭代輪次會造成什么問題?————過擬合

而上述計(jì)算量的大小由三個(gè)參數(shù)控制,即為C(客戶端隨機(jī)選取的比例)、E(客戶端在第t輪通過本地?cái)?shù)據(jù)集訓(xùn)練的次數(shù))、B(參與本地局部模型更新所需的數(shù)據(jù)批量size)

所以,上述的FedSGD算法中有:C=1,E=1,B=無窮大

故,對于第K個(gè)客戶端本地?cái)?shù)據(jù)集大小為nk時(shí),可得到這個(gè)客戶端每輪的本地更新數(shù)為:

【FedAvg論文筆記】&【代碼復(fù)現(xiàn)】

ps:客戶端本地?cái)?shù)據(jù)集與局部訓(xùn)練輪次的乘積/批量處理大小,為這個(gè)本輪客戶端本地SGD的次數(shù),F(xiàn)edAvg的偽代碼如下:

【FedAvg論文筆記】&【代碼復(fù)現(xiàn)】

【FedAvg論文筆記】&【代碼復(fù)現(xiàn)】

實(shí)驗(yàn)結(jié)果:

1、基于mnist數(shù)據(jù)集手寫照片的數(shù)字識別任務(wù):

MNIST 2NN :一個(gè)簡單的多層感知器,2個(gè)隱藏層,每個(gè)隱藏層200個(gè)單元,使用ReLu激活(199210個(gè)參數(shù))

CNN:由兩個(gè)5x5卷積層的CNN層(第一層有32個(gè)通道,第二個(gè)有64個(gè),每個(gè)之后是2x2 max池化),一個(gè)全連接層(有512個(gè)單元)和ReLu激活,最后是一個(gè)softmax輸出層(1663370個(gè)參數(shù))

增加并行性實(shí)驗(yàn):使用比例C控制并行處理的客戶端數(shù)量

增加本地計(jì)算量實(shí)驗(yàn)結(jié)果:使用B(更新數(shù)據(jù)批量大?。┖虴(本地?cái)?shù)據(jù)訓(xùn)練次數(shù))來控制本地計(jì)算量

下圖

可以看到隨著比例C的增大,訓(xùn)練輪數(shù)在減小,C過大時(shí)會在指定時(shí)間內(nèi)達(dá)不到希望的準(zhǔn)確度。

?

【FedAvg論文筆記】&【代碼復(fù)現(xiàn)】

(左圖)可以看到對于獨(dú)立同分布數(shù)據(jù),B=無窮大、E=1(數(shù)據(jù)大批量更新、本地訓(xùn)練1次時(shí))的效果最差,B=10、E=20時(shí)(小批量數(shù)據(jù)更新、本地訓(xùn)練數(shù)據(jù)次數(shù)20次)精度最高。

(右圖)類似于右圖效果。

【FedAvg論文筆記】&【代碼復(fù)現(xiàn)】

而在下圖中:可以清楚的看到并不是局部模型更新的次數(shù)越高越好,E=1比E=5的訓(xùn)練效果要好得多。

【FedAvg論文筆記】&【代碼復(fù)現(xiàn)】

思考:FedAvg算法的局限性主要在于:對于網(wǎng)絡(luò)的連通性要求十分嚴(yán)格,不同的客戶端規(guī)定采用一致的局部模型更新次數(shù)的做法過于死板,可能會導(dǎo)致模型過擬合。

但是,FedAvg會“拋棄落后者”或者合并“落后者”信息,即直接丟棄無法完成指定計(jì)算輪數(shù)E的設(shè)備,或者將未完成的設(shè)備信息聚合,會影響模型的收斂,加大計(jì)算量。(后面的prox算法,主要會解決這個(gè)問題)

3、代碼解釋

在此之前,看到這兒的人一定要懂得,跑一個(gè)項(xiàng)目,就一定得看項(xiàng)目的readme文件,這個(gè)文件里面幾乎什么都會寫到,比如這個(gè)項(xiàng)目所依賴的環(huán)境。配置環(huán)境不難 但就是很煩人。

代碼是在Git上獲取的:federated-learning · GitHub

【FedAvg論文筆記】&【代碼復(fù)現(xiàn)】

?3.1、main_fed.py主函數(shù)

首先,前面一大段基本是導(dǎo)入工具包的過程,這個(gè)不重要:

import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import copy
import numpy as np
from torchvision import datasets, transforms
import torch

from utils.sampling import mnist_iid, mnist_noniid, cifar_iid
from utils.options import args_parser
from models.Update import LocalUpdate
from models.Nets import MLP, CNNMnist, CNNCifar
from models.Fed import FedAvg
from models.test import test_img

接下來是main函數(shù):

首先傳參,接下來調(diào)用設(shè)備,首選cuda 其次cpu?

if __name__ == '__main__':
    # parse args
    args = args_parser()#用于調(diào)用option.py的函數(shù)
    args.device = torch.device('cuda:{}'.format(args.gpu) if torch.cuda.is_available() and args.gpu != -1 else 'cpu')

接下來是加載數(shù)據(jù)集,劃分?jǐn)?shù)據(jù)集,這兒注意,‘../data/mnist/’的意思是 將mnist數(shù)據(jù)集下載到一級文件夾下的data文件夾中,也可以手動指定。。

    if args.dataset == 'mnist':
        #tensor就是個(gè)多維數(shù)組
        trans_mnist = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
        # trans_mnist處理方式 將圖片轉(zhuǎn)化為tensor張量類型,進(jìn)行歸一化處理
        dataset_train = datasets.MNIST('../data/mnist/', train=True, download=True, transform=trans_mnist)
        dataset_test = datasets.MNIST('../data/mnist/', train=False, download=True, transform=trans_mnist)
        # 數(shù)據(jù)集的訓(xùn)練和測試調(diào)用datasets庫 數(shù)據(jù)集內(nèi)容被下載到data文件夾中的cifar和mnist文件夾

        # sample users
        if args.iid:
            dict_users = mnist_iid(dataset_train, args.num_users)
        else:
            dict_users = mnist_noniid(dataset_train, args.num_users)
            # 數(shù)據(jù)劃分方式將數(shù)據(jù)分為 iid 和 non-iid 兩種

    elif args.dataset == 'cifar':#類似對mnist上面的操作
        trans_cifar = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
        dataset_train = datasets.CIFAR10('../data/cifar', train=True, download=True, transform=trans_cifar)
        dataset_test = datasets.CIFAR10('../data/cifar', train=False, download=True, transform=trans_cifar)
        if args.iid:
            dict_users = cifar_iid(dataset_train, args.num_users)
        else:
            exit('Error: only consider IID setting in CIFAR10')
    else:
        exit('Error: unrecognized dataset')
    img_size = dataset_train[0][0].shape

接下來是build model 階段:

# build model
    #這兒得使用model文件夾下定義的nets.py中的神經(jīng)網(wǎng)絡(luò)模型
    if args.model == 'cnn' and args.dataset == 'cifar':
        net_glob = CNNCifar(args=args).to(args.device)
    elif args.model == 'cnn' and args.dataset == 'mnist':
        net_glob = CNNMnist(args=args).to(args.device)
    elif args.model == 'mlp':
        len_in = 1
        for x in img_size:
            len_in *= x
        net_glob = MLP(dim_in=len_in, dim_hidden=200, dim_out=args.num_classes).to(args.device)
    else:
        exit('Error: unrecognized model')
    print(net_glob)#打印具體網(wǎng)絡(luò)結(jié)構(gòu)
    net_glob.train()#對網(wǎng)絡(luò)進(jìn)行訓(xùn)練

接下來是復(fù)制權(quán)重與訓(xùn)練過程:

# copy weights復(fù)制權(quán)重
    w_glob = net_glob.state_dict()

    # training
    #fedavg 核心代碼
    loss_train = []
    cv_loss, cv_acc = [], []
    val_loss_pre, counter = 0, 0 # 預(yù)測損失,計(jì)數(shù)器
    net_best = None
    best_loss = None 
    val_acc_list, net_list = [], []# 剛開始 先置空

    if args.all_clients:
        print("Aggregation over all clients")
        w_locals = [w_glob for i in range(args.num_users)]# 給參與訓(xùn)練的局部下發(fā)全局初始模型
    for iter in range(args.epochs):# epochs 局部迭代輪次
        loss_locals = [] # 局部預(yù)測損失
        if not args.all_clients:
            w_locals = []
        
        m = max(int(args.frac * args.num_users), 1)#每輪被選參與聯(lián)邦學(xué)習(xí)的用戶比例frac
        #sample client
        idxs_users = np.random.choice(range(args.num_users), m, replace=False)#隨機(jī)選取用戶
        
        for idx in idxs_users:
            #local model training process
            local = LocalUpdate(args=args, dataset=dataset_train, idxs=dict_users[idx])
            w, loss = local.train(net=copy.deepcopy(net_glob).to(args.device))
            # 初始的本地模型利用deepcopy函數(shù) 深復(fù)制來源于 全局下發(fā)的初始模型 net_glob 傳給(args.device)計(jì)算局部損失
            if args.all_clients:
                w_locals[idx] = copy.deepcopy(w)
            else:
                w_locals.append(copy.deepcopy(w))
            loss_locals.append(copy.deepcopy(loss))# 局部損失以列表的形式往后添加
            #w_locals以列表的形式匯總本地客戶端訓(xùn)練權(quán)重結(jié)果
            
        # update global weights全局更新
        w_glob = FedAvg(w_locals)# 調(diào)用FedAvg函數(shù)進(jìn)行更新聚合 得到全局模型

        # copy weight to net_glob
        net_glob.load_state_dict(w_glob)#復(fù)制權(quán)重 準(zhǔn)備下發(fā)

        # print loss在每輪后打印輸出全局訓(xùn)練損失
        loss_avg = sum(loss_locals) / len(loss_locals)
        print('Round {:3d}, Average loss {:.3f}'.format(iter, loss_avg))
        loss_train.append(loss_avg)

    # plot loss curve
    plt.figure()
    plt.plot(range(len(loss_train)), loss_train)
    plt.ylabel('train_loss')
    plt.savefig('./save/fed_{}_{}_{}_C{}_iid{}.png'.format(args.dataset, args.model, args.epochs, args.frac, args.iid))

接下來就是,測試:

# testing
    net_glob.eval()# eavl()函數(shù) 關(guān)閉batch normalization與dropout 處理
    acc_train, loss_train = test_img(net_glob, dataset_train, args)
    acc_test, loss_test = test_img(net_glob, dataset_test, args)
    print("Training accuracy: {:.2f}".format(acc_train))
    print("Testing accuracy: {:.2f}".format(acc_test))

3.2、Fed.py:

FedAvg函數(shù)定義如下:

def FedAvg(w):
    w_avg = copy.deepcopy(w[0]) # 利用深拷貝獲取初始w_0
    for k in w_avg.keys():
        for i in range(1, len(w)):
            w_avg[k] += w[i][k] # 累加
        w_avg[k] = torch.div(w_avg[k], len(w)) #平均
    return w_avg

3.3、Nets.py:模型定義

繼承nn.Module類構(gòu)造自己的神經(jīng)網(wǎng)絡(luò),定義輸入、隱藏、輸出層,利用nn.linear設(shè)置網(wǎng)絡(luò)中的全連接。定義前向傳播 forward()

import torch
from torch import nn
import torch.nn.functional as F

class MLP(nn.Module):#多層感知機(jī)
    def __init__(self,dim_in,dim_hidden,dim_out):#定義
        super(MLP,self).__init__()#進(jìn)行初始化
        self.layer_input = nn.Linear(dim_in, dim_hidden)#nn.linear線性變換
        self.relu = nn.ReLU()#激活函數(shù)
        self.dropout = nn.Dropout()#防止過擬合而設(shè)置的
        self.layer_hidden = nn.Linear(dim_hidden, dim_out)

    def forward(self, x):
        x = x.view(-1, x.shape[1] * x.shape[-2] * x.shape[-1])
        #shape快速讀取矩陣向量的形狀,將其傳入全連接層
        x = self.layer_input(x)
        x = self.dropout(x)
        x = self.relu(x)
        x = self.layer_hidden(x)
        return x

定義處理mnist、cifar數(shù)據(jù)集的CNN:這個(gè)也是繼承nn.module:

class CNNMnist(nn.Module):#處理MNIST的CNN
    def __init__(self, args):
        super(CNNMnist, self).__init__()
        #兩個(gè)卷積層
        self.conv1 = nn.Conv2d(args.num_channels, 10, kernel_size=5)
        #卷積核大小為5*5,nn.conv2d為2維卷積神經(jīng)網(wǎng)絡(luò)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        #in_channel=10,out_channel=20
        self.conv2_drop = nn.Dropout2d()
        #全連接層
        self.fc1 = nn.Linear(320, 50)#輸入特征和輸出特征數(shù)
        self.fc2 = nn.Linear(50, args.num_classes)

    def forward(self, x):
        #卷積層-》池化層-》激活函數(shù)
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
        x = x.view(-1, x.shape[1] * x.shape[2] * x.shape[3])#展開數(shù)據(jù),將要輸入全連接層
        x = F.relu(self.fc1(x))
        x = F.dropout(x, training=self.training)
        x = self.fc2(x)
        return x


class CNNCifar(nn.Module):#卷積神經(jīng)網(wǎng)絡(luò)
    def __init__(self, args):
        super(CNNCifar, self).__init__()
        #兩個(gè)卷積層
        self.conv1 = nn.Conv2d(3, 6, 5)#輸入三個(gè)通道圖片,產(chǎn)生6個(gè)特征
        self.pool = nn.MaxPool2d(2, 2)#最大池化層2*2
        self.conv2 = nn.Conv2d(6, 16, 5)#產(chǎn)生16個(gè)更深層次的特征
        self.fc1 = nn.Linear(16 * 5 * 5, 120)#添加全連接層
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, args.num_classes)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)#平鋪圖片為16*5*5
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

3.4、option.py超參數(shù)設(shè)置

python文件中,實(shí)驗(yàn)參數(shù)可在這兒修改 也可以在終端運(yùn)行的時(shí)候直接鍵入。

import argparse

def args_parser():
    parser = argparse.ArgumentParser()
    # federated arguments
    parser.add_argument('--epochs', type=int, default=10, help="rounds of training")
    parser.add_argument('--num_users', type=int, default=100, help="number of users: K")
    parser.add_argument('--frac', type=float, default=0.1, help="the fraction of clients: C")
    parser.add_argument('--local_ep', type=int, default=5, help="the number of local epochs: E")
    parser.add_argument('--local_bs', type=int, default=10, help="local batch size: B")
    parser.add_argument('--bs', type=int, default=128, help="test batch size")
    parser.add_argument('--lr', type=float, default=0.01, help="learning rate")
    parser.add_argument('--momentum', type=float, default=0.5, help="SGD momentum (default: 0.5)")
    parser.add_argument('--split', type=str, default='user', help="train-test split type, user or sample")

    # model arguments
    parser.add_argument('--model', type=str, default='mlp', help='model name')
    parser.add_argument('--kernel_num', type=int, default=9, help='number of each kind of kernel')
    parser.add_argument('--kernel_sizes', type=str, default='3,4,5',
                        help='comma-separated kernel size to use for convolution')
    parser.add_argument('--norm', type=str, default='batch_norm', help="batch_norm, layer_norm, or None")
    parser.add_argument('--num_filters', type=int, default=32, help="number of filters for conv nets")
    parser.add_argument('--max_pool', type=str, default='True',
                        help="Whether use max pooling rather than strided convolutions")

    # other arguments
    parser.add_argument('--dataset', type=str, default='mnist', help="name of dataset")
    parser.add_argument('--iid', action='store_true', help='whether i.i.d or not')
    parser.add_argument('--num_classes', type=int, default=10, help="number of classes")
    parser.add_argument('--num_channels', type=int, default=3, help="number of channels of imges")
    parser.add_argument('--gpu', type=int, default=0, help="GPU ID, -1 for CPU")
    parser.add_argument('--stopping_rounds', type=int, default=10, help='rounds of early stopping')
    parser.add_argument('--verbose', action='store_true', help='verbose print')
    parser.add_argument('--seed', type=int, default=1, help='random seed (default: 1)')
    parser.add_argument('--all_clients', action='store_true', help='aggregation over all clients')
    args = parser.parse_args()
    return args

3.5、sampling.py:

將數(shù)據(jù)集中的數(shù)據(jù)樣本劃分成iid/non-iid數(shù)據(jù)樣本,分配給Client。

對于獨(dú)立同分布情況,將數(shù)據(jù)集中的數(shù)據(jù)打亂,為每個(gè)Client隨機(jī)分配600個(gè)。

對于non-iid情況,根據(jù)數(shù)據(jù)集標(biāo)簽將數(shù)據(jù)集排序,將其劃分為200組大小為300的數(shù)據(jù)切片,每個(gè)client分配兩個(gè)切片。

import numpy as np
from torchvision import datasets, transforms

def mnist_iid(dataset, num_users): # mnist獨(dú)立同分布數(shù)據(jù)采樣
    """
    Sample I.I.D. client data from MNIST dataset
    :param dataset:
    :param num_users:
    :return: dict of image index
    """
    num_items = int(len(dataset)/num_users) # num_items=MINIST數(shù)據(jù)集大小/用戶數(shù)量
    # 數(shù)據(jù)集以矩陣形式存在,行為user,列為iterm,則有:len(Dataset)=num_user*num_item
    dict_users, all_idxs = {}, [i for i in range(len(dataset))]

    for i in range(num_users):
        dict_users[i] = set(np.random.choice(all_idxs, num_items, replace=False))
        # 從序列中隨機(jī)采樣,且不重用
        all_idxs = list(set(all_idxs) - dict_users[i])
        # all_idxs 作為序列順序
    return dict_users

def mnist_noniid(dataset, num_users): # mnist非獨(dú)立同分布數(shù)據(jù)采樣
    """
    Sample non-I.I.D client data from MNIST dataset
    :param dataset:
    :param num_users:
    :return:
    """
    num_shards, num_imgs = 200, 300
    # num_shards 200分片索引
    idx_shard = [i for i in range(num_shards)]
    dict_users = {i: np.array([], dtype='int64') for i in range(num_users)}
    idxs = np.arange(num_shards*num_imgs) # idxs1~6000
    labels = dataset.train_labels.numpy()
    # 用numpy 將mnist數(shù)據(jù)轉(zhuǎn)化成張量tensor格式

    # sort labels 標(biāo)簽分類
    idxs_labels = np.vstack((idxs, labels))
    # 按垂直方向?qū)dxs 與 labels堆疊構(gòu)成一個(gè)新的數(shù)組
    idxs_labels = idxs_labels[:,idxs_labels[1,:].argsort()]# 排序
    idxs = idxs_labels[0,:]

    # divide and assign 分配
    for i in range(num_users):
        rand_set = set(np.random.choice(idx_shard, 2, replace=False))
        # 從idx中隨機(jī)選擇2個(gè) 分配給客戶端,不重復(fù)
        idx_shard = list(set(idx_shard) - rand_set) # idx_shard序列0~...
        for rand in rand_set:
            dict_users[i] = np.concatenate(
                (dict_users[i], idxs[rand*num_imgs:(rand+1)*num_imgs]),
                axis=0)# 行拼接
            # concatenate() 對應(yīng)數(shù)組拼接
            # idxs 存下標(biāo) num_imgs=300 當(dāng)rand=8時(shí),idxs[2400:2700]
            # dict_users[i]=【dict_user[i],300】 每個(gè)dict_users[i]有被隨機(jī)分配300個(gè)下標(biāo)數(shù)據(jù)
    return dict_users


def cifar_iid(dataset, num_users):# cifar 獨(dú)立同分布數(shù)據(jù)
    """
    Sample I.I.D. client data from CIFAR10 dataset
    :param dataset:
    :param num_users:
    :return: dict of image index
    """
    num_items = int(len(dataset)/num_users)
    dict_users, all_idxs = {}, [i for i in range(len(dataset))]
    for i in range(num_users):
        dict_users[i] = set(np.random.choice(all_idxs, num_items, replace=False))
        all_idxs = list(set(all_idxs) - dict_users[i])
    return dict_users


if __name__ == '__main__':
    dataset_train = datasets.MNIST('../data/mnist/', train=True, download=True,
                                   transform=transforms.Compose([
                                       transforms.ToTensor(),
                                       transforms.Normalize((0.1307,), (0.3081,))
                                       # 將照片格式轉(zhuǎn)化成張量形式 
                                       #進(jìn)行歸一化處理
                                   ]))
    num = 100
    d = mnist_noniid(dataset_train, num)

3.6、update.py :局部更新

import torch
from torch import nn, autograd
from torch.utils.data import DataLoader, Dataset
import numpy as np
import random
from sklearn import metrics


class DatasetSplit(Dataset): # 數(shù)據(jù)集劃分
    def __init__(self, dataset, idxs):
        self.dataset = dataset
        self.idxs = list(idxs)

    def __len__(self): # 數(shù)據(jù)集大小
        return len(self.idxs)

    def __getitem__(self, item):
        # sampling中idxs
        image, label = self.dataset[self.idxs[item]]
        return image, label


class LocalUpdate(object):
    def __init__(self, args, dataset=None, idxs=None):
        self.args = args
        self.loss_func = nn.CrossEntropyLoss() # 交叉熵?fù)p失函數(shù)
        self.selected_clients = [] # 用戶選取
        self.ldr_train = DataLoader(DatasetSplit(dataset, idxs), batch_size=self.args.local_bs, shuffle=True)
        # 將劃分的數(shù)據(jù)集當(dāng)做本地?cái)?shù)據(jù)集 進(jìn)行小批量更新 batch_size=local_bs
        # shuffle 用于打亂數(shù)據(jù)集,每次都會以不同的順序返回

    def train(self, net):
        net.train()
        # train and update
        optimizer = torch.optim.SGD(net.parameters(), lr=self.args.lr, momentum=self.args.momentum)
        # 優(yōu)化器 SGD,加入動量momentum 學(xué)習(xí)率:lr

        epoch_loss = [] # 每迭代一次的損失
        for iter in range(self.args.local_ep):
            batch_loss = [] # 為了提高計(jì)算效率,不會對每個(gè)client進(jìn)行l(wèi)oss統(tǒng)計(jì),統(tǒng)計(jì)batch_loss
            for batch_idx, (images, labels) in enumerate(self.ldr_train):
                # enumerate()函數(shù)將()里面的內(nèi)容 轉(zhuǎn)化成為一個(gè)序列,一個(gè)一個(gè)的取出 batch_size大小的數(shù)據(jù),訓(xùn)練
                images, labels = images.to(self.args.device), labels.to(self.args.device)
                net.zero_grad() # 將其所有參數(shù)(包括子模塊的參數(shù))的梯度設(shè)置為零
                log_probs = net(images) # 獲得前向傳播結(jié)果
                loss = self.loss_func(log_probs, labels) #計(jì)算損失
                loss.backward() # 反向傳播損失
                optimizer.step()
                if self.args.verbose and batch_idx % 10 == 0:
                    print('Update Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                        iter,
                        batch_idx * len(images), 
                        len(self.ldr_train.dataset),
                        100. * batch_idx / len(self.ldr_train),
                        loss.item()))
                batch_loss.append(loss.item())
            epoch_loss.append(sum(batch_loss)/len(batch_loss))
            # 總的批量損失/批量個(gè)數(shù)=一個(gè)epoch的損失
            # 一行一行附加到epoch_loss序列中
            
        return net.state_dict(), sum(epoch_loss) / len(epoch_loss)
        # 局部迭代loss之和/迭代輪次=平均每epoch損失

3.7、main_nn.py對照組 普通的nn

注意,這兒Git上的的main_nn.py中定義了text函數(shù),這與調(diào)用的pytest發(fā)生了矛盾,所以我將text()改成了ceshi()

#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Python version: 3.6

import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt

import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torch.optim as optim
from torchvision import datasets, transforms

from utils.options import args_parser
from models.Nets import MLP, CNNMnist, CNNCifar

# main_nn.py普通nn對比main_Fed.py
# 運(yùn)行測試集并輸出準(zhǔn)確率與Loss大?。ń徊骒睾瘮?shù),適用于多標(biāo)簽分類任務(wù))
def ceshi(net_g, data_loader):
    # testing
    net_g.eval() # 關(guān)閉歸一化化與dropout
    test_loss = 0
    correct = 0
    l = len(data_loader) # 載入數(shù)據(jù)集大小
    for idx, (data, target) in enumerate(data_loader):# 一個(gè)一個(gè)取出載入的數(shù)據(jù)
        data, target = data.to(args.device), target.to(args.device) # 傳到設(shè)備
        log_probs = net_g(data) # 獲得前向傳播結(jié)果
        test_loss += F.cross_entropy(log_probs, target).item()
        # 取出item的結(jié)果 計(jì)算交叉損失熵 付給test_loss
        y_pred = log_probs.data.max(1, keepdim=True)[1]
        # 最大值得索引位置為y_pred
        correct += y_pred.eq(target.data.view_as(y_pred)).long().cpu().sum()
        # 通過與真實(shí)值的索引位置來對比


    test_loss /= len(data_loader.dataset)
    print('\nTest set: Average loss: {:.4f} \nAccuracy: {}/{} ({:.2f}%)\n'.format(
        test_loss, correct, len(data_loader.dataset),
        100. * correct / len(data_loader.dataset)))

    return correct, test_loss

# 與main_fed.py中的main函數(shù)相比,不調(diào)用fed.py即可
if __name__ == '__main__':
    # parse args
    args = args_parser()
    args.device = torch.device('cuda:{}'.format(args.gpu) if torch.cuda.is_available() and args.gpu != -1 else 'cpu')

    torch.manual_seed(args.seed)

    # load dataset and split users
    #分別對mnist cifar數(shù)據(jù)集載入 劃分
    if args.dataset == 'mnist':
        dataset_train = datasets.MNIST('./data/mnist/', train=True, download=True,
                   transform=transforms.Compose([
                       transforms.ToTensor(),
                       transforms.Normalize((0.1307,), (0.3081,))
                   ]))
        img_size = dataset_train[0][0].shape
    elif args.dataset == 'cifar':
        transform = transforms.Compose(
            [transforms.ToTensor(),
             transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
        dataset_train = datasets.CIFAR10('./data/cifar', train=True, transform=transform, target_transform=None, download=True)
        img_size = dataset_train[0][0].shape
    else:
        exit('Error: unrecognized dataset')

    # build model
    if args.model == 'cnn' and args.dataset == 'cifar':
        net_glob = CNNCifar(args=args).to(args.device)
    elif args.model == 'cnn' and args.dataset == 'mnist':
        net_glob = CNNMnist(args=args).to(args.device)
    elif args.model == 'mlp':
        len_in = 1
        for x in img_size:
            len_in *= x
        net_glob = MLP(dim_in=len_in, dim_hidden=64, dim_out=args.num_classes).to(args.device)
    else:
        exit('Error: unrecognized model')
    print(net_glob)

    # training
    optimizer = optim.SGD(net_glob.parameters(), lr=args.lr, momentum=args.momentum)
    train_loader = DataLoader(dataset_train, batch_size=64, shuffle=True)

    list_loss = []
    net_glob.train()
    for epoch in range(args.epochs):
        batch_loss = []
        for batch_idx, (data, target) in enumerate(train_loader):
            data, target = data.to(args.device), target.to(args.device)
            optimizer.zero_grad()
            output = net_glob(data)
            loss = F.cross_entropy(output, target)
            loss.backward()
            optimizer.step()
            if batch_idx % 50 == 0:
                print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                    epoch, batch_idx * len(data), len(train_loader.dataset),
                           100. * batch_idx / len(train_loader), loss.item()))
            batch_loss.append(loss.item())
        loss_avg = sum(batch_loss)/len(batch_loss)
        print('\nTrain loss:', loss_avg)
        list_loss.append(loss_avg)

    # plot loss
    plt.figure()
    plt.plot(range(len(list_loss)), list_loss)
    plt.xlabel('epochs')
    plt.ylabel('train loss')
    plt.savefig('./log/nn_{}_{}_{}.png'.format(args.dataset, args.model, args.epochs))

    # testing
    if args.dataset == 'mnist':
        dataset_test = datasets.MNIST('./data/mnist/', train=False, download=True,
                   transform=transforms.Compose([
                       transforms.ToTensor(),
                       transforms.Normalize((0.1307,), (0.3081,))
                   ]))
        test_loader = DataLoader(dataset_test, batch_size=1000, shuffle=False)
    elif args.dataset == 'cifar':
        transform = transforms.Compose(
            [transforms.ToTensor(),
             transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
        dataset_test = datasets.CIFAR10('./data/cifar', train=False, transform=transform, target_transform=None, download=True)
        test_loader = DataLoader(dataset_test, batch_size=1000, shuffle=False)
    else:
        exit('Error: unrecognized dataset')

    print('test on', len(dataset_test), 'samples')
    test_acc, test_loss = ceshi(net_glob, test_loader)

參考:

聯(lián)邦學(xué)習(xí)方法FedAvg實(shí)戰(zhàn)(Pytorch) - 知乎 (zhihu.com)

FedAvg源碼學(xué)習(xí)_mnist_iid_idkmn_的博客-CSDN博客

機(jī)器學(xué)習(xí)中的獨(dú)立同分布_半夜起來敲代碼的博客-CSDN博客_機(jī)器學(xué)習(xí) 獨(dú)立同分布

從零開始 | FedAvg 代碼實(shí)現(xiàn)詳解 - 知乎 (zhihu.com)

pytorch教程之nn.Module類詳解——使用Module類來自定義模型_LoveMIss-Y的博客-CSDN博客

【代碼解析(3)】Communication-Efficient Learning of Deep Networks from Decentralized Data_enumerate(self.trainloader)_緘默的天空之城的博客-CSDN博客文章來源地址http://www.zghlxwxcb.cn/news/detail-429192.html

到了這里,關(guān)于【FedAvg論文筆記】&【代碼復(fù)現(xiàn)】的文章就介紹完了。如果您還想了解更多內(nèi)容,請?jiān)谟疑辖撬阉鱐OY模板網(wǎng)以前的文章或繼續(xù)瀏覽下面的相關(guān)文章,希望大家以后多多支持TOY模板網(wǎng)!

本文來自互聯(lián)網(wǎng)用戶投稿,該文觀點(diǎn)僅代表作者本人,不代表本站立場。本站僅提供信息存儲空間服務(wù),不擁有所有權(quán),不承擔(dān)相關(guān)法律責(zé)任。如若轉(zhuǎn)載,請注明出處: 如若內(nèi)容造成侵權(quán)/違法違規(guī)/事實(shí)不符,請點(diǎn)擊違法舉報(bào)進(jìn)行投訴反饋,一經(jīng)查實(shí),立即刪除!

領(lǐng)支付寶紅包贊助服務(wù)器費(fèi)用

相關(guān)文章

  • 【自用】SAM模型論文筆記與復(fù)現(xiàn)代碼(segment-anything-model)

    【自用】SAM模型論文筆記與復(fù)現(xiàn)代碼(segment-anything-model)

    一個(gè) prompt encoder ,對提示進(jìn)行編碼, image encoder 對圖像編碼,生成embedding, 最后融合2個(gè) encoder ,再接一個(gè)輕量的 mask decoder ,輸出最后的mask。 模型結(jié)構(gòu)示意圖: 流程圖: 模型的結(jié)構(gòu)如上圖所示. prompt會經(jīng)過 prompt encoder , 圖像會經(jīng)過 image encoder 。然后將兩部分embedding經(jīng)過一個(gè)

    2024年01月24日
    瀏覽(24)
  • 經(jīng)典神經(jīng)網(wǎng)絡(luò)論文超詳細(xì)解讀(八)——ResNeXt學(xué)習(xí)筆記(翻譯+精讀+代碼復(fù)現(xiàn))

    經(jīng)典神經(jīng)網(wǎng)絡(luò)論文超詳細(xì)解讀(八)——ResNeXt學(xué)習(xí)筆記(翻譯+精讀+代碼復(fù)現(xiàn))

    今天我們一起來學(xué)習(xí)何愷明大神的又一經(jīng)典之作:? ResNeXt(《Aggregated Residual Transformations for Deep Neural Networks》) 。這個(gè)網(wǎng)絡(luò)可以被解釋為 VGG、ResNet?和 Inception 的結(jié)合體,它通過重復(fù)多個(gè)block(如在 VGG 中)塊組成,每個(gè)block塊聚合了多種轉(zhuǎn)換(如 Inception),同時(shí)考慮到跨層

    2024年02月03日
    瀏覽(30)
  • 經(jīng)典神經(jīng)網(wǎng)絡(luò)論文超詳細(xì)解讀(五)——ResNet(殘差網(wǎng)絡(luò))學(xué)習(xí)筆記(翻譯+精讀+代碼復(fù)現(xiàn))

    經(jīng)典神經(jīng)網(wǎng)絡(luò)論文超詳細(xì)解讀(五)——ResNet(殘差網(wǎng)絡(luò))學(xué)習(xí)筆記(翻譯+精讀+代碼復(fù)現(xiàn))

    《Deep Residual Learning for Image Recognition》這篇論文是何愷明等大佬寫的,在深度學(xué)習(xí)領(lǐng)域相當(dāng)經(jīng)典,在2016CVPR獲得best paper。今天就讓我們一起來學(xué)習(xí)一下吧! 論文原文:https://arxiv.org/abs/1512.03385 前情回顧: 經(jīng)典神經(jīng)網(wǎng)絡(luò)論文超詳細(xì)解讀(一)——AlexNet學(xué)習(xí)筆記(翻譯+精讀)

    2024年02月08日
    瀏覽(23)
  • MFAN論文閱讀筆記(待復(fù)現(xiàn))

    MFAN論文閱讀筆記(待復(fù)現(xiàn))

    論文標(biāo)題:MFAN: Multi-modal Feature-enhanced Attention Networks for Rumor Detection 論文作者:Jiaqi Zheng, Xi Zhang, Sanchuan Guo, Quan Wang, Wenyu Zang, Yongdong Zhang 論文來源:IJCAI 2022 代碼來源:Code 一系列 基于深度神經(jīng)網(wǎng)絡(luò) 融合 文本和視覺特征 以產(chǎn)生多模態(tài)后表示的多媒體謠言檢測器被提出,其表現(xiàn)

    2024年02月08日
    瀏覽(24)
  • FixMatch+DST論文閱讀筆記(待復(fù)現(xiàn))

    FixMatch+DST論文閱讀筆記(待復(fù)現(xiàn))

    論文標(biāo)題:FixMatch: Simplifying Semi-Supervised Learning with Consistency and Confidence 論文作者:Kihyuk Sohn, David Berthelot, Chun-Liang Li, Zizhao Zhang, Nicholas Carlini, Ekin D. Cubuk, Alex Kurakin, Han Zhang, Colin Raffel 論文來源:NeurIPS 2020 代碼來源:Code 半監(jiān)督學(xué)習(xí)有效的利用沒有標(biāo)注的數(shù)據(jù),從而提高模型的

    2024年02月08日
    瀏覽(20)
  • 【單目3D目標(biāo)檢測】SMOKE論文解析與代碼復(fù)現(xiàn)

    【單目3D目標(biāo)檢測】SMOKE論文解析與代碼復(fù)現(xiàn)

    在正篇之前,有必要先了解一下yacs庫,因?yàn)镾MOKE源碼的參數(shù)配置文件,都是基于yacs庫建立起來的,不學(xué)看不懂?。。。。?yacs是一個(gè)用于定義和管理參數(shù)配置的庫(例如用于訓(xùn)練模型的超參數(shù)或可配置模型超參數(shù)等)。yacs使用yaml文件來配置參數(shù)。另外,yacs是在py-fast -rcnn和

    2024年02月09日
    瀏覽(48)
  • 目標(biāo)檢測論文解讀復(fù)現(xiàn)之十:基于YOLOv5的遙感圖像目標(biāo)檢測(代碼已復(fù)現(xiàn))

    目標(biāo)檢測論文解讀復(fù)現(xiàn)之十:基于YOLOv5的遙感圖像目標(biāo)檢測(代碼已復(fù)現(xiàn))

    前言 ? ? ? ?此前出了目標(biāo)改進(jìn)算法專欄,但是對于應(yīng)用于什么場景,需要什么改進(jìn)方法對應(yīng)與自己的應(yīng)用場景有效果,并且多少改進(jìn)點(diǎn)能發(fā)什么水平的文章,為解決大家的困惑,此系列文章旨在給大家解讀最新目標(biāo)檢測算法論文,幫助大家解答疑惑。解讀的系列文章,本人

    2024年02月06日
    瀏覽(31)
  • AAAI最佳論文Informer 復(fù)現(xiàn)(含python notebook代碼)

    AAAI最佳論文Informer 復(fù)現(xiàn)(含python notebook代碼)

    Github論文源碼 由于很菜,零基礎(chǔ)看源碼的時(shí)候喜歡按照代碼運(yùn)行的順序來跑一遍一個(gè)batch,從外層一點(diǎn)點(diǎn)拆進(jìn)去,看代碼內(nèi)部的邏輯。最初復(fù)現(xiàn)的時(shí)候大部分都沿用args里的default,后面再嘗試改用自己的數(shù)據(jù)+調(diào)參(哈哈至今也無法參透調(diào)參的這部分,希望不是玄學(xué)。。。) 記

    2024年02月20日
    瀏覽(21)
  • [論文筆記] Open-Sora 1、sora復(fù)現(xiàn)方案概覽

    GitHub - hpcaitech/Open-Sora: Unofficial implementation of OpenAI\\\'s Sora Open-Sora已涵蓋: 提供 完整的Sora復(fù)現(xiàn)架構(gòu)方案 ,包含從數(shù)據(jù)處理到訓(xùn)練推理全流程。 支持 動態(tài)分辨率 ,訓(xùn)練時(shí)可直接訓(xùn)練任意分辨率的視頻,無需進(jìn)行縮放。 支持 多種模型結(jié)構(gòu) 。由于Sora實(shí)際模型結(jié)構(gòu)未知,我們實(shí)現(xiàn)

    2024年03月10日
    瀏覽(20)
  • 卷積神經(jīng)網(wǎng)絡(luò)學(xué)習(xí)—Resnet50(論文精讀+pytorch代碼復(fù)現(xiàn))

    卷積神經(jīng)網(wǎng)絡(luò)學(xué)習(xí)—Resnet50(論文精讀+pytorch代碼復(fù)現(xiàn))

    如果說在CNN領(lǐng)域一定要學(xué)習(xí)一個(gè)卷積神經(jīng)網(wǎng)絡(luò),那一定非Resnet莫屬了。 接下來我將按照:Resnet論文解讀、Pytorch實(shí)現(xiàn)ResNet50模型兩部分,進(jìn)行講解,博主也是初學(xué)者,不足之處歡迎大家批評指正。 預(yù)備知識 :卷積網(wǎng)絡(luò)的深度越深,提取的特征越高級,性能越好,但傳統(tǒng)的卷積

    2024年01月19日
    瀏覽(18)

覺得文章有用就打賞一下文章作者

支付寶掃一掃打賞

博客贊助

微信掃一掃打賞

請作者喝杯咖啡吧~博客贊助

支付寶掃一掃領(lǐng)取紅包,優(yōu)惠每天領(lǐng)

二維碼1

領(lǐng)取紅包

二維碼2

領(lǐng)紅包