Pyraformer复现心得

Pyraformer复现心得

引用

Liu, Shizhan, et al. “Pyraformer: Low-complexity pyramidal attention for long-range time series modeling and forecasting.” International conference on learning representations. 2021.

代码部分

def long_forecast(self, x_enc, x_mark_enc, x_dec, x_mark_dec, mask=None):enc_out = self.encoder(x_enc, x_mark_enc)[:, -1, :]#B,dmodel*3dec_out = self.projection(enc_out).view(enc_out.size(0), self.pred_len, -1)#B,pre,Nreturn dec_out

预测部分就这么长

x_dec, x_mark_dec, mask=None都没用到

enc_out = self.encoder(x_enc, x_mark_enc)[:, -1, :]
#B,dmodel*3
  • 直接进入encoder
def forward(self, x_enc, x_mark_enc):seq_enc = self.enc_embedding(x_enc, x_mark_enc)
  • 重构了encoder和decoder,跟transformer的很不一样
x = self.value_embedding(x) + self.temporal_embedding(x_mark) + self.position_embedding(x)
return self.dropout(x)
  • embedding方法跟former一样
mask = self.mask.repeat(len(seq_enc), 1, 1).to(x_enc.device)

用pyra的方式获取pam掩码

def get_mask(input_size, window_size, inner_size):"""Get the attention mask of PAM-Naive"""# Get the size of all layersall_size = []all_size.append(input_size)for i in range(len(window_size)):layer_size = math.floor(all_size[i] / window_size[i])all_size.append(layer_size)seq_length = sum(all_size)mask = torch.zeros(seq_length, seq_length)# get intra-scale maskinner_window = inner_size // 2for layer_idx in range(len(all_size)):start = sum(all_size[:layer_idx])for i in range(start, start + all_size[layer_idx]):left_side = max(i - inner_window, start)right_side = min(i + inner_window + 1, start + all_size[layer_idx])mask[i, left_side:right_side] = 1# get inter-scale maskfor layer_idx in range(1, len(all_size)):start = sum(all_size[:layer_idx])for i in range(start, start + all_size[layer_idx]):left_side = (start - all_size[layer_idx - 1]) + \(i - start) * window_size[layer_idx - 1]if i == (start + all_size[layer_idx] - 1):right_side = startelse:right_side = (start - all_size[layer_idx - 1]) + (i - start + 1) * window_size[layer_idx - 1]mask[i, left_side:right_side] = 1mask[left_side:right_side, i] = 1mask = (1 - mask).bool()return mask, all_size

接着进入卷积层

seq_enc = self.conv_layers(seq_enc)

先构建CSCM卷积

class Bottleneck_Construct(nn.Module):"""Bottleneck convolution CSCM"""
temp_input = self.down(enc_input).permute(0, 2, 1)
all_inputs = []
self.down = Linear(d_model, d_inner)

下采样

for i in range(len(self.conv_layers)):temp_input = self.conv_layers[i](temp_input)all_inputs.append(temp_input)

堆叠很多次卷积,这个跟former是一样的

class ConvLayer(nn.Module):def __init__(self, c_in, window_size):super(ConvLayer, self).__init__()self.downConv = nn.Conv1d(in_channels=c_in,out_channels=c_in,kernel_size=window_size,stride=window_size)self.norm = nn.BatchNorm1d(c_in)self.activation = nn.ELU()def forward(self, x):x = self.downConv(x)x = self.norm(x)x = self.activation(x)return x

将N次卷积的结果拼接起来

all_inputs = torch.cat(all_inputs, dim=2).transpose(1, 2)#
all_inputs = self.up(all_inputs)
all_inputs = torch.cat([enc_input, all_inputs], dim=1)
self.up = Linear(d_inner, d_model)
all_inputs = self.norm(all_inputs)
return all_inputs
self.norm = nn.LayerNorm(d_model)

之后在跟原始输入拼接起来

  • 卷积layer完了之后是encoderlayer
def forward(self, enc_input, slf_attn_mask=None):attn_mask = RegularMask(slf_attn_mask)
enc_output, _ = self.slf_attn(enc_input, enc_input, enc_input, attn_mask=attn_mask)

进到encoder里面,到了熟悉的former框架

