【深度学习】LSTM、BiLSTM详解

文章目录

  • 1. LSTM简介:
  • 2. LSTM结构图:
  • 3. 单层LSTM详解
  • 4. 双层LSTM详解
  • 5. BiLSTM
  • 6. Pytorch实现LSTM示例
  • 7. nn.LSTM参数详解

1. LSTM简介:

    LSTM是一种循环神经网络,它可以处理和预测时间序列中间隔和延迟相对较长的重要事件。LSTM通过使用门控单元来控制信息的流动,从而缓解RNN中的梯度消失和梯度爆炸的问题。LSTM的核心是三个门:输入门遗忘门输出门

遗忘门: 遗忘门的作用是决定哪些信息从记忆单元中遗忘,它使用sigmoid激活函数,可以输出在0到1之间的值,可以理解为保留信息的比例。
输入门: 作用是决定哪些新信息被存储在记忆单元中
输出门: 输出门决定了下一个隐藏状态,即生成当前时间步的输出并传递到下一时间步
记忆单元:负责长期信息的存储,通过遗忘门和输入门的相互作用,记忆单元能够学习如何选择性地记住或忘记信息

2. LSTM结构图:

在这里插入图片描述

涉及到的计算公式如下:
在这里插入图片描述

3. 单层LSTM详解

(1)设定有3个字的序列【“早”“上”“好”】要经过LSTM处理,每个序列由20个元素组成的列向量构成,所以input size就为20。

(2)设定全连接层中有100个隐藏单元,LSTM的层数为1。

(3)因为是3个字的序列,所以LSTM需要3个时间步(即会自循环3次)才能处理完这个序列。

(4)nn.LSTM()每层也可以拆开写,这样每层的隐藏单元个数就可以分别设定。

在这里插入图片描述
    LSTM单元包含三个输入参数x、c、h;首先t1时刻作为第一个时间步,输入到第一个LSTM单元中,此时输入的初始从c(0)和h(0)都是0矩阵,计算完成后,第一个LSTM单元输出一组h(t1)\c(t1),作为本层LSTM的第二个时间步的输入参数;因此第二个时间步的输入就是h(t1),c(t1),x(t2),而输出是h(t2),c(t2);因此第三个时间步的输入就是h(t2),c(t2),x(t3),而输出是h(t3),c(t3)。

4. 双层LSTM详解

(1)设定有3个字的序列【“早”“上”“好”】要经过LSTM处理,每个序列由20个元素组成的列向量构成,所以input size就为20。

(2)设定全连接层中有100个隐藏单元,LSTM的层数为2。

(3)因为是3个字的序列,所以LSTM需要3个时间步(即会自循环3次)才能处理完这个序列。

(4)nn.LSTM()每层也可以拆开写,这样每层的隐藏单元个数就可以分别设定。

在这里插入图片描述
    第二层LSTM没有输入参数x(t1)、x(t2)、x(t3);所以我们将第一层LSTM输出的h(t1)、h(t2)、h(t3)作为第二层LSTM的输入x(t1)、x(t2)、x(t3)。第一个时间步输入的初始c(0)和h(0)都为0矩阵,计算完成后,第一个时间步输出新的一组h(t1)、c(t1),作为本层LSTM的第二个时间步的输入参数;因此第二个时间步的输入就是h(t1),c(t1),x(t2),而输出是h(t2),c(t2);因此第三个时间步的输入就是h(t2),c(t2),x(t3),而输出是h(t3),c(t3)。

5. BiLSTM

单层的BiLSTM其实就是2个LSTM,一个正向去处理序列,一个反向去处理序列,处理完后,两个LSTM的输出会拼接起来。
在这里插入图片描述

6. Pytorch实现LSTM示例

