如何构建LSTM神经网络模型

一、了解LSTM

1. 核心思想

        首先,LSTM 是 RNN(循环神经网络)的变体。它通过引入细胞状态 C(t) 贯穿于整个网络模型,达到长久记忆的效果,进而解决了 RNN 的长期依赖问题。

2. 思维导图

        每个LSTM层次都有三个重要的门结构,从前往后依次是遗忘门(forget gate layer)、输入门(input gate layer)、输出门(output gate layer)。

        还有两个重要的状态,分别是细胞状态(cell state)、隐藏状态(hidden state),即图示中的 C(t) 和 h(t) 。其中细胞状态不仅记忆某个时间步的信息,而是对整个时间序列保持较为稳定的记忆,是一种长期 “记忆信息” 。对于隐藏状态来说,它更多地关注当前时间步以及上一个时间步的输出,是一种短期 “记忆信息”

        具体内容如下面思维导图所示:


二、利用pytorch构建LSTM

1. 构造神经网络模型

1.1 LSTM层
self.lstm = nn.LSTM(input_size=28,  # 每次输入特征数量为28hidden_size=64,  # 表示每个时间步的输出会有 64 个特征num_layers=1,  # LSTM隐藏层的层数batch_first=True  # 输入数据的格式是“批次在第一位”
)
  • input_size: 这告诉模型,每次输入的数据有多少个特征(比如一张28x28像素的图像,每一行就是一个时间步)。也就是图示中的 x(t) 。
  • hidden_size:这是模型的“记忆”大小。即细胞状态C(t) 和隐藏状态 h(t) 的容量。
  • num_layers:等于1则代表只使用一层 LSTM 网络。
  • batch_first:这个参数表示输入数据的维度格式是(批次,时间步、特征数),即批次在第一维。
1.2 全连接层
self.out = nn.Linear(in_features=64,out_features=10  # 将LSTM层提取到的64个特征进一步转化为10个输出(0~9)
)
  • in_features:全连接层的输入大小,来自LSTM的输出,每个时间步的特征数是64(即 hidden_size )
  • out_features:全连接层的输出大小是10,通常表示有10个类别。
 1.3 Softmax层
self.softmax = nn.Softmax(dim=1)

        这一层主要是将全连接层的输出转化为概率分布。如果使用的是交叉熵代价函数(CrossEntropyLoss),可以不加这层。

2. 前向传播

  1. 在前面LSTM层中batch_first参数设置了输入数据的维度格式,即(批次,时间步、特征数)。所以首先要做的就是调整输入的维度格式。这里每个样本是 28 个时间步,每个时间步有 28 个特征(像是一个28x28的图像)

    x = x.view(-1, 28, 28)
  2. 让输入数据通过LSTM层,并最终输出三个信息,分别是 output,h_n 和 c_n。output 包含了每个时间步的输出信息(理解为LSTM分析每个时间步得到的结果)。h_n 是最后一个时间步的隐藏状态,c_n 是记忆状态。我们重点关注 h_n,因为它代表了 LSTM 在处理完所有时间步后的总结。

    output, (h_n, c_n) = self.lstm(x)
    
  3. 接下来从隐藏状态中拿到最后一个时间步 h_n 的输出 output_in_last_timestep。可以理解为,LSTM看完了所有时间步之后,得到了它对整个序列的理解。

    output_in_last_timestep = h_n[-1, :, :]
    
  4. 最后LSTM的输出被送到全连接层,转化成10个数字,这些数字代表模型对每个类别的预测分数。并通过Softmax转化为概率。

    x = self.out(output_in_last_timestep)
    x = self.softmax(x)
    

        构造好的LSTM神经网络模型代码如下所示: 

