paddle2.3-基于联邦学习实现FedAVg算法

目录

1. 联邦学习介绍

2. 实验流程

3. 数据加载

4. 模型构建

5. 数据采样函数

6. 模型训练


1. 联邦学习介绍

联邦学习是一种分布式机器学习方法,中心节点为server(服务器),各分支节点为本地的client(设备)。联邦学习的模式是在各分支节点分别利用本地数据训练模型,再将训练好的模型汇合到中心节点,获得一个更好的全局模型。

联邦学习的提出是为了充分利用用户的数据特征训练效果更佳的模型,同时,为了保证隐私,联邦学习在训练过程中,server和clients之间通信的是模型的参数(或梯度、参数更新量),本地的数据不会上传到服务器。

本项目主要是升级1.8版本的联邦学习fedavg算法至2.3版本,内容取材于基于PaddlePaddle实现联邦学习算法FedAvg - 飞桨AI Studio星河社区

2. 实验流程

联邦学习的基本流程是:

1. server初始化模型参数,所有的clients将这个初始模型下载到本地;

2. clients利用本地产生的数据进行SGD训练;

3. 选取K个clients将训练得到的模型参数上传到server;

4. server对得到的模型参数整合,所有的clients下载新的模型。

5. 重复执行2-5,直至收敛或达到预期要求

import os
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
import random
import time
import paddle
import paddle.nn as nn
import numpy as np
from paddle.io import Dataset,DataLoader
import paddle.nn.functional as F

3. 数据加载

mnist_data_train=np.load('data/data2489/train_mnist.npy')
mnist_data_test=np.load('data/data2489/test_mnist.npy')
print('There are {} images for training'.format(len(mnist_data_train)))
print('There are {} images for testing'.format(len(mnist_data_test)))
# 数据和标签分离(便于后续处理)
Label=[int(i[0]) for i in mnist_data_train]
Data=[i[1:] for i in mnist_data_train]
There are 60000 images for training
There are 10000 images for testing

4. 模型构建

class CNN(nn.Layer):def __init__(self):super(CNN,self).__init__()self.conv1=nn.Conv2D(1,32,5)self.relu = nn.ReLU()self.pool1=nn.MaxPool2D(kernel_size=2,stride=2)self.conv2=nn.Conv2D(32,64,5)self.pool2=nn.MaxPool2D(kernel_size=2,stride=2)self.fc1=nn.Linear(1024,512)self.fc2=nn.Linear(512,10)# self.softmax = nn.Softmax()def forward(self,inputs):x = self.conv1(inputs)x = self.relu(x)x = self.pool1(x)x = self.conv2(x)x = self.relu(x)x = self.pool2(x)x=paddle.reshape(x,[-1,1024])x = self.relu(self.fc1(x))y = self.fc2(x)return y

5. 数据采样函数

# 均匀采样,分配到各个client的数据集都是IID且数量相等的
def IID(dataset, clients):num_items_per_client = int(len(dataset)/clients)client_dict = {}image_idxs = [i for i in range(len(dataset))]for i in range(clients):client_dict[i] = set(np.random.choice(image_idxs, num_items_per_client, replace=False)) # 为每个client随机选取数据image_idxs = list(set(image_idxs) - client_dict[i]) # 将已经选取过的数据去除client_dict[i] = list(client_dict[i])return client_dict
# 非均匀采样,同时各个client上的数据分布和数量都不同
def NonIID(dataset, clients, total_shards, shards_size, num_shards_per_client):shard_idxs = [i for i in range(total_shards)]client_dict = {i: np.array([], dtype='int64') for i in range(clients)}idxs = np.arange(len(dataset))data_labels = Labellabel_idxs = np.vstack((idxs, data_labels)) # 将标签和数据ID堆叠label_idxs = label_idxs[:, label_idxs[1,:].argsort()]idxs = label_idxs[0,:]for i in range(clients):rand_set = set(np.random.choice(shard_idxs, num_shards_per_client, replace=False)) shard_idxs = list(set(shard_idxs) - rand_set)for rand in rand_set:client_dict[i] = np.concatenate((client_dict[i], idxs[rand*shards_size:(rand+1)*shards_size]), axis=0) # 拼接return client_dict

