国产 无码 综合区,色欲AV无码国产永久播放,无码天堂亚洲国产AV,国产日韩欧美女同一区二区

paddle2.3-基于聯(lián)邦學(xué)習(xí)實(shí)現(xiàn)FedAVg算法-CNN

這篇具有很好參考價(jià)值的文章主要介紹了paddle2.3-基于聯(lián)邦學(xué)習(xí)實(shí)現(xiàn)FedAVg算法-CNN。希望對(duì)大家有所幫助。如果存在錯(cuò)誤或未考慮完全的地方,請(qǐng)大家不吝賜教,您也可以點(diǎn)擊"舉報(bào)違法"按鈕提交疑問。

目錄

1. 聯(lián)邦學(xué)習(xí)介紹

2. 實(shí)驗(yàn)流程

3. 數(shù)據(jù)加載

4. 模型構(gòu)建

5. 數(shù)據(jù)采樣函數(shù)

6. 模型訓(xùn)練


paddle2.3-基于聯(lián)邦學(xué)習(xí)實(shí)現(xiàn)FedAVg算法-CNN,2023 AI,算法

paddle2.3-基于聯(lián)邦學(xué)習(xí)實(shí)現(xiàn)FedAVg算法-CNN,2023 AI,算法

paddle2.3-基于聯(lián)邦學(xué)習(xí)實(shí)現(xiàn)FedAVg算法-CNN,2023 AI,算法

1. 聯(lián)邦學(xué)習(xí)介紹

聯(lián)邦學(xué)習(xí)是一種分布式機(jī)器學(xué)習(xí)方法,中心節(jié)點(diǎn)為server(服務(wù)器),各分支節(jié)點(diǎn)為本地的client(設(shè)備)。聯(lián)邦學(xué)習(xí)的模式是在各分支節(jié)點(diǎn)分別利用本地?cái)?shù)據(jù)訓(xùn)練模型,再將訓(xùn)練好的模型匯合到中心節(jié)點(diǎn),獲得一個(gè)更好的全局模型。

聯(lián)邦學(xué)習(xí)的提出是為了充分利用用戶的數(shù)據(jù)特征訓(xùn)練效果更佳的模型,同時(shí),為了保證隱私,聯(lián)邦學(xué)習(xí)在訓(xùn)練過程中,server和clients之間通信的是模型的參數(shù)(或梯度、參數(shù)更新量),本地的數(shù)據(jù)不會(huì)上傳到服務(wù)器。

本項(xiàng)目主要是升級(jí)1.8版本的聯(lián)邦學(xué)習(xí)fedavg算法至2.3版本,內(nèi)容取材于基于PaddlePaddle實(shí)現(xiàn)聯(lián)邦學(xué)習(xí)算法FedAvg - 飛槳AI Studio星河社區(qū)

2. 實(shí)驗(yàn)流程

聯(lián)邦學(xué)習(xí)的基本流程是:

1. server初始化模型參數(shù),所有的clients將這個(gè)初始模型下載到本地;

2. clients利用本地產(chǎn)生的數(shù)據(jù)進(jìn)行SGD訓(xùn)練;

3. 選取K個(gè)clients將訓(xùn)練得到的模型參數(shù)上傳到server;

4. server對(duì)得到的模型參數(shù)整合,所有的clients下載新的模型。

5. 重復(fù)執(zhí)行2-5,直至收斂或達(dá)到預(yù)期要求

import os
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
import random
import time
import paddle
import paddle.nn as nn
import numpy as np
from paddle.io import Dataset,DataLoader
import paddle.nn.functional as F

3. 數(shù)據(jù)加載

mnist_data_train=np.load('data/data2489/train_mnist.npy')
mnist_data_test=np.load('data/data2489/test_mnist.npy')
print('There are {} images for training'.format(len(mnist_data_train)))
print('There are {} images for testing'.format(len(mnist_data_test)))
# 數(shù)據(jù)和標(biāo)簽分離(便于后續(xù)處理)
Label=[int(i[0]) for i in mnist_data_train]
Data=[i[1:] for i in mnist_data_train]
There are 60000 images for training
There are 10000 images for testing

4. 模型構(gòu)建