class LSTM(nn.Module):def __init__(self):super(LSTM, self).__init__()self.lstm = nn.LSTM(input_size=28,  # 每次输入特征数量hidden_size=64,  num_layers=1,  # LSTM隐藏层的层数batch_first=True  )self.out = nn.Linear(in_features=64,out_features=10  # 将LSTM层提取到的64个特征进一步转化为10个输出(0~9))self.softmax = nn.Softmax(dim=1)def forward(self, x):x = x.view(-1, 28, 28)  # 将输入调整成一个 (批次大小, 时间步数, 特征数) 的形式output, (h_n, c_n) = self.lstm(x)output_in_last_timestep = h_n[-1, :, :]  # 从隐藏状态中拿到最后一个时间步的输出x = self.out(output_in_last_timestep)  # LSTM的输出被送到全连接层,转化成10个数字x = self.softmax(x)  return x

三、测试 LSTM 神经网络模型

        用MNIST数据集测试代码如下:

# 训练集
train_dataset = datasets.MNIST(root='./',train=True,transform=transforms.ToTensor(),  # 数据转换为张量格式download=True)
# 测试集
test_dataset = datasets.MNIST(root='./',train=False,transform=transforms.ToTensor(),download=True)# 批次大小
batch_size = 100
# 装载训练集
train_loader = DataLoader(dataset=train_dataset,batch_size=batch_size,  # 每次加载多少条数据shuffle=True)  # 生成数据前打乱数据# 装载测试集
test_loader = DataLoader(dataset=test_dataset,batch_size=batch_size,shuffle=True)LR = 0.001  # 学习率
model = LSTM()  # 模型
crossEntropy_loss = nn.CrossEntropyLoss()  # 交叉熵代价函数
optimizer = optim.Adam(model.parameters(), LR)def train():model.train()for i, data in enumerate(train_loader):inputs, labels = data  # 获得一个批次的数据和标签out = model(inputs)  # 获得模型预测输出(64张图像,10个数字的概率)loss = crossEntropy_loss(out, labels)  # 使用交叉熵损失函数时,可以直接使用整型标签,无须独热编码optimizer.zero_grad()  # 梯度清0loss.backward()  # 计算梯度optimizer.step()  # 修改权值def test():model.eval()correct = 0for i, data in enumerate(test_loader):inputs, labels = data  # 获得一个批次的数据和标签out = model(inputs)  # 获得模型预测结构(64,10)_, predicted = torch.max(out, 1)  # 获得最大值,以及最大值所在位置correct += (predicted == labels).sum()  # 判断64个值有多少是正确的print("测试集正确率:{}\n".format(correct.item() / len(test_loader)))# 训练20个周期
for epoch in range(20):print("Epoch:{}".format(epoch))train()test()

        测试结果: 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.xdnf.cn/news/1555922.html

如若内容造成侵权/违法违规/事实不符,请联系一条长河网进行投诉反馈,一经查实,立即删除!

相关文章

VMware ESXi更改https的TLS协议版本

简要概述 TLS 1.0 和 1.1 是已弃用的协议,具有广为人知的缺点和漏洞。应在所有接口上启用 TLS 1.2,并在支持的情况下禁用 SSLv3、TL 1.1 和 1.0。强制要求 TLS 1.2 可能会破坏 vSphere 的第三方集成和加载项。在实施 TLS 1.2 后仔细测试这些集成&#x…

maven指定模块快速打包idea插件Quick Maven Package

问题背景描述 在实际开发项目中,我们的maven项目结构可能不是单一maven项目结构,项目一般会用parent方式将各个项目进行规范; 随着组件的数量增加,就会引入一个问题:我们只想打包某一个修改后的组件A时就变得很不方便…

C++ 算法学习——1.8 悬线法

1.问题引入:对于一个矩形图,图中放置着不少障碍,要求出最大的不含障碍的矩形。 2.分析:显然一个极大矩形是左右上下都被障碍挡住,无法再扩大的矩形,此时障碍也包括边界。 3.方法:悬线法考虑以…

01 从0开始搭建django环境

1 安装相关版本的django,这里,我以5.1.1为例子 pip3 install django5.1.1 (.venv) D:\DjangoCode\MS>pip3 install django5.1.1 Looking in indexes: https://pypi.tuna.tsinghua.edu.cn/simple Collecting django5.1.1Using cached https://pypi.t…

STM32定时器(TIM)

目录 一、概述 二、定时器的类型 三、时序 四、定时器中断基本结构 五、定时器定时中断代码 六、定时器外部时钟代码 一、概述 TIM(Timer)定时器 定时器可以对输入的时钟进行计数,并在计数值达到设定值时触发中断16位计数器、预分频器、自动重装寄存器的时基…

TM1618数码管控制芯片使用共阳极数码管过程中的问题和解决办法

控制芯片的基本了解 相比于不用控制芯片的电路:这里带2根电源线和3个信号线,共使用了5根线,但可以控制4个8段数码管显示。若是电路直接控制4个8段数码管需要84113个接口,这对于MCU的珍贵引脚简直是浪费。 这里不会出现余晖效应也…

Python编程常用的35个经典案例

Python 的简洁和强大使其成为许多开发者的首选语言。本文将介绍35个常用的Python经典代码案例。这些示例覆盖了基础语法、常见任务、以及一些高级功能。 1.列表推导式 这个例子展示了列表推导式,用于生成FizzBuzz序列。 fizz_buzz_list ["FizzBuzz" i…

ultralytics yolo pose 示例:加载官方pose模型进行推理

Ultralytics YOLO 是计算机视觉和 ML 领域专业人士的高效工具。 安装 ultralytics 库: pip install ultralytics 官方YoLo Pose 模型列表信息: 实现代码如下: from ultralytics import YOLO import cv2 # Load a model ckpt_dir "…

HTB:Ignition[WriteUP]

目录 连接至HTB服务器并启动靶机 1.Which service version is found to be running on port 80? 2.What is the 3-digit HTTP status code returned when you visit http://{machine IP}/? 3.What is the virtual host name the webpage expects to be accessed by? 4.…

详细解释:前向传播、反向传播等

详细解释:前向传播、反向传播等 在机器学习和深度学习中,**前向传播(Forward Propagation)和反向传播(Backward Propagation)**是训练神经网络的两个核心过程。理解这两个概念对于掌握神经网络的工作原理、优化方法以及模型微调技术(如LoRA、P-tuning等)至关重要。以下…

机器人技术基础(1-3章坐标变换)

位置矢量的意思是B坐标系的原点O相对于A坐标系的平移变换后的矩阵: 齐次坐标最后一个数表示缩放倍数: 左边的是T形变换矩阵,右边的是需要被变换的矩阵:T形变换矩阵的左上角表示旋转,右上角表示平移,左下角最…

好用且不伤眼镜的超声波清洗机排名!谁才是清洁小能手?

对于经常佩戴眼镜的人来说,眼镜的日常清洁保养极为关键。传统清洁方式可能导致镜片刮花和残留污渍,鉴于此,眼镜专用的超声波清洗机应运而生,利用超声振动技术深入微细缝隙,彻底扫除污垢与油脂,保护镜片免受…

JavaEE: 数据链路层的奇妙世界

文章目录 数据链路层以太网源地址和目的地址 类型数据认识 MTU 数据链路层 以太网 以太网的帧格式如下所示: 源地址和目的地址 源地址和目的地址是指网卡的硬件地址(也叫MAC地址). mac 地址和 IP 地址的区别: mac 地址使用6个字节表示,IP 地址4个字节表示. 一般一个网卡,在…

Unity3D 单例模式

Unity3D 泛型单例 单例模式 单例模式是一种创建型设计模式,能够保证一个类只有一个实例,提供访问实例的全局节点。 通常会把一些管理类设置成单例,例如 GameManager、UIManager 等,可以很方便地使用这些管理类单例,…

BM1 反转链表

要求 代码 /*** struct ListNode {* int val;* struct ListNode *next;* };*/ /*** 代码中的类名、方法名、参数名已经指定,请勿修改,直接返回方法规定的值即可*** param head ListNode类* return ListNode类*/ struct ListNode* ReverseList(struct …

从零开始讲PCIe(10)——事务层介绍

一、事务层概述 事务层在响应软件层的请求时,会生成出站数据包。同时,它也会检查入站数据包,并将其中包含的信息传递到软件层。事务层支持非发布事务的分割事务协议,能够将入站的完成数据包与之前传输的非发布请求相关联。该层处理…

After-kaoyan

知乎 - 安全中心 有态度,有回应,有温度,是跟双鱼相处的基础 我今天跟大家泄漏一个秘密,这个秘密也很简单,就是我每次遇到困难险阻时候我从不退缩,我也不会想着:“算了吧,我做不到&a…

C/C++/EasyX——入门图形编程(5)

【说明】友友们好,今天来讲一下键盘消息函数。(其实这个本来准备和鼠标消息函数放在一起的,但是上一篇三个放在一起,内容就有点多了,只写一个又太单调了,所以键盘消息函数的内容就放在这一篇了 (^&#xff…

用manim实现Gram-Schmidt正交化过程

在线性代数中,正交基有许多美丽的性质。例如,由正交列向量组成的矩阵(又称正交矩阵)可以通过矩阵的转置很容易地进行反转。此外,例如:在由彼此正交的向量张成的子空间上投影向量也更容易。Gram-Schmidt过程是一个重要的算法&#…

Oracle 表空间异构传输

已经有了表空间的数据文件,和元数据dump文件,如何把这个表空间传输到异构表空间中? 查询异构传输平台信息: COLUMN PLATFORM_NAME FORMAT A40 SELECT PLATFORM_ID, PLATFORM_NAME, ENDIAN_FORMAT FROM V$TRANSPORTABLE_PLATFORM O…