R6:LSTM实现糖尿病探索与预测

  • 🍨 本文为🔗365天深度学习训练营 中的学习记录博客
  • 🍖 原作者:K同学啊

一、实验目的:

学习使用LSTM对糖尿病进行探索预测

二、实验环境:

  • 语言环境:python 3.8
  • 编译器:Jupyter notebook
  • 深度学习环境:Pytorch
    • torch==2.4.0+cu124
    • torchvision==0.19.0+cu124

三、数据预处理

逻辑回归在二分类问题中应用广泛;KNN(K 近邻算法)、SVM(支持向量机)、决策树、贝叶斯分类器、随机森林和 XGBoost(极端梯度提升树)都是常见的用于结构化数据分类的算法。

本次实验我们采用 LSTM(长短期记忆网络)进行分类预测。LSTM 主要用于处理序列数据,虽然在一些特定情况下可以对序列数据进行分类,但对于一般的二维结构化数据,上述提到的传统分类算法通常更加合适。二维结构化数据通常指表格形式的数据,每一行代表一个样本,每一列代表一个特征,对于这类数据,传统的机器学习分类算法在计算效率和可解释性方面往往具有优势。

在这里插入图片描述

1. 设置GPU、导入数据

#设置GPU 
import torch.nn as nn 
import torch.nn.functional as F 
import torchvision,torch device = torch.device("cuda" if torch.cuda.is_available() else "cpu")device 
#导入数据
import numpy   as np
import pandas  as pd
import seaborn as sns
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
plt.rcParams['savefig.dpi'] = 500 #图片像素
plt.rcParams['figure.dpi'] = 500 #分辨率plt.rcParams['font.sans-serif'] = ['SimHei'] #用来正常显示中文标签import warnings
warnings.filterwarnings('ignore')DataFrame = pd.read_excel('diabetes.xls')
DataFrame.head()

在这里插入图片描述

DataFrame.shape
(1006, 16)    

2. 数据检查

#查看数据是否有缺失值
print('数据缺失值--------------------------')
print(DataFrame.isnull().sum())

在这里插入图片描述

#查看数据是否有重复值
print('数据重复值--------------------------')
print('数据集的重复值为:'f'{DataFrame.duplicated().sum()}')

在这里插入图片描述

3. 数据分布分析

feature_map = { '年龄': '年龄','高密度脂蛋白胆固醇': '高密度脂蛋白胆固醇','低密度脂蛋白胆固醇': '低密度脂蛋白胆固醇','极低密度脂蛋白胆固醇': '极低密度脂蛋白胆固醇','甘油三酯': '甘油三酯','总胆固醇': '总胆固醇','脉搏': '脉搏','舒张压':'舒张压','高血压史':'高血压史','尿素氮':'尿素氮','尿酸':'尿酸','肌酐':'肌酐','体重检查结果':'体重检查结果'}plt.figure(figsize=(15,10))for i, (col, col_name) in enumerate(feature_map.items(), 1):plt.subplot(3,5,i)sns.boxplot(x=DataFrame['是否糖尿病'], y=DataFrame[col])plt.title(f'{col_name}的箱线图', fontsize=14)plt.ylabel('数值', fontsize=12)plt.grid(axis='y', linestyle='--', alpha=0.7)plt.tight_layout()
plt.show()

在这里插入图片描述
以下是分析箱线图的方法,并以年龄的箱线图为例进行介绍:

一、认识箱线图的组成部分

  1. 箱体:箱体的上下边界分别代表数据的上四分位数(Q3)和下四分位数(Q1)。箱体中间的线通常代表中位数。
  2. whiskers(须):从箱体延伸出去的线段,代表数据的范围。一般来说,须的长度是由一些特定的规则决定的,常见的是 1.5 倍的四分位距(IQR,即 Q3 - Q1)。超出须范围的数据点被视为异常值,可能会以单独的点显示。

