paddle2.3-基于联邦学习实现FedAVg算法-CNN

目录

1. 联邦学习介绍

2. 实验流程

3. 数据加载

4. 模型构建

5. 数据采样函数

6. 模型训练


1. 联邦学习介绍

联邦学习是一种分布式机器学习方法,中心节点为server(服务器),各分支节点为本地的client(设备)。联邦学习的模式是在各分支节点分别利用本地数据训练模型,再将训练好的模型汇合到中心节点,获得一个更好的全局模型。

联邦学习的提出是为了充分利用用户的数据特征训练效果更佳的模型,同时,为了保证隐私,联邦学习在训练过程中,server和clients之间通信的是模型的参数(或梯度、参数更新量),本地的数据不会上传到服务器。

本项目主要是升级1.8版本的联邦学习fedavg算法至2.3版本,内容取材于基于PaddlePaddle实现联邦学习算法FedAvg - 飞桨AI Studio星河社区

2. 实验流程

联邦学习的基本流程是:

1. server初始化模型参数,所有的clients将这个初始模型下载到本地;

2. clients利用本地产生的数据进行SGD训练;

3. 选取K个clients将训练得到的模型参数上传到server;

4. server对得到的模型参数整合,所有的clients下载新的模型。

5. 重复执行2-5,直至收敛或达到预期要求

import os
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
import random
import time
import paddle
import paddle.nn as nn
import numpy as np
from paddle.io import Dataset,DataLoader
import paddle.nn.functional as F

3. 数据加载

mnist_data_train=np.load('data/data2489/train_mnist.npy')
mnist_data_test=np.load('data/data2489/test_mnist.npy')
print('There are {} images for training'.format(len(mnist_data_train)))
print('There are {} images for testing'.format(len(mnist_data_test)))
# 数据和标签分离(便于后续处理)
Label=[int(i[0]) for i in mnist_data_train]
Data=[i[1:] for i in mnist_data_train]
There are 60000 images for training
There are 10000 images for testing

4. 模型构建

class CNN(nn.Layer):def __init__(self):super(CNN,self).__init__()self.conv1=nn.Conv2D(1,32,5)self.relu = nn.ReLU()self.pool1=nn.MaxPool2D(kernel_size=2,stride=2)self.conv2=nn.Conv2D(32,64,5)self.pool2=nn.MaxPool2D(kernel_size=2,stride=2)self.fc1=nn.Linear(1024,512)self.fc2=nn.Linear(512,10)# self.softmax = nn.Softmax()def forward(self,inputs):x = self.conv1(inputs)x = self.relu(x)x = self.pool1(x)x = self.conv2(x)x = self.relu(x)x = self.pool2(x)x=paddle.reshape(x,[-1,1024])x = self.relu(self.fc1(x))y = self.fc2(x)return y

5. 数据采样函数

# 均匀采样,分配到各个client的数据集都是IID且数量相等的
def IID(dataset, clients):num_items_per_client = int(len(dataset)/clients)client_dict = {}image_idxs = [i for i in range(len(dataset))]for i in range(clients):client_dict[i] = set(np.random.choice(image_idxs, num_items_per_client, replace=False)) # 为每个client随机选取数据image_idxs = list(set(image_idxs) - client_dict[i]) # 将已经选取过的数据去除client_dict[i] = list(client_dict[i])return client_dict
# 非均匀采样,同时各个client上的数据分布和数量都不同
def NonIID(dataset, clients, total_shards, shards_size, num_shards_per_client):shard_idxs = [i for i in range(total_shards)]client_dict = {i: np.array([], dtype='int64') for i in range(clients)}idxs = np.arange(len(dataset))data_labels = Labellabel_idxs = np.vstack((idxs, data_labels)) # 将标签和数据ID堆叠label_idxs = label_idxs[:, label_idxs[1,:].argsort()]idxs = label_idxs[0,:]for i in range(clients):rand_set = set(np.random.choice(shard_idxs, num_shards_per_client, replace=False)) shard_idxs = list(set(shard_idxs) - rand_set)for rand in rand_set:client_dict[i] = np.concatenate((client_dict[i], idxs[rand*shards_size:(rand+1)*shards_size]), axis=0) # 拼接return client_dict

