20241008深度学习动手篇

文章目录

    • 1.如何写一个神经网络进行训练?
      • 1.1创建一个子类,搭建你需要的神经网络结构
      • 1.2 加载数据集
      • 1.3 自定义一些指标评估函数
      • 1.4训练
      • 1.5 结果展示
    • 2.参考文献

在这里插入图片描述

1.如何写一个神经网络进行训练?

1.1创建一个子类,搭建你需要的神经网络结构

# @File: 241008LeNet.py
# @Author: chen_song
# @Time: 2024/10/8 上午8:31import torch
from torch import nn
from d2l import torch as d2lnet = nn.Sequential(# 进行卷积操作以后,nn.Conv2d(1,6,kernel_size=5,padding=2),nn.Sigmoid(),nn.AvgPool2d(2,stride=2),nn.Conv2d(6,16,kernel_size=5),nn.Sigmoid(),nn.AvgPool2d(2,stride=2),nn.Flatten(),nn.Linear(16*5*5,120),nn.Sigmoid(),nn.Linear(120,84),nn.Sigmoid(),nn.Linear(84,10)
)
print(net)
print("===============================")
X = torch.rand(size=(1,1,28,28),dtype=torch.float32)
Y  = X.copy_(X)
for layer in net:X = layer(X)print(layer.__class__.__name__,X.shape)print("============================")
# 输入给定以后,会进行一系列张量乘法计算
A = net(Y)
# print the last result
print(A)

result below:

Sequential( (0): Conv2d(1, 6, kernel_size=(5, 5), stride=(1, 1),
padding=(2, 2)) (1): Sigmoid() (2): AvgPool2d(kernel_size=2,
stride=2, padding=0) (3): Conv2d(6, 16, kernel_size=(5, 5),
stride=(1, 1)) (4): Sigmoid() (5): AvgPool2d(kernel_size=2,
stride=2, padding=0) (6): Flatten(start_dim=1, end_dim=-1) (7):
Linear(in_features=400, out_features=120, bias=True) (8): Sigmoid()
(9): Linear(in_features=120, out_features=84, bias=True) (10):
Sigmoid() (11): Linear(in_features=84, out_features=10, bias=True) )
=============================== Conv2d torch.Size([1, 6, 28, 28]) Sigmoid torch.Size([1, 6, 28, 28]) AvgPool2d torch.Size([1, 6, 14,
14]) Conv2d torch.Size([1, 16, 10, 10]) Sigmoid torch.Size([1, 16, 10,
10]) AvgPool2d torch.Size([1, 16, 5, 5]) Flatten torch.Size([1, 400])
Linear torch.Size([1, 120]) Sigmoid torch.Size([1, 120]) Linear
torch.Size([1, 84]) Sigmoid torch.Size([1, 84]) Linear torch.Size([1,
10])
============================ tensor([[-0.2278, -0.5057, -0.6303, 0.1526, -0.1510, -0.1933, -0.3120, -0.7823,
0.4070, -0.0717]], grad_fn=)

1.2 加载数据集

batch_size = 256
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size)

打断点调试:
在这里插入图片描述在这里插入图片描述
你会发现:
train_iter和test_iter都是一个torch.utils.dataLoader对象,里面包含几个成员变量,住关键的是dataset对象以及sample对象,仔细研究你就会发现,为啥需要数据加载器了,因为你用神经网络进行训练,数据格式总得对吧,再就是要给个label吧,也就是目标值target吧,所以有余力朋友可以自己设计一个数据加载器…

1.3 自定义一些指标评估函数

def evaluate_accuracy_gpu(net, data_iter, device=None):  # @save"""使用GPU计算模型在数据集上的精度"""if isinstance(net, nn.Module):net.eval()  # 设置为评估模式if not device:device = next(iter(net.parameters())).device# 正确预测的数量,总预测的数量metric = d2l.Accumulator(2)with torch.no_grad():for X, y in data_iter:if isinstance(X, list):# BERT微调所需的(之后将介绍)=== 自然语言处理X = [x.to(device) for x in X]else:X = X.to(device)y = y.to(device)metric.add(d2l.accuracy(net(X), y), y.numel())return metric[0] / metric[1]

注意一下里面net.eval()和net.train()

1.4训练