二、分析年龄箱线图的具体步骤

  1. 观察中位数:

    • 首先找到箱体中间的线,它代表了年龄数据的中位数。如果这条线在箱线图的中间位置附近,说明数据分布相对较为对称;如果偏向箱体的上边界或下边界,则说明数据可能存在偏斜。
    • 假设年龄箱线图中,中位数线靠近箱体上边界,这可能意味着年龄数据整体上偏大,即大部分人的年龄较高。
  2. 分析箱体长度:

    • 箱体的长度反映了数据的离散程度。如果箱体较短,说明数据比较集中;如果箱体较长,说明数据的分散程度较大。
    • 例如,如果年龄箱线图的箱体较短,说明年龄数据相对集中在一个较小的范围内。
  3. 观察须的长度:

    • 须的长度可以让你了解数据的整体范围。较长的须表示数据的范围较大;较短的须可能意味着数据比较集中在一个较小的区间内。
    • 如果年龄箱线图的须较长,说明年龄数据的跨度较大,可能有一些年龄较大或较小的极端值。
  4. 检查异常值:

    • 异常值通常以单独的点显示在箱线图之外。观察异常值的数量和分布,可以了解数据中是否存在极端情况。
    • 如果年龄箱线图中有一些异常值,需要进一步分析这些异常值的来源,例如是否是由于数据录入错误或者特殊的个体情况导致的。
  5. 比较不同组别的箱线图:

    • 如果有多个组别的年龄箱线图,可以比较它们的中位数、箱体长度、须的长度和异常值情况,以了解不同组之间年龄分布的差异。
    • 例如,比较糖尿病患者和非糖尿病患者的年龄箱线图,看是否存在明显的差异。如果糖尿病患者的年龄箱线图中位数较高,箱体较长,可能说明糖尿病患者的年龄普遍较大。

通过以上步骤,你可以对年龄箱线图进行较为全面的分析,了解年龄数据的分布特征和潜在的问题。对于其他变量的箱线图,也可以采用类似的方法进行分析。

df_corr = DataFrame.drop(['卡号'],axis=1).corr()
plt.figure(figsize=(12,10))
plt.title('相关性热图')
sns.heatmap(df_corr,annot=True)
plt.show()

在这里插入图片描述

四、LSTM模型

#数据集构建from sklearn.preprocessing import StandardScaler# '高密度脂蛋白胆固醇'字段与糖尿病负相关,故而在 X 中去掉该字段
X = DataFrame.drop(['卡号','是否糖尿病','高密度脂蛋白胆固醇'],axis=1)
y = DataFrame['是否糖尿病']# sc_X    = StandardScaler()
# X = sc_X.fit_transform(X)X = torch.tensor(np.array(X), dtype=torch.float32)
y = torch.tensor(np.array(y), dtype=torch.int64)train_X, test_X, train_y, test_y = train_test_split(X, y, test_size=0.2,random_state=1)
train_X.shape, train_y.shapefrom torch.utils.data import TensorDataset, DataLoadertrain_dl = DataLoader(TensorDataset(train_X, train_y),batch_size=64,shuffle=False)
test_dl  = DataLoader(TensorDataset(test_X, test_y),batch_size=64,shuffle=False)
#定义模型
class model_lstm(nn.Module):def __init__(self):super(model_lstm, self).__init__()self.lstm0 = nn.LSTM(input_size=13,  hidden_size=200, num_layers=1, batch_first=True)self.lstm1 = nn.LSTM(input_size=200, hidden_size=200, num_layers=1, batch_first=True)self.fc0   = nn.Linear(200, 2)def forward(self, x):out, hidden1 = self.lstm0(x)out, _       = self.lstm1(out, hidden1)out          = self.fc0(out)return outmodel = model_lstm().to(device)
model

在这里插入图片描述

五、训练模型