import torch 
import torch.nn as nndevice = torch.device("cuda" if torch.cuda.is_available() else "cpu")class LSTM(nn.Module):def __init__(self, input_dim, hidden_dim, num_layers, output_dim):super(LSTM, self).__init__()self.hidden_dim = hidden_dim  # 隐藏层维度self.num_layers = num_layers  # LSTM层的数量# LSTM网络层self.lstm = nn.LSTM(input_dim, hidden_dim, num_layers, batch_first=True)# 全连接层,用于将LSTM的输出转换为最终的输出维度self.fc = nn.Linear(hidden_dim, output_dim)def forward(self, x):# 初始化隐藏状态和细胞状态h0 = torch.zeros(self.num_layers, x.size(0), self.hidden_dim).to(device)c0 = torch.zeros(self.num_layers, x.size(0), self.hidden_dim).to(device)# 前向传播LSTM,返回输出和最新的隐藏状态与细胞状态out, (hn, cn) = self.lstm(x, (h0.detach(), c0.detach()))# 将LSTM的最后一个时间步的输出通过全连接层out = self.fc(out[:, -1, :])return out

7. nn.LSTM参数详解

pytorch官方定义:

CLASS torch.nn.LSTM(
        input_size,
        hidden_size,
        num_layers=1,
        bias=True,
        batch_first=False,
        dropout=0.0,
        bidirectional=False,
        proj_size=0,
        device=None,
        dtype=None
    )

input_size – 输入 x 中预期的特征数量
hidden_size – 隐藏状态 h 中的特征数量
num_layers – 循环层的数量。例如,设置 num_layers=2 表示将两个 LSTM 堆叠在一起形成一个 stacked LSTM,其中第二个 LSTM 接收第一个 LSTM 的输出并计算最终结果。默认值:1
bias – 如果 False,则该层不使用偏差权重 b_ih 和 b_hh。默认值:True
batch_first – 如果 True,则输入和输出张量将以 (batch, seq, feature) 而不是 (seq, batch, feature) 的形式提供。请注意,这并不适用于隐藏状态或单元状态。有关详细信息,请参见下面的输入/输出部分。默认值:False
dropout – 如果非零,则在除最后一层之外的每个 LSTM 层的输出上引入一个 Dropout 层,其 dropout 概率等于 dropout。默认值:0
bidirectional – 如果 True,则变为双向 LSTM。默认值:False
proj_size – 如果 > 0,则将使用具有相应大小的投影的 LSTM。默认值:0

对于输入序列每一个元素,每一层都会进行以下计算:
在这里插入图片描述
网络输入:
在这里插入图片描述

网络输出:
在这里插入图片描述

本文参考:https://blog.csdn.net/qq_34486832/article/details/134898868
https://pytorch.ac.cn/docs/stable/generated/torch.nn.LSTM.html#

LSTM每层的输出都要经过全连接层吗,还是直接对隐藏层进行输出?
通过在代码中对lstm的输出进行print输出:

import torch 
import torch.nn as nndevice = torch.device("cuda" if torch.cuda.is_available() else "cpu")class LSTM(nn.Module):def __init__(self, input_dim, hidden_dim, num_layers, output_dim):super(LSTM, self).__init__()self.hidden_dim = hidden_dim  # 隐藏层维度self.num_layers = num_layers  # LSTM层的数量# LSTM网络层self.lstm = nn.LSTM(input_dim, hidden_dim, num_layers, batch_first=True)# 全连接层,用于将LSTM的输出转换为最终的输出维度self.fc = nn.Linear(hidden_dim, output_dim)def forward(self, x):# 初始化隐藏状态和细胞状态h0 = torch.zeros(self.num_layers, x.size(0), self.hidden_dim).to(device)c0 = torch.zeros(self.num_layers, x.size(0), self.hidden_dim).to(device)# 前向传播LSTM,返回输出和最新的隐藏状态与细胞状态out, (hn, cn) = self.lstm(x, (h0.detach(), c0.detach()))print(out)print(hn)print(cn)# 将LSTM的最后一个时间步的输出通过全连接层out = self.fc(out[:, -1, :])return out
if __name__ == "__main__":input_dim = 3        # 输入特征的维度hidden_dim = 4       # 隐藏层的维度num_layers = 1       # LSTM 层的数量output_dim = 1       # 输出特征的维度lstm = LSTM(input_dim, hidden_dim, num_layers, output_dim).to(device)batch_size = 1seq_length = 10input_tensor = torch.randn(batch_size, seq_length, input_dim).to(device)output = lstm(input_tensor)