class CNN(nn.Layer):
    def __init__(self):
        super(CNN,self).__init__()
        self.conv1=nn.Conv2D(1,32,5)
        self.relu = nn.ReLU()
        self.pool1=nn.MaxPool2D(kernel_size=2,stride=2)
        self.conv2=nn.Conv2D(32,64,5)
        self.pool2=nn.MaxPool2D(kernel_size=2,stride=2)
        self.fc1=nn.Linear(1024,512)
        self.fc2=nn.Linear(512,10)
        # self.softmax = nn.Softmax()
    def forward(self,inputs):
        x = self.conv1(inputs)
        x = self.relu(x)
        x = self.pool1(x)

        x = self.conv2(x)
        x = self.relu(x)
        x = self.pool2(x)
        
        x=paddle.reshape(x,[-1,1024])
        x = self.relu(self.fc1(x))
        y = self.fc2(x)
        return y

5. 數(shù)據(jù)采樣函數(shù)

# 均勻采樣,分配到各個(gè)client的數(shù)據(jù)集都是IID且數(shù)量相等的
def IID(dataset, clients):
  num_items_per_client = int(len(dataset)/clients)
  client_dict = {}
  image_idxs = [i for i in range(len(dataset))]
  for i in range(clients):
    client_dict[i] = set(np.random.choice(image_idxs, num_items_per_client, replace=False)) # 為每個(gè)client隨機(jī)選取數(shù)據(jù)
    image_idxs = list(set(image_idxs) - client_dict[i]) # 將已經(jīng)選取過的數(shù)據(jù)去除
    client_dict[i] = list(client_dict[i])

  return client_dict
# 非均勻采樣,同時(shí)各個(gè)client上的數(shù)據(jù)分布和數(shù)量都不同
def NonIID(dataset, clients, total_shards, shards_size, num_shards_per_client):
  shard_idxs = [i for i in range(total_shards)]
  client_dict = {i: np.array([], dtype='int64') for i in range(clients)}
  idxs = np.arange(len(dataset))
  data_labels = Label

  label_idxs = np.vstack((idxs, data_labels)) # 將標(biāo)簽和數(shù)據(jù)ID堆疊
  label_idxs = label_idxs[:, label_idxs[1,:].argsort()]
  idxs = label_idxs[0,:]

  for i in range(clients):
    rand_set = set(np.random.choice(shard_idxs, num_shards_per_client, replace=False)) 
    shard_idxs = list(set(shard_idxs) - rand_set)

    for rand in rand_set:
      client_dict[i] = np.concatenate((client_dict[i], idxs[rand*shards_size:(rand+1)*shards_size]), axis=0) # 拼接
  
  return client_dict
class MNISTDataset(Dataset):
    def __init__(self, data,label):
        self.data = data
        self.label = label

    def __getitem__(self, idx):
        image=np.array(self.data[idx]).astype('float32')
        image=np.reshape(image,[1,28,28])
        label=np.array(self.label[idx]).astype('int64')
        return image, label

    def __len__(self):
        return len(self.label)

6. 模型訓(xùn)練

class ClientUpdate(object):
    def __init__(self, data, label, batch_size, learning_rate, epochs):
        dataset = MNISTDataset(data,label)
        self.train_loader = DataLoader(dataset,
                    batch_size=batch_size,
                    shuffle=True,
                    drop_last=True)
        self.learning_rate = learning_rate
        self.epochs = epochs
        
    def train(self, model):
        optimizer=paddle.optimizer.SGD(learning_rate=self.learning_rate,parameters=model.parameters())
        criterion = nn.CrossEntropyLoss(reduction='mean')
        model.train()
        e_loss = []
        for epoch in range(1,self.epochs+1):
            train_loss = []
            for image,label in self.train_loader:
                # image=paddle.to_tensor(image)
                # label=paddle.to_tensor(label.reshape([label.shape[0],1]))
                output=model(image)
                loss= criterion(output,label)
                # print(loss)
                loss.backward()
                optimizer.step()
                optimizer.clear_grad()
                train_loss.append(loss.numpy()[0])
            t_loss=sum(train_loss)/len(train_loss)
            e_loss.append(t_loss)
        total_loss=sum(e_loss)/len(e_loss)
        return model.state_dict(), total_loss
