【深度学习实验】前馈神经网络(九):整合训练、评估、预测过程(Runner)

目录

一、实验介绍

 二、实验环境

1. 配置虚拟环境

2. 库版本介绍

三、实验内容

0. 导入必要的工具包

1. __init__(初始化)

2. train(训练)

3. evaluate(评估)

4. predict(预测)

5. save_model

6. load_model

7. 代码整合


一、实验介绍

      

 二、实验环境

    本系列实验使用了PyTorch深度学习框架,相关操作如下:

1. 配置虚拟环境

conda create -n DL python=3.7 
conda activate DL
pip install torch==1.8.1+cu102 torchvision==0.9.1+cu102 torchaudio==0.8.1 -f https://download.pytorch.org/whl/torch_stable.html
conda install matplotlib
 conda install scikit-learn

2. 库版本介绍

软件包本实验版本目前最新版
matplotlib3.5.33.8.0
numpy1.21.61.26.0
python3.7.16
scikit-learn0.22.11.3.0
torch1.8.1+cu1022.0.1
torchaudio0.8.12.0.2
torchvision0.9.1+cu1020.15.2

三、实验内容

ChatGPT:

        前馈神经网络(Feedforward Neural Network)是一种常见的人工神经网络模型,也被称为多层感知器(Multilayer Perceptron,MLP)。它是一种基于前向传播的模型,主要用于解决分类和回归问题。

        前馈神经网络由多个层组成,包括输入层、隐藏层和输出层。它的名称"前馈"源于信号在网络中只能向前流动,即从输入层经过隐藏层最终到达输出层,没有反馈连接。

以下是前馈神经网络的一般工作原理:

  1. 输入层:接收原始数据或特征向量作为网络的输入,每个输入被表示为网络的一个神经元。每个神经元将输入加权并通过激活函数进行转换,产生一个输出信号。

  2. 隐藏层:前馈神经网络可以包含一个或多个隐藏层,每个隐藏层由多个神经元组成。隐藏层的神经元接收来自上一层的输入,并将加权和经过激活函数转换后的信号传递给下一层。

  3. 输出层:最后一个隐藏层的输出被传递到输出层,输出层通常由一个或多个神经元组成。输出层的神经元根据要解决的问题类型(分类或回归)使用适当的激活函数(如Sigmoid、Softmax等)将最终结果输出。

  4. 前向传播:信号从输入层通过隐藏层传递到输出层的过程称为前向传播。在前向传播过程中,每个神经元将前一层的输出乘以相应的权重,并将结果传递给下一层。这样的计算通过网络中的每一层逐层进行,直到产生最终的输出。

  5. 损失函数和训练:前馈神经网络的训练过程通常涉及定义一个损失函数,用于衡量模型预测输出与真实标签之间的差异。常见的损失函数包括均方误差(Mean Squared Error)和交叉熵(Cross-Entropy)。通过使用反向传播算法(Backpropagation)和优化算法(如梯度下降),网络根据损失函数的梯度进行参数调整,以最小化损失函数的值。

        前馈神经网络的优点包括能够处理复杂的非线性关系,适用于各种问题类型,并且能够通过训练来自动学习特征表示。然而,它也存在一些挑战,如容易过拟合、对大规模数据和高维数据的处理较困难等。为了应对这些挑战,一些改进的网络结构和训练技术被提出,如卷积神经网络(Convolutional Neural Networks)和循环神经网络(Recurrent Neural Networks)等。

本系列为实验内容,对理论知识不进行详细阐释

(咳咳,其实是没时间整理,待有缘之时,回来填坑)

977468b5ae9843c6a88005e792817cb1.png

0. 导入必要的工具包

import torch
from torch import nn
import torch.nn.functional as F
# 绘画时使用的工具包
import matplotlib.pyplot as plt
# 导入鸢尾花数据集
from sklearn.datasets import load_iris
# 构建自己的数据集,继承自Dataset类
from torch.utils.data import Dataset, DataLoader

1. __init__(初始化)

    def __init__(self, model, optimizer, loss_fn, metric, **kwargs):self.model = modelself.optimizer = optimizerself.loss_fn = loss_fn# 用于计算评价指标self.metric = metric# 记录训练过程中的评价指标变化self.dev_scores = []# 记录训练过程中的损失变化self.train_epoch_losses = []self.dev_losses = []# 记录全局最优评价指标self.best_score = 0
  • 五个参数:
    • model(模型)
    • optimizer(优化器)
    • loss_fn(损失函数)
    • metric(评价指标)
    • 其他可选参数。
  • 该类还定义了一些用于记录训练过程中的指标变化和全局最优指标的属性:
    • self.dev_scores(记录验证集评价指标的变化)
    • self.train_epoch_losses(记录训练集损失的变化)
    • self.dev_losses(记录验证集损失的变化)
    • self.best_score(记录全局最优评价指标)

