57 长短期记忆网络(LSTM)_by《李沐:动手学深度学习v2》pytorch版

系列文章目录


文章目录


长短期记忆网络(LSTM)

长期以来,隐变量模型存在着长期信息保存和短期输入缺失的问题。
解决这一问题的最早方法之一是长短期存储器(long short-term memory,LSTM)它有许多与门控循环单元(GRU)一样的属性。有趣的是,长短期记忆网络的设计比门控循环单元稍微复杂一些,却比门控循环单元早诞生了近20年。

门控记忆元

可以说,长短期记忆网络的设计灵感来自于计算机的逻辑门。
长短期记忆网络引入了记忆元(memory cell),或简称为单元(cell)。
有些文献认为记忆元是隐状态的一种特殊类型,它们与隐状态具有相同的形状,其设计目的是用于记录附加的信息。
为了控制记忆元,我们需要许多门。
其中一个门用来从单元中输出条目,我们将其称为输出门(output gate)。
另外一个门用来决定何时将数据读入单元,我们将其称为输入门(input gate)。
我们还需要一种机制来重置单元的内容,由遗忘门(forget gate)来管理,这种设计的动机与门控循环单元相同,能够通过专用机制决定什么时候记忆或忽略隐状态中的输入。让我们看看这在实践中是如何运作的。

输入门、忘记门和输出门

就如在门控循环单元中一样,当前时间步的输入和前一个时间步的隐状态作为数据送入长短期记忆网络的门中,如下图所示。它们由三个具有sigmoid激活函数的全连接层处理,以计算输入门、遗忘门和输出门的值。
因此,这三个门的值都在 ( 0 , 1 ) (0, 1) (0,1)的范围内。

在这里插入图片描述label:lstm_0

我们来细化一下长短期记忆网络的数学表达。
假设有 h h h个隐藏单元,批量大小为 n n n,输入数为 d d d
因此,输入为 X t ∈ R n × d \mathbf{X}_t \in \mathbb{R}^{n \times d} XtRn×d,前一时间步的隐状态为 H t − 1 ∈ R n × h \mathbf{H}_{t-1} \in \mathbb{R}^{n \times h} Ht1Rn×h
相应地,时间步 t t t的门被定义如下:
输入门是 I t ∈ R n × h \mathbf{I}_t \in \mathbb{R}^{n \times h} ItRn×h
遗忘门是 F t ∈ R n × h \mathbf{F}_t \in \mathbb{R}^{n \times h} FtRn×h
输出门是 O t ∈ R n × h \mathbf{O}_t \in \mathbb{R}^{n \times h} OtRn×h
它们的计算方法如下:

I t = σ ( X t W x i + H t − 1 W h i + b i ) , F t = σ ( X t W x f + H t − 1 W h f + b f ) , O t = σ ( X t W x o + H t − 1 W h o + b o ) , \begin{aligned} \mathbf{I}_t &= \sigma(\mathbf{X}_t \mathbf{W}_{xi} + \mathbf{H}_{t-1} \mathbf{W}_{hi} + \mathbf{b}_i),\\ \mathbf{F}_t &= \sigma(\mathbf{X}_t \mathbf{W}_{xf} + \mathbf{H}_{t-1} \mathbf{W}_{hf} + \mathbf{b}_f),\\ \mathbf{O}_t &= \sigma(\mathbf{X}_t \mathbf{W}_{xo} + \mathbf{H}_{t-1} \mathbf{W}_{ho} + \mathbf{b}_o), \end{aligned} ItFtOt=σ(XtWxi+Ht1Whi+bi),=σ(XtWxf+Ht1Whf+bf),=σ(XtWxo+Ht1Who+bo),

其中 W x i , W x f , W x o ∈ R d × h \mathbf{W}_{xi}, \mathbf{W}_{xf}, \mathbf{W}_{xo} \in \mathbb{R}^{d \times h} Wxi,Wxf,WxoRd×h W h i , W h f , W h o ∈ R h × h \mathbf{W}_{hi}, \mathbf{W}_{hf}, \mathbf{W}_{ho} \in \mathbb{R}^{h \times h} Whi,Whf,WhoRh×h是权重参数, b i , b f , b o ∈ R 1 × h \mathbf{b}_i, \mathbf{b}_f, \mathbf{b}_o \in \mathbb{R}^{1 \times h} bi,bf,boR1×h是偏置参数。

候选记忆元 (相当于RNN中计算 H t H_t Ht)

