对时间序列SOTA模型Patch TST核心代码逻辑的解读

前言

Patch TST发表于ICLR23,其优势在于保留了局部语义信息;更低的计算和内存使用量;模型可以关注更长的历史信息,Patch TST显著提高了时序预测的准确性,Patch可以说已成为时序模型的基本操作。我在先前的一篇文章对Patch TST做了比较细致的论文解读,各位朋友可参考。

但是最近很多朋友私信问我:Patch TST到底好在哪里?Transformer模型也对时序数据进行了切分,和Patch TST的切片有何区别?其实在我没有阅读Patch TST的代码之前,我也一直没想明白:对时间序列数据进行Patch操作之后,数据是怎么放入到Transformer的编码器

图片

只看论文,确实很难对patch有深刻的理解,最佳的方法还是打断点走一遍代码。今天我这篇文章就梳理了Patch TST代码中几个关键的节点,并标注了数据的维度信息,掌握了Transformer和Patch TST维度变化上的差异,也就解答了上面所有的问题,对Patch的好处也就有了更深刻的理解。

Patch TST与Transformer输入特征的对比

01. Transformer的数据输入维度

我们首先统一基本的符号表示,batch_size表示batch的维度;seq_len表示输入时序数据的长度;Channel表示时序特征的数量;patch_len表示patch的长度;patch_num表示分段后patch的数量;d_model表示模型的维度。

好了,我们现在统一了符号表示,思考第一个问题:原始transformer中时序特征输入到编码器时的特征维度是怎样的?

答案其实是:[batch_size,seq_len,d_model]!

02. Patch TST的数据输入维度

那么切换到Patch TST模型,经过patch处理后,它输入到encoder编码器之前的特征维度是怎么样的?

答案其实是:[(batch_size*channel),patch_num,d_model]

我们对比transformer和Patch TST的输入数据维度可以发现,两者的第三个维度d_model是一致的。但是,序列长度由seq_len变为patch_numbatch的大小由batch_size变为(batch_size*channel)。

经过切分后,patch_num的大小肯定是远远小于seq_len的,相当于输入序列变短了,正是因为如此,patch TST的在计算Attention的时候计算效率大幅提升。同时,我们可以看到Patch TST的第一个维度变为(batch_size*channel)。整个过程(我个人)理解为通过patch降低了序列长度,但增加了batch数量。就是通过这种方式,实现了计算量的减少。

核心代码解读

Patch TST代码下载地址:https://github.com/yuqinie98/PatchTST

以上的分析其实已经给出了本篇文章想说的结论,即为什么Patch效果要比原始模型好。但是,从代码解读的角度来看,我们仍有两个问题没有搞清楚:1、Patch TST是如何把输入到Transformer模型的数据维度变为[(batch_size*channel),patch_num,d_model]的;2、Patch TST的代码逻辑是怎么样的?

图片

上面这张图是我对照Patch TST的代码整理数据流向,其中关键的节点我用绿色和橘黄色做了标注。

  • 我们发现从执行train()函数开始,途经Patch TST类、PatchTST_backbone类、到做Normalization方法,这个过程数据维度一直没变,是[batch_size,channel,seq_len]

  • 执行到PatchTST_backbone.py的unfold()方法时,此时维度发生变化,变为:[batch_size,channel,patch_num,patch_len],代码如下所示,经过这一步完成了数据的切分

# do patchingif self.padding_patch == 'end':    z = self.padding_patch_layer(z)# unfold函数就是按照步长和patch_len进行切分z = z.unfold(dimension=-1, size=self.patch_len, step=self.stride) z = z.permute(0,1,3,2)
  • 然后,经过TSTiEncoder类中的reshape()方法,数据维度变为[(batch_size*channel),patch_num,d_model],代码如下:

def forward(self, x) -> Tensor:                                                 n_vars = x.shape[1]    # Input encoding    x = x.permute(0,1,3,2)                                                                                                             u = torch.reshape(x, (x.shape[0]*x.shape[1],x.shape[2],x.shape[3]))    u = self.dropout(u + self.W_pos)                                             # Encoder    z = self.encoder(u)                                                         z = torch.reshape(z, (-1,n_vars,z.shape[-2],z.shape[-1]))                  z = z.permute(0,1,3,2)                                                        return z
 
  • reshape之后的数据送入到encoder(),encoder输出的仍然是三维的,既[(batch_size*channel),patch_num,d_model],所以我们看到encoder的输出结果再次经过reshape变回四维,然后再经过head()变到与预测序列的维度一致,从而计算损失。

