lag-llama源码解读(Lag-Llama: Towards Foundation Models for Time Series Forecasting)

Lag-Llama: Towards Foundation Models for Time Series Forecasting
文章内容:
时间序列预测任务,单变量预测单变量,基于Llama大模型,在zero-shot场景下模型表现优异。创新点,引入滞后特征作为协变量来进行预测。

获得不同频率的lag,来自glunoTS库里面的源码

def _make_lags(middle: int, delta: int) -> np.ndarray:"""Create a set of lags around a middle point including +/- delta."""return np.arange(middle - delta, middle + delta + 1).tolist()def get_lags_for_frequency(freq_str: str,lag_ub: int = 1200,num_lags: Optional[int] = None,num_default_lags: int = 7,
) -> List[int]:"""Generates a list of lags that that are appropriate for the given frequencystring.By default all frequencies have the following lags: [1, 2, 3, 4, 5, 6, 7].Remaining lags correspond to the same `season` (+/- `delta`) in previous`k` cycles. Here `delta` and `k` are chosen according to the existing code.Parameters----------freq_strFrequency string of the form [multiple][granularity] such as "12H","5min", "1D" etc.lag_ubThe maximum value for a lag.num_lagsMaximum number of lags; by default all generated lags are returned.num_default_lagsThe number of default lags; by default it is 7."""# Lags are target values at the same `season` (+/- delta) but in the# previous cycle.def _make_lags_for_second(multiple, num_cycles=3):# We use previous ``num_cycles`` hours to generate lagsreturn [_make_lags(k * 60 // multiple, 2) for k in range(1, num_cycles + 1)]def _make_lags_for_minute(multiple, num_cycles=3):# We use previous ``num_cycles`` hours to generate lagsreturn [_make_lags(k * 60 // multiple, 2) for k in range(1, num_cycles + 1)]def _make_lags_for_hour(multiple, num_cycles=7):# We use previous ``num_cycles`` days to generate lagsreturn [_make_lags(k * 24 // multiple, 1) for k in range(1, num_cycles + 1)]def _make_lags_for_day(multiple, num_cycles=4, days_in_week=7, days_in_month=30):# We use previous ``num_cycles`` weeks to generate lags# We use the last month (in addition to 4 weeks) to generate lag.return [_make_lags(k * days_in_week // multiple, 1)for k in range(1, num_cycles + 1)] + [_make_lags(days_in_month // multiple, 1)]def _make_lags_for_week(multiple, num_cycles=3):# We use previous ``num_cycles`` years to generate lags# Additionally, we use previous 4, 8, 12 weeksreturn [_make_lags(k * 52 // multiple, 1) for k in range(1, num_cycles + 1)] + [[4 // multiple, 8 // multiple, 12 // multiple]]def _make_lags_for_month(multiple, num_cycles=3):# We use previous ``num_cycles`` years to generate lagsreturn [_make_lags(k * 12 // multiple, 1) for k in range(1, num_cycles + 1)]# multiple, granularity = get_granularity(freq_str)offset = to_offset(freq_str)# normalize offset name, so that both `W` and `W-SUN` refer to `W`offset_name = norm_freq_str(offset.name)if offset_name == "A":lags = []elif offset_name == "Q":assert (offset.n == 1), "Only multiple 1 is supported for quarterly. Use x month instead."lags = _make_lags_for_month(offset.n * 3.0)elif offset_name == "M":lags = _make_lags_for_month(offset.n)elif offset_name == "W":lags = _make_lags_for_week(offset.n)elif offset_name == "D":lags = _make_lags_for_day(offset.n) + _make_lags_for_week(offset.n / 7.0)elif offset_name == "B":lags = _make_lags_for_day(offset.n, days_in_week=5, days_in_month=22) + _make_lags_for_week(offset.n / 5.0)elif offset_name == "H":lags = (_make_lags_for_hour(offset.n)+ _make_lags_for_day(offset.n / 24)+ _make_lags_for_week(offset.n / (24 * 7)))# minuteselif offset_name == "T":lags = (_make_lags_for_minute(offset.n)+ _make_lags_for_hour(offset.n / 60)+ _make_lags_for_day(offset.n / (60 * 24))+ _make_lags_for_week(offset.n / (60 * 24 * 7)))# secondelif offset_name == "S":lags = (_make_lags_for_second(offset.n)+ _make_lags_for_minute(offset.n / 60)+ _make_lags_for_hour(offset.n / (60 * 60)))else:raise Exception("invalid frequency")# flatten lags list and filterlags = [int(lag) for sub_list in lags for lag in sub_list if 7 < lag <= lag_ub]lags = list(range(1, num_default_lags + 1)) + sorted(list(set(lags)))return lags[:num_lags]

