论文速览【序列模型 seq2seq】—— 【Ptr-Net】Pointer Networks

  • 标题:Pointer Networks
  • 文章链接:Pointer Networks
  • 参考代码(非官方):keon/pointer-networks
  • 发表:NIPS 2015
  • 领域:序列模型(RNN seq2seq)改进 / 深度学习解决组合优化问题
  • 【本文为速览笔记,仅记录核心思想,具体细节请看原文】

  • 摘要:我们引入了一种新的神经网络结构,用于学习一个输出序列的条件概率,其中输出序列的元素是对应于输入序列位置的离散标记。这类问题不能通过Seq2Seq模型和神经图灵机等现有方法轻松解决,因为(这些问题中)输出每一步的目标类别数量取决于输入的长度,而输入的长度是可变的。排序可变长度序列、各种组合优化问题等属于这类问题。我们的模型使用了最近提出的神经注意力机制,来解决可变长度输出字典的问题。与先前的注意力机制不同的是,我们的模型不是在每个解码器步骤中将编码器的隐藏单元与上下文向量混合,而是使用注意力作为指针来选择输入序列的成员作为输出。我们将这种结构称为指针网络(Ptr-Net)。我们在平面凸包(planar convex hulls)、计算德劳内三角剖分(computing Delaunay triangulations)和旅行商问题(TSP)三个有挑战性的几何问题上验证了指针网络有能力以 Data-driven 的形式学到近似解。指针网络不仅改进了具有输入注意力的序列到序列模型,还能够推广到可变长度输出字典。我们展示了训练的模型在超过其训练最大长度的情况下也能泛化

文章目录

  • 0. 本文考虑的问题
  • 1. 传统方法及其问题
    • 1.1 Sequence-to-Sequence Model
    • 1.2 Content Based Input Attention
    • 1.3 问题
  • 2. 本文方法
  • 3. 实验
  • 4. 总结

0. 本文考虑的问题

  • 本文主要考虑那些 “输出序列是离散的,并对应于输入序列中位置” 的 Seq2Seq 问题。实验的问题包括

    1. 平面凸包问题planar convex hulls:给定平面上若干个点的坐标,输出一组点的索引,使得这些点围成的多边形可以覆盖所有点
    2. 计算德劳内三角剖分computing Delaunay triangulations:给定平面上若干个点的坐标,以点索引形式输出德劳内三角剖分结果(这是一种以最近的三点形成三角形,且各线段皆不相交的三角网剖分方式)
      在这里插入图片描述

    除了以上两个例子外,很多组合优化问题也具有这种形式,作者测试了 Tsp 问题。作者开源了以上三类问题的数据集

  • 注意这类问题的特点是:输出序列的每个元素都是输入序列包含的位置索引,输入序列长度 = 输出索引范围

1. 传统方法及其问题