class MNISTDataset(Dataset):def __init__(self, data,label):self.data = dataself.label = labeldef __getitem__(self, idx):image=np.array(self.data[idx]).astype('float32')image=np.reshape(image,[1,28,28])label=np.array(self.label[idx]).astype('int64')return image, labeldef __len__(self):return len(self.label)

6. 模型训练

class ClientUpdate(object):def __init__(self, data, label, batch_size, learning_rate, epochs):dataset = MNISTDataset(data,label)self.train_loader = DataLoader(dataset,batch_size=batch_size,shuffle=True,drop_last=True)self.learning_rate = learning_rateself.epochs = epochsdef train(self, model):optimizer=paddle.optimizer.SGD(learning_rate=self.learning_rate,parameters=model.parameters())criterion = nn.CrossEntropyLoss(reduction='mean')model.train()e_loss = []for epoch in range(1,self.epochs+1):train_loss = []for image,label in self.train_loader:# image=paddle.to_tensor(image)# label=paddle.to_tensor(label.reshape([label.shape[0],1]))output=model(image)loss= criterion(output,label)# print(loss)loss.backward()optimizer.step()optimizer.clear_grad()train_loss.append(loss.numpy()[0])t_loss=sum(train_loss)/len(train_loss)e_loss.append(t_loss)total_loss=sum(e_loss)/len(e_loss)return model.state_dict(), total_loss

train_x = np.array(Data)
train_y = np.array(Label)
BATCH_SIZE = 32
# 通信轮数
rounds = 100
# client比例
C = 0.1
# clients数量
K = 100
# 每次通信在本地训练的epoch
E = 5
# batch size
batch_size = 10
# 学习率
lr=0.001
# 数据切分
iid_dict = IID(mnist_data_train, 100)
def training(model, rounds, batch_size, lr, ds,L, data_dict, C, K, E, plt_title, plt_color):global_weights = model.state_dict()train_loss = []start = time.time()# clients与server之间通信for curr_round in range(1, rounds+1):w, local_loss = [], []m = max(int(C*K), 1) # 随机选取参与更新的clientsS_t = np.random.choice(range(K), m, replace=False)for k in S_t:# print(data_dict[k])sub_data = ds[data_dict[k]]sub_y = L[data_dict[k]]local_update = ClientUpdate(sub_data,sub_y, batch_size=batch_size, learning_rate=lr, epochs=E)weights, loss = local_update.train(model)w.append(weights)local_loss.append(loss)# 更新global weightsweights_avg = w[0]for k in weights_avg.keys():for i in range(1, len(w)):# weights_avg[k] += (num[i]/sum(num))*w[i][k]weights_avg[k]=weights_avg[k]+w[i][k]   weights_avg[k]=weights_avg[k]/len(w)global_weights[k].set_value(weights_avg[k])# global_weights = weights_avg# print(global_weights)#模型加载最新的参数model.load_dict(global_weights)loss_avg = sum(local_loss) / len(local_loss)if curr_round % 10 == 0:print('Round: {}... \tAverage Loss: {}'.format(curr_round, np.round(loss_avg, 5)))train_loss.append(loss_avg)end = time.time()fig, ax = plt.subplots()x_axis = np.arange(1, rounds+1)y_axis = np.array(train_loss)ax.plot(x_axis, y_axis, 'tab:'+plt_color)ax.set(xlabel='Number of Rounds', ylabel='Train Loss',title=plt_title)ax.grid()fig.savefig(plt_title+'.jpg', format='jpg')print("Training Done!")print("Total time taken to Train: {}".format(end-start))return model.state_dict()#导入模型
mnist_cnn = CNN()
mnist_cnn_iid_trained = training(mnist_cnn, rounds, batch_size, lr, train_x,train_y, iid_dict, C, K, E, "MNIST CNN on IID Dataset", "orange")

