基于Python的人工智能应用案例系列(15):LSTM酒类销售预测

        在本篇文章中,我们将使用时间序列数据分析技术,基于美国联邦储备经济数据库(FRED)中的酒类销售数据,使用LSTM(长短期记忆网络)进行未来销售量的预测。本案例展示了如何构建LSTM模型,训练模型进行时间序列预测,并使用历史数据进行模型的评估与未来预测。

1. 加载数据

        首先,我们加载酒类销售数据,该数据集包含了从1992年到2019年的月度销售记录。

import pandas as pd# 加载数据集
df = pd.read_csv('../data/Alcohol_Sales.csv', index_col=0, parse_dates=True)# 查看数据长度
len(df)# 清理数据,移除空值
df.dropna(inplace=True)
len(df)# 查看前5行数据
df.head()# 查看最后5行数据
df.tail()

2. EDA:时间序列数据可视化

        接下来,我们将数据进行可视化,以便更好地理解时间序列趋势。

import matplotlib.pyplot as pltplt.figure(figsize=(12,4))
plt.title('Beer, Wine, and Alcohol Sales')
plt.ylabel('Sales (millions of dollars)')
plt.grid(True)
plt.autoscale(axis='x',tight=True)
plt.plot(df['S4248SM144NCEN'])
plt.show()

3. 特征提取与数据准备

        我们将数据分为训练集和测试集,分别用于模型训练和评估。训练集将被归一化为-1到1之间的值,以提高训练效率。

import numpy as np
from sklearn.preprocessing import MinMaxScaler# 提取数据
y = df['S4248SM144NCEN'].values.astype(float)# 定义测试集大小
test_size = 12# 划分训练集和测试集
train_set = y[:-test_size]
test_set  = y[-test_size:]# 实例化归一化工具
scaler = MinMaxScaler(feature_range=(-1, 1))# 归一化训练集
train_norm = scaler.fit_transform(train_set.reshape(-1, 1))# 转换为张量
import torch
train_norm = torch.FloatTensor(train_norm).view(-1)# 定义窗口大小
window_size = 12# 创建输入数据
def input_data(seq,ws):out = []L = len(seq)for i in range(L-ws):window = seq[i:i+ws]label = seq[i+ws:i+ws+1]out.append((window,label))return outtrain_data = input_data(train_norm, window_size)

4. 构建LSTM模型

        接下来,我们定义一个包含LSTM层的神经网络模型。

import torch.nn as nnclass LSTMnetwork(nn.Module):def __init__(self, input_size=1, hidden_size=100, output_size=1):super().__init__()self.hidden_size = hidden_sizeself.lstm = nn.LSTM(input_size, hidden_size)self.linear = nn.Linear(hidden_size, output_size)self.hidden = (torch.zeros(1, 1, self.hidden_size),torch.zeros(1, 1, self.hidden_size))def forward(self, seq):lstm_out, self.hidden = self.lstm(seq.view(len(seq), 1, -1), self.hidden)pred = self.linear(lstm_out.view(len(seq), -1))return pred[-1]  # 我们只需要最后一个预测值

5. 训练模型

        我们将使用均方误差损失函数(MSE)和Adam优化器来训练模型。

torch.manual_seed(101)
model = LSTMnetwork()criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)# 训练模型
epochs = 100for epoch in range(epochs):for seq, y_train in train_data:optimizer.zero_grad()model.hidden = (torch.zeros(1, 1, model.hidden_size),torch.zeros(1, 1, model.hidden_size))y_pred = model(seq)loss = criterion(y_pred, y_train)loss.backward()optimizer.step()print(f'Epoch: {epoch+1:2} Loss: {loss.item():10.8f}')

6. 进行预测并与测试集比较

        模型训练完毕后,我们使用模型对未来12个月的数据进行预测,并与测试集进行对比。

future = 12# 使用最后一个训练窗口的值进行预测
preds = train_norm[-window_size:].tolist()model.eval()
for i in range(future):seq = torch.FloatTensor(preds[-window_size:])with torch.no_grad():model.hidden = (torch.zeros(1, 1, model.hidden_size),torch.zeros(1, 1, model.hidden_size))preds.append(model(seq).item())# 反归一化预测结果
true_predictions = scaler.inverse_transform(np.array(preds[window_size:]).reshape(-1, 1))# 可视化预测结果
import numpy as npx = np.arange('2018-02-01', '2019-02-01', dtype='datetime64[M]').astype('datetime64[D]')plt.figure(figsize=(12,4))
plt.title('Beer, Wine, and Alcohol Sales')
plt.ylabel('Sales (millions of dollars)')
plt.grid(True)
plt.plot(df['S4248SM144NCEN'])
plt.plot(x,true_predictions)
plt.show()

7. 预测未来数据

        我们将模型应用于整个数据集,并预测未来12个月的销售数据。