class MNISTDataset(Dataset):def __init__(self, data,label):self.data = dataself.label = labeldef __getitem__(self, idx):image=np.array(self.data[idx]).astype('float32')image=np.reshape(image,[1,28,28])label=np.array(self.label[idx]).astype('int64')return image, labeldef __len__(self):return len(self.label)

6. 模型训练

class ClientUpdate(object):def __init__(self, data, label, batch_size, learning_rate, epochs):dataset = MNISTDataset(data,label)self.train_loader = DataLoader(dataset,batch_size=batch_size,shuffle=True,drop_last=True)self.learning_rate = learning_rateself.epochs = epochsdef train(self, model):optimizer=paddle.optimizer.SGD(learning_rate=self.learning_rate,parameters=model.parameters())criterion = nn.CrossEntropyLoss(reduction='mean')model.train()e_loss = []for epoch in range(1,self.epochs+1):train_loss = []for image,label in self.train_loader:# image=paddle.to_tensor(image)# label=paddle.to_tensor(label.reshape([label.shape[0],1]))output=model(image)loss= criterion(output,label)# print(loss)loss.backward()optimizer.step()optimizer.clear_grad()train_loss.append(loss.numpy()[0])t_loss=sum(train_loss)/len(train_loss)e_loss.append(t_loss)total_loss=sum(e_loss)/len(e_loss)return model.state_dict(), total_loss

train_x = np.array(Data)
train_y = np.array(Label)
BATCH_SIZE = 32
# 通信轮数
rounds = 100
# client比例
C = 0.1
# clients数量
K = 100
# 每次通信在本地训练的epoch
E = 5
# batch size
batch_size = 10
# 学习率
lr=0.001
# 数据切分
iid_dict = IID(mnist_data_train, 100)
def training(model, rounds, batch_size, lr, ds,L, data_dict, C, K, E, plt_title, plt_color):global_weights = model.state_dict()train_loss = []start = time.time()# clients与server之间通信for curr_round in range(1, rounds+1):w, local_loss = [], []m = max(int(C*K), 1) # 随机选取参与更新的clientsS_t = np.random.choice(range(K), m, replace=False)for k in S_t:# print(data_dict[k])sub_data = ds[data_dict[k]]sub_y = L[data_dict[k]]local_update = ClientUpdate(sub_data,sub_y, batch_size=batch_size, learning_rate=lr, epochs=E)weights, loss = local_update.train(model)w.append(weights)local_loss.append(loss)# 更新global weightsweights_avg = w[0]for k in weights_avg.keys():for i in range(1, len(w)):# weights_avg[k] += (num[i]/sum(num))*w[i][k]weights_avg[k]=weights_avg[k]+w[i][k]   weights_avg[k]=weights_avg[k]/len(w)global_weights[k].set_value(weights_avg[k])# global_weights = weights_avg# print(global_weights)#模型加载最新的参数model.load_dict(global_weights)loss_avg = sum(local_loss) / len(local_loss)if curr_round % 10 == 0:print('Round: {}... \tAverage Loss: {}'.format(curr_round, np.round(loss_avg, 5)))train_loss.append(loss_avg)end = time.time()fig, ax = plt.subplots()x_axis = np.arange(1, rounds+1)y_axis = np.array(train_loss)ax.plot(x_axis, y_axis, 'tab:'+plt_color)ax.set(xlabel='Number of Rounds', ylabel='Train Loss',title=plt_title)ax.grid()fig.savefig(plt_title+'.jpg', format='jpg')print("Training Done!")print("Total time taken to Train: {}".format(end-start))return model.state_dict()#导入模型
mnist_cnn = CNN()
mnist_cnn_iid_trained = training(mnist_cnn, rounds, batch_size, lr, train_x,train_y, iid_dict, C, K, E, "MNIST CNN on IID Dataset", "orange")

W0605 23:22:00.961916 10307 gpu_context.cc:278] Please NOTE: device: 0, GPU Compute Capability: 7.0, Driver API Version: 11.2, Runtime API Version: 10.1
W0605 23:22:00.966121 10307 gpu_context.cc:306] device: 0, cuDNN Version: 7.6.Round: 10... 	Average Loss: 0.033
Round: 20... 	Average Loss: 0.011
Round: 30... 	Average Loss: 0.012
Round: 40... 	Average Loss: 0.008
Round: 50... 	Average Loss: 0.003
Round: 60... 	Average Loss: 0.002
Round: 70... 	Average Loss: 0.001
Round: 80... 	Average Loss: 0.001
Round: 90... 	Average Loss: 0.001

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.xdnf.cn/news/142641.html