1.1 Sequence-to-Sequence Model

  • 本文是对 Seq2Seq 模型的一个改进,Seq2Seq 模型用于把一个序列转换成另外一个序列,且不要求输入序列和输出序列等长,典型应用有机器翻译等
    在这里插入图片描述
    如上图所示,Seq2Seq 模型通常使用 RNN 及其变种(LSTM/GRU)以 encoder-decoder 结构构建。RNN 类模型内部有一个隐状态代表目前积累的信息,每读入一个序列样本就将其更新,隐变量值可以较好地捕获前驱序列特征。通常使用两个独立的 RNN 模型,一个作为 encoder 提取输入序列的特征,另一个作为 decoder 以 Autoregress 形式解码得到输出序列
  • 考虑第 0 节的问题,形式化地讲,给定训练样例 ( P , C P ) (\mathcal{P}, \mathcal{C^P}) (P,CP),Seq2Seq 模型使用参数模型计算条件概率
    p ( C P ∣ P ; θ ) = p ( C 2 ∣ C 1 , P ; θ ) ⋅ p ( C 3 ∣ C 2 , C 1 , P ; θ ) ⋅ ⋅ ⋅ p ( C m ( P ) ∣ C m ( P ) − 1 , C m ( P ) − 2 , . . . , C 0 , P ; θ ) = ∏ i = 1 m ( P ) p ( C i ∣ C 1 , … , C i − 1 , P ; θ ) \begin{aligned} p(\mathcal{C}^{\mathcal{P}}|\mathcal{P};\theta) &= p (C_2|C_1,\mathcal{P};\theta)·p (C_3|C_2,C_1,\mathcal{P};\theta) ···p (C_{m(\mathcal{P})}|C_{m(\mathcal{P})-1},C_{m(\mathcal{P})-2},...,C_0,\mathcal{P};\theta) \\ &=\prod_{i=1}^{m(\mathcal{P})}p (C_{i}|C_{1},\ldots,C_{i-1},\mathcal{P};\theta) \end{aligned} p(CPP;θ)=p(C2C1,P;θ)p(C3C2,C1,P;θ)⋅⋅⋅p(Cm(P)Cm(P)1,Cm(P)2,...,C0,P;θ)=i=1m(P)p(CiC1,,Ci1,P;θ) 其中 P = { P 1 , … , P n } \mathcal{P}=\{P_{1},\ldots,P_{n}\} P={P1,,Pn} 是包含 n n n 个向量的输入序列(上图中的 v v v), C P = { C 1 , … , C m ( P ) } \mathcal{C}^{\mathcal{P}}=\{C_{1},\ldots,C_{m(\mathcal{P})}\} CP={C1,,Cm(P)} 是由 m ( P ) m(\mathcal{P}) m(P) 个索引组成的序列,每个索引取值范围为 [ 1 , n ] [1,n] [1,n]。直观地看这个条件概率就是模型解码出目标序列的概率。注意目标序列的长度 m ( P ) m(\mathcal{P}) m(P) 通常取决于 P \mathcal{P} P。学习目标是最大化训练集中所有样本的上述概率之和,即
    θ ∗ = arg max ⁡ θ ∑ P , C P log ⁡ p ( C P ∣ P ; θ ) , \theta^{*}=\operatorname*{arg\,max}_{\theta}\sum_{\mathcal{P},\mathcal{C}^{ \mathcal{P}}}\log p(\mathcal{C}^{\mathcal{P}}|\mathcal{P};\theta), θ=θargmaxP,CPlogp(CPP;θ), 训练之后评估阶段,给定输入序列 P \mathcal{P} P,使用学习到的参数 θ ∗ \theta^{*} θ 选择具有最高概率的序列
    C ^ P = arg max ⁡ C P p ( C P ∣ P ; θ ∗ ) \hat{\mathcal{C}}^{\mathcal{P}}=\operatorname*{arg\,max}_{\mathcal{C}^{\mathcal{P}}}p(\mathcal{C}^{\mathcal{P}}|\mathcal{P};\theta^{*}) C^P=CPargmaxp(CPP;θ) 由于输出序列空间大小为 n m ( P ) n^{m(\mathcal{P})} nm(P),找到真正的最大概率输出序列的计算量太大,工程上通常使用贪心或者 beam search 方法进行解码