通过对LSTM网络的输出我们可以看到,out的最后一层与最后一层隐藏层hn一致,说明并未经过全连接层,而是直接输出隐藏层
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.xdnf.cn/news/13034.html

如若内容造成侵权/违法违规/事实不符,请联系一条长河网进行投诉反馈,一经查实,立即删除!

相关文章

PyQt5 加载UI界面与资源文件

步骤一: 使用 Qt Designer 创建 XXX.ui文件 步骤二: 使用 Qt Designer 创建 资源文件 步骤三: Python文件中创建相关类, 使用 uic.loadUi(mainwidget.ui, self ) 加载UI文件 import sys from PyQt5 import QtCore, QtWidgets, uic from PyQt5.QtCore import Qt f…

ENSP作业——小型园区网

题目 根据上图,可得需求为: 1.配置交换机上的VLAN及IP地址。 2.设置SW1为VLAN 2/3的主根桥,设置SW2为VLAN 20/30的主根桥,且两台交换机互为主备。 3.可以使用super vlan。(本次实验中未使用) 4.上层通过静…

计算机网络:运输层 —— 运输层端口号

文章目录 运输层端口号的分类端口号与应用程序的关联应用举例发送方的复用和接收方的分用 运输层端口号的分类 端口号只具有本地意义,即端口号只是为了标识本计算机网络协议栈应用层中的各应用进程。在因特网中不同计算机中的相同端口号是没有关系的,即…

【C++练习】使用C++编写程序计算π的近似值