#定义训练函数
def train(dataloader, model, loss_fn, optimizer):size = len(dataloader.dataset)  # 训练集的大小num_batches = len(dataloader)   # 批次数目, (size/batch_size,向上取整)train_loss, train_acc = 0, 0  # 初始化训练损失和正确率for X, y in dataloader:  # 获取图片及其标签X, y = X.to(device), y.to(device)# 计算预测误差pred = model(X)          # 网络输出loss = loss_fn(pred, y)  # 计算网络输出和真实值之间的差距,targets为真实值,计算二者差值即为损失# 反向传播optimizer.zero_grad()  # grad属性归零loss.backward()        # 反向传播optimizer.step()       # 每一步自动更新# 记录acc与losstrain_acc  += (pred.argmax(1) == y).type(torch.float).sum().item()train_loss += loss.item()train_acc  /= sizetrain_loss /= num_batchesreturn train_acc, train_loss
#定义测试函数
def test (dataloader, model, loss_fn):size        = len(dataloader.dataset)  # 测试集的大小num_batches = len(dataloader)          # 批次数目, (size/batch_size,向上取整)test_loss, test_acc = 0, 0# 当不进行训练时,停止梯度更新,节省计算内存消耗with torch.no_grad():for imgs, target in dataloader:imgs, target = imgs.to(device), target.to(device)# 计算losstarget_pred = model(imgs)loss        = loss_fn(target_pred, target)test_loss += loss.item()test_acc  += (target_pred.argmax(1) == target).type(torch.float).sum().item()test_acc  /= sizetest_loss /= num_batchesreturn test_acc, test_loss
#训练模型
loss_fn    = nn.CrossEntropyLoss() # 创建损失函数
learn_rate = 1e-4   # 学习率
opt        = torch.optim.Adam(model.parameters(),lr=learn_rate)
epochs     = 30train_loss = []
train_acc  = []
test_loss  = []
test_acc   = []for epoch in range(epochs):model.train()epoch_train_acc, epoch_train_loss = train(train_dl, model, loss_fn, opt)model.eval()epoch_test_acc, epoch_test_loss = test(test_dl, model, loss_fn)train_acc.append(epoch_train_acc)train_loss.append(epoch_train_loss)test_acc.append(epoch_test_acc)test_loss.append(epoch_test_loss)# 获取当前的学习率lr = opt.state_dict()['param_groups'][0]['lr']template = ('Epoch:{:2d}, Train_acc:{:.1f}%, Train_loss:{:.3f}, Test_acc:{:.1f}%, Test_loss:{:.3f}, Lr:{:.2E}')print(template.format(epoch+1, epoch_train_acc*100, epoch_train_loss, epoch_test_acc*100, epoch_test_loss, lr))print("="*20, 'Done', "="*20)

在这里插入图片描述

六、模型评估

#Loss与Accuracy图
import matplotlib.pyplot as plt
#隐藏警告
import warnings
warnings.filterwarnings("ignore")               #忽略警告信息
plt.rcParams['font.sans-serif']    = ['SimHei'] # 用来正常显示中文标签
plt.rcParams['axes.unicode_minus'] = False      # 用来正常显示负号
plt.rcParams['figure.dpi']         = 100        #分辨率epochs_range = range(epochs)plt.figure(figsize=(12, 3))
plt.subplot(1, 2, 1)plt.plot(epochs_range, train_acc, label='Training Accuracy')
plt.plot(epochs_range, test_acc, label='Test Accuracy')
plt.legend(loc='lower right')
plt.title('Training and Validation Accuracy')plt.subplot(1, 2, 2)
plt.plot(epochs_range, train_loss, label='Training Loss')
plt.plot(epochs_range, test_loss, label='Test Loss')
plt.legend(loc='upper right')
plt.title('Training and Validation Loss')
plt.show()

在这里插入图片描述

七、总结

分析数据可知存在有一定的过拟合迹象:

  • 随着训练的进行,训练准确率不断上升,而测试准确率在前期长时间停滞,后期虽然有所上升,但上升幅度小于训练准确率。这表明模型在训练集上的学习能力较强,但在测试集上的泛化能力相对较弱。
  • 训练损失持续下降,而测试损失在下降过程中出现了波动,并且在后期与训练损失的差距有一定程度的扩大。这也暗示模型可能过度拟合了训练数据,导致在测试集上的表现不如在训练集上的表现稳定。
  • 实验中尝试通过提高学习率至1e-3,可以将预测准确率提升到71.3%,而提高训练轮数则始终难以收敛。而在构建数据集部分,可以看到注释部分的代码为数据的标准化处理。
  • 除此之外,还可以考虑采用正则化方法、增加数据量、早停法等技术来缓解过拟合问题。

在划分数据集过程中添加标准化处理可以提升测试数据集准确率的原因主要有以下几点:

一、消除量纲影响

  1. 不同特征往往具有不同的量纲和尺度。例如,一个特征可能取值范围在 0 到 100 之间,而另一个特征可能取值在 0 到 1 之间。这会使得在某些算法中,具有较大数值范围的特征对模型的影响更大,从而可能导致模型偏向于这些特征,而忽略了其他重要特征的作用。
  2. 标准化处理将数据的各个特征转换到相同的尺度上,通常使得特征的均值为 0,标准差为 1。这样可以确保每个特征在模型中具有相对平等的影响力,避免了因量纲差异而导致的不公平性。