train_x = np.array(Data)
train_y = np.array(Label)
BATCH_SIZE = 32
# 通信輪數(shù)
rounds = 100
# client比例
C = 0.1
# clients數(shù)量
K = 100
# 每次通信在本地訓(xùn)練的epoch
E = 5
# batch size
batch_size = 10
# 學(xué)習(xí)率
lr=0.001
# 數(shù)據(jù)切分
iid_dict = IID(mnist_data_train, 100)
def training(model, rounds, batch_size, lr, ds,L, data_dict, C, K, E, plt_title, plt_color):
    global_weights = model.state_dict()
    train_loss = []
    start = time.time()
    # clients與server之間通信
    for curr_round in range(1, rounds+1):
        w, local_loss = [], []
        m = max(int(C*K), 1) # 隨機(jī)選取參與更新的clients
        S_t = np.random.choice(range(K), m, replace=False)
        for k in S_t:
            # print(data_dict[k])
            sub_data = ds[data_dict[k]]
            sub_y = L[data_dict[k]]
            local_update = ClientUpdate(sub_data,sub_y, batch_size=batch_size, learning_rate=lr, epochs=E)
            weights, loss = local_update.train(model)
            w.append(weights)
            local_loss.append(loss)

        # 更新global weights
        weights_avg = w[0]
        for k in weights_avg.keys():
            for i in range(1, len(w)):
                # weights_avg[k] += (num[i]/sum(num))*w[i][k]
                weights_avg[k]=weights_avg[k]+w[i][k]   
            weights_avg[k]=weights_avg[k]/len(w)
            global_weights[k].set_value(weights_avg[k])
        # global_weights = weights_avg
        # print(global_weights)
    #模型加載最新的參數(shù)
        model.load_dict(global_weights)

        loss_avg = sum(local_loss) / len(local_loss)
        if curr_round % 10 == 0:
            print('Round: {}... \tAverage Loss: {}'.format(curr_round, np.round(loss_avg, 5)))
        train_loss.append(loss_avg)

    end = time.time()
    fig, ax = plt.subplots()
    x_axis = np.arange(1, rounds+1)
    y_axis = np.array(train_loss)
    ax.plot(x_axis, y_axis, 'tab:'+plt_color)

    ax.set(xlabel='Number of Rounds', ylabel='Train Loss',title=plt_title)
    ax.grid()
    fig.savefig(plt_title+'.jpg', format='jpg')
    print("Training Done!")
    print("Total time taken to Train: {}".format(end-start))
  
    return model.state_dict()

#導(dǎo)入模型
mnist_cnn = CNN()
mnist_cnn_iid_trained = training(mnist_cnn, rounds, batch_size, lr, train_x,train_y, iid_dict, C, K, E, "MNIST CNN on IID Dataset", "orange")

paddle2.3-基于聯(lián)邦學(xué)習(xí)實(shí)現(xiàn)FedAVg算法-CNN,2023 AI,算法文章來源地址http://www.zghlxwxcb.cn/news/detail-731116.html

Round: 10... 	Average Loss: [0.024]
Round: 20... 	Average Loss: [0.015]
Round: 30... 	Average Loss: [0.008]
Round: 40... 	Average Loss: [0.003]
Round: 50... 	Average Loss: [0.004]
Round: 60... 	Average Loss: [0.002]
Round: 70... 	Average Loss: [0.002]
Round: 80... 	Average Loss: [0.002]
Round: 90... 	Average Loss: [0.001]
Round: 100... 	Average Loss: [0.]
Training Done!
Total time taken to Train: 759.6239657402039

到了這里,關(guān)于paddle2.3-基于聯(lián)邦學(xué)習(xí)實(shí)現(xiàn)FedAVg算法-CNN的文章就介紹完了。如果您還想了解更多內(nèi)容,請(qǐng)?jiān)谟疑辖撬阉鱐OY模板網(wǎng)以前的文章或繼續(xù)瀏覽下面的相關(guān)文章,希望大家以后多多支持TOY模板網(wǎng)!

本文來自互聯(lián)網(wǎng)用戶投稿,該文觀點(diǎn)僅代表作者本人,不代表本站立場(chǎng)。本站僅提供信息存儲(chǔ)空間服務(wù),不擁有所有權(quán),不承擔(dān)相關(guān)法律責(zé)任。如若轉(zhuǎn)載,請(qǐng)注明出處: 如若內(nèi)容造成侵權(quán)/違法違規(guī)/事實(shí)不符,請(qǐng)點(diǎn)擊違法舉報(bào)進(jìn)行投訴反饋,一經(jīng)查實(shí),立即刪除!

領(lǐng)支付寶紅包贊助服務(wù)器費(fèi)用

