pytorch实现RNN网络

目录

1.导包

2. 加载本地文本数据

 3.构建循环神经网络层

4.初始化隐藏状态state

5.创建随机的数据,检测一下代码是否能正常运行

6. 构建一个完整的循环神经网络¶ 

7.模型训练 

8.个人知识点理解


 

1.导包

import torch
from torch import nn
from torch.nn import functional as F
import dltools

2. 加载本地文本数据

#声明变量:批次大小(一批所取的数据量)、子序列的长度
batch_size, num_steps =32, 35
#获取训练数据的迭代器, 词汇表
train_iter, vocab = dltools.load_data_time_machine(batch_size=batch_size, num_steps=num_steps)

 3.构建循环神经网络层

#声明变量:隐藏层的神经元数量(每个神经元都会有一个输出)
num_hiddens = 256
#构建一个具有256个隐藏单元的单隐藏层的循环神经网络
#num_layers=1默认值:一层神经网络
rnn_layer = nn.RNN(input_size=len(vocab), hidden_size=num_hiddens, num_layers=1)

4.初始化隐藏状态state

# 括号中的1:因为num_layers=1默认值:一层神经网络
state = torch.zeros((1, batch_size, num_hiddens))
state.shape
torch.Size([1, 32, 256])

5.创建随机的数据,检测一下代码是否能正常运行

X = torch.rand(size=(num_steps, batch_size, len(vocab)))
#传入X和初始化时的state,获取Y和state_new
Y, state_new = rnn_layer(X, state)
Y.shape, state_new.shape#有输出表示代码正常运行!!!

 (torch.Size([35, 32, 256]), torch.Size([1, 32, 256])) 

6. 构建一个完整的循环神经网络¶ 

.long() 方法‌:这是PyTorch张量的一个方法,用于将张量的数据类型转换为torch.long。torch.long是一种整数数据类型,通常用于索引或存储不需要浮点数精度的整数数据。 

class RNNModel(nn.Module):   #继承nn.Module#初始化(需要用到的)参数,  **kwargs表示继承的其他参数(不一一写明的意思)#vocab_size = len(vocab)def __init__(self, rnn_layer, vocab_size, **kwargs):#继承父类的属性和方法super().__init__(**kwargs)self.rnn_layer = rnn_layer#词汇表的长度self.vocab_size =vocab_sizeself.num_hiddens = self.rnn_layer.hidden_size#判断是否为双向循环if not self.rnn_layer.bidirectional:self.num_directions = 1#nn.Linear用于定义线性层的类,一般用于全连接层self.linear = nn.Linear(in_features=self.num_hiddens, out_features=self.vocab_size)else:self.num_directions = 2self.linear = nn.Linear(self.num_hiddens*2, self.vocab_size)#定义了数据在模型中的前向传播过程。(串联每一件事件的逻辑顺序)def forward(self, inputs, state):#one_hot编码,处理输入的X数据,此时的X.shape=(batch_size, num_steps)#。T转置之后,X.shape=(num_steps,batch_size)#one_hot编码之后, X.shape=(num_steps,batch_size, len(vocab)X = F.one_hot(inputs.T.long(), self.vocab_size)#将数据转化为tensorX = X.to(torch.float32)Y, state = self.rnn_layer(X, state)#此时,Y.shape = torch.Size(num_steps, batch_size, num_hiddens)#输出层:Y.shape必须是一个二维的, -1表示合并Y.shape中的num_steps与batch_size,outputs = self.linear(Y.reshape(-1, Y.shape[-1]))return outputs, state# 初始化隐藏状态def begin_state(self, device, batch_size=1):return torch.zeros((self.num_directions * self.rnn_layer.num_layers, batch_size, self.num_hiddens), device=device)
#在训练之前,基于随机初始化的权重进行预测,测试模型
device = dltools.try_gpu()
rnn_net = RNNModel(rnn_layer, vocab_size=len(vocab))
rnn_net = rnn_net.to(device)
dltools.predict_ch8(prefix='time traveller',num_preds=10, net=rnn_net, vocab=vocab, device=device)
'time travellergghhhhhhhh'

7.模型训练 