2. train(训练)

 def train(self, train_loader, dev_loader=None, **kwargs):# 将模型设置为训练模式,此时模型的参数会被更新self.model.train()num_epochs = kwargs.get('num_epochs', 0)log_steps = kwargs.get('log_steps', 100)save_path = kwargs.get('save_path', 'best_mode.pth')eval_steps = kwargs.get('eval_steps', 0)# 运行的step数,不等于epoch数global_step = 0if eval_steps:if dev_loader is None:raise RuntimeError('Error: dev_loader can not be None!')if self.metric is None:raise RuntimeError('Error: Metric can not be None')# 遍历训练的轮数for epoch in range(num_epochs):total_loss = 0# 遍历数据集for step, data in enumerate(train_loader):x, y = datalogits = self.model(x.float())loss = self.loss_fn(logits, y.long())total_loss += lossif log_steps and global_step % log_steps == 0:print(f'loss:{loss.item():.5f}')loss.backward()self.optimizer.step()self.optimizer.zero_grad()# 每隔一定轮次进行一次验证,由eval_steps参数控制,可以采用不同的验证判断条件if (epoch + 1) % eval_steps == 0:dev_score, dev_loss = self.evaluate(dev_loader, global_step=global_step)print(f'[Evalute] dev score:{dev_score:.5f}, dev loss:{dev_loss:.5f}')if dev_score > self.best_score:self.save_model(f'model_{epoch + 1}.pth')print(f'[Evaluate]best accuracy performance has been updated: {self.best_score:.5f}-->{dev_score:.5f}')self.best_score = dev_score# 验证过程结束后,请记住将模型调回训练模式self.model.train()global_step += 1# 保存当前轮次训练损失的累计值train_loss = (total_loss / len(train_loader)).item()self.train_epoch_losses.append((global_step, train_loss))print('[Train] Train done')

3. evaluate(评估)

    def evaluate(self, dev_loader, **kwargs):assert self.metric is not None# 将模型设置为验证模式,此模式下,模型的参数不会更新self.model.eval()global_step = kwargs.get('global_step', -1)total_loss = 0self.metric.reset()for batch_id, data in enumerate(dev_loader):x, y = datalogits = self.model(x.float())loss = self.loss_fn(logits, y.long()).item()total_loss += lossself.metric.update(logits, y)dev_loss = (total_loss / len(dev_loader))self.dev_losses.append((global_step, dev_loss))dev_score = self.metric.accumulate()self.dev_scores.append(dev_score)return dev_score, dev_loss

4. predict(预测)

    predict方法用于模型的阶段,输入数据x,返回模型对输入的预测结果。

 def predict(self, x, **kwargs):self.model.eval()logits = self.model(x)return logits

5. save_model

 def save_model(self, save_path):torch.save(self.model.state_dict(),save_path)

6. load_model

  def load_model(self, model_path):self.model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))

7. 代码整合

class Runner(object):def __init__(self, model, optimizer, loss_fn, metric, **kwargs):self.model = modelself.optimizer = optimizerself.loss_fn = loss_fn# 用于计算评价指标self.metric = metric# 记录训练过程中的评价指标变化self.dev_scores = []# 记录训练过程中的损失变化self.train_epoch_losses = []self.dev_losses = []# 记录全局最优评价指标self.best_score = 0# 模型训练阶段def train(self, train_loader, dev_loader=None, **kwargs):# 将模型设置为训练模式,此时模型的参数会被更新self.model.train()num_epochs = kwargs.get('num_epochs', 0)log_steps = kwargs.get('log_steps', 100)save_path = kwargs.get('save_path','best_mode.pth')eval_steps = kwargs.get('eval_steps', 0)# 运行的step数,不等于epoch数global_step = 0if eval_steps:if dev_loader is None:raise RuntimeError('Error: dev_loader can not be None!')if self.metric is None:raise RuntimeError('Error: Metric can not be None')# 遍历训练的轮数for epoch in range(num_epochs):total_loss = 0# 遍历数据集for step, data in enumerate(train_loader):x, y = datalogits = self.model(x.float())loss = self.loss_fn(logits, y.long())total_loss += lossif log_steps and global_step%log_steps == 0:print(f'loss:{loss.item():.5f}')loss.backward()self.optimizer.step()self.optimizer.zero_grad()# 每隔一定轮次进行一次验证,由eval_steps参数控制,可以采用不同的验证判断条件if (epoch+1)% eval_steps ==  0:dev_score, dev_loss = self.evaluate(dev_loader, global_step=global_step)print(f'[Evalute] dev score:{dev_score:.5f}, dev loss:{dev_loss:.5f}')if dev_score > self.best_score:self.save_model(f'model_{epoch+1}.pth')print(f'[Evaluate]best accuracy performance has been updated: {self.best_score:.5f}-->{dev_score:.5f}')self.best_score = dev_score# 验证过程结束后,请记住将模型调回训练模式   self.model.train()global_step += 1# 保存当前轮次训练损失的累计值train_loss = (total_loss/len(train_loader)).item()self.train_epoch_losses.append((global_step,train_loss))print('[Train] Train done')# 模型评价阶段def evaluate(self, dev_loader, **kwargs):assert self.metric is not None# 将模型设置为验证模式,此模式下,模型的参数不会更新self.model.eval()global_step = kwargs.get('global_step',-1)total_loss = 0self.metric.reset()for batch_id, data in enumerate(dev_loader):x, y = datalogits = self.model(x.float())loss = self.loss_fn(logits, y.long()).item()total_loss += loss self.metric.update(logits, y)dev_loss = (total_loss/len(dev_loader))self.dev_losses.append((global_step, dev_loss))dev_score = self.metric.accumulate()self.dev_scores.append(dev_score)return dev_score, dev_loss# 模型预测阶段,def predict(self, x, **kwargs):self.model.eval()logits = self.model(x)return logits# 保存模型的参数def save_model(self, save_path):torch.save(self.model.state_dict(),save_path)# 读取模型的参数def load_model(self, model_path):self.model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.xdnf.cn/news/140212.html