由于还没有指定各种门的操作,所以先介绍候选记忆元(candidate memory cell) C ~ t ∈ R n × h \tilde{\mathbf{C}}_t \in \mathbb{R}^{n \times h} C~tRn×h
它的计算与上面描述的三个门的计算类似,但是使用 tanh ⁡ \tanh tanh函数作为激活函数,函数的值范围为 ( − 1 , 1 ) (-1, 1) (1,1)
下面导出在时间步 t t t处的方程:

C ~ t = tanh ( X t W x c + H t − 1 W h c + b c ) , \tilde{\mathbf{C}}_t = \text{tanh}(\mathbf{X}_t \mathbf{W}_{xc} + \mathbf{H}_{t-1} \mathbf{W}_{hc} + \mathbf{b}_c), C~t=tanh(XtWxc+Ht1Whc+bc),

其中 W x c ∈ R d × h \mathbf{W}_{xc} \in \mathbb{R}^{d \times h} WxcRd×h W h c ∈ R h × h \mathbf{W}_{hc} \in \mathbb{R}^{h \times h} WhcRh×h是权重参数, b c ∈ R 1 × h \mathbf{b}_c \in \mathbb{R}^{1 \times h} bcR1×h是偏置参数。

候选记忆元的如下图 :numref:lstm_1所示。

在这里插入图片描述label:lstm_1

记忆元

在门控循环单元中,有一种机制来控制输入和遗忘(或跳过)。
类似地,在长短期记忆网络中,也有两个门用于这样的目的:
输入门 I t \mathbf{I}_t It控制采用多少来自 C ~ t \tilde{\mathbf{C}}_t C~t的新数据,而遗忘门 F t \mathbf{F}_t Ft控制保留多少过去的记忆元 C t − 1 ∈ R n × h \mathbf{C}_{t-1} \in \mathbb{R}^{n \times h} Ct1Rn×h的内容。
使用按元素乘法,得出:

C t = F t ⊙ C t − 1 + I t ⊙ C ~ t . \mathbf{C}_t = \mathbf{F}_t \odot \mathbf{C}_{t-1} + \mathbf{I}_t \odot \tilde{\mathbf{C}}_t. Ct=FtCt1+ItC~t.

如果遗忘门始终为 1 1 1且输入门始终为 0 0 0,则过去的记忆元 C t − 1 \mathbf{C}_{t-1} Ct1将随时间被保存并传递到当前时间步。
引入这种设计是为了缓解梯度消失问题,并更好地捕获序列中的长距离依赖关系。

这样我们就得到了计算记忆元的流程图,如 :numref:lstm_2

在这里插入图片描述label:lstm_2

隐状态

最后,我们需要定义如何计算隐状态 H t ∈ R n × h \mathbf{H}_t \in \mathbb{R}^{n \times h} HtRn×h,这就是输出门发挥作用的地方。
在长短期记忆网络中,它仅仅是记忆元的 tanh ⁡ \tanh tanh的门控版本。
这就确保了 H t \mathbf{H}_t Ht的值始终在区间 ( − 1 , 1 ) (-1, 1) (1,1)内:

H t = O t ⊙ tanh ⁡ ( C t ) . \mathbf{H}_t = \mathbf{O}_t \odot \tanh(\mathbf{C}_t). Ht=Ottanh(Ct).

只要输出门接近 1 1 1,我们就能够有效地将所有记忆信息传递给预测部分,而对于输出门接近 0 0 0,我们只保留记忆元内的所有信息,而不需要更新隐状态(相当于重置隐状态)。

下图 :numref:lstm_3提供了数据流的图形化演示。

在这里插入图片描述label:lstm_3
在这里插入图片描述

从零开始实现

现在,我们从零开始实现长短期记忆网络。我们首先加载时光机器数据集。

import torch
from torch import nn
from d2l import torch as d2lbatch_size, num_steps = 32, 35
train_iter, vocab = d2l.load_data_time_machine(batch_size, num_steps)

初始化模型参数

接下来,我们需要定义和初始化模型参数。
如前所述,超参数num_hiddens定义隐藏单元的数量。
我们按照标准差 0.01 0.01 0.01的高斯分布初始化权重,并将偏置项设为 0 0 0

def get_lstm_params(vocab_size, num_hiddens, device):num_inputs = num_outputs = vocab_sizedef normal(shape):return torch.randn(size=shape, device=device)*0.01def three():return (normal((num_inputs, num_hiddens)),normal((num_hiddens, num_hiddens)),torch.zeros(num_hiddens, device=device))W_xi, W_hi, b_i = three()  # 输入门参数W_xf, W_hf, b_f = three()  # 遗忘门参数W_xo, W_ho, b_o = three()  # 输出门参数W_xc, W_hc, b_c = three()  # 候选记忆元参数# 输出层参数W_hq = normal((num_hiddens, num_outputs))b_q = torch.zeros(num_outputs, device=device)# 附加梯度params = [W_xi, W_hi, b_i, W_xf, W_hf, b_f, W_xo, W_ho, b_o, W_xc, W_hc,b_c, W_hq, b_q]for param in params:param.requires_grad_(True)return params