1.2 Content Based Input Attention

  • RNN 类模型只能利用隐状态间接地获取之前序列的信息,由于隐藏状态维度一定远远小于之前的变长序列所有样本的连接维度,这种做法无可避免地会损失一些信息。一种补偿方式是引入额外的 Attention 模块,它对整个输入序列的所有 hidden state e 1 , . . . , e n e_1,...,e_n e1,...,en 构造 key 向量,之后在任意第 i i i 个解码位置,用其 hidden state d i d_i di 构造 query 并和整个输入序列计算 attention,根据 attention 结果汇聚(加权平均)整个输入序列的 hidden state,最后用得到的结果和 d i d_i di 做 concatenate 来增强解码时信息输入,缓解信息损耗问题。下图是一个示意
    在这里插入图片描述

  • 形式化地,设 encoder 和 decoder 的隐藏状态为 ( e 1 , … , e n ) (e_{1},\ldots,e_{n}) (e1,,en) ( d 1 , … , d m ( P ) ) (d_{1},\ldots,d_{m(\mathcal{P})}) (d1,,dm(P)),如下计算每个输出时刻 i i i 的附加信息
    u j i = v T tanh ⁡ ( W 1 e j + W 2 d i ) j ∈ ( 1 , … , n ) a j i = softmax ⁡ ( u j i ) j ∈ ( 1 , … , n ) d i ′ = ∑ j = 1 n a j i e j \begin{aligned} u_{j}^{i} & =v^{T} \tanh \left(W_{1} e_{j}+W_{2} d_{i}\right) & j \in(1, \ldots, n) \\ a_{j}^{i} & =\operatorname{softmax}\left(u_{j}^{i}\right) & j \in(1, \ldots, n) \\ d_{i}^{\prime} & =\sum_{j=1}^{n} a_{j}^{i} e_{j} & \end{aligned} ujiajidi=vTtanh(W1ej+W2di)=softmax(uji)=j=1najiejj(1,,n)j(1,,n) 注意这里使用了比较早期的加性注意力,向量 v v v 和矩阵 W 1 , W 2 W_1,W_2 W1,W2 是三组要学习的参数。最后用增强后的 [ d i , d i ′ ] [d_i, d_i'] [di,di] 作为隐状态进行解码。这种方式相对 1.1 节的朴素 Seq2Seq 方法有显著性能提高

1.3 问题

  • 以上两个方法虽然也能部分解决第 0 节的问题,但它们都有一个显著缺陷,即处理问题的尺度无法随着输入泛化:对于每个不同的输入长度 n n n 都要单独训练一个模型。这个问题的本质在于模型无法动态地从输入序列中构造词表,训练时都是事先根据问题规模设置好词表大小的

2. 本文方法

  • 作者解决 1.3 节问题的思路很直接,他注意随着输入序列长度的变化,attention 范围可以自适应地变化,所以解码过程中只要想办法自回归地让 attention 像指针一样指出输入序列中的目标位置即可。这其实是对 1.2 节 decoder 的一种简化,我们不再需要根据 attention 汇聚特征再做分类任务,而是直接用经过 softmax 的 attention 向量做分类,这样输出空间可以根据输入序列长度自动调整。示意图如下
    在这里插入图片描述
    图(a) 是 1.1 节的朴素 Seq2Seq 模型,图(b) 是作者提出的指针网络模型
  • 形式化地,如下用 attention 机制改写 1.1 节中对 p ( C i ∣ C 1 , … , C i − 1 , P ) p(C_{i}|C_{1},\ldots,C_{i-1},\mathcal{P}) p(CiC1,,Ci1,P) 建模的方式
    u j i = v T tanh ⁡ ( W 1 e j + W 2 d i ) j ∈ ( 1 , … , n ) p ( C i ∣ C 1 , … , C i − 1 , P ) = softmax ⁡ ( u i ) \begin{aligned} u_{j}^{i} & =v^{T} \tanh \left(W_{1} e_{j}+W_{2} d_{i}\right) \quad j \in(1, \ldots, n) \\ p\left(C_{i} \mid C_{1}, \ldots, C_{i-1}, \mathcal{P}\right) & =\operatorname{softmax}\left(u^{i}\right) \end{aligned} ujip(CiC1,,Ci1,P)=vTtanh(W1ej+W2di)j(1,,n)=softmax(ui) 这里 u j i u_j^i uji 是长度为 n n n 的注意力得分向量, softmax ⁡ \operatorname{softmax} softmax 操作将其转换为输入序列上的分布,直接把这个 attention 分布看作在尺寸为 n n n 的词表上做分类时的 softmax 分布,使用交叉熵损失进行优化即可

3. 实验

  • 这里仅介绍 TSP 上的结果,另外两个问题详见论文

    TSP 问题是说给定一个城市列表,希望找到一个最短的路线,要求把每个城市访问一次并能返回到起点。作者假设两个城市之间的距离是对称的,即 A->B 的距离 = B-> A 的距离。

    1. 数据生成: 任意训练样本 ( P , C P ) (\mathcal{P}, \mathcal{C^P}) (P,CP) 中, P = { P 1 , … , P n } \mathcal{P}=\{P_{1},\ldots,P_{n}\} P={P1,,Pn} 是在 [0,1]×[0,1] 区间中随机采样的 n n n 个笛卡尔坐标, C P = { C 1 , . . . , C n } \mathcal{C^P} = \{C_1,...,C_n\} CP={C1,...,Cn} 是一个从 1 到 n 的排列,代表最优路线。为了一致性,在训练数据集中,数据集总是从第一个城市开始。为了生成精确的数据,城市数量不一样,所构建的数据集输出结果方式也不一样,具体地说: 在城市数量 n ≤ 20 n\leq 20 n20 的情况下,采用 Held-Karp 算法;对于 n > 20 n>20 n>20 的情况,作者考虑了 A1 A2 A3 三种启发式搜索算法,其中 A3 算法保证在离最优长度1.5倍的范围内找到一个解
    2. 模型设置:所有模型都使用了具有 256 或 512 个隐藏单元的单层 LSTM;使用随机梯度下降(SGD)训练;学习率为1.0;batch_size为128;随机均匀权重初始化从 -0.08 到 0.08;L2正则化梯度裁剪为 2.0
  • 实验结果如下
    在这里插入图片描述
    1. 由于 TSP 问题是有约束的(不能重复访问城市,也不能忽略城市),作者在解码时的波束搜索(beam Search)过程中过滤有效的结果。这种过滤过程在 n > 20 n>20 n>20 时是必须的。当 n = 30 n=30 n=30 时(超过训练时城市数量),失败率到达 30%;当 n = 40 n=40 n=40 时失败率上升到 98%
    2. 表中 OPTIMAL 列是真实最优结果,缺少 n = 50 n=50 n=50 的结果是因为计算复杂度太高了
    3. 表中第一组行显示了在 n n n 相同时使用最优数据训练的结果。注意到使用最差的算法 (A1) 数据来训练 Ptr-Net 时,模型优于其试图模仿的A1算法(6.42 < 6.46)
    4. 表中第二组行显示了在5~20个城市的最佳数据上训练的 Ptr-Nets 如何能够推广到更多的城市。 结果对于n=25来说几乎是完美的,对于n=30来说是好的,但在40或更长的时间里似乎会崩溃(尽管如此,结果还是比随机策略要好得多)

4. 总结

  • Pointer Networks 天生具备从输入序列中提取元素的能力,因此它非常适合用来实现 “复制” 这个功能。NLP 领域很多研究者也确实把它用于复制源文本中的一些词汇。比如摘要任务,由于所需的词汇较多,非常适合使用复制的方法来复制一些词,目前Pointer Networks 已经称为了文本摘要方法中的利器。
  • 此外,在组合优化领域,Ptr-Nets 也得到了广泛的应用,并已成为组合优化问题的端到端方法的入门模型,后来基于此模型,研究者也进行了很多改进,比如与强化学习结合,将 Attention 换成 Transformer 中采用的Self- Attention等。总之,Ptr-Nets为组合优化的端到端解决办法起了一个好头,并促使广大研究者进行更加深入的研究

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.xdnf.cn/news/141789.html

如若内容造成侵权/违法违规/事实不符,请联系一条长河网进行投诉反馈,一经查实,立即删除!

相关文章

【基于MBD开发模式的matlab持续集成(一)】

基于MBD开发模式的matlab持续集成 引言 或许是感受到行业内卷的愈加激烈&#xff0c;在传统制造和高新技术相结合的新能源领域对软件工程开发的要求也愈加提高&#xff0c;尤其在互联网已经大行 其道的敏捷开发&#xff0c;便顺其自然的被新能源的老板们所看重。 概述 本文…

java面试题-设计模式基础

面试专题-设计模式 前言 在平时的开发中&#xff0c;涉及到设计模式的有两块内容&#xff0c;第一个是我们平时使用的框架&#xff08;比如spring、mybatis等&#xff09;&#xff0c;第二个是我们自己开发业务使用的设计模式。 面试官一般比较关心的是你在开发过程中&#…

Vue 学习笔记 错误ResizeObserver loop completed with undelivered notifications

环境Vue3 Ts 使用了el-table 后&#xff0c;容易出现如下错误 ERROR ResizeObserver loop completed with undelivered notifications. at handleError (webpack-internal:///./node_modules/webpack-dev-server/client/overlay.js:299:58) at eval (webpack-internal:///./nod…

基于FFmpeg+SDL的视频播放器的制作

基于FFmpegSDL的视频播放器的制作 基于FFmpegSDL的视频播放器的制作实验1实验2实验3实验4基本练习进阶练习 实验5 基于FFmpegSDL的视频播放器的制作 雷霄骅博士的课程。 课程链接&#xff1a;https://blog.csdn.net/leixiaohua1020/article/details/47068015 初学 FFmpeg&am…

PTE 口语练习准备(二)

目录 Day 1单词咬字专项 ChapterI 咬字练习之音书与辅音 Chapter 2咬字练习之“字典音 Chapter 3元“漏&#xff0c;加&#xff0c;换”语速练习 Day 1单词咬字专项 考量项目 单词 单词咬字也决定了句子呈现的位置 重音位置 电脑是有固有的模型的&#xff0c;给你的模型…

Python爬虫在Web应用自动化测试中的应用

在Web应用开发过程中&#xff0c;自动化测试是确保应用质量和稳定性的重要环节。本文将介绍如何使用Python爬虫与自动化测试技术相结合&#xff0c;实现对Web应用进行自动化测试的方法和步骤。通过这种结合&#xff0c;我们可以提高测试效率、减少人力成本&#xff0c;并确保应…

探索入行嵌入式应该要学会哪些知识?

入行嵌入式领域&#xff0c;需掌握嵌入式系统基础、C语言、C编程&#xff0c;了解微控制器、微处理器&#xff0c;学习电子电路基础&#xff0c;正好看我这一套保姆式嵌入式休息资料&#xff0c;里面包含了编程教学、数据处理、毕设800套和语言类教学&#xff0c;非常的全面、放…

自注意力机制

回顾以下注意力机制&#xff1a; 自注意力机制 Self-Attention的关键点 在于 K ≈ \approx ≈V ≈ \approx ≈Q 来源于同一个X&#xff0c;三者是同源的&#xff0c;通过 W Q W_Q WQ​, W K W_K WK​, W V W_V WV​做了一层线性变换。 接下来步骤和注意力机制一模一样。 …

开发者必备!如何将闲置iPad Pro打造为编程工具,使用VS Code编写代码

文章目录 前言1. 本地环境配置2. 内网穿透2.1 安装cpolar内网穿透(支持一键自动安装脚本)2.2 创建HTTP隧道 3. 测试远程访问4. 配置固定二级子域名4.1 保留二级子域名4.2 配置二级子域名 5. 测试使用固定二级子域名远程访问6. ipad pro通过软件远程vscode6.1 创建TCP隧道 7. ip…

Vue3最佳实践 第五章 Vue 组件应用 3( Slots )

5.4 Slots 我们已经了解到组件能够接收任意类型的 JavaScript 值作为 props&#xff0c;但组件要如何接收模板内容呢&#xff1f;在某些场景中&#xff0c;我们可能想要为子组件传递一些模板片段&#xff0c;让子组件在它们的组件中渲染这些片段。Slots 可用于将Html内容从父组…

一文了解什么SEO

搜索引擎优化 (SEO) 是一门让页面在 Google 等搜索引擎中排名更高的艺术和科学。 一、搜索引擎优化的好处 搜索引擎优化是在线营销的关键部分&#xff0c;因为搜索是用户浏览网络的主要方式之一。 搜索结果以有序列表的形式呈现&#xff0c;网站在该列表中的排名越高&#x…

指针笔试题讲解(让指针变得简单易懂)

数组名的理解 : 数组名就是首元素地址 但是有两个例外&#xff1a; 1. sizeof&#xff08;数组名&#xff09;这里的数组名表示整个数组的大小&#xff0c;sizeof&#xff08;数组名&#xff09;计算的是整个数组的大小&#xff0c;单位是字节 2. &数组名 这里的数组…

彩色文本进度条

动态加色打印&#xff0c;\033控制&#xff0c;显示进行到的百分比&#xff0c;实时更新总共用时。 (本笔记适合能熟练应用字符串和循环技能的 coder 翻阅) 【学习的细节是欢悦的历程】 Python 官网&#xff1a;https://www.python.org/ Free&#xff1a;大咖免费“圣经”教程…

CH2--x86系统架构概览

2.1 OVERVIEW OF THE SYSTEM-LEVEL ARCHITECTURE 图中的实线箭头表示线性地址&#xff0c;虚线表示段选择器&#xff0c;虚线箭头表示物理地址 2.1.1 Global and Local Descriptor Tables 全局描述符表 (GDT) GDT是一个全局的段描述符表&#xff0c;它存储在系统内存中的一个固…

Aztec交易架构解析

1. 引言 前序博客有&#xff1a; Aztec的隐私抽象&#xff1a;在尊重EVM合约开发习惯的情况下实现智能合约隐私完全保密的以太坊交易&#xff1a;Aztec网络的隐私架构Aztec.nr&#xff1a;Aztec的隐私智能合约框架——用Noir扩展智能合约功能Account Abstraction账号抽象——…

力扣:102. 二叉树的层序遍历(Python3)

题目&#xff1a; 给你二叉树的根节点 root &#xff0c;返回其节点值的 层序遍历 。 &#xff08;即逐层地&#xff0c;从左到右访问所有节点&#xff09;。 来源&#xff1a;力扣&#xff08;LeetCode&#xff09; 链接&#xff1a;力扣&#xff08;LeetCode&#xff09;官网…

Linux文件内容显示练习

1.新建2个文件b1.txt b2.txt ,使用vim打开b1.txt 输入“Hello World”字符串,将b1.txt硬链接到b2.txt 查看2个文件的硬连接数 [rootserver ~]# vim b1.txt [rootserver ~]# ln b1.txt b2.txt #建立硬链接 [rootserver ~]# stat b2.txt [rootserver ~]# stat b1.txt [r…

Android开发MVP架构记录

Android开发MVP架构记录 安卓的MVP&#xff08;Model-View-Presenter&#xff09;架构是一种常见的软件设计模式&#xff0c;用于帮助开发者组织和分离应用程序的不同组成部分。MVP架构的目标是将应用程序的业务逻辑&#xff08;Presenter&#xff09;、用户界面&#xff08;V…

CockroachDB集群部署

CockroachDB集群部署 1、CockroachDB简介 CockroachDB(有时简称为CRDB)是一个免费的、开源的分布式 SQL 数据库&#xff0c;它建立在一个事务性和强一致性的键 值存储之上。它由 PebbleDB(一个受 RocksDB/leveldb 启发的 K/B 存储库)支持&#xff0c;并使用 Raft 分布式共识…

2023-Chrome插件推荐

Chrome插件推荐 一键管理扩展 链接 https://chromewebstore.google.com/detail/lboblnfejcmcaplhnbkkfcienhlhpnni 介绍 一键开启、禁用Chrome插件。 Checker Plus for Gmail™ 链接 https://jasonsavard.com/zh-CN/Checker-Plus-for-Gmail https://chromewebstore.goo…