第一部分,生成以middle为中心,以delta为半径的区间[middle-delta,middle+delta] ,这很好理解,比如一周的周期是7天,周期大小在7天附近波动很正常。
在这里插入图片描述

第二部分,对于年月日时分秒这些不同的采样频率,采用不同的具体的函数来确定lags,其中有一个参数num_cycle,进一步利用了周期性,我们考虑间隔1、2、3、…num个周期的时间点之间的联系
在这里插入图片描述
原理类似于这张图,这种周期性的重复性体现在邻近的多个周期上

在这里插入图片描述

lag的用途

计算各类窗口大小

计算采样窗口大小

window_size = estimator.context_length + max(estimator.lags_seq) + estimator.prediction_length# Here we make a window slightly bigger so that instance sampler can sample from each window# An alternative is to have exact size and use different instance sampler (e.g. ValidationSplitSampler)
window_size = 10 * window_size
# We change ValidationSplitSampler to add min_pastestimator.validation_sampler = ValidationSplitSampler(min_past=estimator.context_length + max(estimator.lags_seq),min_future=estimator.prediction_length,)
  1. 构建静态特征
lags = lagged_sequence_values(self.lags_seq, prior_input, input, dim=-1)#构建一个包含给定序列的滞后值的数组static_feat = torch.cat((loc.abs().log1p(), scale.log()), dim=-1)
expanded_static_feat = unsqueeze_expand(static_feat, dim=-2, size=lags.shape[-2]
)return torch.cat((lags, expanded_static_feat, time_feat), dim=-1), loc, scale

数据集准备过程

对每个数据集采样,window_size=13500,也挺离谱的

 train_data, val_data = [], []for name in TRAIN_DATASET_NAMES:new_data = create_sliding_window_dataset(name, window_size)train_data.append(new_data)new_data = create_sliding_window_dataset(name, window_size, is_train=False)val_data.append(new_data)

采样的具体过程,这里有个问题,样本数量很小的数据集,实际采样窗口大小小于设定的window_size,后续会如何对齐呢?

文章设置单变量预测单变量,所以样本进行了通道分离,同一样本的不同特征被采样为不同的样本

def create_sliding_window_dataset(name, window_size, is_train=True):#划分非重叠的滑动窗口数据集,window_size是对数据集采样的数量,对每个数据集只取前windowsize个样本# Splits each time series into non-overlapping sliding windowsglobal_id = 0freq = get_dataset(name, path=dataset_path).metadata.freq#从数据集中获取时间频率data = ListDataset([], freq=freq)#创建空数据集dataset = get_dataset(name, path=dataset_path).train if is_train else get_dataset(name, path=dataset_path).test#获取原始数据集for x in dataset:windows = []#划分滑动窗口#target:滑动窗口的目标值#start:滑动窗口的起始位置#item_id,唯一标识符#feat_static_cat:静态特征数组for i in range(0, len(x['target']), window_size):windows.append({'target': x['target'][i:i+window_size],'start': x['start'] + i,'item_id': str(global_id),'feat_static_cat': np.array([0]),})global_id += 1data += ListDataset(windows, freq=freq)return data

合并数据集

# Here weights are proportional to the number of time series (=sliding windows)weights = [len(x) for x in train_data]# Here weights are proportinal to the number of individual points in all time series# weights = [sum([len(x["target"]) for x in d]) for d in train_data]train_data = CombinedDataset(train_data, weights=weights)val_data = CombinedDataset(val_data, weights=weights)
class CombinedDataset:def __init__(self, datasets, seed=None, weights=None):self._seed = seedself._datasets = datasetsself._weights = weightsn_datasets = len(datasets)if weights is None:#如果未提供权重,默认平均分配权重self._weights = [1 / n_datasets] * n_datasetsdef __iter__(self):return CombinedDatasetIterator(self._datasets, self._seed, self._weights)def __len__(self):return sum([len(ds) for ds in self._datasets])

网络结构

lagllama

class LagLlamaModel(nn.Module):def __init__(self,max_context_length: int,scaling: str,input_size: int,n_layer: int,n_embd: int,n_head: int,lags_seq: List[int],rope_scaling=None,distr_output=StudentTOutput(),num_parallel_samples: int = 100,) -> None:super().__init__()self.lags_seq = lags_seqconfig = LTSMConfig(n_layer=n_layer,n_embd=n_embd,n_head=n_head,block_size=max_context_length,feature_size=input_size * (len(self.lags_seq)) + 2 * input_size + 6,rope_scaling=rope_scaling,)self.num_parallel_samples = num_parallel_samplesif scaling == "mean":self.scaler = MeanScaler(keepdim=True, dim=1)elif scaling == "std":self.scaler = StdScaler(keepdim=True, dim=1)else:self.scaler = NOPScaler(keepdim=True, dim=1)self.distr_output = distr_outputself.param_proj = self.distr_output.get_args_proj(config.n_embd)self.transformer = nn.ModuleDict(dict(wte=nn.Linear(config.feature_size, config.n_embd),h=nn.ModuleList([Block(config) for _ in range(config.n_layer)]),ln_f=RMSNorm(config.n_embd),))