定义模型

在[初始化函数]中,长短期记忆网络的隐状态需要返回一个额外的记忆元,单元的值为0,形状为(批量大小,隐藏单元数)。因此,我们得到以下的状态初始化。

def init_lstm_state(batch_size, num_hiddens, device):return (torch.zeros((batch_size, num_hiddens), device=device),torch.zeros((batch_size, num_hiddens), device=device))

[实际模型]的定义与我们前面讨论的一样:
提供三个门和一个额外的记忆元。
请注意,只有隐状态才会传递到输出层,而记忆元 C t \mathbf{C}_t Ct不直接参与输出计算。

def lstm(inputs, state, params):[W_xi, W_hi, b_i, W_xf, W_hf, b_f, W_xo, W_ho, b_o, W_xc, W_hc, b_c,W_hq, b_q] = params(H, C) = stateoutputs = []for X in inputs:I = torch.sigmoid((X @ W_xi) + (H @ W_hi) + b_i)F = torch.sigmoid((X @ W_xf) + (H @ W_hf) + b_f)O = torch.sigmoid((X @ W_xo) + (H @ W_ho) + b_o)C_tilda = torch.tanh((X @ W_xc) + (H @ W_hc) + b_c)C = F * C + I * C_tildaH = O * torch.tanh(C)Y = (H @ W_hq) + b_qoutputs.append(Y) #Y的shape是(批量大小,词表长度)只有这里输出了批量大小的预测,之后才能用来计算损失return torch.cat(outputs, dim=0), (H, C)

训练和预测

让我们通过实例化RNN从零实现中引入的RNNModelScratch类来训练一个长短期记忆网络,就如我们在GRU中所做的一样。

vocab_size, num_hiddens, device = len(vocab), 256, d2l.try_gpu()
num_epochs, lr = 500, 1
model = d2l.RNNModelScratch(len(vocab), num_hiddens, device, get_lstm_params,init_lstm_state, lstm)
d2l.train_ch8(model, train_iter, vocab, lr, num_epochs, device)
perplexity 14.5, 27965.3 tokens/sec on cuda:0
time traveller te at at at at at at at at at at at at at at at a
traveller te at at at at at at at at at at at at at at at a<Figure size 350x250 with 1 Axes>

在这里插入图片描述

简洁实现

使用高级API,我们可以直接实例化LSTM模型。
高级API封装了前文介绍的所有配置细节。
这段代码的运行速度要快得多,因为它使用的是编译好的运算符而不是Python来处理之前阐述的许多细节。

num_inputs = vocab_size
lstm_layer = nn.LSTM(num_inputs, num_hiddens)
model = d2l.RNNModel(lstm_layer, len(vocab))
model = model.to(device)
d2l.train_ch8(model, train_iter, vocab, lr, num_epochs, device)
perplexity 11.2, 233619.5 tokens/sec on cuda:0
time traveller the the the the the the the the the the the the t
traveller the the the the the the the the the the the the t<Figure size 350x250 with 1 Axes>

在这里插入图片描述
长短期记忆网络是典型的具有重要状态控制的隐变量自回归模型。
多年来已经提出了其许多变体,例如,多层、残差连接、不同类型的正则化。
然而,由于序列的长距离依赖性,训练长短期记忆网络和其他序列模型(例如门控循环单元)的成本是相当高的。
在后面的内容中,我们将讲述更高级的替代模型,如Transformer。

小结

  • 长短期记忆网络有三种类型的门:输入门、遗忘门和输出门。
  • 长短期记忆网络的隐藏层输出包括“隐状态”和“记忆元”。只有隐状态会传递到输出层,而记忆元完全属于内部信息。
  • 长短期记忆网络可以缓解梯度消失和梯度爆炸。

