004_动手实现MLP(pytorch)

import torch
from torch import nn
from torch.nn import init
import numpy as np
import sys
import d2lzh_pytorch as d2l
# 1.数据预处理
mnist_train = torchvision.datasets.FashionMNIST(root='/Users/w/PycharmProjects/DeepLearning_with_LiMu/datasets/FashionMnist', train=True, download=True,transform=transforms.ToTensor())
mnist_test = torchvision.datasets.FashionMNIST(root='/Users/w/PycharmProjects/DeepLearning_with_LiMu/datasets/FashionMnist', train=False, download=True,transform=transforms.ToTensor())
# 1.2 数据集的预处理
batch_size = 256
if sys.platform.startswith('win'):num_worker = 0
else:num_worker = 4
train_iter = torch.utils.data.DataLoader(mnist_train, batch_size=batch_size, shuffle=True, num_workers=num_worker)
test_iter  = torch.utils.data.DataLoader(mnist_test, batch_size=batch_size, shuffle=False, num_workers=num_worker)# 封装自定义的结构转换函数
class FlattenLayer(nn.Module):def __init__(self):super(FlattenLayer, self).__init__()def forward(self, x): # x shape: (batch, *, *, ...)return x.view(x.shape[0], -1)
#定义网络结构
num_inputs, num_outputs, num_hiddens = 784, 10, 256
net = nn.Sequential(FlattenLayer(),nn.Linear(num_inputs,num_hiddens),nn.ReLU(),nn.Linear(num_hiddens,num_outputs)
)
for param in net.parameters():print(param.shape)
# 在 PyTorch 中,init.normal_ 是一个初始化方法,用于直接将张量中的元素初始化为来自正态分布(高斯分布)随机生成的值。它属于 torch.nn.init 模块,通常在神经网络的权重初始化中使用。
for params in net.parameters():init.normal_(params, mean=0, std=0.01)
# print 结果 torch.Size([256, 784])
#torch.Size([256])
#torch.Size([10, 256])
#torch.Size([10])batch_size = 256
loss = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(net.parameters(), lr=0.5)
num_epochs = 5def train(net, train_iter, test_iter, loss, num_epochs, batch_size,params=None, lr=None, optimizer=None):for epoch in range(num_epochs):train_l_sum, train_acc_sum, n = 0.0, 0.0, 0for X, y in train_iter:y_hat = net(X)l = loss(y_hat, y).sum()# 梯度清零if optimizer is not None:optimizer.zero_grad()elif params is not None and params[0].grad is not None:for param in params:param.grad.data.zero_()l.backward()if optimizer is None:sgd(params, lr, batch_size)else:optimizer.step()  # “softmax回归的简洁实现”一节将用到train_l_sum += l.item()train_acc_sum += (y_hat.argmax(dim=1) == y).sum().item()n += y.shape[0]test_acc = evaluate_accuracy(test_iter, net)print('epoch %d, loss %.4f, train acc %.3f, test acc %.3f'% (epoch + 1, train_l_sum / n, train_acc_sum / n, test_acc))train(net, train_iter, test_iter, loss, num_epochs, batch_size, None, None, optimizer)

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.xdnf.cn/news/1544347.html

如若内容造成侵权/违法违规/事实不符,请联系一条长河网进行投诉反馈,一经查实,立即删除!

相关文章

二刷LeetCode:“51.N皇后 37.解数独”题解心得(简单易懂)

引言(初遇噩梦,再遇坦然) 在阅读本文之前,建议大家已经接触过回溯算法,并完成回溯相关题目,例如:子集问题、组合问题、排列问题。 子集:子集II、子集 组合:组合、组合总和…

多比特AI事业部VP程伟光受邀为第四届中国项目经理大会演讲嘉宾

全国项目经理专业人士年度盛会 武汉市多比特信息科技有限公司AI事业部VP程伟光先生受邀为PMO评论主办的全国项目经理专业人士年度盛会——2024第四届中国项目经理大会演讲嘉宾,演讲议题为“AI对于项目经理工作的影响和变化解析”。大会将于10月26-27日在北京举办&am…

Scanner流程控制语句

1. Scanner类 Scanner的意思是扫描 Scanner是JDK提供的一个类,位于java.util包下,所以我们如果需要使用则必须导包,导包的语句必须在声明包之后,在声明类之前 Scanner类是用来接受用户输入的各种信息 Scanner类提供了用于接受…

SpringBoot开发——整合Hutool工具类轻松生成验证码

文章目录 1、Hutool简介2、验证码效果展示2.1 扭曲干扰验证码2.2 线条干扰验证码2.3 圆圈干扰验证码3、验证码应用场景3.1. 用户注册与身份验证3.2. 支付验证3.3. 订单与物流通知3.4. 信息安全与隐私保护3.5. 通知与提醒3.6. 其他应用场景4、Hutool工具类实现验证码生成4.1 引入…

学习threejs,绘制任意字体模型

👨‍⚕️ 主页: gis分享者 👨‍⚕️ 感谢各位大佬 点赞👍 收藏⭐ 留言📝 加关注✅! 👨‍⚕️ 收录于专栏:threejs gis工程师 文章目录 一、🍀前言二、🍀绘制任意字体模型…

Python邮件发送附件:怎么配置SMTP服务器?

