Contents
I. Notes on the original FedAvg paper
1. The federated optimization problem
2. The federated averaging algorithm
The FedSGD algorithm
The FedAvg algorithm
Experimental results
3. Code walkthrough
3.1 main_fed.py: the main script
3.2 Fed.py
3.3 Nets.py: model definitions
3.4 options.py: hyperparameter settings
3.5 sampling.py
3.6 Update.py: local updates
3.7 main_nn.py: a plain NN baseline
I. Notes on the original FedAvg paper
The classic federated averaging paper: McMahan B, Moore E, Ramage D, et al. Communication-efficient learning of deep networks from decentralized data[C]//Artificial intelligence and statistics. PMLR, 2017: 1273-1282.
The idea behind federated learning is distributed machine learning that also addresses data privacy. Federated averaging (FedAvg) is one of its most representative algorithms: it combines local stochastic gradient descent on each client with a server that averages the resulting models.
1. The federated optimization problem
Federated optimization differs from classical distributed optimization in four ways:
1. The data are non-IID (not independent and identically distributed)
2. The data are unbalanced across clients
3. The number of clients is massive
4. Communication is limited
The most important of these is understanding what non-IID client data actually means.
For example, suppose dataset A's training data contains 250 handwritten-digit images covering 5 classes (digits 1-5). Client 1's local dataset holds only 50 images of digits 1 and 2 (1/5 of the data), while client 2 holds 200 images of digits 2, 3, 4, and 5 (4/5). Training locally, client 1 can only learn digits 1 and 2, and client 2 only digits 2-5. After aggregation weighted by dataset size, the global model's ability to recognize digit 1 becomes weaker, since only the lightly weighted client 1 ever saw it. This example shows what non-IID client data means: each client sees few classes, and its local sample does not represent the global distribution.
In more complex settings, where sample labels are noisy or mixed rather than clean, the non-IID problem becomes even more severe.
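The example above can be made concrete with a toy script (the numbers are the hypothetical ones from the example, not real MNIST counts):

```python
from collections import Counter

# Toy version of the example above: 250 digit images across classes 1-5;
# client 1 holds 50 images of digits 1-2, client 2 holds 200 of digits 2-5.
client1 = [1] * 25 + [2] * 25
client2 = [2] * 50 + [3] * 50 + [4] * 50 + [5] * 50

n1, n2 = len(client1), len(client2)
w1, w2 = n1 / (n1 + n2), n2 / (n1 + n2)  # FedAvg aggregation weights n_k / n
print(w1, w2)               # 0.2 0.8
print(Counter(client1)[1])  # 25 -- only client 1 (weight 0.2) ever sees digit 1
print(Counter(client2)[1])  # 0
```

Digit 1 enters the weighted average only through a client carrying 20% of the weight, which is exactly why the global model's accuracy on that class degrades.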
2. The federated averaging algorithm
Note that, compared with the traditional data-center setting, in federated learning the computation spent on each client and on server-side aggregation is small, while the communication cost between clients and the server is large. The paper therefore proposes two ways to reduce communication cost:
1. Increase parallelism (use more clients training independently in each round)
2. Increase the amount of computation per client
The paper first presents the FedSGD algorithm.
The FedSGD algorithm
For each of the K clients, compute the gradient of the local loss, g_k = ∇F_k(w_t), where F_k(w_t) is the loss of model w_t on client k's data.
The server aggregates the K gradients to obtain the round-(t+1) model: w_{t+1} = w_t - η Σ_k (n_k / n) g_k.
FedAvg simply runs this local SGD for multiple steps before averaging, and selects only a fraction of the clients each round rather than all of them (the experiments show that full participation converges more slowly and reaches lower accuracy than partial participation).
The FedAvg algorithm
Each client updates its local model: w_{t+1}^k = w_t - η ∇F_k(w_t), repeated over its local data.
The clients upload their local models, and the server only averages them: w_{t+1} = Σ_k (n_k / n) w_{t+1}^k.
As you can see, the algorithm shifts the computation onto the clients, while the server only aggregates and averages. We can therefore run several local updates before each averaging step. (It is worth asking whether more is always better: with a small local dataset, what do too many local epochs cause? Overfitting.)
The amount of computation is controlled by three parameters: C (the fraction of clients selected each round), E (the number of passes each client makes over its local data in round t), and B (the local mini-batch size).
In this notation, the FedSGD algorithm above corresponds to C = 1, E = 1, B = ∞ (the full local batch).
Thus, for client k with local dataset of size n_k, the number of local updates per round is u_k = E · n_k / B.
Note: local dataset size times the number of local epochs, divided by the batch size, gives the number of local SGD steps that client performs in the round. (The paper's Algorithm 1 gives the full FedAvg pseudocode.)
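As a worked example of u_k = E · n_k / B, with hypothetical values: a client holding n_k = 600 samples, E = 5 local epochs, and batch size B = 10:

```python
# Worked example of u_k = E * n_k / B (hypothetical values).
E, B, n_k = 5, 10, 600
u_k = E * n_k // B
print(u_k)  # 300 local SGD steps in this round
```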
Experimental results
1. Handwritten-digit recognition on the MNIST dataset, with two models:
MNIST 2NN: a simple multilayer perceptron with 2 hidden layers of 200 units each and ReLU activations (199,210 parameters).
CNN: two 5x5 convolutional layers (the first with 32 channels, the second with 64, each followed by 2x2 max pooling), a fully connected layer with 512 units and ReLU activation, and a final softmax output layer (1,663,370 parameters).
Increasing parallelism: the fraction C controls how many clients are processed in parallel.
Increasing local computation: B (the local batch size) and E (the number of local training passes) control the amount of computation per client.
As the fraction C increases, the number of communication rounds needed decreases; but when C is too large, the desired accuracy may not be reached within the allotted time.
(Left figure) For IID data, B = ∞ with E = 1 (large-batch updates, a single local pass) performs worst, while B = 10 with E = 20 (small mini-batches, 20 local passes) reaches the highest accuracy.
(Right figure) The non-IID results show a similar pattern.
A further figure makes clear that more local updates are not always better: E = 1 trains noticeably better than E = 5.
Reflection: FedAvg's main limitations are that it demands very reliable network connectivity, and that forcing every client to use the same number of local updates is rigid and can lead to overfitting.
Moreover, FedAvg either drops stragglers or merges straggler information: it discards devices that cannot finish the prescribed E local epochs, or aggregates their unfinished updates, which hurts convergence and increases computation. (The later FedProx algorithm mainly addresses this problem.)
3. Code walkthrough
Before anything else: when running a project, always read its README first. It covers almost everything you need, including the environment the project depends on. Setting up the environment is not hard, just tedious.
The code is from GitHub: federated-learning · GitHub
3.1 main_fed.py: the main script
The first long stretch simply imports the required packages:
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import copy
import numpy as np
from torchvision import datasets, transforms
import torch
from utils.sampling import mnist_iid, mnist_noniid, cifar_iid
from utils.options import args_parser
from models.Update import LocalUpdate
from models.Nets import MLP, CNNMnist, CNNCifar
from models.Fed import FedAvg
from models.test import test_img
Then comes the main function.
First parse the arguments, then select the device: CUDA if available, otherwise CPU.
if __name__ == '__main__':
    # parse args
    args = args_parser()  # calls the parser defined in utils/options.py
    args.device = torch.device('cuda:{}'.format(args.gpu) if torch.cuda.is_available() and args.gpu != -1 else 'cpu')
Next, load and split the dataset. Note that '../data/mnist/' means the MNIST dataset is downloaded into the data folder one level up; you can also point it elsewhere.
    if args.dataset == 'mnist':
        # trans_mnist: convert images to tensors and normalize them
        trans_mnist = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
        dataset_train = datasets.MNIST('../data/mnist/', train=True, download=True, transform=trans_mnist)
        dataset_test = datasets.MNIST('../data/mnist/', train=False, download=True, transform=trans_mnist)
        # sample users: split the training data either IID or non-IID
        if args.iid:
            dict_users = mnist_iid(dataset_train, args.num_users)
        else:
            dict_users = mnist_noniid(dataset_train, args.num_users)
    elif args.dataset == 'cifar':  # same treatment for CIFAR-10
        trans_cifar = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
        dataset_train = datasets.CIFAR10('../data/cifar', train=True, download=True, transform=trans_cifar)
        dataset_test = datasets.CIFAR10('../data/cifar', train=False, download=True, transform=trans_cifar)
        if args.iid:
            dict_users = cifar_iid(dataset_train, args.num_users)
        else:
            exit('Error: only consider IID setting in CIFAR10')
    else:
        exit('Error: unrecognized dataset')
    img_size = dataset_train[0][0].shape
Next comes the build-model stage:
    # build model, using the networks defined in models/Nets.py
    if args.model == 'cnn' and args.dataset == 'cifar':
        net_glob = CNNCifar(args=args).to(args.device)
    elif args.model == 'cnn' and args.dataset == 'mnist':
        net_glob = CNNMnist(args=args).to(args.device)
    elif args.model == 'mlp':
        len_in = 1
        for x in img_size:
            len_in *= x
        net_glob = MLP(dim_in=len_in, dim_hidden=200, dim_out=args.num_classes).to(args.device)
    else:
        exit('Error: unrecognized model')
    print(net_glob)  # print the network architecture
    net_glob.train()  # switch the network to training mode
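The len_in loop in the MLP branch just multiplies the image dimensions together. For MNIST (1x28x28 images) the maths works out as:

```python
# How len_in is derived for the MLP branch: multiply the image dimensions.
# For MNIST, img_size is (1, 28, 28), giving 784 input features.
img_size = (1, 28, 28)
len_in = 1
for x in img_size:
    len_in *= x
print(len_in)  # 784
```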
Next, copy the weights and run the training loop:
    # copy weights
    w_glob = net_glob.state_dict()

    # training: the core FedAvg loop
    loss_train = []
    cv_loss, cv_acc = [], []
    val_loss_pre, counter = 0, 0  # validation loss, counter
    net_best = None
    best_loss = None
    val_acc_list, net_list = [], []  # start empty

    if args.all_clients:
        print("Aggregation over all clients")
        w_locals = [w_glob for i in range(args.num_users)]  # give every client the initial global model
    for iter in range(args.epochs):  # epochs = number of communication rounds
        loss_locals = []  # local training losses
        if not args.all_clients:
            w_locals = []
        m = max(int(args.frac * args.num_users), 1)  # frac = fraction of clients per round
        # sample clients at random, without replacement
        idxs_users = np.random.choice(range(args.num_users), m, replace=False)
        for idx in idxs_users:
            # local training: each client starts from a deep copy of the global model net_glob
            local = LocalUpdate(args=args, dataset=dataset_train, idxs=dict_users[idx])
            w, loss = local.train(net=copy.deepcopy(net_glob).to(args.device))
            if args.all_clients:
                w_locals[idx] = copy.deepcopy(w)
            else:
                w_locals.append(copy.deepcopy(w))  # w_locals collects the clients' trained weights
            loss_locals.append(copy.deepcopy(loss))  # append each client's loss
        # update global weights: FedAvg aggregates the local models into the global model
        w_glob = FedAvg(w_locals)

        # copy the aggregated weights back into net_glob, ready to send out again
        net_glob.load_state_dict(w_glob)

        # print the average training loss after each round
        loss_avg = sum(loss_locals) / len(loss_locals)
        print('Round {:3d}, Average loss {:.3f}'.format(iter, loss_avg))
        loss_train.append(loss_avg)

    # plot loss curve
    plt.figure()
    plt.plot(range(len(loss_train)), loss_train)
    plt.ylabel('train_loss')
    plt.savefig('./save/fed_{}_{}_{}_C{}_iid{}.png'.format(args.dataset, args.model, args.epochs, args.frac, args.iid))
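The client-sampling line deserves a note: m = max(int(frac * num_users), 1), so even a tiny fraction keeps at least one client in each round. A quick check with a few illustrative fractions:

```python
# How many clients join each round: m = max(int(frac * num_users), 1).
# The max(..., 1) guard ensures at least one client always participates.
num_users = 100
for frac in (0.1, 0.05, 0.005):
    m = max(int(frac * num_users), 1)
    print(frac, m)  # 0.1 -> 10, 0.05 -> 5, 0.005 -> 1
```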
Finally, testing:
    # testing
    net_glob.eval()  # eval() disables batch-norm statistics updates and dropout
    acc_train, loss_train = test_img(net_glob, dataset_train, args)
    acc_test, loss_test = test_img(net_glob, dataset_test, args)
    print("Training accuracy: {:.2f}".format(acc_train))
    print("Testing accuracy: {:.2f}".format(acc_test))
3.2 Fed.py
The FedAvg function is defined as follows:
def FedAvg(w):
    w_avg = copy.deepcopy(w[0])  # deep-copy the first client's weights as the starting point
    for k in w_avg.keys():
        for i in range(1, len(w)):
            w_avg[k] += w[i][k]  # accumulate
        w_avg[k] = torch.div(w_avg[k], len(w))  # average
    return w_avg
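Note that this implementation gives every client equal weight 1/len(w) rather than the paper's n_k/n weighting, which coincides here because the sampler assigns equal-sized shards. A torch-free sketch of the same element-wise averaging, using plain dicts of Python lists in place of state_dicts (the key name is illustrative):

```python
# Torch-free sketch of FedAvg's element-wise model averaging.
def fed_avg(w):
    w_avg = {k: list(v) for k, v in w[0].items()}  # copy client 0's weights
    for k in w_avg:
        for i in range(1, len(w)):
            w_avg[k] = [a + b for a, b in zip(w_avg[k], w[i][k])]  # accumulate
        w_avg[k] = [a / len(w) for a in w_avg[k]]  # divide by the client count
    return w_avg

w_locals = [{'fc.weight': [1.0, 2.0]}, {'fc.weight': [3.0, 4.0]}]
print(fed_avg(w_locals))  # {'fc.weight': [2.0, 3.0]}
```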
3.3 Nets.py: model definitions
Each network subclasses nn.Module: define the input, hidden, and output layers, use nn.Linear for the fully connected layers, and implement the forward pass in forward().
import torch
from torch import nn
import torch.nn.functional as F


class MLP(nn.Module):  # multilayer perceptron
    def __init__(self, dim_in, dim_hidden, dim_out):
        super(MLP, self).__init__()
        self.layer_input = nn.Linear(dim_in, dim_hidden)  # linear transformation
        self.relu = nn.ReLU()  # activation function
        self.dropout = nn.Dropout()  # dropout against overfitting
        self.layer_hidden = nn.Linear(dim_hidden, dim_out)

    def forward(self, x):
        # read the tensor shape and flatten before the fully connected layer
        x = x.view(-1, x.shape[1] * x.shape[-2] * x.shape[-1])
        x = self.layer_input(x)
        x = self.dropout(x)
        x = self.relu(x)
        x = self.layer_hidden(x)
        return x
The CNNs for MNIST and CIFAR also subclass nn.Module:
class CNNMnist(nn.Module):  # CNN for MNIST
    def __init__(self, args):
        super(CNNMnist, self).__init__()
        # two convolutional layers (nn.Conv2d is a 2-D convolution, kernel size 5x5)
        self.conv1 = nn.Conv2d(args.num_channels, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)  # in_channels=10, out_channels=20
        self.conv2_drop = nn.Dropout2d()
        # fully connected layers (input and output feature counts)
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, args.num_classes)

    def forward(self, x):
        # conv -> max-pool -> activation
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
        x = x.view(-1, x.shape[1] * x.shape[2] * x.shape[3])  # flatten before the fully connected layers
        x = F.relu(self.fc1(x))
        x = F.dropout(x, training=self.training)
        x = self.fc2(x)
        return x


class CNNCifar(nn.Module):  # CNN for CIFAR-10
    def __init__(self, args):
        super(CNNCifar, self).__init__()
        # two convolutional layers
        self.conv1 = nn.Conv2d(3, 6, 5)   # 3 input channels -> 6 feature maps
        self.pool = nn.MaxPool2d(2, 2)    # 2x2 max pooling
        self.conv2 = nn.Conv2d(6, 16, 5)  # 16 deeper feature maps
        self.fc1 = nn.Linear(16 * 5 * 5, 120)  # fully connected layers
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, args.num_classes)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)  # flatten the 16x5x5 feature maps
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x
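Why does fc1 in CNNCifar take 16*5*5 inputs? Trace a 32x32 CIFAR image through the two (kernel-5 convolution, then 2x2 max-pool) stages:

```python
# Spatial-size arithmetic behind nn.Linear(16 * 5 * 5, 120):
# each stage is conv(kernel 5, no padding) -> 2x2 max-pool.
size = 32
for _ in range(2):
    size = (size - 5 + 1) // 2  # conv shrinks by 4, pooling halves
print(size, 16 * size * size)  # 5 400
```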
3.4 options.py: hyperparameter settings
The experiment parameters can be modified here, or passed on the command line when running from a terminal.
import argparse


def args_parser():
    parser = argparse.ArgumentParser()
    # federated arguments
    parser.add_argument('--epochs', type=int, default=10, help="rounds of training")
    parser.add_argument('--num_users', type=int, default=100, help="number of users: K")
    parser.add_argument('--frac', type=float, default=0.1, help="the fraction of clients: C")
    parser.add_argument('--local_ep', type=int, default=5, help="the number of local epochs: E")
    parser.add_argument('--local_bs', type=int, default=10, help="local batch size: B")
    parser.add_argument('--bs', type=int, default=128, help="test batch size")
    parser.add_argument('--lr', type=float, default=0.01, help="learning rate")
    parser.add_argument('--momentum', type=float, default=0.5, help="SGD momentum (default: 0.5)")
    parser.add_argument('--split', type=str, default='user', help="train-test split type, user or sample")
    # model arguments
    parser.add_argument('--model', type=str, default='mlp', help='model name')
    parser.add_argument('--kernel_num', type=int, default=9, help='number of each kind of kernel')
    parser.add_argument('--kernel_sizes', type=str, default='3,4,5',
                        help='comma-separated kernel size to use for convolution')
    parser.add_argument('--norm', type=str, default='batch_norm', help="batch_norm, layer_norm, or None")
    parser.add_argument('--num_filters', type=int, default=32, help="number of filters for conv nets")
    parser.add_argument('--max_pool', type=str, default='True',
                        help="Whether use max pooling rather than strided convolutions")
    # other arguments
    parser.add_argument('--dataset', type=str, default='mnist', help="name of dataset")
    parser.add_argument('--iid', action='store_true', help='whether i.i.d or not')
    parser.add_argument('--num_classes', type=int, default=10, help="number of classes")
    parser.add_argument('--num_channels', type=int, default=3, help="number of channels of imges")
    parser.add_argument('--gpu', type=int, default=0, help="GPU ID, -1 for CPU")
    parser.add_argument('--stopping_rounds', type=int, default=10, help='rounds of early stopping')
    parser.add_argument('--verbose', action='store_true', help='verbose print')
    parser.add_argument('--seed', type=int, default=1, help='random seed (default: 1)')
    parser.add_argument('--all_clients', action='store_true', help='aggregation over all clients')
    args = parser.parse_args()
    return args
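How defaults combine with command-line overrides can be seen with a trimmed-down stand-in for args_parser (only three of the flags above, fed an argument list directly instead of reading sys.argv):

```python
import argparse

# Trimmed-down stand-in for args_parser, for illustration only.
parser = argparse.ArgumentParser()
parser.add_argument('--epochs', type=int, default=10)
parser.add_argument('--frac', type=float, default=0.1)
parser.add_argument('--iid', action='store_true')

# Equivalent to running: python main_fed.py --epochs 50 --iid
args = parser.parse_args(['--epochs', '50', '--iid'])
print(args.epochs, args.frac, args.iid)  # 50 0.1 True
```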
3.5 sampling.py
This module splits the dataset into IID or non-IID samples and assigns them to clients.
For the IID case, shuffle the data and randomly assign 600 samples to each client.
For the non-IID case, sort the data by label, divide it into 200 shards of 300 samples each, and assign two shards to each client.
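A quick sanity check of the shard arithmetic: the MNIST training set has 60,000 images, which exactly covers 200 shards of 300 for 100 users:

```python
# Shard arithmetic for the non-IID split: 200 shards x 300 images = 60000,
# two shards per client for 100 clients.
num_shards, num_imgs, num_users = 200, 300, 100
assert num_shards * num_imgs == 60000   # every training image lands in a shard
assert num_shards // num_users == 2     # two shards per client
print(num_users * 2 * num_imgs)         # 60000 images assigned in total
```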
import numpy as np
from torchvision import datasets, transforms


def mnist_iid(dataset, num_users):  # IID sampling from MNIST
    """
    Sample I.I.D. client data from MNIST dataset
    :param dataset:
    :param num_users:
    :return: dict of image index
    """
    num_items = int(len(dataset)/num_users)  # num_items = dataset size / number of users
    dict_users, all_idxs = {}, [i for i in range(len(dataset))]
    for i in range(num_users):
        # sample num_items indices per user, without replacement
        dict_users[i] = set(np.random.choice(all_idxs, num_items, replace=False))
        all_idxs = list(set(all_idxs) - dict_users[i])  # remove the assigned indices
    return dict_users
def mnist_noniid(dataset, num_users):  # non-IID sampling from MNIST
    """
    Sample non-I.I.D client data from MNIST dataset
    :param dataset:
    :param num_users:
    :return:
    """
    num_shards, num_imgs = 200, 300  # 200 shards of 300 images each
    idx_shard = [i for i in range(num_shards)]
    dict_users = {i: np.array([], dtype='int64') for i in range(num_users)}
    idxs = np.arange(num_shards*num_imgs)  # indices 0..59999
    labels = dataset.train_labels.numpy()  # the MNIST labels as a numpy array

    # sort by label: stack idxs and labels vertically, then reorder columns by label
    idxs_labels = np.vstack((idxs, labels))
    idxs_labels = idxs_labels[:, idxs_labels[1, :].argsort()]
    idxs = idxs_labels[0, :]

    # divide and assign
    for i in range(num_users):
        # pick 2 shards per client, without replacement
        rand_set = set(np.random.choice(idx_shard, 2, replace=False))
        idx_shard = list(set(idx_shard) - rand_set)
        for rand in rand_set:
            # concatenate the shard's 300 indices onto the client's list,
            # e.g. rand=8 selects idxs[2400:2700]
            dict_users[i] = np.concatenate(
                (dict_users[i], idxs[rand*num_imgs:(rand+1)*num_imgs]),
                axis=0)
    return dict_users
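The vstack/argsort step above just reorders the indices by label. The same idea in plain Python (a stable sort keeps tied labels in their original order):

```python
# Reorder indices by label, as the vstack + argsort step does.
labels = [2, 0, 1, 0, 2]
idxs = sorted(range(len(labels)), key=lambda i: labels[i])
print(idxs)  # [1, 3, 2, 0, 4] -- indices of the 0s, then the 1, then the 2s
```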
def cifar_iid(dataset, num_users):  # IID sampling from CIFAR-10
    """
    Sample I.I.D. client data from CIFAR10 dataset
    :param dataset:
    :param num_users:
    :return: dict of image index
    """
    num_items = int(len(dataset)/num_users)
    dict_users, all_idxs = {}, [i for i in range(len(dataset))]
    for i in range(num_users):
        dict_users[i] = set(np.random.choice(all_idxs, num_items, replace=False))
        all_idxs = list(set(all_idxs) - dict_users[i])
    return dict_users


if __name__ == '__main__':
    dataset_train = datasets.MNIST('../data/mnist/', train=True, download=True,
                                   transform=transforms.Compose([
                                       transforms.ToTensor(),  # convert images to tensors
                                       transforms.Normalize((0.1307,), (0.3081,))  # normalize
                                   ]))
    num = 100
    d = mnist_noniid(dataset_train, num)
3.6 Update.py: local updates
import torch
from torch import nn, autograd
from torch.utils.data import DataLoader, Dataset
import numpy as np
import random
from sklearn import metrics


class DatasetSplit(Dataset):  # wraps one client's subset of the dataset
    def __init__(self, dataset, idxs):
        self.dataset = dataset
        self.idxs = list(idxs)

    def __len__(self):  # size of the subset
        return len(self.idxs)

    def __getitem__(self, item):
        # idxs come from sampling.py
        image, label = self.dataset[self.idxs[item]]
        return image, label


class LocalUpdate(object):
    def __init__(self, args, dataset=None, idxs=None):
        self.args = args
        self.loss_func = nn.CrossEntropyLoss()  # cross-entropy loss
        self.selected_clients = []
        # treat the assigned subset as the local dataset; mini-batch size = local_bs,
        # shuffle so each epoch sees the data in a different order
        self.ldr_train = DataLoader(DatasetSplit(dataset, idxs), batch_size=self.args.local_bs, shuffle=True)

    def train(self, net):
        net.train()
        # train and update: SGD optimizer with momentum and learning rate lr
        optimizer = torch.optim.SGD(net.parameters(), lr=self.args.lr, momentum=self.args.momentum)

        epoch_loss = []  # loss per local epoch
        for iter in range(self.args.local_ep):
            batch_loss = []  # loss is tracked per batch, not per sample, for efficiency
            for batch_idx, (images, labels) in enumerate(self.ldr_train):
                images, labels = images.to(self.args.device), labels.to(self.args.device)
                net.zero_grad()  # zero the gradients of all parameters (including submodules)
                log_probs = net(images)  # forward pass
                loss = self.loss_func(log_probs, labels)  # compute the loss
                loss.backward()  # backpropagate
                optimizer.step()
                if self.args.verbose and batch_idx % 10 == 0:
                    print('Update Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                        iter,
                        batch_idx * len(images),
                        len(self.ldr_train.dataset),
                        100. * batch_idx / len(self.ldr_train),
                        loss.item()))
                batch_loss.append(loss.item())
            # mean batch loss = loss of one local epoch
            epoch_loss.append(sum(batch_loss)/len(batch_loss))
        # return the updated weights and the mean loss over local epochs
        return net.state_dict(), sum(epoch_loss) / len(epoch_loss)
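How LocalUpdate.train aggregates loss: the mean over batches gives one epoch's loss, and the mean over epochs is what gets returned. With toy numbers (not real losses):

```python
# Two local epochs with two batches each; mean per epoch, then mean of means.
batch_losses = [[0.9, 0.7], [0.5, 0.3]]
epoch_loss = [sum(b) / len(b) for b in batch_losses]  # ~[0.8, 0.4]
avg_loss = sum(epoch_loss) / len(epoch_loss)          # ~0.6
print(avg_loss)
```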
3.7 main_nn.py: a plain NN baseline
Note: on GitHub, main_nn.py defines a function named test, which clashes with pytest's test collection, so I renamed test() to ceshi().
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Python version: 3.6
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torch.optim as optim
from torchvision import datasets, transforms
from utils.options import args_parser
from models.Nets import MLP, CNNMnist, CNNCifar


# main_nn.py: a plain centralized NN for comparison with main_fed.py.
# Runs the test set and reports accuracy and loss
# (cross-entropy, suitable for multi-class classification).
def ceshi(net_g, data_loader):
    # testing
    net_g.eval()  # disable batch-norm statistics updates and dropout
    test_loss = 0
    correct = 0
    l = len(data_loader)  # number of test batches
    for idx, (data, target) in enumerate(data_loader):  # iterate over the test batches
        data, target = data.to(args.device), target.to(args.device)  # move to device
        log_probs = net_g(data)  # forward pass
        test_loss += F.cross_entropy(log_probs, target).item()  # accumulate cross-entropy loss
        y_pred = log_probs.data.max(1, keepdim=True)[1]  # predicted class = index of the maximum
        # count predictions that match the targets
        correct += y_pred.eq(target.data.view_as(y_pred)).long().cpu().sum()
    test_loss /= len(data_loader.dataset)
    print('\nTest set: Average loss: {:.4f} \nAccuracy: {}/{} ({:.2f}%)\n'.format(
        test_loss, correct, len(data_loader.dataset),
        100. * correct / len(data_loader.dataset)))
    return correct, test_loss
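The accuracy bookkeeping in ceshi(), stripped of torch: take the arg-max class per sample and count matches against the targets (toy numbers):

```python
# Torch-free sketch of the max(1)[1] / eq() accuracy computation.
log_probs = [[0.1, 0.9], [0.8, 0.2], [0.3, 0.7]]  # 3 samples, 2 classes
targets = [1, 0, 0]
y_pred = [row.index(max(row)) for row in log_probs]  # arg-max per sample
correct = sum(p == t for p, t in zip(y_pred, targets))
print(y_pred, correct)  # [1, 0, 1] 2
```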
# compared with main_fed.py, the main function here simply skips Fed.py
if __name__ == '__main__':
    # parse args
    args = args_parser()
    args.device = torch.device('cuda:{}'.format(args.gpu) if torch.cuda.is_available() and args.gpu != -1 else 'cpu')
    torch.manual_seed(args.seed)

    # load dataset and split users (MNIST or CIFAR-10)
    if args.dataset == 'mnist':
        dataset_train = datasets.MNIST('./data/mnist/', train=True, download=True,
                                       transform=transforms.Compose([
                                           transforms.ToTensor(),
                                           transforms.Normalize((0.1307,), (0.3081,))
                                       ]))
        img_size = dataset_train[0][0].shape
    elif args.dataset == 'cifar':
        transform = transforms.Compose(
            [transforms.ToTensor(),
             transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
        dataset_train = datasets.CIFAR10('./data/cifar', train=True, transform=transform, target_transform=None, download=True)
        img_size = dataset_train[0][0].shape
    else:
        exit('Error: unrecognized dataset')

    # build model
    if args.model == 'cnn' and args.dataset == 'cifar':
        net_glob = CNNCifar(args=args).to(args.device)
    elif args.model == 'cnn' and args.dataset == 'mnist':
        net_glob = CNNMnist(args=args).to(args.device)
    elif args.model == 'mlp':
        len_in = 1
        for x in img_size:
            len_in *= x
        net_glob = MLP(dim_in=len_in, dim_hidden=64, dim_out=args.num_classes).to(args.device)
    else:
        exit('Error: unrecognized model')
    print(net_glob)

    # training
    optimizer = optim.SGD(net_glob.parameters(), lr=args.lr, momentum=args.momentum)
    train_loader = DataLoader(dataset_train, batch_size=64, shuffle=True)

    list_loss = []
    net_glob.train()
    for epoch in range(args.epochs):
        batch_loss = []
        for batch_idx, (data, target) in enumerate(train_loader):
            data, target = data.to(args.device), target.to(args.device)
            optimizer.zero_grad()
            output = net_glob(data)
            loss = F.cross_entropy(output, target)
            loss.backward()
            optimizer.step()
            if batch_idx % 50 == 0:
                print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                    epoch, batch_idx * len(data), len(train_loader.dataset),
                    100. * batch_idx / len(train_loader), loss.item()))
            batch_loss.append(loss.item())
        loss_avg = sum(batch_loss)/len(batch_loss)
        print('\nTrain loss:', loss_avg)
        list_loss.append(loss_avg)

    # plot loss
    plt.figure()
    plt.plot(range(len(list_loss)), list_loss)
    plt.xlabel('epochs')
    plt.ylabel('train loss')
    plt.savefig('./log/nn_{}_{}_{}.png'.format(args.dataset, args.model, args.epochs))

    # testing
    if args.dataset == 'mnist':
        dataset_test = datasets.MNIST('./data/mnist/', train=False, download=True,
                                      transform=transforms.Compose([
                                          transforms.ToTensor(),
                                          transforms.Normalize((0.1307,), (0.3081,))
                                      ]))
        test_loader = DataLoader(dataset_test, batch_size=1000, shuffle=False)
    elif args.dataset == 'cifar':
        transform = transforms.Compose(
            [transforms.ToTensor(),
             transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
        dataset_test = datasets.CIFAR10('./data/cifar', train=False, transform=transform, target_transform=None, download=True)
        test_loader = DataLoader(dataset_test, batch_size=1000, shuffle=False)
    else:
        exit('Error: unrecognized dataset')
    print('test on', len(dataset_test), 'samples')
    test_acc, test_loss = ceshi(net_glob, test_loader)