def forward(self, queries, keys, values, attn_mask, tau=None, delta=None):#后面俩参数应该是作者指定的B, L, _ = queries.shape#B,seq,dmodel_, S, _ = keys.shapeH = self.n_heads
#其实L和S是一个数queries = self.query_projection(queries).view(B, L, H, -1)#B, L, H, dmodel/hkeys = self.key_projection(keys).view(B, S, H, -1)#一样的计算方法values = self.value_projection(values).view(B, S, H, -1)#H 表示头的数量-1 表示自动计算该维度
  • encoder的注意力用的fullattention。并且用到了掩码

回到pyra的encoder

self.pos_ffn = PositionwiseFeedForward(d_model, d_inner, dropout=dropout, normalize_before=normalize_before)
def forward(self, x):residual = xif self.normalize_before:x = self.layer_norm(x)x = F.gelu(self.w_1(x))x = self.dropout(x)x = self.w_2(x)x = self.dropout(x)x = x + residualif not self.normalize_before:x = self.layer_norm(x)return x
indexes = self.indexes.repeat(seq_enc.size(0), 1, 1, seq_enc.size(2)).to(seq_enc.device)
#B,seq,3,dmodel
indexes = indexes.view(seq_enc.size(0), -1, seq_enc.size(2))
#B,seq+pred,dmodel
all_enc = torch.gather(seq_enc, 1, indexes)
##B,seq+pred,dmodel
seq_enc = all_enc.view(seq_enc.size(0), self.all_size[0], -1)
#B,seq,dmodel*3
return seq_enc

总结

x_dec, x_mark_dec, mask=None都没用到

  • 直接进入encoder

重构了encoder和decoder,跟transformer的很不一样

embedding方法跟former一样

encoder的注意力用的fullattention,并且用到了掩码

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.xdnf.cn/news/7792.html

如若内容造成侵权/违法违规/事实不符,请联系一条长河网进行投诉反馈,一经查实,立即删除!

相关文章

SDL简介和初次尝试

文章目录 SDL的用途和概念SDL下载 SDL的用途和概念 SDL(Simple DirectMedia Layer)是一套开放源代码的跨平台开发库 ,使用C语言写成,SDL提供了数种 操作 图像 ,声音输入输出的函数,让开发者使用 相识的代码 就能够开发出跨平台的…

WiFi一直获取不到IP地址是怎么回事?

在当今这个信息化时代,WiFi已成为我们日常生活中不可或缺的一部分。无论是家庭、办公室还是公共场所,WiFi都为我们提供了便捷的无线互联网接入。然而,有时我们可能会遇到WiFi连接后无法获取IP地址的问题,这不仅影响了我们的网络使…

【车道线检测】一、传统车道线检测:基于霍夫变换的车道线检测史诗级详细教程

1、定义图像显示函数 首先定义一个函数,函数的作用是通过plt库显示两幅图,为后续实验做准备。该函数的主要功能是: 从指定路径加载图像显示图像的基本信息将图像从BGR格式转换为RGB格式并在一个图形窗口中显示两幅图像进行对比 import nump…

Ftrans数据跨境传输方案:保护隐私与促进合作

数据跨境传输是指在不同国家、地区和法律框架下进行的数据交换和传输,数据跨境传输流程周期是数据产生--数据传输--数据接收,而困境来源也来自这3个环节: 1.本地合规限制 数据出口国(数据输出国)的法律对于数据收集的…

Mybatis学习笔记(三)