#声明变量
#模型训练时,可以先让学习率的值稍大一些,让梯度下降的快一些,然后
#梯度下降到一定程度再改成较小的值
num_epochs, lr = 500, 0.1
dltools.train_ch8(net=rnn_net, train_iter=train_iter, vocab=vocab, lr=lr, num_epochs=num_epochs, device=device)

 

8.个人知识点理解

 

 

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.xdnf.cn/news/145398.html

如若内容造成侵权/违法违规/事实不符,请联系一条长河网进行投诉反馈,一经查实,立即删除!

相关文章

API安全推荐厂商瑞数信息入选IDC《中国数据安全技术发展路线图》

近日,全球领先的IT研究与咨询公司IDC发布报告《IDC TechScape:中国数据安全技术发展路线图,2024》。瑞数信息凭借其卓越的技术实力和广泛的行业应用,被IDC评选为“增量型”技术曲线API安全的推荐厂商。 IDC指出,数据安…

Liveweb视频汇聚平台支持GB28181转RTMP、HLS、RTSP、FLV格式播放方案

GB28181协议凭借其在安防流媒体行业独有的大统一地位,目前已经在各种安防项目上使用。雪亮工程、幼儿园监控、智慧工地、物流监控等等项目上目前都需要接入安防摄像头或平台进行直播、回放。而GB28181协议作为国家推荐标准,目前基本所有厂家的安防摄像头…

Netty源码解析-请求处理与多路复用

Netty基本介绍,参考 Netty与网络编程 摘要 Netty源码系列-NioEventLoop 1.1 Netty给Channel分配Nio Event Loop的规则 看下图,EventLoopGroup是线程组,每个EventLoop是一个线程,那么线程处理请求是怎么分配的呢?我…

Docker 以外置数据库方式部署禅道

2.安装步骤 2.1.参考资料 禅道官网文档: https://www.zentao.net/book/zentaopms/docker-1111.html https://www.zentao.net/book/zentaopms/405.html 2.2.详细步骤 ssh 登录服务器创建目录 /opt/zentao /opt/zentao/data /opt/zentao/db cd /opt mkdir zentao mkdir zentao…

回归预测 | Matlab实现SSA-HKELM麻雀算法优化混合核极限学习机多变量回归预测

回归预测 | Matlab实现SSA-HKELM麻雀算法优化混合核极限学习机多变量回归预测 目录 回归预测 | Matlab实现SSA-HKELM麻雀算法优化混合核极限学习机多变量回归预测效果一览基本介绍程序设计参考资料 效果一览 基本介绍 1.Matlab实现SSA-HKELM麻雀算法优化混合核极限学习机多变量…

java -versionbash:/usr/lib/jvm/jdk1.8.0_162/bin/java:无法执行二进制文件:可执行文件格式错误

实验环境:Apple M1在VMwareFusion使用Utubun Jdk文件错误  尝试: 1、重新在网盘下载java1.8 2、在终端通过命令下载 3、确保 JDK 正确安装在系统中,可以通过 echo $JAVA_HOME 检查 JAVA_HOME 环境变量是否设置正确。 &#xfff…

十种果冻的做法

菠萝果冻 1.在菠萝的1/5处切开,切去顶做盖子用,用水果刀在四周划一圈使皮和果肉分离 2.注意底部切透了,用水果刀把菠萝肉挖出,菠萝肉切丁用盐水浸泡备用 3.把菠萝丁放入料理机中加入少许纯净水,打成菠萝汁备用 4.打好…

【学习笔记】数据结构(六 ②)

树和二叉树(二) 文章目录 树和二叉树(二)6.3.2 线索二叉树 6.4 树和森林6.4.1 树的存储结构6.4.2 森林与二叉树的转换6.4.3 树和森林的遍历 6.5 树与等价问题6.5.1 等价定义6.5.2 划分等价类的方法6.5.3 划分等价类的具体操作 - 并…

【IoTDB 线上小课 07】多类写入接口,快速易懂的“说明书”!

【IoTDB 视频小课】稳定更新中!第七期来啦~ 关于 IoTDB,关于物联网,关于时序数据库,关于开源... 一个问题重点,3-5 分钟,我们讲给你听: 一条视频了解写入接口 了解我们的友友们,应该…