如若内容造成侵权/违法违规/事实不符,请联系一条长河网进行投诉反馈,一经查实,立即删除!

相关文章

【算法】滑动窗口破解长度最小子数组

Problem: 209. 长度最小的子数组 文章目录 题意分析算法原理讲解暴力枚举O(N^2)利用单调性,滑动窗口求解 复杂度Code 题意分析 首先来分析一下本题的题目意思 题目中会给到一个数组,我们的目的是找出在这个数组中 长度最小的【连续】子数组,而…

latexocr安装过程中遇到的问题解决办法

环境要求:需要Python版本3.7,并安装相应依赖文件 具体的详细安装步骤可见我上次写的博文:Mathpix替代者|科研人必备公式识别插件|latexocr安装教程 ‘latexocr‘ 不是内部或外部命令,也不是可运行的程序或批处理文件的相关解决办…

无线感知之手势识别模型:Widar 3.0

目录 一、前言 二、无线感知 三、国内的一些工作 四、WiFi 手势识别模型:Widar 3.0 一、前言 最近不少人吐槽WiFi CSI定位已经做无可做了,也发不了什么好的期刊,顶多冲一个SCI 2区。回首WiFi 指纹定位这块,RSS指纹定位已经发…

Leetcode 剑指 Offer II 045. 找树左下角的值

题目难度: 中等 原题链接 今天继续更新 Leetcode 的剑指 Offer(专项突击版)系列, 大家在公众号 算法精选 里回复 剑指offer2 就能看到该系列当前连载的所有文章了, 记得关注哦~ 题目描述 给定一个二叉树的 根节点 root,请找出该二叉树的 最底…

从零开始—【Mac系统】MacOS配置Java环境变量

系统环境说明 Apple M1 macOS Ventura 版本13.5.2 1.下载JDK安装包 Oracle官网下载地址 JDK下载【注:推荐下载JDK8 Oracle官网JDK8下载】 关于JDK、JRE、JVM的关系说明 JDK(Java Development Kit,Java开发工具包) ,是整个JAVA的核心&#…

【完全二叉树魔法:顺序结构实现堆的奇象】

本章重点 二叉树的顺序结构堆的概念及结构堆的实现堆的调整算法堆的创建堆排序TOP-K问题 1.二叉树的顺序结构 普通的二叉树是不适合用数组来存储的,因为可能会存在大量的空间浪费。而完全二叉树更适合使用顺序结构存储。现实中我们通常把堆(一种二叉树)使用顺序结构…

SpringMVC自定义注解---[详细介绍]

一,对于SpringMVC自定义注解概念 是一种特殊的 Java 注解,它允许开发者在代码中添加自定义的元数据,并且可以在运行时使用反射机制来获取和处理这些信息。在 Spring MVC 中,自定义注解通常用于定义控制器、请求处理方法、参数或者…

3、靶场——Pinkys-Place v3(3)

文章目录 一、获取flag41.1 关于SUID提权1.2 通过端口转发获取setuid文件1.3 运行pinksecd文件1.4 利用nm对文件进行分析1.5 构建payload1.6 Fire 二、获取flag52.1 生成ssh公钥2.2 免密登录ssh2.3 以pinksecmanagement的身份进行信息收集2.4 测试程序/usr/local/bin/PSMCCLI2.…

基于matlab实现的额 BP神经网络电力系统短期负荷预测未来(对比+误差)完整程序分享