def train_ch6(net, train_iter, test_iter, num_epochs, lr, device):"""用GPU训练模型(在第六章定义)"""def init_weights(m):if type(m) == nn.Linear or type(m) == nn.Conv2d:nn.init.xavier_uniform_(m.weight)net.apply(init_weights)print('training on', device)net.to(device)optimizer = torch.optim.SGD(net.parameters(), lr=lr)loss = nn.CrossEntropyLoss()animator = d2l.Animator(xlabel='epoch', xlim=[1, num_epochs],legend=['train loss', 'train acc', 'test acc'])timer, num_batches = d2l.Timer(), len(train_iter)for epoch in range(num_epochs):# 训练损失之和,训练准确率之和,样本数metric = d2l.Accumulator(3)net.train()for i, (X, y) in enumerate(train_iter):timer.start()optimizer.zero_grad()X, y = X.to(device), y.to(device)y_hat = net(X)l = loss(y_hat, y)l.backward()optimizer.step()with torch.no_grad():metric.add(l * X.shape[0], d2l.accuracy(y_hat, y), X.shape[0])timer.stop()train_l = metric[0] / metric[2]train_acc = metric[1] / metric[2]if (i + 1) % (num_batches // 5) == 0 or i == num_batches - 1:animator.add(epoch + (i + 1) / num_batches,(train_l, train_acc, None))test_acc = evaluate_accuracy_gpu(net, test_iter)animator.add(epoch + 1, (None, None, test_acc))print(f'loss {train_l:.3f}, train acc {train_acc:.3f}, 'f'test acc {test_acc:.3f}')print(f'{metric[2] * num_epochs / timer.sum():.1f} examples/sec 'f'on {str(device)}')lr, num_epochs = 0.9, 10
train_ch6(net, train_iter, test_iter, num_epochs, lr, d2l.try_gpu())
d2l.plt.show()

1.5 结果展示

在这里插入图片描述

2.参考文献

[1]王辉,张帆,刘晓凤,等.基于DarkNet-53和YOLOv3的水果图像识别[J].东北师大学报(自然科学版),2020,52(04):60-65.DOI:10.16163/j.cnki.22-1123/n.2020.04.010.
[2]王治国,曹爽,管海燕,等.基于改进SSD的城市地下排水管道缺陷识别算法[J].测绘工程,2024,33(05):7-13.DOI:10.19349/j.cnki.issn1006-7949.2024.05.002.
[3]杨继雯.基于深度学习的监控视频中人员异常行为识别技术[D].西安工业大学,2024.DOI:10.27391/d.cnki.gxagu.2024.000829.

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.xdnf.cn/news/1558231.html

如若内容造成侵权/违法违规/事实不符,请联系一条长河网进行投诉反馈,一经查实,立即删除!

相关文章

“炫我”受邀出席虚拟现实及元宇宙产业创新论坛!

当前,新一轮科技革命和产业变革向纵深演进,虚拟现实及元宇宙等相关产业加速发展,催生了新产业新业态新模式,发展潜力巨大、应用前景广阔。 9月27日,由北京市科学技术委员会、中关村科技园区管理委员会,北京…

什么是变阻器?

变阻器是一种电子元件,主要用于调整电路中的电阻值,从而实现对电流、电压等电学参数的控制。它在电路中起到非常重要的作用,广泛应用于各种电子设备和实验装置中。 变阻器的主要作用是改变电路中的电阻值。在电路中,电阻值的大小…

基于springboot vue 学生就业信息管理系统设计与实现

博主介绍:专注于Java(springboot ssm springcloud等开发框架) vue .net php phython node.js uniapp小程序 等诸多技术领域和毕业项目实战、企业信息化系统建设,从业十五余年开发设计教学工作☆☆☆ 精彩专栏推荐订阅☆☆☆☆…

LC538 - 把二叉搜索树转换为累加树

文章目录 1 题目2 思路3 ACM模式参考 1 题目 https://leetcode.cn/problems/convert-bst-to-greater-tree/description/ 给出二叉 搜索 树的根节点,该树的节点值各不相同,请你将其转换为累加树(Greater Sum Tree) 累加树&#…

递归特征消除(RFE)与随机森林回归模型的 MATLAB 实现

在机器学习中,特征选择是提高模型性能的重要步骤。本文将详细探讨使用递归特征消除(RFE)结合随机森林回归模型的实现,以研究对股票收盘价影响的特征。我们将逐步分析代码并介绍相关的数学原理。 1. 数据准备 首先,我…

wordpress发邮件SMTP服务器配置步骤指南?

wordpress发邮件功能如何优化?怎么用wordpress发信? 由于WordPress默认的邮件发送功能可能不够稳定,配置SMTP服务器成为了许多网站管理员的选择。AokSend将详细介绍如何在WordPress中配置SMTP服务器,以确保邮件能够顺利发送。 w…

注册安全分析报告:惠农网

前言 由于网站注册入口容易被黑客攻击,存在如下安全问题: 暴力破解密码,造成用户信息泄露短信盗刷的安全问题,影响业务及导致用户投诉带来经济损失,尤其是后付费客户,风险巨大,造成亏损无底洞…

BLE MESH学习2——自定义MESH网络架构思考