[Linux]Vi和Vim编辑器

Vi和Vim编辑器 Linux系统会内置vi文本编辑器, 类似于windows中的记事本 Vim具有程序编辑的能力, 可以看作是Vi的增强版本, 可以进行语法检查, 代码补全,代码编译和错误调整等功能 Vi和Vim的模式 快速入门 使用vim开发一个Hello.java程序 通过Xshell连接Linux系统命令行输入…

XML:DOM4j解析XML

XML简介: 什么是XML:XML 是独立于软件和硬件的信息传输工具。 XML 的设计宗旨是传输数据,而不是显示数据。XML 标签没有被预定义。您需要自行定义标签。XML不会做任何事情,XML被设计用来结构化、存储以及传输信息。 XML可以发明…

企业内网安全

企业内网安全 1.安全域2.终端安全3.网络安全网络入侵检测系统异常访问检测系统隐蔽信道检测系统 4.服务器安全基础安全配置入侵防护检测 5.重点应用安全活动目录邮件系统VPN堡垒机 6.蜜罐体系建设蜜域名蜜网站蜜端口蜜服务蜜库蜜表蜜文件全民皆兵 1.安全域 企业出于不同安全防…

[备忘]测算.net中对象所占用的内存

.net 基础库中应该是没有直接提供计算某个对象所占内存的方法。简单查了下,找到几种方式: 1、运行态用工具进行内存分析 比如,微软这篇教程中有介绍。《使用 .NET 对象分配工具分析内存使用情况》https://learn.microsoft.com/zh-cn/visuals…

优数:助力更高效的边缘计算

在数字化时代的浪潮中,数据已成为企业最宝贵的资产之一。随着物联网(IoT)设备的激增和5G技术的兴起,我们正迅速步入一个新时代,在这个时代中,数据不仅在量上爆炸性增长,更在速度和实时性上提出了…

idea 恢复 pom 文件呈现灰色并带删除线

今天在 idea 中导入别人的项目时发现有几个 pom 文件是灰色的并带删除线。 可以用以下方式解决: 打开file - settings - build,execution,deployment - Build Tools - Maven - Ignored Files 把 pom.xml 前面的复选框去掉,去掉之后,点击 appl…

IMS 中private user id/public user id的格式

private user identity(IMPI)的格式为 "<IMSI>ims.mnc<MNC>.mcc<MCC>.3gppnetwork.org" public user identity SIP URI 格式为 sip:usernamedomain&#xff1b;而Tel URI 格式为 tel:<CC><NDC><SN> temporary public user ide…

linux安装nginx+前端部署vue项目(实际测试react项目也可以)

&#x1f9f8;本篇博客作者测试上线过不下5个项目&#xff0c;包括单纯的静态资源&#xff0c;vue项目和react项目&#xff0c;包好用&#xff0c;请放心使用 &#x1f4dc;作者首页&#xff1a;dream_ready-CSDN博客 &#x1f4dc;有任何问题都可以评论留言&#xff0c;作者将…

无人机飞手教员培训及就业分析

无人机飞手教员培训是一个综合性的教育体系&#xff0c;旨在培养具备专业飞行技能、扎实理论知识以及高效教学能力的无人机飞手导师。培训内容广泛覆盖从无人机基础知识到高级飞行技巧&#xff0c;再到教学方法与技巧&#xff0c;确保学员能够全面掌握无人机操作与教学的精髓。…

pcdn盒子连接方式

连接方式 大部分连接方式如下 光猫拨号 → 路由器 → 盒子 优点&#xff1a;光猫负责拨号&#xff0c;路由器只需做路由转发&#xff0c;性能要求不高缺点&#xff1a;光猫会有一层nat&#xff0c;路由器还有一层nat&#xff0c;两层nat需要在两个设备上都做nat优化注意&…

macOS 中搭建 Flutter 开发环境

如果你的 Mac 是 Apple silicon 处理器&#xff0c;那么有些 Flutter 组件就需要通过 Rosetta 2 来转换适配&#xff08;详情&#xff09;。要在 Apple silicon 处理器上运行所有 Flutter 组件&#xff0c;请运行以下指令来安装 Rosetta 2。 sudo softwareupdate --install-ro…