【pytorch20】多分类问题

网络结构以及示例

在这里插入图片描述
该网络的输出不是一层或两层的,而是一个十层的代表有十分类
在这里插入图片描述

新建三个线性层,每个线性层都有w和b的tensor

首先输入维度是784,第一个维度是ch_out,第二个维度才是ch_in(由于后面要转置),没有经过softmax函数和sigmoid,即logits

上图已经完成了网络的参数的定义和网络的前向传播过程

在这里插入图片描述
nn.CrossEntropyLoss()F.cross_entropy()是一样的功能,都包含softmax和log和F.nll_loss

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transformsbatch_size = 200
learning_rate = 0.01
epochs = 10train_loader = torch.utils.data.DataLoader(datasets.MNIST('data', train=True, download=True,transform=transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.1307,), (0.3081,))])),batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(datasets.MNIST('data', train=False, transform=transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.1307,), (0.3081,))])),batch_size=batch_size, shuffle=True)w1, b1 = torch.randn(200, 784, requires_grad=True), \torch.zeros(200, requires_grad=True)
w2, b2 = torch.randn(200, 200, requires_grad=True), \torch.zeros(200, requires_grad=True)
w3, b3 = torch.randn(10, 200, requires_grad=True), \torch.zeros(10, requires_grad=True)def forward(x):x = x @ w1.t() + b1x = F.relu(x)x = x @ w2.t() + b2x = F.relu(x)x = x @ w3.t() + b3x = F.relu(x)return x# train
optimizer = optim.SGD([w1, b1, w2, b2, w3, b3], lr=learning_rate)
criten = nn.CrossEntropyLoss()for epoch in range(epochs):for batch_idx, (data, target) in enumerate(train_loader):data = data.view(-1, 28 * 28)logits = forward(data)loss = criten(logits, target)optimizer.zero_grad()loss.backward()optimizer.step()if batch_idx % 100 == 0:print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(epoch, batch_idx * len(data), len(train_loader.dataset),100. * batch_idx / len(train_loader), loss.item()))test_loss = 0correct = 0for data, target in test_loader:data = data.view(-1, 28 * 28)logits = forward(data)test_loss += criten(logits, target).item()# 每一行的最大值对应的索引pred = logits.data.max(1)[1]correct += pred.eq(target.data).sum()test_loss /= len(test_loader.dataset)print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(test_loss, correct, len(test_loader.dataset),100. * correct / len(test_loader.dataset)))

在这里插入图片描述
创建的loss一直不变,为什么会出现这个问题,这个网络的结构层次并不深,数据集也比较简单,但这里出现了梯度弥散的情况,因为loss长时间得不到更新,因为梯度信息几乎接近于0

为什么会出现梯度为0?
影响训练的因素,除了有loss,学习率过大,还有初始化的问题,把初始化代码加上

torch.nn.init.kaiming_normal_(w1)
torch.nn.init.kaiming_normal_(w2)
torch.nn.init.kaiming_normal_(w3)

为什么b不初始化,因为已经初始化为0了

但是w也初始化,只是它们使用的是高斯分布进行初始化,即使是用高斯分布初始化后,结果也不满意,所以用了何凯明的初始化
在这里插入图片描述
可以看出loss直接到0.4了,准确率也达到了80%,而且这里还没运行完,运行完效果会更好

可以看出对于分类问题,初始化参数非常关键

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.xdnf.cn/news/1474184.html

如若内容造成侵权/违法违规/事实不符,请联系一条长河网进行投诉反馈,一经查实,立即删除!

相关文章

Fast R-CNN(论文阅读)

论文名:Fast R-CNN 论文作者:Ross Girshick 期刊/会议名:ICCV 2015 发表时间:2015-9 ​论文地址:https://arxiv.org/pdf/1504.08083 源码:https://github.com/rbgirshick/fast-rcnn 摘要 这篇论文提出了一…

基于java+springboot+vue实现的校园外卖服务系统(文末源码+Lw)292

摘 要 传统信息的管理大部分依赖于管理人员的手工登记与管理,然而,随着近些年信息技术的迅猛发展,让许多比较老套的信息管理模式进行了更新迭代,外卖信息因为其管理内容繁杂,管理数量繁多导致手工进行处理不能满足广…

Stream流真的很好,但答应我别用toMap()