如若内容造成侵权/违法违规/事实不符,请联系一条长河网进行投诉反馈,一经查实,立即删除!

相关文章

【操作系统笔记四】高速缓存

CPU 高速缓存 存储器的分层结构: 问题:为什么这种存储器层次结构行之有效呢? 衡量 CPU 性能的两个指标: 响应时间(或执行时间):执行一条指令平均时间 吞吐量,就是 1 秒内 CPU 可以…

Kafka的消息存储机制

前面咱们简单讲了K啊开发入门相关的概念、架构、特点以及安装启动。 今天咱们来说一下它的消息存储机制。 前言: Kafka通过将消息持久化到磁盘上的日志文件来实现高吞吐量的消息传递。 这种存储机制使得Kafka能够处理大量的消息,并保证消息的可靠性。 1…

Vue+ElementUI实现动态树和表格数据的查询

目录 前言 一、动态树的实现 1.数据表 2.编写后端controller层 3.定义前端发送请求路径 4.前端左侧动态树的编写 4.1.发送请求获取数据 4.2.遍历左侧菜单 5.实现左侧菜单点击展示右边内容 5.1.定义组件 5.2.定义组件与路由的对应关系 5.3.渲染组件内容 5.4.通过动态…

OpenAI 更新 ChatGPT:支持图片和语音输入【附点评】

一、消息正文 9月25日消息,近日OpenAI宣布其对话AI系统ChatGPT进行升级,添加了语音输入和图像处理两个新功能。据OpenAI透露,这些新功能将在未来两周内面向ChatGPT Plus付费用户推出,免费用户也将很快可以使用这些新功能。这标志着ChatGPT继续朝着多模态交互的方向发展,为用户提…

Cpp/Qt-day040920Qt

目录 时钟 头文件&#xff1a;Widget.h: 源文件:Widget.c: 效果图&#xff1a; 思维导图 时钟 头文件&#xff1a;Widget.h: #ifndef WIDGET_H #define WIDGET_H#include <QWidget> #include <QPaintEvent> #include <QPainter> #include <QTime>…

无需求文档,保障测试质量的可行性做法

001 没有需求文档3种可能情况 &#xff1a; 1、公司都没产品经理&#xff0c;开发人员的意识不足&#xff0c;收到的客户需求&#xff0c;直接开干&#xff08;写需求文档 &#xff1f;不可能的&#xff09; 。 2、项目进度紧张&#xff0c;需求变动大&#xff0c;一直在变&…

如何在.NET电子表格应用程序中创建流程图

前言 流程图是一种常用的图形化工具&#xff0c;用于展示过程中事件、决策和操作的顺序和关系。它通过使用不同形状的图标和箭头线条&#xff0c;将任务和步骤按照特定的顺序连接起来&#xff0c;以便清晰地表示一个过程的执行流程。 在企业环境中&#xff0c;高管和经理利用…

【C语言】模拟实现内存函数

本篇文章目录 相关文章1. 模拟 memcpy 内存拷贝2. 模拟 memmove 内存移动 相关文章 【C语言】数据在内存中是以什么顺序存储的&#xff1f;【C语言】整数在内存中如何存储&#xff1f;又是如何进行计算使用的&#xff1f;【C语言】利用void*进行泛型编程【C语言】4.指针类型部…

关于MATLAB R2022b中MATLAB function没有edit data选项的解决办法

问题描述 在MATLAB 2022b的simulink中双击MATLAB function&#xff0c;出来的是这个界面&#xff0c;而不是跳转到MATLAB的编辑窗口。因此就找不到edit data选项&#xff0c;没法完成新建data store memory 全局变量。 解决办法&#xff1a; 点击 编辑数据 按钮 在弹出的窗…

孟晚舟最新发声!华为吹响人工智能的号角,发布“全面智能化”战略部署