练习

  1. 调整和分析超参数对运行时间、困惑度和输出顺序的影响。
  2. 如何更改模型以生成适当的单词,而不是字符序列?
  3. 在给定隐藏层维度的情况下,比较门控循环单元、长短期记忆网络和常规循环神经网络的计算成本。要特别注意训练和推断成本。
  4. 既然候选记忆元通过使用 tanh ⁡ \tanh tanh函数来确保值范围在 ( − 1 , 1 ) (-1,1) (1,1)之间,那么为什么隐状态需要再次使用 tanh ⁡ \tanh tanh函数来确保输出值范围在 ( − 1 , 1 ) (-1,1) (1,1)之间呢?
  5. 实现一个能够基于时间序列进行预测而不是基于字符序列进行预测的长短期记忆网络模型。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.xdnf.cn/news/1549603.html

如若内容造成侵权/违法违规/事实不符,请联系一条长河网进行投诉反馈,一经查实,立即删除!

相关文章

QT基础 制作简单登录界面

作业&#xff1a; 1、创建一个新项目&#xff0c;将默认提供的程序都注释上意义 01zy.pro代码 QT core gui # QT表示要引入的类库 core&#xff1a;核心库例如IO操作在该库中 gui&#xff1a;图形化界面库 # 如果要使用其他类库中的相关函数&#xff0c;则需要加对…

【深度学习】—线性回归 线性回归的基本元素 线性模型 损失函数 解析解 随机梯度下降

【深度学习】— 线性回归线性回归的基本元素 线性模型损失函数解析解随机梯度下降小批量随机梯度下降梯度下降算法的详细步骤解释公式 线性回归 回归&#xff08;regression&#xff09;是能为⼀个或多个⾃变量与因变量之间关系建模的⼀类⽅法。在⾃然科学和社会科学领域&…

正点原子——DS100示波器操作手册

目录 基础按键&#xff1a; 快捷键 主界面&#xff1a; 垂直设置&#xff1a; 通道设置&#xff1a; 探头比列&#xff1a; 垂直档位&#xff1a; 垂直偏移&#xff1a; 幅度单位&#xff1a; 水平设置&#xff1a; 触发方式&#xff1a; 测量和运算: 光标测量&am…

如何用好通义灵码企业知识库问答能力?

通义灵码企业版&#xff1a;通义灵码企业标准版快速入门_智能编码助手_AI编程_智能编码助手通义灵码(Lingma)-阿里云帮助中心 通义灵码提供了基于企业知识库的问答检索增强的能力&#xff0c;在开发者使用通义灵码 IDE 插件时&#xff0c;可以结合企业知识库内上传的文档、文件…

《深度学习》【项目】OpenCV 发票识别 透视变换、轮廓检测解析及案例解析

目录 一、透视变换 1、什么是透视变换 2、操作步骤 1&#xff09;选择透视变换的源图像和目标图像 2&#xff09;确定透视变换所需的关键点 3&#xff09;计算透视变换的变换矩阵 4&#xff09;对源图像进行透视变换 5&#xff09;对变换后的图像进行插值处理 二、轮廓检测…

YOLOv8改进,YOLOv8主干网络替换为GhostNetV3(2024年华为提出的轻量化架构,全网首发),助力涨点

摘要 GhostNetV3 是由华为诺亚方舟实验室的团队发布的,于2024年4月发布。 摘要:紧凑型神经网络专为边缘设备上的应用设计,具备更快的推理速度,但性能相对适中。然而,紧凑型模型的训练策略目前借鉴自传统模型,这忽略了它们在模型容量上的差异,可能阻碍紧凑型模型的性能…

【d53】【Java】【力扣】24.两两交换链表中的节点

思路 定义一个指针cur, 先指向头节点&#xff0c; 1.判断后一个节点是否为空&#xff0c;不为空则交换值&#xff0c; 2.指针向后走两次 代码 /*** Definition for singly-linked list.* public class ListNode {* int val;* ListNode next;* ListNode() {}*…

Java_集合_双列集合_Map

第一章Map集合 Map是双列集合顶级接口 什么叫做双列集合:一个元素有两部分构成:key和value -> 键值对 1.1.HashMap 常用方法: V put(K key, V value) -> 添加元素,返回的是被替换的value值 V remove(Object key) ->根据key删除键值对,返回的是被删除的value…

Codeforces Round 975 (Div. 1) D. Max Plus Min Plus Size(思维题 并查集/动态dp 线段树维护状态合并)

题目 思路来源 hhoppitree代码 官方题解 题解 注意到最大值一定会被取到&#xff0c; 对于最小值固定的话&#xff0c;对于1 2 3 4 5的连续段&#xff0c;要么贪心地取1 3 5&#xff0c;要么取2 4 如果最大值被包含在1 3 5里显然取1 3 5&#xff0c;否则换成2 4一定能取到…

亚马逊爆款三明治封口器发明专利维权,恐涉及大量卖家,速查