你可能会想,toList 和 toSet 都这么便捷顺手了,当又怎么能少得了 toMap() 呢。 答应我,一定打消你的这个想法,否则这将成为你噩梦的开端。 让我们先准备一个用户实体类。 Data AllArgsConstructor public class User { priv…

昇思MindSpore学习总结十——ResNet50迁移学习

1、迁移学习 (抄自CS231n Convolutional Neural Networks for Visual Recognition) 在实践中,很少有人从头开始训练整个卷积网络(使用随机初始化),因为拥有足够大小的数据集相对罕见。相反,通常…

LLM - 循环神经网络(RNN)

1. RNN的关键点:即在处理序列数据时会有顺序的记忆。比如,RNN在处理一个字符串时,在对字母表顺序有记忆的前提下,处理这个字符串会更容易。就像人一样,读取下面第一个字符串会更容易,因为人对字母出现的顺序…

Linux应用---信号

写在前面:在前面的学习过程中,我们学习了进程间通信的管道以及内存映射的方式。这次我们介绍另外一种应用较为广泛的进程间通信的方式——信号。信号的内容比较多,是学习的重点,大家一定要认真学,多多思考。 一、信号概…

【AI资讯】可以媲美GPT-SoVITS的低显存开源文本转语音模型Fish Speech

Fish Speech是一款由fishaudio开发的全新文本转语音工具,支持中英日三种语言,语音处理接近人类水平,使用Flash-Attn算法处理大规模数据,提供高效、准确、稳定的TTS体验。 Fish Audio

驾校管理系统设计

驾校管理系统设计旨在提高驾校运营效率、学员管理、教练安排、考试预约、财务结算等方面的能力。以下是一个基本的设计框架,包括关键模块和数据表设计: 1. 系统架构设计 前端界面:提供给学员、教练和管理员使用的Web界面或移动应用&#xf…

【高中数学/基本不等式】当x是正实数时,求函数f(x)=x/(1+x^2)的最大值?

【问题】 当x是正实数时&#xff0c;求函数f(x)x/(1x^2)的最大值&#xff1f; 【解答】 解&#xff1a; f(x)x/(1x^2)1/(x1/x))<1/2倍根号下(x*1/x)1/2 所以函数在[0,∞)的区域最大值为0.5 【函数图像】 f(x)x/(1x^2)是奇函数&#xff0c;没有断点&#xff0c;是可以…

idea推送到gitee 401错误

在idea上推送时遇到这样的问题&#xff0c;解决方法如下&#xff1a; 在https://的后面加上 用户名:密码 然后再提交就ok啦&#xff01;

使用qt creator配置msvc环境(不需要安装shit一样的宇宙第一IDE vs的哈)

1. 背景 习惯使用Qt编程的童鞋&#xff0c;尤其是linux下开发Qt的童鞋一般都是使用qt creator作为首选IDE的&#xff0c;通常在windows上使用Qt用qt creator作为IDE的话一般编译器有mingw和msvc两种&#xff0c;使用mingw版本和在linux下的方式基本上一样十分简单&#xff0c;不…

Alt与Tab切换窗口时将Edge多个标签页作为一个整体参与切换的方法

本文介绍在Windows电脑中&#xff0c;使用Alt与Tab切换窗口时&#xff0c;将Edge浏览器作为一个整体参与切换&#xff0c;而不是其中若干个页面参与切换的方法。 最近&#xff0c;需要将主要使用的浏览器由原本的Chrome换为Edge&#xff1b;但是&#xff0c;在更换后发现&#…

数据结构与算法笔记:实战篇 - 剖析微服务接口鉴权限流背后的数据结构和算法

概述 微服务是最近几年才兴起的概念。简单点将&#xff0c;就是把复杂的大应用&#xff0c;解耦成几个小的应用 。这样做的好处有很多。比如&#xff0c;这样有利于团队组织架构的拆分&#xff0c;比较团队越大协作的难度越大&#xff1b;再比如&#xff0c;每个应用都可以独立…

mybatis-plus参数绑定异常

前言 最近要搞个发票保存的需求&#xff0c;当发票数据有id时说明是发票已经保存只需更新发票数据即可&#xff0c;没有id时说明没有发票数据需要新增发票&#xff1b;于是将原有的发票提交接口改造了下&#xff0c;将调用mybatis-plus的save方法改为saveOrUpdate方法&#xff…

从零开始实现大语言模型(四):简单自注意力机制

1. 前言 理解大语言模型结构的关键在于理解自注意力机制(self-attention)。自注意力机制可以判断输入文本序列中各个token与序列中所有token之间的相关性&#xff0c;并生成包含这种相关性信息的context向量。 本文介绍一种不包含训练参数的简化版自注意力机制——简单自注意…

Spark快速大数据分析PDF下载读书分享推荐

《Spark 快速大数据分析》是一本为 Spark 初学者准备的书&#xff0c;它没有过多深入实现细节&#xff0c;而是更多关注上层用户的具体用法。不过&#xff0c;本书绝不仅仅限于 Spark 的用法&#xff0c;它对 Spark 的核心概念和基本原理也有较为全面的介绍&#xff0c;让读者能…

clickhouse高可用可拓展部署

clickhouse高可用&可拓展部署 1.部署架构 1.1高可用架构 1.2硬件资源 部署服务 节点名称 节点ip 核数 内存 磁盘 zookeeper zk-01 / 4c 8G 100G zk-02 / 4c 8G 100G zk-03 / 4c 8G 100G clikehouse ck-01 / 32c 128G 2T ck-02 / 32c 128G 2T ck-03 / 32c 128G 2T ck-04 /…

昇思25天学习打卡营第13天 | LLM原理和实践:文本解码原理--以MindNLP为例

1. 文本解码原理--以MindNLP为例 1.1 自回归语言模型 根据前文预测下一个单词 一个文本序列的概率分布可以分解为每个词基于其上文的条件概率的乘积 W 0 W_0 W0​:初始上下文单词序列 t t t: 时间步 当生成EOS标签时&#xff0c;停止生成。 MindNLP/huggingface Transfor…

13.SQL注入-宽字节

SQL注入-宽字节 含义&#xff1a; MySQL是用的PHP语言&#xff0c;然后PHP有addslashes()等函数&#xff0c;这类函数会自动过滤 ’ ‘’ null 等这些敏感字符&#xff0c;将它们转义成’ ‘’ \null&#xff1b;然后宽字节字符集比如GBK它会自动把两个字节的字符识别为一个汉…

解决分布式环境下session共享问题

在分布式环境下&#xff0c;session会存在两个问题 第一个问题:不同域名下&#xff0c;浏览器存储的jsessionid是没有存储的。比如登录时认证服务auth.gulimall.com存储了session&#xff0c;但是搜索服务search.gulimall.com是没有这个session的&#xff1b; 第二个问题&…