原创 | 文 BFT机器人 1、华为孟晚舟新发声&#xff0c;华为发布“全面智能化”战略 上周三&#xff08;9月30号&#xff09;上午&#xff0c;华为全联接大会2023正式在上海举行&#xff0c;作为华为副董事长、轮值董事长、CFO的孟晚舟代表华为再次发声&#xff01;在演讲上&am…

力扣刷题-链表-链表相交

02.07. 链表相交 给你两个单链表的头节点 headA 和 headB &#xff0c;请你找出并返回两个单链表相交的起始节点。如果两个链表没有交点&#xff0c;返回 null 。 图示两个链表在节点 c1 开始相交&#xff1a; 题目数据 保证 整个链式结构中不存在环。 注意&#xff0c;函数返…

基于springboot+vue的大学生科创项目在线管理系统

博主主页&#xff1a;猫头鹰源码 博主简介&#xff1a;Java领域优质创作者、CSDN博客专家、公司架构师、全网粉丝5万、专注Java技术领域和毕业设计项目实战 主要内容&#xff1a;毕业设计(Javaweb项目|小程序等)、简历模板、学习资料、面试题库、技术咨询 文末联系获取 项目介绍…

海大校园学习《乡村振兴战略下传统村落文化旅游设计》许少辉八一新著

海大校园学习《乡村振兴战略下传统村落文化旅游设计》许少辉八一新著

WARNING:tensorflow:Your input ran out of data; interrupting training. 解决方法

问题详情&#xff1a; WARNING:tensorflow:Your input ran out of data; interrupting training. Make sure that your dataset or generator can generate at least steps_per_epoch * epochs batches (in this case, 13800 batches). You may need to use the repeat() funct…

【数据结构-树】哈夫曼树

&#x1f49d;&#x1f49d;&#x1f49d;欢迎来到我的博客&#xff0c;很高兴能够在这里和您见面&#xff01;希望您在这里可以感受到一份轻松愉快的氛围&#xff0c;不仅可以获得有趣的内容和知识&#xff0c;也可以畅所欲言、分享您的想法和见解。 推荐:kuan 的首页,持续学…

汽车电子——产品标准规范汇总和梳理(自动驾驶)

文章目录 前言 一、分级 二、定位 三、地图 四、座舱 五、远程 六、信息数据 七、场景 八、智慧城市 九、方法论 总结 前言 见《汽车电子——产品标准规范汇总和梳理》 一、分级 《GB/T 40429-2021 汽车驾驶自动化分级》 《QC/T XXXXX—XXXX 智能网联汽车 自动驾…

C语言动态内存管理malloc、calloc、realloc、free函数的讲解

一.为什么存在动态内存管理&#xff1a; 我们知道&#xff0c;在此之前向内存申请空间的方式有以下两种&#xff1a;&#xff08;变量和数组&#xff09; 但这两种方法有几个缺陷&#xff1a; ①&#xff1a;空间开辟大小是固定的&#xff1b; ②&#xff1a;数组在声明的时候&…

Qt扫盲-QSqlQueryModel理论总结

QSqlQueryModel理论总结 一、概述二、使用1. 与 view 视图 绑定2. 分离视图&#xff0c;只存数据 一、概述 QSqlQueryModel是用于执行SQL语句和遍历结果集的高级接口。它构建在较低级的 QSqlQuery之上&#xff0c;可用于向QTableView 等视图类提供数据&#xff0c;也是使用了Q…

微信开发者工具appdata\local\微信开发者工具有啥用,能删掉吗?占用空间8G

你好这边 微信开发者工具\User Data 存储的都是一些用户开发者在工具的一些数据存储&#xff0c;不建议全部删除&#xff0c;这样可能你较常用的一些项目记录和缓存信息就会找不到&#xff0c;如果需要清理的话&#xff0c;可以考虑删除&#xff1a; WeappApplication 应用更新…

【Java 基础篇】Java 接口组成与更新详解

在Java编程中&#xff0c;接口&#xff08;interface&#xff09;是一种非常重要的概念。它允许类定义一组抽象方法&#xff0c;这些方法可以在不同的类中实现。接口在Java中起到了重要的角色&#xff0c;被广泛应用于代码的组织和设计中。本文将详细解释Java接口的组成和最新的…