题目:使用C编写程序计算π的近似值 描述: 编写一个C程序,使用一个特定的数学公式来计算圆周率(π)的近似值。该程序定义了一个函数calculatePi(),该函数通过一个迭代算法和一个涉及反正切函数(…

Hook小程序

下载: https://github.com/JaveleyQAQ/WeChatOpenDevTools-Python 配置: pip install -r requirements 实现: 开启小程序开发者模式,类似浏览器F12 效果: 使用: 退出微信,进入安装的目录…

如何在pycharm中 判断是否成功安装pytorch环境

1、在电脑开始端,找到 2、打开后 在base环境下 输入conda env list 目前我的环境中没有pytorch 学习视频:【Anaconda、Pytorch、Pycharm到底是什么关系?什么是环境?什么是包?】https://www.bilibili.com/video/BV1CN411s7Ue?vd_sourcefad0750b8c6…

AI陪伴走热,网易云信“融合通讯+AI”新方案发布!附场景App及源码

信息秒回、不会失联、724h 情感陪伴、随时提供情绪价值......在 AI 能力越来越强大的今天,我们开始有了“AI 助手”、“AI 搭子”,甚至开始谈起“AI 男友/女友”,AI 的角色早已超越了简单的生产力工具,它正深入到我们生活的方方面…

力扣 LeetCode 704. 二分查找(Day1:数组)

解题思路: 二分查找主要分为[ left , right ]左闭右闭和[ left , right )左闭右开两种 此处采取[ left , right ]左闭右闭写法 注意: 1. right的初始化取值 2. while中取等 3. right mid -1 ; class Solution {public int search(int[] nums, i…

java-AOP编程示例

SpringBoot工程,有不懂的留言or Kimi一下 LogAspect.java package com.xxx.javaaopdemo.Aspect;import com.xxx.javaaopdemo.LogAnnotation; import org.aspectj.lang.ProceedingJoinPoint; import org.aspectj.lang.annotation.Around; import org.aspectj.lang…

Kafka入门:Java客户端库的使用

在现代的分布式系统中,消息队列扮演着至关重要的角色,而Apache Kafka以其高吞吐量、可扩展性和容错性而广受欢迎。本文将带你了解如何使用Kafka的Java客户端库来实现生产者(Producer)和消费者(Consumer)的基…

使用 npm 安装 Yarn

PS E:\WeChat Files\wxid_fipwhzebc1yh22\FileStorage\File\2024-11\spid-admin\spid-admin> yarn install yarn : 无法将“yarn”项识别为 cmdlet、函数、脚本文件或可运行程序的名称。请检查名称的拼写,如果包括路径,请确保路径正确,然后…

力扣617:合并二叉树

给你两棵二叉树: root1 和 root2 。 想象一下,当你将其中一棵覆盖到另一棵之上时,两棵树上的一些节点将会重叠(而另一些不会)。你需要将这两棵树合并成一棵新二叉树。合并的规则是:如果两个节点重叠&#…

谷歌浏览器支持的开发者工具详解

谷歌浏览器(Google Chrome)是全球最受欢迎的网页浏览器之一,它不仅提供了快速、安全的浏览体验,还为开发者提供了强大的开发者工具。本文将详细介绍如何使用谷歌浏览器的开发者工具,并解答一些常见问题。(本…

HTB:OpenAdmin[WriteUP]

目录 连接至HTB服务器并启动靶机 使用nmap对靶机TCP端口进行开放扫描 继续使用nmap对靶机22、80端口进行脚本、服务扫描 使用Dirbuster对靶机网页路径进行递归扫描 ​编辑 尝试在searchsploit中搜索改WebAPP漏洞 横向移动(其实没有横) 启动Metasploit 特权提升 USER_…

IO作业5

设置双信号实现交替生产者线程和消费者线程 #include <myhead.h> int n0; pthread_mutex_t fastmutex;//定义互斥锁 pthread_cond_t cond;//定义条件变量 pthread_cond_t cond2; void *product() {int i;for(i0;i<10;i){n;printf("我生产了一台特斯拉n%d\n"…

Web3.0安全开发实践|BNB链安全开发,这10大实用tips一定要收藏

BNB Chain是Web3世界中最受欢迎的区块链之一&#xff0c;其费用合理、交易迅速以及项目生态系统丰富几大原因吸引了广大用户。与任何的区块链都一样&#xff0c;BNB Chain上的开发者在开发过程中首先考虑的应该是安全问题&#xff1a;因为任何资金的损失都会导致用户对协议及平…

QT开发笔记之小知识

QCoreApplication::aboutToQuit 主事件循环退出前发出的信号&#xff0c;是程序退出前等待QT线程退出回收资源的神器。 官方帮助文档 [signal] void QCoreApplication::aboutToQuit() 该信号在应用程序即将退出主事件循环时发出&#xff0c;例如&#xff1a;当事件循环级别降至…

插入排序(C语言)

直接插入排序的基本思想&#xff1a;把待排序的记录按其关键码值的大小逐个插入到一个已经排好序的有序序列中&#xff0c;直到所有的记录插入完为止&#xff0c;得到一个新的有序序列 。 一、步骤 1.给定一个乱序的数组&#xff0c;如 从第一个元素开始排序&#xff0c;当只…

文心一言 VS 讯飞星火 VS chatgpt (389)-- 算法导论25.1 2题

二、为什么要求对于所有的 1 ⩽ i ⩽ n 1⩽i⩽n 1⩽i⩽n&#xff0c;有 w i i 0 w_{ii}0 wii​0 &#xff1f;如果要写代码&#xff0c;请用go语言。 文心一言&#xff1a; 在许多数学和计算应用中&#xff0c;要求矩阵 W W W 的对角线元素 w i i 0 w_{ii} 0 wii​0 是…

Java多线程详解⑦(全程干货!!!)内存可见性 || volatile || JMM || wait notify notifyAll

这里是Themberfue 在上一节中&#xff0c;我们讨论了死锁的概念&#xff0c;产生的场景 &#xff0c;产生的必要条件...... 内存可见性 我们先来看一段百度百科关于 "内存可见性" 的解释 "内存可见性" 就是造成线程安全问题的原因之一 如果是单纯地看…