总结

Patch TST的代码推荐大家亲自跑一遍,其实模型结构没有太大变化,重点是对数据数据的前处理,特别是要理解patch切分后,从四维向量到三维的转变过程(batch_size*channel),经过这一步骤,输入序列长度大大减小,同时batch数量增加。


欢迎大家关注我的公众号【科学最top】,专注于时序高水平论文解读。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.xdnf.cn/news/147084.html

如若内容造成侵权/违法违规/事实不符,请联系一条长河网进行投诉反馈,一经查实,立即删除!

相关文章

【掘金量化使用技巧】用日线合成长周期k线

掘金API中的接口最长的周期是‘1d’的,因此周线/月线/年线等数据需要自己进行合成。 基本思路 用日线合成长周期的k线只需要确定好合成的周期以及需要的数据即可。 周期: 一般行情软件上提供年k、月k、周k,我也选择年、月、周再加一个季度频率。 数据:…

Linux:终端(terminal)与终端管理器(agetty)

终端的设备文件 打开/dev目录可以发现其中有许多字符设备文件,例如对于我的RedHat操作系统,拥有tty0到tty59,它们是操作系统提供的终端设备。对于tty1-tty12使用ctrlaltF*可以进行快捷切换,下面的命令可以进行通用切换。 sudo ch…

GPU加速时代:如何用CuPy让你的Python代码飞起来?

你是不是也有这样的感受:明明写的Python代码很简洁,用NumPy处理数据也很方便,可是一跑起来就慢得像乌龟?尤其是当你面对庞大的数据集时,光是等结果出来,就已经耗掉大半天了。其实,我以前也是这么干的,直到我发现了CuPy,一个能让NumPy飞速跑起来的GPU加速神器。 你…

10. 排序

一、排序的概念及引用 1. 排序的概念 排序:所谓排序,就是使一串记录,按照其中的某个或某些关键字的大小,递增或递减的排列起来的操作。 稳定性:假定在待排序的记录序列中,存在多个具有相同的关键字的记录…

基于SpringBoot的医院管理系统【附源码】

基于SpringBoot的医院管理系统(源码L文说明文档) 目录 4 系统设计 4.1 系统概述 4系统概要设计 4.1概述 4.2系统结构 4.3.数据库设计 4.3.1数据库实体 4.3.2数据库设计表 5系统详细实现 5.1 医生模块的实现 5.1.…

Mybatis 返回 Map 对象

一、场景介绍 假设有如下一张学生表: CREATE TABLE student (id int NOT NULL AUTO_INCREMENT COMMENT 主键,name varchar(100) NOT NULL COMMENT 姓名,gender varchar(10) NOT NULL COMMENT 性别,grade int NOT NULL COMMENT 年级,PRIMARY KEY (id) ) ENGINEInnoD…

【RocketMQ】一、基本概念

文章目录 1、举例2、MQ异步通信3、背景4、Rocket MQ 角色概述4.1 主题4.2 队列4.3 消息4.4 生产者4.5 消费者分组4.6 消费者4.7 订阅关系 5、消息传输模型5.1 点对点模型5.2 发布订阅模型 1、举例 以坐火车类比MQ: 安检大厅就像是一个系统的门面,接受来…

整合多方大佬博客以及视频 一文读懂 servlet

参考文章以及视频 文章: 都2023年了,Servlet还有必要学习吗?一文带你快速了解Servlet_servlet用得多吗-CSDN博客 【计算机网络】HTTP 协议详解_3.简述浏览器请求一个网址的过程中用到的网络协议,以及协议的用途(写关键点即可)-CSDN博客 【…

大数据可视化-三元图

