LSTM时序预测 | Python实现LSTM长短期记忆神经网络时间序列预测

本文内容:Python实现LSTM长短期记忆神经网络时间序列预测,使用的数据集为AirPassengers

目录

数据集简介

1.步骤一

2.步骤二

3.步骤三

4.步骤四

数据集简介

AirPassengers 数据集的来源可以追溯到经典的统计和时间序列分析文献。原始数据集由 Box, Jenkins 和 Reinsel 在他们的书籍《Time Series Analysis: Forecasting and Control》中引入,这本书在时间序列分析领域非常著名

1.训练结果

2.步骤一

安装darts库:

pip install darts

#在连接处添加注意力机制
class UNetAttention1(nn.Module):def __init__(self, n_channels, n_classes, bilinear=False, attention=False):super(UNetAttention1, self).__init__()self.model_name = 'UNetAttention1'self.n_channels = n_channelsself.n_classes = n_classesself.bilinear = bilinearself.attention = attentionself.inc = (DoubleConv(n_channels, 64))self.down1 = (Down(64, 128))self.down2 = (Down(128, 256))self.down3 = (Down(256, 512))factor = 2 if bilinear else 1self.down4 = (Down(512, 1024 // factor))self.up1 = (Up(1024, 512 // factor, bilinear))self.up2 = (Up(512, 256 // factor, bilinear))self.up3 = (Up(256, 128 // factor, bilinear))self.up4 = (Up(128, 64, bilinear))self.outc = (OutConv(64, n_classes))if self.attention:self.attention1 = CBAM(64)self.attention2 = CBAM(128)self.attention3 = CBAM(256)self.attention4 = CBAM(512)def forward(self, x):x1 = self.inc(x)if self.attention:x1 = self.attention1(x1) + x1x2 = self.down1(x1)if self.attention:x2 = self.attention2(x2) + x2x3 = self.down2(x2)if self.attention:x3 = self.attention3(x3) + x3x4 = self.down3(x3)if self.attention:x4 = self.attention4(x4) + x4x5 = self.down4(x4)x = self.up1(x5, x4)x = self.up2(x, x3)x = self.up3(x, x2)x = self.up4(x, x1)logits = self.outc(x)return logitsdef use_checkpointing(self):self.inc = torch.utils.checkpoint(self.inc)self.down1 = torch.utils.checkpoint(self.down1)self.down2 = torch.utils.checkpoint(self.down2)self.down3 = torch.utils.checkpoint(self.down3)self.down4 = torch.utils.checkpoint(self.down4)self.up1 = torch.utils.checkpoint(self.up1)self.up2 = torch.utils.checkpoint(self.up2)self.up3 = torch.utils.checkpoint(self.up3)self.up4 = torch.utils.checkpoint(self.up4)self.outc = torch.utils.checkpoint(self.outc)

3.步骤二

部分代码如下:


import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import pandas as pd
import shutil
from sklearn.preprocessing import MinMaxScaler
from tqdm import tqdm_notebook as tqdm
import matplotlib.pyplot as pltfrom darts import TimeSeries
from darts.dataprocessing.transformers import Scaler
from darts.models import RNNModel, ExponentialSmoothing, BlockRNNModel
from darts.metrics import mape, mae, mse, rmse
from darts.utils.statistics import check_seasonality, plot_acf
from darts.datasets import AirPassengersDataset, SunspotsDataset
from darts.utils.timeseries_generation import datetime_attribute_timeseriesimport warningswarnings.filterwarnings("ignore")
import logginglogging.disable(logging.CRITICAL)####################数据准备##########################
# Read data:
series = AirPassengersDataset().load()  #原始数据集由 Box, Jenkins 和 Reinsel 在他们的书籍《Time Series Analysis: Forecasting and Control》中引入# Create training and validation sets:
train, val = series.split_after(pd.Timestamp("19590101")) ##可以填写具体的日期,也可以填写比例# Normalize the time series (note: we avoid fitting the transformer on the validation set)
transformer = Scaler()
train_transformed = transformer.fit_transform(train)
val_transformed = transformer.transform(val)
series_transformed = transformer.transform(series)# create month and year covariate series
year_series = datetime_attribute_timeseries(pd.date_range(start=series.start_time(), freq=series.freq_str, periods=1000),attribute="year",one_hot=False,
)
year_series = Scaler().fit_transform(year_series)
month_series = datetime_attribute_timeseries(year_series, attribute="month", one_hot=True
)
covariates = year_series.stack(month_series)
cov_train, cov_val = covariates.split_after(pd.Timestamp("19590101"))####################构建模型##########################
my_model = RNNModel(model="LSTM",hidden_dim=20,dropout=0,batch_size=16,n_epochs=300,optimizer_kwargs={"lr": 1e-3},model_name="Air_RNN",log_tensorboard=True,random_state=42,training_length=20,input_chunk_length=14,force_reset=True,save_checkpoints=True,
)my_model.fit(train_transformed,future_covariates=covariates,val_series=val_transformed,val_future_covariates=covariates,verbose=True,
)

完整代码下载地址:下载地址

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.xdnf.cn/news/1556967.html

如若内容造成侵权/违法违规/事实不符,请联系一条长河网进行投诉反馈,一经查实,立即删除!

相关文章

面向对象特性中 继承详解

目录 概念: 定义: 定义格式 继承关系和访问限定符 基类和派生类对象赋值转换: 继承中的作用域: 派生类的默认成员函数 继承与友元: 继承与静态成员: 复杂的菱形继承及菱形虚拟继承: 虚…

VGG16模型实现MNIST图像分类

MNIST图像数据集 MNIST(Modified National Institute of Standards and Technology)是一个经典的机器学习数据集,常用于训练和测试图像处理和机器学习算法,特别是在数字识别领域。该数据集包含了大约 7 万张手写数字图片&#xf…

喜讯 | 攸信技术入选第六批专精特新“小巨人”企业

日前,根据工信部评审结果,厦门市工业和信息化局公示了第六批专精特新“小巨人”企业和第三批专精特新“小巨人”复核通过企业名单,其中,厦门攸信信息技术有限公司进入第六批专精特新“小巨人”企业培育。 “专精特新”企业是指具有…

图像分割恢复方法

传统的图像分割方法主要依赖于图像的灰度值、纹理、颜色等特征,通过不同的算法将图像分割成多个区域。这些方法通常可以分为以下几类: 1.基于阈值的方法 2.基于边缘的方法 3.基于区域的方法 4.基于聚类的方法 下面详细介绍这些方法及其示例代码。 1. 基…

代码随想录--栈与队列--用栈实现队列

队列是先进先出,栈是先进后出。 如图所示: 题目 使用栈实现队列的下列操作: push(x) – 将一个元素放入队列的尾部。 pop() – 从队列首部移除元素。 peek() – 返回队列首部的元素。 empty() – 返回队列是否为空。 示例: MyQueue qu…

draw.io 设置默认字体及添加常用字体

需求描述 draw.io 是一个比较好的开源免费画图软件。但是其添加容器或者文本框时默认的字体是 Helvetica,一般的期刊、会议论文或者学位论文要求的英文字体是 Times New Roman,中文字体是 宋体,所以一般需要在文本字体选项里的下拉列表选择 …

分层解耦-05.IOCDI-DI详解

一.依赖注入的注解 在我们的项目中,EmpService的实现类有两个,分别是EmpServiceA和EmpServiceB。这两个实现类都加上Service注解。我们运行程序,就会报错。 这是因为我们依赖注入的注解Autowired默认是按照类型来寻找bean对象的进行依赖注入…

2-115 基于matlab的瞬态提取变换(TET)时频分析

基于matlab的瞬态提取变换(TET)时频分析,瞬态提取变换是一种比较新的TFA方法。该方法的分辨率较高,能够较好地提取出故障的瞬态特征,用于故障诊断领域。通过对原始振动信号设置不同信噪比噪声,对该方法的抗…

关于一个模仿qq通信程序

7月份的时候还在学校那个时候想要学习嵌入式Linux,但是还没有买开发板来玩,再学linux系统编程,网络编程,Linux系统的文件IO,于是学完之后想做一个模仿qq的通信程序于是就有了这个“ailun.exe”,因为暑假去打…

【数据结构与算法】线性表

文章目录 一.什么是线性表?二.线性表如何存储?三.线性表的类型 我们知道从应用中抽象出共性的逻辑结构和基本操作就是抽象数据类型,然后实现其存储结构和基本操作。下面我们依然按这个思路来认识线性表 一.什么是线性表? 定义 线性…

TryHackMe 第7天 | Web Fundamentals (二)

继续介绍一些 Web hacking 相关的漏洞。 IDOR IDOR (Insecure direct object reference),不安全的对象直接引用,这是一种访问控制漏洞。 当 Web 服务器接收到用户提供的输入来检索对象时 (包括文件、数据、文档),如果对用户输入数据过于信…

【springboot】使用代码生成器快速开发

接上一项目&#xff0c;使用mybatis-plus-generator实现简易代码文件生成 在fast-demo-web模块中的pom.xml中添加mybatis-plus-generator、freemarker和Lombok依赖 <dependency><groupId>com.baomidou</groupId><artifactId>mybatis-plus-generator&…

Python | 由高程计算坡度和坡向

写在前面 之前参加一个比赛&#xff0c;提供了中国的高程数据&#xff0c;可以基于该数据进一步计算坡度和坡向进行相关分析。 对于坡度和坡向&#xff0c;这里分享一个找到的库&#xff0c;可以方便快捷的计算。这个库为&#xff1a;RichDEM&#xff0c;官网地址如下 https…

SAP学习笔记 - 豆知识11 - 如何查询某个字段/DataElement/Domain在哪个表里使用?

大家知道SAP的表有10几万个&#xff08;也有说30多万个的&#xff0c;总之很多就是了&#xff09;&#xff0c;而且不断增多&#xff0c;那么当想知道一个字段在哪个表里使用的时候该怎么办呢&#xff1f; 思路就是SAP的表其实也是存在表里的&#xff1a;&#xff09;&#xf…

【Git】TortoiseGitPlink提示输入密码解决方法

问题 克隆仓库&#xff0c;TortoiseGitPlink提示输入密码 解法 1、打开TortoiseGit 下的puttygen工具 位置&#xff1a;C:\Program Files\TortoiseGit\bin\ 2、点击【Load】按钮&#xff0c;载入 C:\Users\Administrator\.ssh\ 文件夹下的id_rsa文件。 3、点击save private …

qt_c++_xml存这种复杂类型

demo&#xff0c;迅雷链接。或者我主页上传的资源 链接&#xff1a;https://pan.xunlei.com/s/VO8bIvYFfhmcrwF-7wmcPW1SA1?pwdnrp4# 复制这段内容后打开手机迅雷App&#xff0c;查看更方便 #ifndef MAINWINDOW_H #define MAINWINDOW_H#include <QMainWindow>#include…

请散户股民看过来,密切关注两件大事

明天股市要开市&#xff0c;不仅散户股民期盼节后股市大涨&#xff0c;上面也同样想在节后来上一个“开门红”。 为此&#xff0c;上面没休假&#xff0c;关起门来办了两件大事&#xff0c;这两天发布消息已提前预热了。 两件大事如下&#xff1a; 一是&#xff0c;上交所10…

什么是 JavaScript 的数组空槽

JavaScript 中的数组空槽一直是一个非常有趣且颇具争议的话题。我们可能对它的实际意义、历史以及现今的新版本中对它的处理方式有所疑问。数组空槽的存在最早可以追溯到 JavaScript 的诞生之初&#xff0c;当时的设计决定让它成为了现代 JavaScript 开发中的一种特别的现象。 …

大数据新视界 --大数据大厂之数据血缘追踪与治理:确保数据可追溯性

&#x1f496;&#x1f496;&#x1f496;亲爱的朋友们&#xff0c;热烈欢迎你们来到 青云交的博客&#xff01;能与你们在此邂逅&#xff0c;我满心欢喜&#xff0c;深感无比荣幸。在这个瞬息万变的时代&#xff0c;我们每个人都在苦苦追寻一处能让心灵安然栖息的港湾。而 我的…

计算机毕业设计hadoop+spark天气预测 天气可视化 天气大数据 空气质量检测 空气质量分析 气象大数据 气象分析 大数据毕业设计 大数据毕设

Hadoop天气预测系统开题报告 一、研究背景与意义 在信息化和大数据时代&#xff0c;天气数据已成为社会生活和经济发展中不可或缺的重要资源。天气预测系统作为现代气象学的重要组成部分&#xff0c;对于农业生产、交通管理、环境保护以及防灾减灾等方面都具有重要意义。然而…