十、MyBatis的逆向工程 (一)逆向工程介绍 MyBatis的一个主要的特点就是需要程序员自己编写sql,那么如果表太多的话,难免会很麻烦,所以mybatis官方提供了一个逆向工程,可以针对单表自动生成mybatis执行所需要的代码(包…

Github 2024-11-08Java开源项目日报 Top9

根据Github Trendings的统计,今日(2024-11-08统计)共有9个项目上榜。根据开发语言中项目的数量,汇总情况如下: 开发语言项目数量Java项目9Vue项目1经验丰富的Java(后端)开发人员核心面试问题和答案 | 互联网Java工程师进阶知识完全扫盲 创建周期:2085 天开发语言:Java协议…

【新闻文本分类识别】Python+CNN卷积神经网络算法+深度学习+人工智能+机器学习+文本处理

一、介绍 文本分类识别系统。本系统使用Python作为主要开发语言,首先收集了10种中文文本数据集(“体育类”, “财经类”, “房产类”, “家居类”, “教育类”, “科技类”, “时尚类”, “时政类”, “游戏类”, “娱乐类”),然…

数据结构 ——— 链式二叉树的前中后序遍历递归实现

目录 前言 链式二叉树示意图​编辑 手搓一个链式二叉树 链式二叉树的前序遍历 链式二叉树的中序遍历 链式二叉树的后序遍历 前言 在上一章学习了链式二叉树的前中后序遍历的解析 数据结构 ——— 链式二叉树的前中后序遍历解析-CSDN博客 接下来要学习的是代码实现链式…

<项目代码>YOLOv8 pcb板缺陷检测<目标检测>

YOLOv8是一种单阶段(one-stage)检测算法,它将目标检测问题转化为一个回归问题,能够在一次前向传播过程中同时完成目标的分类和定位任务。相较于两阶段检测算法(如Faster R-CNN),YOLOv8具有更高的…

yarn报错`warning ..\..\package.json: No license field`:已解决

出现这个报错有两个原因 1、项目中没有配置许可证 在项目根目录package.json添加 {"name": "next-starter","version": "1.0.0",# 添加这一行"license": "MIT", }或者配置私有防止发布到外部仓库 {"priv…

大模型学习笔记------CLIP模型解读与思考

大模型学习笔记------CLIP模型详解 1、为什么提出CLIP模型2、CLIP模型详解3、CLIP模型的意义4、一些思考 上文说到,多模态大模型应该是非常有发展前景的,首先来学习 CLIP(Contrastive Language-Image Pretraining)这个多模态模型…

昇思25天学习打卡营第1天|快速入门

昇思25天学习打卡营第1天|快速入门 目录 昇思25天学习打卡营第1天|快速入门实操教程 一、MindSpore内容简介 主要特点: MindSpore的组成部分: 二、入门实操步骤 1. 安装必要的依赖包 2. 下载并处理数据集 3. 构建网络模型 4. 训练模型 5. 测试…

【Python TensorFlow】入门到精通

TensorFlow 是一个开源的机器学习框架,由 Google 开发,广泛应用于机器学习和深度学习领域。本篇将详细介绍 TensorFlow 的基础知识,并通过一系列示例来帮助读者从入门到精通 TensorFlow 的使用。 1. TensorFlow 简介 1.1 什么是 TensorFlow…

Python 学习完基础语法知识后,如何进一步提高?

入门Python后,就可以拿些小案例练手了,这时候千万不要傻乎乎地成天啃语法书。 编程是一门实践的手艺,讲究孰能生巧。不管是去手撸算法、或者照葫芦画瓢写几个小游戏都可以让你的Python突飞猛进。 之前看github比较多,推荐给大家…

Java:数据结构-再谈String类

字符串常量池 首先我们来思考这段代码,为什么运行结果一个是true,一个是false呢? public class Test {public static void main(String[] args) {String s1"123";String s2"123";String s3new String("555")…

书生第四期实训营基础岛——L1G2000 玩转书生「多模态对话」与「AI搜索」产品

基础任务 MindSearch使用示例 书生浦语使用示例 书生万象使用示例 进阶任务 问题:目前生成式AI在学术和工业界有什么最新进展? 回答截图: 知乎回答链接:目前生成式AI在学术和工业界有什么最新进展?

ReactPress:重塑内容管理的未来

ReactPress Github项目地址:https://github.com/fecommunity/reactpress 欢迎提出宝贵的建议,欢迎一起共建,感谢Star。 ReactPress:重塑内容管理的未来 在当今信息爆炸的时代,一个高效、易用的内容管理系统&#xff0…

短视频矩阵系统源码/抖去推源头技术4年开发

#短视频矩阵系统# #短视频矩阵系统源码# #短视频矩阵系统源码开发# #短视频矩阵系统源头技术开发# 抖音短视频矩阵系统集成开发是指利用抖音平台的开放接口和API,构建一个系统,该系统能够管理多个抖音矩阵账号,实现内容的统一发布、账号管理、…

CJ/T188-2004 报文举例

CJ/T188-2004 报文举例 # 读水表地址 # 请求报文: FE FE FE FE 68 AA AA AA AA AA AA AA AA 03 03 81 0A 00 49 16FE FE FE FE :前导字符 FE68 :起始字符AA :仪表类型AA AA AA AA AA AA AA :仪表地址(当…

JavaEE进阶---第一个SprintBoot项目创建过程我的感受

文章目录 1.我的创建感受2.环境配置说明2.1xml文件国内源2.2配置流程 3.创建项目4.项目创建说明5.第一个程序--helloworld 1.我的创建感受 今天是学习这个spring boot项目创建的一天,这个确实过程坎坷,于是我自己决定弄一个这个IDEA的 专业版本&#xf…