二、加速模型收敛

  1. 许多优化算法在处理标准化后的数据时能够更快地收敛。例如,梯度下降算法在标准化的数据上能够更有效地确定下降的方向和步长,因为数据的分布更加稳定,不会因为特征的尺度差异而导致梯度在不同方向上的变化幅度差异巨大。
  2. 当数据经过标准化后,模型在训练过程中可以更稳定地更新参数,减少了因数据尺度不一致而引起的震荡,从而更快地找到最优解,这也有助于提高模型在测试集上的准确率。

三、提高模型的泛化能力

  1. 标准化可以使模型对不同单位和尺度的输入数据具有更好的适应性,从而提高模型的泛化能力。如果模型在训练时只适应了特定尺度的数据集,那么在面对测试集上不同尺度的数据时,可能表现不佳。
  2. 标准化处理可以减少异常值对模型的影响。异常值在未标准化的数据中可能会对模型产生较大的干扰,而经过标准化后,异常值的影响相对减小,模型能够更加关注数据的整体分布特征,从而提高在测试集上的准确率。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.xdnf.cn/news/8442.html

如若内容造成侵权/违法违规/事实不符,请联系一条长河网进行投诉反馈,一经查实,立即删除!

相关文章

笔试题11 -- 装箱问题(01背包)

装箱问题(01背包) 文章目录 装箱问题(01背包)一、原题复现二、思路剖析三、示例代码 题目链接:NOIP2001装箱问题 一、原题复现 题目描述 有一个箱子容量为V(正整数,0 ≤ V ≤ 20000)…

【D3.js in Action 3 精译_038】4.2 D3 折线图的绘制方法及曲线插值处理

当前内容所在位置(可进入专栏查看其他译好的章节内容) 第一部分 D3.js 基础知识 第一章 D3.js 简介(已完结) 1.1 何为 D3.js?1.2 D3 生态系统——入门须知1.3 数据可视化最佳实践(上)1.3 数据可…

测试-正交表与工具pairs的介绍使用(1)

目录 正交表 生成正交表 步骤 实操 注意事项 编写测试用例 根据正交表编写测试用例 补充遗漏的重要测试用例 正交表 关于长篇大论也不多介绍了,我们只需要知道正交法的⽬的是为了减少⽤例数⽬,⽤尽量少的⽤例覆盖输⼊的两两组合 正交表的构成&…

抗晃电马达保护器在工业厂房中的应用

安科瑞刘鸿鹏 摘要 随着工业自动化水平的提高,生产线上电动机作为关键设备的使用频率不断增加。然而,工厂生产环境中的电力波动,尤其是晃电现象,会对电动机的正常运转造成干扰,甚至导致设备停机和生产中断。抗晃电型…

linux之调度管理(2)-调度器 如何触发运行

一、调度器是如何在程序稳定运行的情况下进行进程调度的 1.1 系统定时器 因为我们主要讲解的是调度器,而会涉及到一些系统定时器的知识,这里我们简单讲解一下内核中定时器是如何组织,又是如何通过通过定时器实现了调度器的间隔调度。首先我们…

RHCE循环执行的例行性任务--crontab(周期性)

1.每分钟执行命令 2.每小时执行 3.每天凌晨3点半和12点半执行脚本 4.每隔6小时,相当于6,12,18,24点半执行脚本 5.30半点,8-18/2表示早上8点到下午18点之间每隔2小时执行脚本代表 6.每天晚上9点30重启nginx 7.每月1号和10号4点45执行脚本 8. 每周六和周日…

ETLCloud异常问题分析ai功能

在数据处理和集成的过程中,异常问题的发生往往会对业务运营造成显著影响。为了提高ETL(提取、转换、加载)流程的稳定性与效率,ETLCloud推出了智能异常问题分析AI功能。这一创新工具旨在实时监测数据流动中的潜在异常,自…

Java项目实战II基于Spring Boot的个人云盘管理系统设计与实现(开发文档+数据库+源码)

目录 一、前言 二、技术介绍 三、系统实现 四、文档参考 五、核心代码 六、源码获取 全栈码农以及毕业设计实战开发,CSDN平台Java领域新星创作者,专注于大学生项目实战开发、讲解和毕业答疑辅导。 一、前言 基于Spring Boot的个人云盘管理系统设计…

还在为慢速数据传输苦恼?Linux 零拷贝技术来帮你!