基于matlab实现的额 BP神经网络电力系统短期负荷预测 完整程序: clear; clc; %%输入矢量P(15*10) P[0.2452 0.1466 0.1314 0.2243 0.5523 0.6642 0.7105 0.6981 0.6821 0.6945 0.7549 0.8215 0.2415 0.3027 0; 0.2217 0.1581 0.1408 0.23…

JS-ECharts-前端图表 多层级联合饼图、柱状堆叠图、柱/线组合图、趋势图、自定义中线、平均线、气泡备注点

本篇博客背景为JavaScript。在ECharts在线编码快速上手,绘制相关前端可视化图表。 ECharts官网:https://echarts.apache.org/zh/index.html 其他的一些推荐: AntV:https://antv.vision/zh chartcube:https://chartcub…

【力扣1464】数组中两元素的最大乘积

👑专栏内容:力扣刷题⛪个人主页:子夜的星的主页💕座右铭:前路未远,步履不停 目录 一、题目描述二、题目分析1、排序2、最值模拟 一、题目描述 题目链接:数组中两元素的最大乘积 给你一个整数数…

基于SSM的社区志愿者招募系统

博主主页:猫头鹰源码 博主简介:Java领域优质创作者、CSDN博客专家、公司架构师、全网粉丝5万、专注Java技术领域和毕业设计项目实战 主要内容:毕业设计(Javaweb项目|小程序等)、简历模板、学习资料、面试题库、技术咨询 文末联系获取 项目介绍…

ERR_CONNECTION_REFUSED等非标准的HTTP错误状态码原因分析和解决办法

文章目录 一、DNS Resolution Failed1,DNS服务器故障2,DNS配置错误3,DNS劫持4,域名过期-5,其他网络问题 二、ERR_CONNECTION_REFUSED-"ERR_CONNECTION_REFUSED" 错误可能有多种原因 三、ERR_SSL_PROTOCOL_ER…

组队竞赛(int溢出问题)

目录 一、题目 二、代码 &#xff08;一&#xff09;没有注意int溢出 &#xff08;二&#xff09;正确代码 1. long long sum0 2. #define int long long 3. 使用现成的sort函数 一、题目 二、代码 &#xff08;一&#xff09;没有注意int溢出 #include <iostream&g…

CoreData 在新建或更新托管对象中途发生错误时如何恢复如初?

问题现象 在 CoreData 支持的 App 中,当我们新建或更新托管对象到一半突然出现错误时,应该禁止任何已发生的改变被写入内存或数据库中。不过,有时仍会出现始料未及的“意外”: 从上面的演示可以看到:即使在 Item 对象新建和更新途中出现错误后不执行后续的保存操作,但界…

追光者的梦

追光者的梦 鸿蒙中我茫然于世&#xff0c;你是钻入我心里的那束光 我所有的梦想都是和你热烈的拥抱 没有追到你时&#xff0c;我一直在路上 追到你时&#xff0c;我的人生就被你点燃 ——致所有的追光者 合肥先进光源国家重大科技基础设施项目及配套工程启动会刚开过&…

重新认识架构—不只是软件设计

前言 什么是架构&#xff1f; 通常情况下&#xff0c;人们对架构的认知仅限于在软件工程中的定义&#xff1a;架构主要指软件系统的结构设计&#xff0c;比如常见的SOLID准则、DDD架构。一个良好的软件架构可以帮助团队更有效地进行软件开发&#xff0c;降低维护成本&#xff0…

RestTemplate:简化HTTP请求的强大工具

文章目录 什么是RestTemplateRestTemplate的作用代码示例 RestTemplate与HttpClient 什么是RestTemplate RestTemplate是一个在Java应用程序中发送RESTful HTTP请求的强大工具。本文将介绍RestTemplate的定义、作用以及与HttpClient的对比&#xff0c;以帮助读者更好地理解和使…

建构居住安全生态,鹿客科技2023秋季发布会圆满举办

9月20日&#xff0c;以「Lockin Opening」为主题的2023鹿客秋季发布会在上海隆重举办&#xff0c;面向居住安全领域鹿客带来了最新的高端旗舰智能锁新品、多眸OS1.0、Lockin Care服务以及全联接OPENING计划。此外&#xff0c;现场还邀请了国家机构、合作伙伴、技术专家等业界同…

什么是单点登录?什么又是 OAuth2.0?

对于刚开始接触身份认证的朋友对于单点登录&#xff0c;OAuth2.0&#xff0c;JWT 等等会有诸多疑惑&#xff0c;甚至还会问既然有了 JWT 还拿 单点登录做什么&#xff1f;还拿 OAuth2.0 做什么&#xff1f; 不知做过身份认证的 xdm 看到这里是不是感觉这句话有点迷&#xff1f…