Python邮件发送附件如何实现?Python发送带附件邮件? 在自动化和脚本编写中,使用Python发送带有附件的邮件是一个非常实用的功能。AokSend将详细介绍如何配置SMTP服务器,以便在Python中实现邮件发送附件的功能。 Python邮件发送附…

叉车高位显示器无线摄影,安装更加便捷!

叉车叉货,基本功能,但货叉升降高度确不一定,普通的3米左右,高的十几米,特别是仓储车,仓库叉货空间小,环境昏暗,视线受阻严重,司机叉货升的那么高怎么准确无误的插到货呢&…

艾体宝产品丨无需代码开发!Redis数据集成助你轻松优化数据库

我们不仅致力于加速应用程序的构建过程,更专注于助力您达成最终目标——实现应用的高效运行。因此,我们欣然宣布,Redis 数据集成(Redis Data Integration,RDI)(https://redis.io/data-integration/) 已经正…

探索LLM中的CoT链式推理:ECHO方法深度解读

近年来,随着大型语言模型(LLMs)的快速发展,如何有效利用这些模型进行复杂任务的推理成为了研究热点。其中,链式思考(Chain-of-Thought, CoT)推理方法作为一种有效的策略,能够显著提升…

Windows 安全事件日记中账户登录失败问题处理

一 window系统安全日记 在使用 Windows 系统时,我们可能会在安全事件日记中发现账户登录失败的记录。当遇到这种情况时,不必惊慌。 今天在检查自己的操作系统日记时发现系统的安全事件记录存在大量的-帐户登录失败日记如下: 从上次清除日记到…

Ansys Zemax | 模拟偏振敏感的散射过程

附件下载 联系工作人员获取附件 概述 这篇文章介绍了如何在OpticStudio中使用一个自定义的DLL模拟偏振敏感的体散射和荧光现象。该散射模型由MSP.DLL文件定义,它考虑了非序列模式下入射光的偏振属性,模拟了散射对光线的传播方向和偏振态的影响&#x…

上海交通大学《2020年+2021年816自动控制原理真题》 (完整版)

本文内容,全部选自自动化考研联盟的:《25届上海交通大学816自控考研资料》的真题篇。后续会持续更新更多学校,更多年份的真题,记得关注哦~ 目录 2020年真题 2021年真题 Part1:2020年2021年完整版真题 2020年真题 2…

中电金信多模态鉴伪技术抵御AI造假威胁

AI换脸技术,属于深度伪造最常见方式之一,是一种利用人工智能生成逼真的虚假人脸图片或视频的技术。基于深度学习算法,可以将一个人的面部特征映射到另一个人的面部,创造出看似真实的伪造内容。近年来,以AI换脸为代表的…

带着徒弟从一次跨域漏洞修复展开的学习

一.背景 本次测试使用到的主要工具包含:eclipse、谷歌浏览器、Windows11家庭版、ApiPost。 (一)发生的问题 公司安全兄弟提示我们一个应用存在跨域攻击的漏洞,需要我们修复。扫描情况及整改建议如下: 昨天晚上扫描了…

免费制作证件照的小程序源码

1、效果展示 可以下载程序包,最初级版本免费下载。以上是高级版本。如果你有开发能力的话可以自己写前端,然后以下调用以下api接口,代码如下: 证件照检测制作 接口地址:https://api.zheyings.cn/idcardv3/all 请求方…

2024年网络安全人才平均年薪 24.09 万,跳槽周期 31 个月,安全工程师现状大曝光!_2024网络安全人才市场状况研究报告

网络安全作为近两年兴起的热门行业,成了很多就业无门但是想转行的人心中比较向往但是又心存疑惑的行业,毕竟网络安全的发展史比较短,而国内目前网安的环境和市场情况还不算为大众所知晓,所以到底零基础转行入门网络安全之后&#…

通过pyenv local 3.6.1 这里设置了当前目录的python版本,通过pycharm基于这个版本创建一个虚拟环境

要在 PyCharm 中基于你通过 pyenv local 设置的 Python 版本创建虚拟环境,可以按照以下步骤进行操作: 步骤 1: 获取当前使用的 Python 路径 通过 pyenv 查找当前项目下的 Python 解释器路径,使用以下命令: pyenv which python …

『功能项目』3D模型动态UI显示【76】

本章项目成果展示 我们打开上一篇75主角属性值显示的项目, 本章要做的事情是将3D模型动态显示在主角属性展示界面 首先创建RawImage 调整尺寸 创建文件夹:RenderTexture 创建 Render Texture 创建Camera 在场景中放置一个主角预制体删除所有组件 清空标…

从理论到实践:业务能力建模在数字化转型中的落地实施路径

在数字化转型的浪潮下,企业正在寻求有效的方法来将复杂的战略目标、业务需求和技术能力整合为可执行的操作路径。《业务能力指南》提供了一个系统性的框架,通过业务能力建模帮助企业实现从理论到实践的平稳过渡。本文将以“从理论到实践应用”的视角&…

优思学院:六西格玛(6 Sigma)是什么?

自1987年起,在摩托罗拉公司的推动下,六西格玛的定义已经经历了多次演进。六西格玛可以分为三个基本类别:一种质量方案,主要关注财务成果;一种统计方法,基于过程改进;以及一种统计定义&#xff0…