lstm实践

今年华为杯研究生数学建模的C题第四问用到了lstm,这里配合代码简要地讲一下。

数据类型

磁通密度是一个时序数据,包含了一个周期内的磁通密度变化,我们需要对它进行降维,但PCA是不合适的,因为PCA主要关注数据的方差,无法有效捕捉周期性数据的重要特征,而磁通密度是周期性变化的。

在自然语言处理领域中,LSTM可以捕获序列内部元素之间的关联性,并且其隐藏层可以包含前序序列的信息。最后一层的隐藏层就包含了整个序列的信息,所以我们可以将最后一层的隐藏层作为降维后的向量。

我们选择LSTM对1024 维的磁通密度进行降维,具体做法是:训练时对一个周期进行切片,使用LSTM预测切片的下一时刻的磁通密度;降维时使用整个周期,获取最后一 层的hidden state作为该样本的磁通密度特征。

代码

1.数据处理

import pandas as pd
import torch as pt
import os
os.chdir('/home/burger/math/')df1 = pd.read_excel('./data/附件一(训练集).xlsx', sheet_name='材料1')
df2 = pd.read_excel('./data/附件一(训练集).xlsx', sheet_name='材料2')
df3 = pd.read_excel('./data/附件一(训练集).xlsx', sheet_name='材料3')
df4 = pd.read_excel('./data/附件一(训练集).xlsx', sheet_name='材料4')collom_name = [i for i in range(1,1024)]
B1 = df1[['0(磁通密度B,T)']+collom_name]
B2 = df2[['0(磁通密度,T)']+collom_name]
B3 = df3[['0(磁通密度B,T)']+collom_name]
B4 = df4[['0(磁通密度B,T)']+collom_name]
print(B1.head())B1_t = pt.tensor(B1.values)
B2_t = pt.tensor(B2.values)
B3_t = pt.tensor(B3.values)
B4_t = pt.tensor(B4.values)
B = pt.cat((B1_t, B2_t, B3_t, B4_t), 0)
print(B.shape)def create_dataset(data, time_step=64):  x, y = [], []  for i in range(0, data.shape[1] - time_step, 32):  a = data[:, i:(i + time_step)]  x.append(a)  y.append(data[:, i + time_step])  return pt.concat(x).float(), pt.concat(y).float()X, Y = create_dataset(B)
print(X.shape, Y.shape)
X, Y = X.unsqueeze(-1), Y.unsqueeze(-1)
print(X.shape, Y.shape)

 2.模型定义

import torch.nn as nn 
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
from tqdm import tqdm# LSTM模型定义
class LSTMModel(nn.Module):def __init__(self, input_size, hidden_size, output_size, num_layers=1):super(LSTMModel, self).__init__()self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)self.fc = nn.Linear(hidden_size, output_size)def forward(self, x):lstm_out, _ = self.lstm(x)out = self.fc(lstm_out[:, -1, :])  # 只取最后一个时间步的输出return outdef embedding(self, x):_, hid_cell = self.lstm(x)return hid_cell[0]

3.训练

input_size = 1
hidden_size = 1
output_size = 1
num_layers = 1
num_epochs = 5
batch_size = 2048
gpu = 6
train_dataset = TensorDataset(X.to(gpu), Y.to(gpu))
train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)model = LSTMModel(input_size, hidden_size, output_size, num_layers).to(gpu)
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)def train(model, dataloader, criterion, optimizer, num_epochs):model.train()for epoch in range(num_epochs):for inputs, labels in tqdm(dataloader, unit='batch'):optimizer.zero_grad()outputs = model(inputs)loss = criterion(outputs, labels)loss.backward()optimizer.step()print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}')
train(model, train_loader, criterion, optimizer, num_epochs)

4.降维

df_san = pd.read_excel('./data/附件三(测试集).xlsx')
B_san = df_san[['0(磁通密度B,T)']+collom_name]
B_san_t = pt.tensor(B_san.values)
B_san_t = B_san_t.unsqueeze(-1).float().to(gpu)
print(B_san_t.shape)emb_dataset = TensorDataset(B_san_t)
emb_loader = DataLoader(dataset=emb_dataset, batch_size=400, shuffle=False)def embedding(model, dataloader):model.eval()embeddings = []for inputs in tqdm(dataloader, unit='batch'):outputs = model.embedding(inputs[0])embeddings.append(outputs)return pt.cat(embeddings).cpu().detach().numpy()embeddings = embedding(model, emb_loader)
print(embeddings.shape)
embeddings = embeddings.reshape(-1)
print(embeddings.shape)embeddings_df = pd.DataFrame({'磁通密度编码': embeddings,
})
embeddings_df.to_excel('./data/磁通密度编码1.xlsx', index=False)