主要是transformer里面首先是一个线性层,然后加了n_layer个Block,最后是RMSNorm,接下来解析Block的代码

在这里插入图片描述

Block

class Block(nn.Module):def __init__(self, config: LTSMConfig) -> None:super().__init__()self.rms_1 = RMSNorm(config.n_embd)self.attn = CausalSelfAttention(config)self.rms_2 = RMSNorm(config.n_embd)self.mlp = MLP(config)self.y_cache = Nonedef forward(self, x: torch.Tensor, is_test: bool) -> torch.Tensor:if is_test and self.y_cache is not None:# Only use the most recent one, rest is in cachex = x[:, -1:]x = x + self.attn(self.rms_1(x), is_test)y = x + self.mlp(self.rms_2(x))if is_test:if self.y_cache is None:self.y_cache = y  # Build cacheelse:self.y_cache = torch.cat([self.y_cache, y], dim=1)[:, 1:]  # Update cachereturn y

代码看到这里不太想继续看了,太多glunoTS库里面的函数了,我完全不熟悉这个库,看起来太痛苦了,还有很多的困惑,最大的困惑就是数据是怎么对齐的,怎么输入到Llama里面的,慢慢看吧

其他

来源
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.xdnf.cn/news/823240.html

如若内容造成侵权/违法违规/事实不符,请联系一条长河网进行投诉反馈,一经查实,立即删除!

相关文章

Power Apps 学习笔记 - IOrganizationService Interface

文章目录 1. IOrganization Interface1.1 基本介绍1.2 方法分析 2. Entity对象2.1 Constructor2.2 Properties2.3 Methods 3. 相关方法3.1 单行查询 Retrive3.2 多行查询 RetriveMultiple3.3 增加 Create3.4 删除 Delete3.5 修改 Update 1. IOrganization Interface 1.1 基本介…

rax3000m刷openwrt固件

rax3000m刷机过程&#xff08;nand版本&#xff09; 刷机准备文件https://www.123pan.com/s/X5m9-6Ynj.html提取码:VtBW 接线关系&#xff1a;路由器lan口接电脑 1.上传配置开启ssh的配置文件&#xff08;登录路由器后台管理界面在找到配置管理&#xff0c;上传配置文件rax3…

[NCTF 2022] web题解

[NCTF 2022]calc 考点&#xff1a;python环境变量注入 打开题目&#xff0c;F12有hint 访问一下得到源码 app.route("/calc",methods[GET]) def calc():ip request.remote_addrnum request.values.get("num")log "echo {0} {1} {2}> ./tmp/log…

【Unity美术】Unity工程师对3D模型需要达到的了解【一】

&#x1f468;‍&#x1f4bb;个人主页&#xff1a;元宇宙-秩沅 &#x1f468;‍&#x1f4bb; hallo 欢迎 点赞&#x1f44d; 收藏⭐ 留言&#x1f4dd; 加关注✅! &#x1f468;‍&#x1f4bb; 本文由 秩沅 原创 &#x1f468;‍&#x1f4bb; 收录于专栏&#xff1a;Uni…

后端程序员React初接触1

后端程序员React初接触 学习react基础与相关库的使用学习 包括react基础 路由 组件库等等 react是用于构建用户界面的JavaScript库 发送请求获取数据处理数据操作dom呈现页面&#xff08;react帮忙操作dom&#xff09; 数据渲染为视图 有facebook打造并开源 解决的问题 dom操…

集群部署篇--Redis 哨兵模式

文章目录 前言一、哨兵模式介绍&#xff1a;1.1 介绍&#xff1a;1.2 工作机制&#xff1a; 二、哨兵模式搭建&#xff1a;2. 1 redis 主从搭建&#xff1a;2.2 setinel 集群搭建&#xff1a;2.2.1 配置&#xff1a; sentinel.conf &#xff1a;2.2.2 运行容器&#xff1a;2.2.…

jQuery日历签到插件下载

jQuery日历签到插件下载-遇见你与你分享

【MySQL】数据库之存储过程(“SQL语句的脚本“)

目录 一、什么是存储过程&#xff1f; 二、存储过程的作用 三、如何创建、调用、查看、删除、修改存储过程 四、存储过程的参数&#xff08;输入参数&#xff0c;输出参数&#xff0c;输入输出参数&#xff09; 第一种&#xff1a;输入参数 第二种&#xff1a;输出参数 …

Leetcode算法系列| 10. 正则表达式匹配