前言 程序员的终极追求是什么?当系统流量大增,用户体验却丝滑依旧?没错!然而,在大量文件传输、数据传递的场景中,传统的“数据搬运”却拖慢了性能。为了解决这一痛点,Linux 推出了 零拷贝 技术&…

密码学是如何保护数据传输的安全性?

密码学通过一系列算法和协议来保护数据传输的安全性。 一、加密技术 对称加密算法 原理:使用相同的密钥进行加密和解密。应用:在数据传输过程中,发送方和接收方共享一个密钥,数据在传输前被加密,接收方使用相同的密钥…

python怎么打开py文件

1、首先在资源管理器里复制一下py文件存放的路径,按下windows键+r,在运行里输入cmd,回车打开命令行: 2、在命令行里,先切换到py文件的路径下面,接着输入“python 文件名.py ”运行python文件&a…

云计算——ACA学习 云计算核心技术

作者简介:一名云计算网络运维人员、每天分享网络与运维的技术与干货。 座右铭:低头赶路,敬事如仪 个人主页:网络豆的主页​​​​​ 写在前面 本系列将会持续更新云计算阿里云ACA的学习,了解云计算及网络安全相关…

企业办公管理软件排名 | 九款企业管理软件助你制胜职场!(好用+实用+全面)

在寻找合适的企业办公管理软件时,你是否感到困惑不已,不知道从众多选项中选择哪一个? 一款好的管理软件不仅能简化工作流程,还能增强数据安全性,优化决策支持。 以下是九款备受推崇的企业管理软件,它们将助…

DNS服务器

DNS服务器 1、简介 DNS域名解析服务器,它作为将域名和IP地址相互映射的一个分布式数据库,端口号为53,通常使用UDP协议,但是在没有查询到完整的信息时,会以TCP这个协议来重新查询,所以在启动NDS服务器时&a…

顾荣辉在新加坡金融科技节发表主旨演讲:安全不仅是竞争优势,更是共同责任

在全球数字化和去中心化进程中,Web3的作用日益凸显,安全问题也日益成为行业的焦点。在这一背景下,顾荣辉教授于新加坡金融科技节(SFF)上发表主旨演讲《超越代码,引领信任》。顾教授在演讲中深入阐述了安全在…

Leetcode328奇偶链表,Leetcode21合并两个有序链表,Leetcode206反转链表 三者综合题

题目描述 思路分析 这题的思路就和我们的标题所述一样,可以看作是这3个题的合并,但是稍微还有一点点区别 比如:奇偶链表这道题主要是偶数链在了奇数后面,字节这个的话是奇偶链表分离了 所以字节这题的大概思路就是: …

「Mac玩转仓颉内测版1」入门篇1 - Cangjie环境的搭建

本篇详细介绍在Mac系统上快速搭建Cangjie开发环境的步骤,涵盖VSCode的下载与安装、Cangjie插件的离线安装、工具链的配置及验证。通过这些步骤,确保开发环境配置完成,为Cangjie项目开发提供稳定的基础支持。 关键词 Cangjie开发环境搭建VSC…

2023数学分析【南昌大学】

计算 求极限 lim ⁡ n → ∞ ( 1 n 2 + 1 2 + 1 n 2 + 2 2 + ⋯ + 1 n 2 + n 2 ) \mathop{\lim }\limits_{n \to \infty } \left( \frac{1}{{\sqrt {n^2 + 1^2} }} + \frac{1}{{\sqrt {n^2 + 2^2} }} + \cdots + \frac{1}{{\sqrt {n^2 + n^2} }} \right) n→∞lim​(n2+12 ​1…

从技术创新到商业应用,智象未来(HiDream.ai)创新不止步

在人工智能领域的最新动态中,智象未来(HiDream.ai)公司,作为全球领先的多模态生成式人工智能技术先驱,已经引起了广泛的行业瞩目。该公司专注于深度学习和计算机视觉技术的融合,致力于开发和优化视觉多模态…

ssm基于Vue的戏剧推广网站+vue

系统包含:源码论文 所用技术:SpringBootVueSSMMybatisMysql 免费提供给大家参考或者学习,获取源码看文章最下面 需要定制看文章最下面 目 录 摘 要 I Abstract II 第1章 绪论 1 1.1 课题背景 1 1.2 课题意义 1 1.3 研究内容 1 第2…