有问题欢迎在评论区讨论!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.xdnf.cn/news/1552730.html

如若内容造成侵权/违法违规/事实不符,请联系一条长河网进行投诉反馈,一经查实,立即删除!

相关文章

Python核心知识:pip使用方法大全

什么是 pip? pip 是 Python 的包管理工具,允许用户安装、升级和管理 Python 的第三方库和依赖。它极大地简化了开发过程,使开发者可以轻松地获取并安装所需的软件包。pip 已成为 Python 项目中最常见的包管理工具,并且自 Python …

windows C++-UWP 应用中使用 HttpRequest 类

在 UWP 应用中使用 HttpRequest 类 本节演示在 UWP 应用中如何使用 HttpRequest 类。 应用程序会提供一个输入框,该输入框定义了一个 URL 资源、用于执行 GET 和 POST 操作的按钮命令和用于取消当前操作的按钮命令。 使用 HttpRequest 类 1. 在 MainPage.xaml 中…

8639 折半插入排序

### 思路 折半插入排序是一种改进的插入排序算法,通过二分查找来确定插入位置,从而减少比较次数。每次插入时,先用二分查找找到插入位置,然后将元素插入到正确的位置。 ### 伪代码 1. 读取输入的待排序关键字个数n。 2. 读取n个待…

class 030 异或运算的骚操作

这篇文章是看了“左程云”老师在b站上的讲解之后写的, 自己感觉已经能理解了, 所以就将整个过程写下来了。 这个是“左程云”老师个人空间的b站的链接, 数据结构与算法讲的很好很好, 希望大家可以多多支持左程云老师, 真心推荐. https://space.bilibili.com/8888480?spm_id_f…

【CKA】五、网络策略–NetworkPolicy

5、配置网络策略–NetworkPolicy 1. 考题内容: 2. 答题思路: 1、根据题目分析要创建怎样的网络策略 2、按题目要求查看ns corp-net的label 3、编写yaml,其中注意 namespace、label、port 3. 官网地址: https://kubernetes.io/…

解决connect因父类不明确而报错的问题

如图所示&#xff0c;connect函数报错&#xff0c;原因是connect的检查是在编译期完成的&#xff0c;而传入父类则是在运行时&#xff0c;从而引起connect不知道parent是谁而报错。只需加入类型转换即可。 connect(qobject_cast<TableWidget*>(parent), &TableWidg…

STM32F1+HAL库+FreeTOTS学习15——互斥信号量

STM32F1HAL库FreeTOTS学习15——互斥信号量 1. 优先级翻转2. 互斥信号量3. 相关API函数&#xff1b;3.1 互斥信号量创建3.2 获取信号量3.3 释放信号量3.4 删除信号量 4. 操作实验1. 实验内容2. 代码实现3. 运行结果 上期我们介绍了数值信号量。这一期我们来介绍互斥信号量 1. 优…

【计算机毕业设计】springboot企业客户信息反馈平台

摘 要 网络的广泛应用给生活带来了十分的便利。所以把企业客户信息反馈管理与现在网络相结合&#xff0c;利用java技术建设企业客户信息反馈平台&#xff0c;实现企业客户信息反馈的信息化。则对于进一步提高企业客户信息反馈管理发展&#xff0c;丰富企业客户信息反馈管理经验…

官网:视觉是第一记忆,没有记忆点的官网设计是失败的。

官方网站虽然不像之前那么火爆了&#xff0c;但是依然是企业展示品牌形象和吸引用户的重要渠道。仅仅拥有一个官方网站并不足以吸引用户&#xff0c;更重要的是网站的设计是否能够给用户留下深刻的记忆。 当前&#xff0c;用户对于网站的要求也越来越高&#xff0c;他们不仅仅希…

Arduino UNO R3自学笔记16 之 Arduino的定时器介绍及应用

注意&#xff1a;学习和写作过程中&#xff0c;部分资料搜集于互联网&#xff0c;如有侵权请联系删除。 前言&#xff1a;学习定时器的功能。 1.定时器介绍 定时器也是一种中断&#xff0c;属于软件中断。 它就像一个时钟&#xff0c;可以测量事件的时间间隔。 比如早…