目录 1.题目2.题解C# 解法一&#xff1a;分段匹配法C# 解法二&#xff1a;回溯法C# 解法三&#xff1a;动态规划 1.题目 给你一个字符串 s 和一个字符规律 p&#xff0c;请你来实现一个支持 ‘.’ 和 ‘*’ 的正则表达式匹配。 1.‘.’ 匹配任意单个字符 2.‘.’ 匹配任意单个字…

【DevOps 工具链】日志管理工具 - 22种 选型(读这一篇就够了)

文章目录 1、简述2、内容分类3、归纳对比表&#xff08;排序不分先后&#xff09;4、日志管理主要目的5、日志管理工具 22种 详细&#xff08;排序不分先后&#xff09;5.1、ManageEngine EventLog Analyzer5.1.1、简介5.1.2、效果图5.1.3、日志管理架构5.1.4、EventLog Analyz…

HarmonyOS 路由传参

本文 我们来说两个page界面间的数据传递 路由跳转 router.pushUrl 之前我们用了不少了 但是我们只用了它的第一个参数 url 其实他还有个params参数 我们第一个组件可以编写代码如下 import router from ohos.router Entry Component struct Index {build() {Row() {Column() …

交互式笔记Jupyter Notebook本地部署并实现公网远程访问内网服务器

最近&#xff0c;我发现了一个超级强大的人工智能学习网站。它以通俗易懂的方式呈现复杂的概念&#xff0c;而且内容风趣幽默。我觉得它对大家可能会有所帮助&#xff0c;所以我在此分享。点击这里跳转到网站。 文章目录 1.前言2.Jupyter Notebook的安装2.1 Jupyter Notebook下…

C编程指针篇----包括历年真题

一&#xff0c;&#xff08;20年&#xff09;用指针字符逆序 代码&#xff1a; int main() {char s[7] "monkey", * p1, * p2, c;p1 p2 s;while (*p2) p2;p2--;while (p2 > p1) {c *p1; *p1 *p2; *p2-- c; }printf("%s", s);return 0; } 运行结…

【华为机试】2023年真题B卷(python)-解密犯罪时间

一、题目 题目描述&#xff1a; 警察在侦破一个案件时&#xff0c;得到了线人给出的可能犯罪时间&#xff0c;形如 “HH:MM” 表示的时刻。 根据警察和线人的约定&#xff0c;为了隐蔽&#xff0c;该时间是修改过的&#xff0c;解密规则为&#xff1a; 利用当前出现过的数字&am…

jdk与cglib动态代理及原理

Spring的AOP在运行时多以jdk及cglib动态代理来实现。&#xff08;作者jdk是1.8版本&#xff09; 1 jdk 动态代理 Java中使用动态代理&#xff0c;只能对接口进行代理&#xff0c;不能对普通类进行代理。主要是由一个类及一个接口来实现&#xff1a; InvocationHandler&#…

【并发设计模式】聊聊等待唤醒机制的规范实现

在多线程编程中&#xff0c;其实就是分工、协作、互斥。在很多场景中&#xff0c;比如A执行的过程中需要同步等待另外一个线程处理的结果&#xff0c;这种方式下&#xff0c;就是一种等待唤醒的机制。本篇我们来讲述等待唤醒机制的三种实现&#xff0c;以及对应的应用场景。 G…

Python基础进阶3:函数和方法不是一回事

你好&#xff0c;我是kelly&#xff0c;今天分享的是Python的函数与方法的不同点。 对于Python的函数和方法是不一样的&#xff0c;这一点需要注意下。 一、结论 1、不存在隐式传参&#xff0c;所有参数都是显式传递的是函数。 2、存在隐式传参的是方法&#xff0c;一般指隐式…

神经元科技发布AI agent—“萨蔓莎”

今天神经元科技发布AI agent—“萨蔓莎“&#xff08;Samantha &#xff09;&#xff01; 取名“萨蔓莎”&#xff0c;是来自于一部讲述AI的电影《HER》。 电影讲述的是电影讲述男子西奥多汤布里&#xff08;Theodore Twombly&#xff0c;饰&#xff09;与拟人化萨曼莎&#…

日志记录、跟踪和指标

我的新书《Android App开发入门与实战》已于2020年8月由人民邮电出版社出版&#xff0c;欢迎购买。点击进入详情 日志记录、跟踪和指标是系统可观察性的三大支柱。 下图显示了它们的定义和典型架构。 记录 日志记录系统中的离散事件。例如&#xff0c;我们可以将传入请求或对…

挑战Python100题(8)

100+ Python challenging programming exercises 8 Question 71 Please write a program which accepts basic mathematic expression from console and print the evaluation result. 请编写一个从控制台接受基本数学表达式的程序,并打印评估结果。 Example: If the follo…