# 训练整个数据集并预测未来
epochs = 100# 归一化整个数据集
y_norm = scaler.fit_transform(y.reshape(-1, 1))
y_norm = torch.FloatTensor(y_norm).view(-1)
all_data = input_data(y_norm, window_size)for epoch in range(epochs):for seq, y_train in all_data:optimizer.zero_grad()model.hidden = (torch.zeros(1, 1, model.hidden_size),torch.zeros(1, 1, model.hidden_size))y_pred = model(seq)loss = criterion(y_pred, y_train)loss.backward()optimizer.step()# 预测未来12个月数据
future = 12
preds = y_norm[-window_size:].tolist()model.eval()
for i in range(future):seq = torch.FloatTensor(preds[-window_size:])with torch.no_grad():model.hidden = (torch.zeros(1, 1, model.hidden_size),torch.zeros(1, 1, model.hidden_size))preds.append(model(seq).item())# 反归一化预测值并可视化
true_predictions = scaler.inverse_transform(np.array(preds).reshape(-1, 1))x = np.arange('2019-02-01', '2020-02-01', dtype='datetime64[M]').astype('datetime64[D]')plt.figure(figsize=(12,4))
plt.title('Beer, Wine, and Alcohol Sales')
plt.ylabel('Sales (millions of dollars)')
plt.grid(True)
plt.plot(df['S4248SM144NCEN'])
plt.plot(x,true_predictions[window_size:])
plt.show()

结语

        在本篇案例中,我们通过LSTM模型对美国酒类销售的时间序列数据进行了分析和预测。通过归一化处理、模型训练、测试集验证以及未来趋势预测,我们可以看到LSTM模型能够有效捕捉数据的时间依赖性,进而在一定程度上准确预测未来的销售趋势。虽然我们的模型表现较好,但仍有一些误差,可以通过调整模型参数、增加训练数据或采用其他高级算法进一步优化。

        本案例展示了LSTM在时间序列预测中的应用,证明了其在捕捉长期依赖性和模式识别中的强大能力。对于需要预测趋势、销售额、库存等的商业决策场景,LSTM提供了可靠的解决方案。未来的工作中,我们可以探索更多模型的改进和应用场景,以提升预测的准确性和实用性。

        通过本案例的学习,希望您对时间序列预测和LSTM网络有了更深入的理解,并能将其应用到更多实际问题中。

如果你觉得这篇博文对你有帮助,请点赞、收藏、关注我,并且可以打赏支持我!

欢迎关注我的后续博文,我将分享更多关于人工智能、自然语言处理和计算机视觉的精彩内容。

谢谢大家的支持!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.xdnf.cn/news/1551038.html

如若内容造成侵权/违法违规/事实不符,请联系一条长河网进行投诉反馈,一经查实,立即删除!

相关文章

touch命令:创建文件,更新时间戳

一、命令简介 ​touch​ 命令在 Linux 和其他类 Unix 系统中用于创建空白文件或者更新已存在文件的时间戳。如果指定的文件不存在,touch​ 命令会创建一个空白文件;如果文件已经存在,touch​ 命令会更新文件的访问时间和修改时间&#xff0c…

springboot+大数据+基于协同过滤算法的校园食堂订餐系统【内含源码+文档+部署教程】

博主介绍:✌全网粉丝10W,前互联网大厂软件研发、集结硕博英豪成立工作室。专注于计算机相关专业毕业设计项目实战6年之久,选择我们就是选择放心、选择安心毕业✌ 🍅由于篇幅限制,想要获取完整文章或者源码,或者代做&am…

数据权限的设计与实现系列11——前端筛选器组件Everright-filter集成功能完善2