三元图是一种用于表示三种变量之间关系的可视化工具,常用于化学、材料科学和地质学等领域。它的特点是将三个变量的比例关系在一个等边三角形中展示,使得每个点的位置代表三个变量的相对比例。 1. 结构 三个角分别表示三个变量的最大值(通常…

TikTok流量不佳:是网络环境选择不当还是其他原因?

TikTok,作为全球短视频社交平台的佼佼者,每天都有海量的内容被上传和分享。然而,很多用户和内容创作者发现,他们的TikTok视频流量并不理想。这引发了一个问题:TikTok流量不佳,是因为网络环境选择不当&#…

Lumos学习王佩丰Excel第十五讲:条件格式与公式

一、使用简单的条件格式 1、为特定范围的数值标记特殊颜色 条件格式-需选择设定范围(大于/小于/介于/......): 数值会动态根据条件判断更新颜色: 模糊匹配+条件格式:选择包含部分文本的特殊值 2、查找重复…

【BurpSuite】Cross-site scripting (XSS 学徒部分:1-9)

🏘️个人主页: 点燃银河尽头的篝火(●’◡’●) 如果文章有帮到你的话记得点赞👍收藏💗支持一下哦 【BurpSuite】Cross-site scripting (XSS 学徒部分:1-9) 实验一 Lab: Reflected XSS into HTML context with nothing…

国自然基金项目撰写技巧、技术路线与ChatGPT融合应用

随着社会经济发展和科技进步,基金项目对创新性的要求越来越高。申请人需要提出独特且有前瞻性的研究问题,具备突破性的科学思路和方法。因此,基金项目申请往往需要进行跨学科的技术融合。申请人需要与不同领域结合,形成多学科交叉…

一款批量下载 B 站动态页图片的脚本

在逛 B 站的时候,总能看到不少 UP 会发很多图片,此时一个一个保存非常麻烦,而且文件名都是随机的字符串,还得手工重命名。 为此,特地搜索了下有没相关的浏览器插件或油猴脚本,还真给我找到一个。 脚本地址…

图解 TCP 四次挥手|深度解析|为什么是四次|为什么要等2MSL

写在前面 今天我们来图解一下TCP的四次挥手、深度解析为什么是四次? 上一片文章我们已经介绍了TCP的三次握手 解析四次挥手 数据传输完毕之后,通信的双方都可释放连接。现在客户端A和服务端B都处于ESTABLISHED状态。 第一次挥手 客户端A的应用进…

计算机网络-小型综合网络的搭建涉及到无线路由交换安全

目录 1 拓扑架构 2 做项目的思路 3 做配置 3.1先做核心交换 3.2 防火墙的配置 4 ac 和ap 的配置 4.1 ac上配置安全的东西 5.1 测试​编辑 1 拓扑架构 要求看上面的图 2 做项目的思路 这张网很明显是一个小综合,设计到我们的无线交换,路由…

MISC - 第二天(wireshark,base64解密图片,zip文件伪加密,LSB二进制最低位,ARCHPR工具)

前言 各位师傅大家好,我是qmx_07,今天给大家讲解杂项 乌镇峰会种图 使用了stegsolve工具,查看更多信息 发现flag信息 更改为html后缀flag{97314e7864a8f62627b26f3f998c37f1} wireshark 看题目是 分析pacp数据包,通过网站登录…

HarmonyOS应用开发(组件库)--组件模块化开发、工具包、设计模式(持续更新)

致力于,UI开发拿来即用,提高开发效率 正则表达式...手机号校验...邮箱校验 文件判断文件是否存在 网络下载下载图片从沙箱中图片转为Base64格式从资源文件中读取图片转Base64 组件输入框...矩形输入框...输入框堆叠效果(用于登录使用&#xf…

基于ECC簇内分组密钥管理算法的无线传感器网络matlab性能仿真

目录 1.程序功能描述 2.测试软件版本以及运行结果展示 3.核心程序 4.本算法原理 5.完整程序 1.程序功能描述 基于ECC簇内分组密钥管理算法的无线传感器网络matlab性能仿真,对比网络通信开销,存活节点数量,网络能耗以及数据通信量四个指标…

C语言的文件函数

此篇文章主要对C语言中的" 文件读写函数 "进行详细的刨析~通过此篇文章能够了解并学习到:" 字符读写函数 "," 文本行读写函数 "," 格式化读写函数 "," 二进制读写函数 "&#…