相關(guān)文章

  • 基于區(qū)塊鏈的分層聯(lián)邦學(xué)習(xí)

    基于區(qū)塊鏈的分層聯(lián)邦學(xué)習(xí)

    分層聯(lián)邦學(xué)習(xí)(HFL)在保留聯(lián)邦學(xué)習(xí)(FL)隱私保護(hù)優(yōu)勢(shì)的同時(shí),減輕了通信開銷,具有高帶寬和豐富計(jì)算資源的優(yōu)點(diǎn)。當(dāng)FL的工作人員或參數(shù)服務(wù)器不可信或惡意時(shí),方法是使用分層聯(lián)邦學(xué)習(xí)。 IEEE Access QIMEI CHEN1, (Member, IEEE), ZEHUA YOU1, JING WU1, YUNPENG LIU1, and HAO JIANG1 2022 (端邊

    2024年02月03日
    瀏覽(56)
  • 《橫向聯(lián)邦學(xué)習(xí)中 PCA差分隱私數(shù)據(jù)發(fā)布算法》論文算法原理筆記

    《橫向聯(lián)邦學(xué)習(xí)中 PCA差分隱私數(shù)據(jù)發(fā)布算法》論文算法原理筆記

    論文地址:https://www.arocmag.com/article/01-2022-01-041.html 論文摘要 ?????為了讓不同組織在保護(hù)本地敏感數(shù)據(jù)和降維后發(fā)布數(shù)據(jù)隱私的前提下,聯(lián)合使用 PCA進(jìn)行降維和數(shù)據(jù)發(fā)布,提出 橫向聯(lián)邦 PCA差分隱私數(shù)據(jù)發(fā)布算法 。引入隨機(jī)種子聯(lián)合協(xié)商方案,在各站點(diǎn)之間以較少通信代

    2024年02月08日
    瀏覽(21)
  • 基于區(qū)塊鏈的聯(lián)邦學(xué)習(xí)工作流程

    基于區(qū)塊鏈的聯(lián)邦學(xué)習(xí)工作流程

    1.初始化(Initialization) :從預(yù)定義好的目標(biāo)函數(shù)和全局梯度中隨機(jī)選擇參數(shù)。 2.本地模型更新(Local model update) :終端設(shè)備根據(jù)所需的迭代次數(shù)來訓(xùn)練本地模型。 3.本地模型上傳(Local model upload) :礦工與終端設(shè)備進(jìn)行綁定。終端設(shè)備上傳本地模型參數(shù)給礦工,同時(shí)上傳相

    2024年02月05日
    瀏覽(33)
  • 聯(lián)邦學(xué)習(xí)實(shí)戰(zhàn)-1:用python從零開始實(shí)現(xiàn)橫向聯(lián)邦學(xué)習(xí)

    聯(lián)邦學(xué)習(xí)實(shí)戰(zhàn)-1:用python從零開始實(shí)現(xiàn)橫向聯(lián)邦學(xué)習(xí)

    什么是聯(lián)邦學(xué)習(xí)? 簡(jiǎn)單來說就是在一個(gè)多方的環(huán)境中,數(shù)據(jù)集是零散的(在各個(gè)不同的客戶端中),那么怎樣實(shí)現(xiàn)機(jī)器學(xué)習(xí)算法呢? 首先想到的就是將多個(gè)數(shù)據(jù)集合并合并起來,然后統(tǒng)一的使用傳統(tǒng)的機(jī)器學(xué)習(xí)或者深度學(xué)習(xí)算法進(jìn)行計(jì)算,但是如果有一方因?yàn)閿?shù)據(jù)隱私問題

    2023年04月08日
    瀏覽(148)
  • Paddle進(jìn)階實(shí)戰(zhàn)系列(三):基于SVTR算法的手寫英文單詞識(shí)別

    Paddle進(jìn)階實(shí)戰(zhàn)系列(三):基于SVTR算法的手寫英文單詞識(shí)別

    ????? 作者簡(jiǎn)介: CSDN、阿里云人工智能領(lǐng)域博客專家,新星計(jì)劃計(jì)算機(jī)視覺導(dǎo)師,百度飛槳PPDE,專注大數(shù)據(jù)與AI知識(shí)分享。 公眾號(hào):GoAI的學(xué)習(xí)小屋 ,免費(fèi)分享書籍、簡(jiǎn)歷、導(dǎo)圖等,更有交流群分享寶藏資料,關(guān)注公眾號(hào)回復(fù)“加群”或?? 鏈接 加群。 ?? 專欄推薦:

    2023年04月18日
    瀏覽(18)
  • 基于聯(lián)邦強(qiáng)化學(xué)習(xí)的集群機(jī)器人協(xié)同導(dǎo)航

    基于聯(lián)邦強(qiáng)化學(xué)習(xí)的集群機(jī)器人協(xié)同導(dǎo)航

    1.1 集群機(jī)器人技術(shù)仿生背景 靈感來自群居昆蟲,比如螞蟻,它們利用信息素進(jìn)行長(zhǎng)距離覓食。由于群居昆蟲能夠集體完成單個(gè)個(gè)體無法完成的具有挑戰(zhàn)性的任務(wù),因此群體機(jī)器人系統(tǒng)有望在動(dòng)態(tài)復(fù)雜環(huán)境下完成單個(gè)機(jī)器人難以完成的具有挑戰(zhàn)性的任務(wù)。 示例1:蟻群協(xié)同工作

    2024年03月20日
    瀏覽(25)
  • 基于SGX和聯(lián)邦學(xué)習(xí)的電腦程序數(shù)據(jù)共享

    在當(dāng)今大數(shù)據(jù)時(shí)代,數(shù)據(jù)共享及數(shù)據(jù)安全問題已經(jīng)成為了一大關(guān)注焦點(diǎn)。 為了解決這一問題,研究者們不斷探索新的技術(shù)方法,其中包括英特爾的軟件保護(hù)擴(kuò)展(SGX)和聯(lián)邦學(xué)習(xí)(FL)。在本篇博客中,我們將深入探討基于SGX和聯(lián)邦學(xué)習(xí)的電腦程序數(shù)據(jù)共享方法,以及它們?nèi)?/p>

    2024年02月16日
    瀏覽(13)
  • 2023年3月版聯(lián)邦學(xué)習(xí)(fate)從主機(jī)安裝到實(shí)現(xiàn)聯(lián)邦學(xué)習(xí)

    2023年3月版聯(lián)邦學(xué)習(xí)(fate)從主機(jī)安裝到實(shí)現(xiàn)聯(lián)邦學(xué)習(xí)

    單機(jī)版提供3種部署方式,這里選擇在 主機(jī)中安裝FATE (官方建議使用Docker鏡像,但不熟悉Docker的人容易找不到FATE路徑) 使用 虛擬機(jī)VMware 進(jìn)行實(shí)驗(yàn),實(shí)驗(yàn)過程中隨時(shí) 拍攝快照 ,節(jié)約重裝時(shí)間。 項(xiàng)目 Value 虛擬機(jī)配置 內(nèi)存4G + 硬盤150G 操作系統(tǒng) centos 7 這里不重復(fù)寫了,請(qǐng)參考

    2023年04月15日
    瀏覽(49)
  • PFL-MoE:基于混合專家的個(gè)性聯(lián)邦學(xué)習(xí)

    PFL-MoE:基于混合專家的個(gè)性聯(lián)邦學(xué)習(xí)

    文章鏈接:PFL-MoE: Personalized Federated Learning Based on Mixture of Experts 發(fā)表會(huì)議:APWeb-WAIM 2021(CCF-C) 過去幾年,深度學(xué)習(xí)在AI應(yīng)用領(lǐng)域(CV、NLP、RS)中快速發(fā)展,這離不開海量數(shù)據(jù)集的支持。這些數(shù)據(jù)集通常是來自不同組織、設(shè)備或用戶的數(shù)據(jù)集合。 分布式機(jī)器學(xué)習(xí)(distributed m

    2024年02月07日
    瀏覽(24)
  • 【聯(lián)邦學(xué)習(xí)論文閱讀】常用算法理解(SCAFFOLD、FedPD、FedBN)-目前僅SCAFFOLD

    【聯(lián)邦學(xué)習(xí)論文閱讀】常用算法理解(SCAFFOLD、FedPD、FedBN)-目前僅SCAFFOLD

    SCAFFOLD(ICML-2020):SCAFFOLD: Stochastic Controlled Averaging for Federated Learning FedPD:https://arxiv.org/abs/2005.11418 FedBN(ICLR 2021):FEDBN: FEDERATED LEARNING ON NON-IID FEATURES VIA LOCAL BATCH NORMALIZATION 1… 梯度 實(shí)際上是對(duì)用戶數(shù)據(jù)進(jìn)行函數(shù)變換,在訓(xùn)練數(shù)據(jù)時(shí)攜帶信息,可能有泄露梯度隱私的風(fēng)險(xiǎn)。

    2023年04月20日
    瀏覽(45)

覺得文章有用就打賞一下文章作者

支付寶掃一掃打賞

博客贊助

微信掃一掃打賞

請(qǐng)作者喝杯咖啡吧~博客贊助

支付寶掃一掃領(lǐng)取紅包,優(yōu)惠每天領(lǐng)

二維碼1

領(lǐng)取紅包

二維碼2

領(lǐng)紅包