‍ 筛选条件数据类型完善 文本类 筛选器组件给了一个文本类操作的范例,如下: Text: [{label: 等于,en_label: Equal,style: noop},{label: 等于其中之一,en_label: Equal to one of,value: one_of,style: tags},{label: 不等于,en_label: Not equal,v…

LeetCode 面试经典150题 69.x的平方根

题目:给你一个非负整数 x ,计算并返回 x 的 算术平方根 。 由于返回类型是整数,结果只保留 整数部分 ,小数部分将被 舍去 。注意:不允许使用任何内置指数函数和算符,例如 pow(x, 0.5) 或者 x ** 0.5 。 思…

【Python报错已解决】TypeError: ‘list‘ object is not callable

🎬 鸽芷咕:个人主页 🔥 个人专栏: 《C干货基地》《粉丝福利》 ⛺️生活的理想,就是为了理想的生活! 专栏介绍 在软件开发和日常使用中,BUG是不可避免的。本专栏致力于为广大开发者和技术爱好者提供一个关于BUG解决的经…

基于Springboot+Vue的课程教学平台的设计与实现系统(含源码数据库)

1.开发环境 开发系统:Windows10/11 架构模式:MVC/前后端分离 JDK版本: Java JDK1.8 开发工具:IDEA 数据库版本: mysql5.7或8.0 数据库可视化工具: navicat 服务器: SpringBoot自带 apache tomcat 主要技术: Java,Springboot,mybatis,mysql,vue 2.视频演示地址 3.功能 这个系…

【电力系统】电力系统状态估计

摘要 电力系统状态估计是确保电力系统安全稳定运行的重要技术之一。本文利用Matlab实现了一种基于加权最小二乘法(WLS)的状态估计算法,能够在不同测量条件下准确估计电力系统的状态变量。通过对典型电力系统的仿真分析,验证了算法…

第三节-类与对象(2)默认成员函数详解

1.类的6个默认成员函数 如果一个类中什么成员都没有,简称为空类(空类大小为1)。 空类中真的什么都没有吗?并不是,任何类在什么都不写时,编译器会自动生成以下6个默认成员函数。 默认成员函数:…

第L2周:机器学习|线性回归模型 LinearRegression:2. 多元线性回归模型

本文为365天深度学习训练营 中的学习记录博客原作者:K同学啊 任务: ●1. 学习本文的多元线形回归模型。 ●2. 参考文本预测花瓣宽度的方法,选用其他三个变量来预测花瓣长度。 一、多元线性回归 简单线性回归:影响 Y 的因素唯一&…

依赖倒置原则(学习笔记)

抽象不应该依赖细节,细节应该依赖抽象。简单的说就是要求对抽象进行编程,不要对实现进行编程,这样就降低了客户与实现模块间的耦合。 依赖倒转原则是基于这样的设计理念:相对于细节的多变性,抽象的东西要稳定的多。 以…

vue + echarts 快速入门

vue echarts 快速入门 本案例即有nodejs和vue的基础,又在vue的基础上整合了echarts Nodejs基础 1、Node简介 1.1、为什么学习Nodejs(了解) 轻量级、高性能、可伸缩web服务器前后端JavaScript同构开发简洁高效的前端工程化 1.2、Nodejs能做什么(了解) Node 打破了…

Android 安卓内存安全漏洞数量大幅下降的原因

谷歌决定使用内存安全的编程语言 Rust 向 Android 代码库中写入新代码,尽管旧代码(用 C/C 编写)没有被重写,但内存安全漏洞却大幅减少。 Android 代码库中每年发现的内存安全漏洞数量(来源:谷歌&#xff09…

常用的cmd命令——使用bat命令创建程序的快捷方式

示例使用场景:例如便携版的软件,需要往桌面发快捷方式 如便携的浏览器,给桌面发送快捷方式,同时设置快捷方式的启动参数。 下面以谷歌浏览器为例: 浏览器的App的下级目录为如下内容 知道了所需文件的位置,…

废品回收小程序/环保垃圾回收/收二手垃圾小程序/分类资源回收系统/独立版系统源码

>>>系统简述: 1.以微信小程序为基础进行开发,体验好,操作方便 2.从用户下单到回收员接单,在到回收站接收,在到代理全流程通过手机端管理 3.支持废品分类下单,并支持分类数据统计 4.独创回收员多个…

五金精密加工提升效率的方法与技巧

在五金精密加工领域,提高加工效率是企业增强竞争力的关键。以下是一些有效的提升方法与技巧。 一、优化加工设备 设备升级与更新 定期评估加工设备的性能,引进先进的五金精密加工机床。例如,高精度的数控加工中心能够实现多轴联动加工&#x…

Android15车载音频之CarAudioService加载解析各音区参数过程(八十七)

简介: CSDN博客专家、《Android系统多媒体进阶实战》一书作者 新书发布:《Android系统多媒体进阶实战》🚀 优质专栏: Audio工程师进阶系列【原创干货持续更新中……】🚀 优质专栏: 多媒体系统工程师系列【原创干货持续更新中……】🚀 优质视频课程:AAOS车载系统+…

【sourceTree问题】拉取提交的时候需要频繁输入账号密码

用sourceTree进行代码管理的时候会出现一直让输入账号密码的问题,烦不胜烦,可以点击【设置】 → 【编辑配置文件...】打开配置文件: 在配置文件里找到url,把url里面的网址修改为: http://username:passwordxxxxx/xx…

Qt——如何创建一个项目

前言 本文主要通过实操带领大家来实现基础文件的操作,主要包括文件的打开,读取,写入,当然文件读写我们可以有几种不同的方式来进行操作,分别是文件流,字节流来进行的操作这里就需要两个类分别是文件流&…

速通数据结构与算法第六站 树堆

系列文章目录 速通数据结构与算法系列 1 速通数据结构与算法第一站 复杂度 http://t.csdnimg.cn/sxEGF 2 速通数据结构与算法第二站 顺序表 http://t.csdnimg.cn/WVyDb 3 速通数据结构与算法第三站 单链表 http://t.csdnimg.cn/cDpcC 4 速通…

学习记录:js算法(四十四):二叉树的最大深度

文章目录 二叉树的最大深度我的思路网上思路 总结 二叉树的最大深度 给定一个二叉树 root ,返回其最大深度。 二叉树的 最大深度 是指从根节点到最远叶子节点的最长路径上的节点数。 图一: 示例 1:(如图一) 输入:root [3,9,20,…