重置linux后vscode无法再次使用ssh连接

如果你使用过vscode ssh远程连接了一个Linux系统&#xff0c;但该系统被重置了&#xff0c;并且关键配置没有改变。再次使用vscode连接时&#xff0c;vscode可能无法连接。 原因&#xff1a;vscode远程连接后会在C:\Users{{你的用户名}}.ssh下的known_hosts和known_hosts.old。…

停止模式下USART为什么可以唤醒MCU?

在MCU的停止模式下&#xff0c;USART之类的外设时钟是关闭的&#xff0c;但是USART章节有描述到在停止模式下可以用USART来对MCU进行唤醒&#xff1a; 大家是否会好奇在外设的时钟被关闭的情况下&#xff0c;USART怎么能通过接收中断或者唤醒事件对MCU进行唤醒的呢&#xff1…

2024多模态大模型发展调研

随着生成式大语言模型应用的日益广泛&#xff0c;其输入输出模态受限的问题日益凸显&#xff0c;成为制约技术进一步发展的瓶颈。为突破这一局限&#xff0c;本文聚焦于研究多模态信息的协同交互策略&#xff0c;旨在探索一种能够统一理解与生成的多模态模型构建方法。在此基础…

基于springboot+小程序的在线选课管理系统1(源码+sql脚本+视频导入教程+文档)

&#x1f449;文末查看项目功能视频演示获取源码sql脚本视频导入教程视频 1、项目介绍 基于springboot小程序的在线选课管理系统实现了管理员、教师及学生。 1、管理员实现了首页、个人中心、管理员管理、教师管理、学生管理、课程信息管理、选课信息、公告管理、论坛管理、基…

Redis哨兵模式的搭建以及配置参数简介

原理 Redis哨兵模式是一种用于在Redis主从复制环境中进行高可用性监控和故障恢复的机制。该模式引入了一个或多个哨兵节点&#xff0c;这些节点负责监控Redis服务器的状态&#xff0c;并在主节点发生故障时切换为新的主节点。 哨兵节点的工作原理如下&#xff1a; 1、哨兵节点…

PDF阅读器工具集萃:满足你的多样需求

现在阅读书籍大部分都喜欢电子书的形式了吧&#xff0c;因为小小的一个设备就能存下上万本书。从流传程度来说PDF无疑是一个使用最广的格式。除了福昕PDF阅读器阅读之外还有哪些好用的阅读工具呢/&#xff1f;今天我们一起来探讨一下吧。 1.福昕阅读器 链接一下>>www.f…

MongoDB微服务部署

一、安装MongoDB 1.在linux中拉去MongoDB镜像文件 docker pull mongo:4.4.18 2. 2.创建数据挂载目录 linux命令创建 命令创建目录: mkdir -p /usr/local/docker/mongodb/data 可以在sshclient工具查看是否创建成功。 进入moogodb目录&#xff0c;给data赋予权限777 cd …

【算法】链表:21.合并两个有序链表(easy)

系列专栏 《分治》 《模拟》 《Linux》 目录 1、题目链接 2、题目介绍 3、解法&#xff08;双指针&#xff09; 4、代码 1、题目链接 21. 合并两个有序链表 - 力扣&#xff08;LeetCode&#xff09; 2、题目介绍 3、解法&#xff08;双指针&#xff09; 推荐一篇题解…

计算机毕业设计Python+Spark知识图谱高考分数线预测 高考志愿推荐系统 高考数据分析 高考可视化 高考大数据 大数据毕业设计

《PythonSpark知识图谱高考分数线预测与志愿推荐系统》开题报告 一、课题背景及意义 1. 背景 随着我国高考制度的不断完善以及大数据技术的快速发展&#xff0c;高考志愿推荐系统的需求日益增长。高考作为中国教育体系中的重要环节&#xff0c;其志愿填报直接关系到考生的未…

双指针--收尾的两道题

双指针 (封面起到吸引读者作用&#xff0c;和文章内容无关哈&#xff0c;但是文章也是用心写的&#xff09; 三数之和 给你一个整数数组 nums &#xff0c;判断是否存在三元组 [nums[i], nums[j], nums[k]] 满足 i ! j、i ! k 且 j ! k &#xff0c;同时还满足 nums[i] nums…