BLE MESH学习2——自定义MESH网络架构思考 基于对WCH CH582这款单片机的了解,其可以实现mesh配网、朋友节点、低功耗节点和中继节点的角色,基本功能无问题。在此基础上,考虑满足IoT需求的MESH架构设计,作为后续设计的“白皮书”。…

第170天:应急响应-战中溯源反制对抗上线CSGoby蚁剑Sqlmap等安全工具

目录 案例一:溯源反制-Webshell工具-Antsword 案例二:溯源反制-SQL注入工具-SQLMAP 案例三:溯源反制-漏洞扫描工具-Goby 案例四:溯源反制-远程控制工具-CobaltStrike 反制Server,爆破密码(通用&#x…

快餐食品检测系统源码分享[一条龙教学YOLOV8标注好的数据集一键训练_70+全套改进创新点发刊_Web前端展示]

快餐食品检测系统源码分享 [一条龙教学YOLOV8标注好的数据集一键训练_70全套改进创新点发刊_Web前端展示] 1.研究背景与意义 项目参考AAAI Association for the Advancement of Artificial Intelligence 项目来源AACV Association for the Advancement of Computer Vision …

sql堆叠注入

准备知识: php中multi_query():一次可以执行多个sql语句比如:查询注入id1;update xxx; 定义:如果后端代码中,数据库执行的方法是multi_query(),那么就可以一次执行多个sql,也就可以…

java面向对之象类的继承与多态

目录 1.类的继承 图解 案例:创建一个动物类和一个猫类 1.代码 1)动物类 2)猫类 3.测试类 2.效果 2.父类方法的重写 案例:如何重写父类的方法 1.代码 1)Animal类 2)Dog类 3)测试类 2.效果 3.super关键字 案例:如何在子类中调用父类的方…

肺结节分割与提取系统(基于传统图像处理方法)

Matlab肺结节分割(肺结节提取)源程序,GUI人机界面版本。使用传统图像分割方法,非深度学习方法。使用LIDC-IDRI数据集。 工作如下: 1、读取图像。读取原始dicom格式的CT图像,并显示,绘制灰度直方图; 2、图像…

系统架构设计师论文《论企业集成平台的理解与应用》精选试读

论文真题 企业集成平台(Enterprise Imtcgation Plaform,EIP)是支特企业信息集成的像环境,其主要功能是为企业中的数据、系统和应用等多种对象的协同行提供各种公共服务及运行时的支撑环境。企业集成平台能够根据业务模型的变化快速地进行信息系统的配置…

使用XML实现MyBatis的基础操作

目录 前言 1.准备工作 1.1⽂件配置 1.2添加 mapper 接⼝ 2.增删改查操作 2.1增(Insert) 2.2删(Delete) 2.3改(Update) 2.4查(Select) 前言 接下来我们会使用的数据表如下: 对应的实体类为:UserInfoMapper 所有的准备工作都在如下文章。 MyBati…

github创建仓库并本地使用流程,以及问题src refspec xxx does not match any

1.在 GitHub 上创建一个新仓库 登录你的 GitHub 账户。 点击右上角的 “” 按钮,然后选择 “New repository”。 填写仓库名称(如 my-repo)。 (可选)添加描述,选择是否公开或私有仓库。 (可选&…

电层相关 -- Transponder Muxponder

光波长转换类单板(Optical Transponder Unit,简称OTU单板)主要将客户侧业务经过封装映射、汇聚等处理后,输出符合WDM系统要求的标准波长的光信号。OTU的主要功能有两类:Transponder 和Muxponder,简称TP和MP…

Python 字典(Dictionary) items(),pop(‘key‘)方法

描述 Python 字典(Dictionary) items() 函数以列表返回可遍历的(键, 值) 元组数组。 语法 items()方法语法: dict.items()参数 NA。 返回值 返回可遍历的(键, 值) 元组数组。 实例 以下实例展示了 items()函数的使用方法: tinydict {Google: …

使用Docker搭建WAF-开源Web防火墙VeryNginx

1、说明 VeryNginx 基于 lua_nginx_module(openrestry) 开发,实现了防火墙、访问统计和其他的一些功能。 集成在 Nginx 中运行,扩展了 Nginx 本身的功能,并提供了友好的 Web 交互界面。 文章目录 1、说明1.1、基本概述1.2、主要功能1.3、应用场景2、拉取镜像3、配置文件4、…

IPv6为什么没有完全代替IPv4

IPv4的设计始于20世纪70年代末,随着ARPANET的扩展和网络需求的增加,工程师们意识到需要一个更大规模、更灵活的地址系统。IPv4在1981年被正式定义为RFC 791,它成为了互联网协议套件的一部分,并迅速被广泛采用。 IPv4地址由32位(4字节)组成,通常以点分十进制表示。例如,…