案件基本情况&#xff1a;起诉时间&#xff1a;2024-9-18案件号&#xff1a;2024-cv-08606原告&#xff1a;Jetteo, LLC原告律所&#xff1a;AVEK IP, LLC起诉地&#xff1a;伊利诺伊州北部法院涉案商标/版权&#xff1a;原告品牌简介&#xff1a;Jetteo&#xff0c;LLC&#x…

蜂鸟bebirdt15、西圣find、泰视朗可视挖耳勺好用吗?测评数据对比看这里

可视挖耳勺在当下已经被广泛使用&#xff0c;不过对于新手来说&#xff0c;选择一款优质产品却并不容易。蜂鸟t15、西圣find、泰视朗可视挖耳勺好用吗&#xff1f;作为一个测评博主&#xff0c;近期有不少用户问我这个问题。 根据目前市场上可视挖耳勺的品牌情况来看&#xff0…

A股突破3000,连续大涨,公司国庆假放10天

关注▲洋洋科创星球▲一起成长&#xff01; 庆祝A股突破3000&#xff0c;连续大涨&#xff0c;也不知道老板抽了什么风&#xff0c;公司今天开始放国庆假了&#xff0c;连休10天&#xff0c;哈哈哈哈哈哈。 27号开始放国庆假&#xff0c;连休10&#xff0c;刺激。 中秋国庆这一…

【C++】继承,菱形继承,虚拟继承,组合详解

目录 1. 继承概念与定义 1.1 概念 1.2 定义 2. 父类与子类的赋值规则 3. 继承的作用域 4. 子类的默认成员函数 5. 继承与友元 6. 继承与静态成员 7. 菱形继承 7.1 继承关系 7.2 菱形继承的问题 7.3 虚拟继承 8. 继承与组合 1. 继承概念与定义 1.1 概念 1. 继承&a…

论文速递 | Management Science 8月文章合集

编者按 在本系列文章中&#xff0c;我们对顶刊《Management Science》于8月份发布文章中进行了精选&#xff08;共9篇&#xff09;&#xff0c;并总结其基本信息&#xff0c;旨在帮助读者快速洞察行业最新动态。 推荐文章1 ● 题目&#xff1a;Optimal Mechanism Design with …

红队信息搜集扫描使用

红队信息搜集扫描使用 红队行动中需要工具化一些常用攻击&#xff0c;所以学习一下 nmap 等的常规使用&#xff0c;提供灵感 nmap 帮助 nmap --help主机扫描 Scan and no port scan&#xff08;扫描但不端口扫描&#xff09;。-sn 在老版本中是 -sP&#xff0c;P的含义是 P…

基于SPI协议的Flash驱动控制

1、理论知识 SPI&#xff08;Serial Peripheral Interface&#xff0c;串行外围设备接口&#xff09;通讯协议&#xff0c;是Motorola公司提出的一种同步串行接口技术&#xff0c;是一种高速、全双工、同步通信总线&#xff0c;在芯片中只占用四根管脚用来控制及数据传输&#…

【Python】利用Python+thinker实现旋转转盘

需求/目的&#xff1a;用Pythonthinker实现转盘&#xff0c;并且能够随机旋转任意角度。 转盘形式&#xff1a; 主界面&#xff1a; from tkinter import *winTk() win.title("大转盘") win.geometry("300x400")win.mainloop() 转盘绘制&#xff1a; 这…

USMART调试组件学习

USMART调试组件学习日记 写于2024/9/24日晚 文章目录 USMART调试组件学习日记1. 简介2. 调试组件组成3.程序流程图4. 移植解析5. 实验效果5. 实验效果 1. 简介 USMART 是由正点原子开发的一个灵巧的串口调试互交组件&#xff0c;通过它你可以通过串口助手调用程序里面的任何函…

SigLIP技术小结

paperhttps://arxiv.org/abs/2303.15343githubhttps://github.com/google-research/big_vision个人博客位置http://myhz0606.com/article/siglip 1 背景 CLIP[1]自提出以来在zero-shot分类、跨模态搜索、多模态对齐等多个领域得到广泛应用。得益于其令人惊叹的能力&#xff0…

备考中考的制胜法宝 —— 全国历年中考真题试卷大全

在中考这场重要的战役中&#xff0c;每一分都至关重要。为了帮助广大考生更好地备考&#xff0c;我们精心整理了这份全国历年中考真题试卷大全&#xff0c;旨在为大家提供最全面、最权威的备考资料。 文章目录 1. 全科覆盖&#xff0c;无遗漏2. 历年真题&#xff0c;权威可靠3.…