Round: 10... 	Average Loss: [0.024]
Round: 20... 	Average Loss: [0.015]
Round: 30... 	Average Loss: [0.008]
Round: 40... 	Average Loss: [0.003]
Round: 50... 	Average Loss: [0.004]
Round: 60... 	Average Loss: [0.002]
Round: 70... 	Average Loss: [0.002]
Round: 80... 	Average Loss: [0.002]
Round: 90... 	Average Loss: [0.001]
Round: 100... 	Average Loss: [0.]
Training Done!
Total time taken to Train: 759.6239657402039

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.xdnf.cn/news/146748.html

如若内容造成侵权/违法违规/事实不符,请联系一条长河网进行投诉反馈,一经查实,立即删除!

相关文章

【精品】Springboot 接收发送日期类型的数据

问题 无法请求到后台,后台报错:[Failed to convert property value of type java.lang.String to required type java.time.LocalDateTime for property : 2023-10-02T09:26:16.06908:00 WARN 14296 --- [p-nio-80-exec-1] .w.s.m.s.Defaul…

跨类型文本文件,反序列化与类型转换的思考

文章目录 应用场景序列化 - 对象替换原内容,方便使用编写程序取得结果数组 序列化 - JSON 应用场景 在编写热更新的时候,我发现了一个古早的 ini 文件,记录了许多有用的数据 由于使用的语言年份较新,没有办法较好地对 ini 文件的…

Canal实现数据同步

1、Canal实现数据同步 canal可以用来监控数据库数据的变化,从而获得新增数据,或者修改的数据。 1.1 Canal工作原理 原理相对比较简单: 1、canal模拟mysql slave的交互协议,伪装自己为mysql slave,向mysql master发送…

图神经网络GNN(一)GraphEmbedding

DeepWalk 使用随机游走采样得到每个结点x的上下文信息,记作Context(x)。 SkipGram优化的目标函数:P(Context(x)|x;θ) θ argmax P(Context(x)|x;θ) DeepWalk这种GraphEmbedding方法是一种无监督方法,个人理解有点类似生成模型的Encoder过程…

树莓派CM4开启I2C与UART串口登录同时serial0映射到ttyS0 开启多串口

文章目录 前言1. 树莓派开启I2C与UART串口登录2. 开启多串口总结: 前言 最近用CM4的时候使用到了I2C以及多个UART的情况。 同时配置端口映射也存在部分问题。 这里集中记录一下。 1. 树莓派开启I2C与UART串口登录 输入指令sudo raspi-config 跳转到如下界面&#…

3D WEB轻量化引擎HOOPS助力3D测量应用蓬勃发展:效率、精度显著提升

在3D开发工具领域,Tech Soft 3D打造的HOOPS SDK已经崭露头角,成为了全球领先的3D领域开发工具提供商。HOOPS SDK包括四种不同的3D软件开发工具,已成为行业的翘楚。 其中,HOOPS Exchange以其CAD数据转换的能力脱颖而出&#xff0c…

Arm Cache学习资料大汇总

关键词:cache学习、mmu学习、cache资料、mmu资料、arm资料、armv8资料、armv9资料、 trustzone视频、tee视频、ATF视频、secureboot视频、安全启动视频、selinux视频,cache视频、mmu视频,armv8视频、armv9视频、FF-A视频、密码学视频、RME/CC…

前端开发网站推荐

每个人都会遇见那么一个人,永远无法忘却,也永远不能拥有。 以下是一些可以用来查找和比较前端框架的推荐网站: JavaScript框架比较: 这些网站提供了对不同JavaScript框架和库的详细比较和评估。 JavaScripting: 提供了大量的JavaS…

JavaScript高阶班之ES6 → ES11(八)

JavaScript高阶班之ES6 → ES11 1、ES6新特性1.1、let 关键字1.2、const关键字1.3、变量的解构赋值1.3.1、数组的解构赋值1.3.2、对象的解构赋值 1.4、模板字符串1.5、简化对象写法1.6、箭头函数1.7、函数参数默认值1.8、rest参数1.9、spread扩展运算符1.9.1、数组合并1.9.2、数…

上古神器:十六位应用程序 Debug 的基本使用

文章目录 参考环境上古神器 DebugBug 与 DebuggingDebugDebug 应用程序淘汰原因使用限制 DOSBox学习 Debug 的必要性DOSBox-X Debug 的基本使用命令 R查看寄存器的状态修改寄存器的内容 命令 D显示内存中的数据指定起始内存空间地址指定内存空间的范围 命令 A使用命令语法错误查…

第8章 Spring(二)

8.11 Spring 中哪些情况下,不能解决循环依赖问题 难度:★★ 重点:★★ 白话解析 有一下几种情况,循环依赖是不能解决的: 1、原型模式下的循环依赖没办法解决; 假设Girl中依赖了Boy,Boy中依赖了Girl;在实例化Girl的时候要注入Boy,此时没有Boy,因为是原型模式,每次都…

Konva离屏缓存

前言 cache实例方法定义在Node基类上,通过该方法可以实现图形缓存,在Konva中Stage、Layer、Group、Shape等所有容器类和图形类都直接或间接继承了Node基类,故而都可以使用缓存方法。本篇文章就是探讨Konva背后的缓存机制,版本是v…

8.3Jmeter使用json提取器提取数组值并循环(循环控制器)遍历使用

Jmeter使用json提取器提取数组值并循环遍历使用 响应返回值例如: {"code":0,"data":{"totalCount":11,"pageSize":100,"totalPage":1,"currPage":1,"list":[{"structuredId":&q…

计算机网络笔记 第二章 物理层

2.1 物理层概述 物理层要实现的功能 物理层接口特性 机械特性 形状和尺寸引脚数目和排列固定和锁定装置 电气特性 信号电压的范围阻抗匹配的情况传输速率距离限制 功能特性 -规定接口电缆的各条信号线的作用 过程特性 规定在信号线上传输比特流的一组操作过程&#xff0…

3. 文档操作

1. 创建文档 1.1 创建一个文档 在相应的索引下面使用_doc创建文档,地址为:http://127.0.0.1:9200/students/_doc,创建一个姓名张三的学生信息: {"姓名":"张三","年级":5,"班级":2,&qu…

28391-2012 建筑施工机械与设备 人力移动式液压动力站

声明 本文是学习GB-T 28391-2012 建筑施工机械与设备 人力移动式液压动力站. 而整理的学习笔记,分享出来希望更多人受益,如果存在侵权请及时联系我们 1 范围 本标准规定了人力移动式液压动力站(以下简称动力站)的范围、分类、要求、试验方法和检验规则。 本标准适用于以中小…

AI编程助手 Amazon CodeWhisperer 全面解析与实践

目录 引言Amazon CodeWhisperer简介智能编程助手智能代码建议代码自动补全 提升代码质量代码质量提升安全性检测 支持多平台多语言 用户体验和系统兼容性用户体验文档和学习资源个性化体验系统兼容性 功能全面性和代码质量功能全面性代码生成质量和代码安全性 CodeWhisperer的代…

MySQL数据库——索引(6)-索引使用(覆盖索引与回表查询,前缀索引,单列索引与联合索引 )、索引设计原则、索引总结

目录 索引使用(下) 覆盖索引与回表查询 思考题 前缀索引 语法 示例 前缀长度 前缀索引的查询流程 单列索引与联合索引 索引设计原则 索引总结 1.索引概述 2.索引结构 3.索引分类 4.索引语法 5.SQL性能分析 6.索引使用 7.索引设计…

应用在手机触摸屏中的电容式触摸芯片

触控屏(Touch panel)又称为触控面板,是个可接收触头等输入讯号的感应式液晶显示装置,当接触了屏幕上的图形按钮时,屏幕上的触觉反馈系统可根据预先编程的程式驱动各种连结装置,可用以取代机械式的按钮面板&…