使用 Keras 训练一个循环神经网络(RNN)

在前面的文章中,我们介绍了如何使用 Keras 训练全连接神经网络(MLP)和卷积神经网络(CNN)。本文将带你深入学习如何使用 Keras 构建和训练一个循环神经网络(RNN),用于处理序列数据。我们将使用 IMDB 电影评论数据集 进行情感分析任务。

目录

  1. 环境准备
  2. 导入必要的库
  3. 加载和预处理数据
  4. 构建循环神经网络模型
  5. 编译模型
  6. 训练模型
  7. 评估模型
  8. 保存和加载模型
  9. 可视化训练过程
  10. 总结

1. 环境准备

确保你已经安装了 Python(推荐 3.6 及以上版本)和 TensorFlow(Keras 已集成在 TensorFlow 中)。如果尚未安装,请运行以下命令:

pip install tensorflow

2. 导入必要的库

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers, models
import numpy as np
import matplotlib.pyplot as plt
  • tensorflow: 深度学习框架,Keras 已集成其中。
  • numpy: 用于数值计算。
  • matplotlib.pyplot: 用于数据可视化。

3. 加载和预处理数据

我们将使用 Keras 自带的 IMDB 电影评论数据集,这是一个用于情感分析的二分类数据集,包含 25,000 条训练评论和 25,000 条测试评论。

# 加载 IMDB 数据集
max_features = 10000  # 词汇表大小
maxlen = 500          # 每条评论的最大长度(x_train, y_train), (x_test, y_test) = keras.datasets.imdb.load_data(num_words=max_features)print(f"训练数据形状: {x_train.shape}, 训练标签形状: {y_train.shape}")
print(f"测试数据形状: {x_test.shape}, 测试标签形状: {y_test.shape}")# 数据预处理
# 将序列填充到相同长度
x_train = keras.preprocessing.sequence.pad_sequences(x_train, maxlen=maxlen)
x_test = keras.preprocessing.sequence.pad_sequences(x_test, maxlen=maxlen)print(f"填充后的训练数据形状: {x_train.shape}")
print(f"填充后的测试数据形状: {x_test.shape}")

说明:

  • max_features: 词汇表大小,表示只考虑最常见的 10,000 个单词。
  • maxlen: 每条评论的最大长度,超过的部分将被截断,不足的部分将被填充。
  • 使用 pad_sequences 将所有序列填充到相同长度,以便输入到 RNN 中。

4. 构建循环神经网络模型

我们将构建一个简单的 RNN 模型,使用 LSTM 层来处理序列数据。

model = models.Sequential([layers.Embedding(input_dim=max_features, output_dim=128, input_length=maxlen),  # 嵌入层,将单词索引转换为向量layers.LSTM(128, dropout=0.2, recurrent_dropout=0.2),  # LSTM 层,128 个单元,dropout 和 recurrent_dropout 用于防止过拟合layers.Dense(1, activation='sigmoid')  # 输出层,二分类使用 sigmoid 激活函数
])# 查看模型结构
model.summary()

说明:

  • Embedding: 将单词索引转换为稠密向量表示。
  • LSTM: 长短期记忆网络,用于处理序列数据。
  • dropoutrecurrent_dropout: 用于防止过拟合。
  • Dense: 输出层,使用 sigmoid 激活函数进行二分类。

5. 编译模型

model.compile(optimizer='adam',loss='binary_crossentropy',metrics=['accuracy'])

说明:

  • 使用 Adam 优化器和二元交叉熵损失函数。
  • 评估指标为准确率。

6. 训练模型

# 设置训练参数
batch_size = 64
epochs = 5# 训练模型
history = model.fit(x_train, y_train,batch_size=batch_size,epochs=epochs,validation_split=0.1)  # 使用 10% 的训练数据作为验证集

说明:

  • 使用 10% 的训练数据作为验证集,以监控模型在验证集上的性能。

7. 评估模型

test_loss, test_acc = model.evaluate(x_test, y_test, verbose=2)
print(f"\n测试准确率: {test_acc:.4f}")

8. 保存和加载模型

# 保存模型
model.save("imdb_rnn_model.h5")# 加载模型
new_model = keras.models.load_model("imdb_rnn_model.h5")

9. 可视化训练过程

# 绘制训练 & 验证的准确率和损失值
plt.figure(figsize=(12,4))# 准确率
plt.subplot(1,2,1)
plt.plot(history.history['accuracy'], label='训练准确率')
plt.plot(history.history['val_accuracy'], label='验证准确率')
plt.xlabel('Epoch')
plt.ylabel('准确率')
plt.legend(loc='lower right')
plt.title('训练与验证准确率')# 损失值
plt.subplot(1,2,2)
plt.plot(history.history['loss'], label='训练损失')
plt.plot(history.history['val_loss'], label='验证损失')
plt.xlabel('Epoch')
plt.ylabel('损失')
plt.legend(loc='upper right')
plt.title('训练与验证损失')plt.show()

说明:

  • 通过可视化训练过程中的准确率和损失值,可以帮助我们了解模型的训练情况,判断是否存在过拟合或欠拟合。

10. 课程回顾

本文介绍了如何使用 Keras 构建和训练一个简单的循环神经网络(RNN),用于处理序列数据(如文本)。主要步骤包括:

  1. 环境准备和库导入: 确保安装了必要的库,并导入所需模块。
  2. 数据加载和预处理: 加载 IMDB 数据集,进行序列填充和标签编码。
  3. 构建 RNN 模型: 使用 Embedding、LSTM、Dense 等层构建模型。
  4. 编译模型: 指定优化器、损失函数和评估指标。
  5. 训练模型: 使用训练数据训练模型,并使用验证集监控性能。
  6. 评估模型: 在测试集上评估模型性能。
  7. 保存和加载模型: 将训练好的模型保存到磁盘,并可加载进行预测。
  8. 可视化训练过程: 通过绘制准确率和损失值曲线,了解模型的训练情况。

其实, RNN 模型如语言建模、机器可以用在,机器翻译、语音识别等应用领域,感兴趣可以自行探索。keras 本身也很容易找到这方面的例子。

作者简介

前腾讯电子签的前端负责人,现 whentimes tech CTO,专注于前端技术的大咖一枚!一路走来,从小屏到大屏,从 Web 到移动,什么前端难题都见过。热衷于用技术打磨产品,带领团队把复杂的事情做到极简,体验做到极致。喜欢探索新技术,也爱分享一些实战经验,帮助大家少走弯路!

温馨提示:可搜老码小张公号联系导师

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.xdnf.cn/news/15458.html

如若内容造成侵权/违法违规/事实不符,请联系一条长河网进行投诉反馈,一经查实,立即删除!

相关文章

Vulkan 开发(十一):Vulkan 交换链

Vulkan 系列文章: 1. 开篇,Vulkan 概述 2. Vulkan 实例 3. Vulkan 物理设备 4. Vulkan 设备队列 5. Vulkan 逻辑设备 6. Vulkan 内存管理 7. Vulkan 缓存 8. Vulkan 图像 9. Vulkan 图像视图 10. Vulkan 窗口表面(Surface&#xff…

【HarmonyOS】鸿蒙系统在租房项目中的项目实战(一)

从今天开始,博主将开设一门新的专栏用来讲解市面上比较热门的技术 “鸿蒙开发”,对于刚接触这项技术的小伙伴在学习鸿蒙开发之前,有必要先了解一下鸿蒙,从你的角度来讲,你认为什么是鸿蒙呢?它出现的意义又是…

百度搜索AI探索版多线程批量生成TXT原创文章软件-可生成3种类型文章

百度搜索AI探索版是百度推出的一款基于大语言模型文心一言的综合搜索产品‌。以下是关于百度搜索AI探索版的详细介绍: ‌产品发布‌:百度搜索AI探索版在百度世界大会上进行了灰度测试,并面向用户开放体验‌。 ‌核心功能‌:与传…

Linux软件包管理与Vim编辑器使用指南

目录 一、Linux软件包管理器yum 1.什么是软件包? 2.什么是软件包管理器? 3.查看软件包 4.安装软件 ​编辑 5.卸载软件 Linux开发工具: 二、Linux编辑器---vim 1.vim的基本概念 (1) 正常/普通模式(Normal mode&#xff0…

Android Osmdroid + 天地图 (一)

Osmdroid 天地图 前言正文一、配置build.gradle二、配置AndroidManifest.xml三、获取天地图的API Key① 获取开发版SHA1② 获取发布版SHA1 四、请求权限五、显示地图六、源码 前言 Osmdroid是一款完全开源的地图基本操作SDK,我们可以通过这个SDK去加一些地图API&am…

2024国内AI工具十大推荐丨亲测好用‼️

🚀探索了市面上数百款AI工具后,我精心挑选了10款在不同场景下超级好用的神器,快来一起看看吧!🌟 1️⃣豆包 基于云雀模型开发,具备聊天机器人、写作助手、英语学习助手等功能,能够进行多轮对话…

Unity学习---IL2CPP打包时可能遇到的问题

写这篇主要是怕自己之后打包的时候出问题不知道怎么搞,所以记录一下。 问题一:类型裁剪 IL2CPP打包后会自动对Unity工程的dll进行裁剪,将代码中没有引用到的类型裁剪掉。特别是通过反射等方式调用一些类的时候,很容易出问题。 …

多模态大模型(2)--BLIP

大模型如火如荼,研究者们已经不再满足于基本文本的大语言模型(LLM, Large Language Model),AI领域的热点正逐步向多模态转移,具备多模态能力的多模态大型语言模型(MM(Multi-Modal)-L…

基于MATLAB身份证号码识别

课题介绍 本课题为基于连通域分割和模板匹配的二代居民身份证号码识别系统,带有一个GUI人机交互界面。可以识别数十张身份证图片。 首先从身份证图像上获取0~9和X共十一个号码字符的样本图像作为后续识别的字符库样本,其次将待测身份证图像…

Siggraph Asia 2024 | Adobe发布MagicClay:可通过文字引导去对3D模型中的特定部分进行雕刻

今天给大家介绍一篇来自Adobe研究人员在Siggraph Asia 2024上发表的最新工作MagicClay,它是一款结合网格和距离场(SDF)的混合式工具,可以通过文字引导去对3D模型中的特定部分进行雕刻。允许艺术家通过文字提示进行局部网格编辑,支持生成具纹理…

滑动窗口的使用

一、定义与基本原理 滑动窗口是一种流量控制技术,也用于管理和处理数据流。它通过定义一个固定大小或可根据特定条件动态调整的窗口,在数据流或数据序列上滑动,以便高效地处理其中的数据。这种技术能够限制同时处理的数据量,从而…

Python学习26天

集合 # 定义集合 num {1, 2, 3, 4, 5} print(f"num:{num}\nnum数据类型为:{type(num)}") # 求集合中元素个数 print(f"num中元素个数为:{len(num)}") # 增加集合中的元素 num.add(6) print(num) # {1,2,3,4,5,6} # 删除…

android开发

文章目录 android开发 类微信界面整体框架展示:主页Fragment_MainActivity2:1. 聊天界面2. 用户界面用户界面的跳转 3. 朋友圈界面4. 我的界面 android开发 类微信界面 整体效果展示: 整体框架展示: 4个主要的fragment页面&#…

【大数据学习 | flume】flume的概述与组件的介绍

1. flume概述 Flume是cloudera(CDH版本的hadoop) 开发的一个分布式、可靠、高可用的海量日志收集系统。它将各个服务器中的数据收集起来并送到指定的地方去,比如说送到HDFS、Hbase,简单来说flume就是收集日志的。 Flume两个版本区别: ​ 1&…

vmware安装Ubuntu桌面版系统

1安装环境 vmware版本:VMware Workstation 17 Ubuntu版本:ubuntu-24.04.1-desktop-amd64.iso 文档时间:2024年11月 每一个Ubuntu的版本安装显示可能不一样,但安装方法是类似的 2镜像下载 Ubuntu官网:[https://ubun…

STL--map、set的使用和模拟实现

1.set 1.1 set的概念 set 是一种基于 平衡二叉搜索树(通常是红黑树) 实现的容器,它提供了有序集合的功能。set 用于存储唯一的元素,并且元素是按照某种顺序排列的(通常是升序)。 set 确实是一个关联式容…

软件测试之什么是缺陷

软件测试之什么是缺陷 1. 缺陷定义2. 缺陷判定标准3. 缺陷产生原因3.1 缺陷产生的原因3.2 缺陷的生命周期 4. 缺陷核心内容5. 缺陷提交要素6. 缺陷类型 1. 缺陷定义 软件在使用过程中存在的任何问题都叫软件的缺陷, 简称Bug. 2. 缺陷判定标准 3. 缺陷产生原因 3.1 缺陷产生的…

二叉树的遍历(手动)

树的遍历分四种: 层序遍历 前序遍历 中序遍历 后序遍历 层序遍历: 很好理解,就是bfs嘛(二不二叉都行) 前序遍历: 又叫先跟遍历,遍历顺序是根->左->右(子树里也是&#…

Unix进程

文章目录 命令行参数进程终止正常结束异常终止exit和_exitatexit 环境变量环境变量性质环境表shell中操作环境变量查看环境变量设置环境变量 环境变量接口获取环境变量设置环境变量 环境变量的继承性 进程资源shell命令查看进程的资源限制 进程关系进程标识进程组会话控制终端控…

供应链管理、一件代发系统功能及源码分享 PHP+Mysql

随着电商行业的不断发展,传统的库存管理模式已经逐渐无法满足市场需求。越来越多的企业选择“一件代发”模式,即商家不需要自己储备商品库存,而是将订单直接转给供应商,由供应商直接进行发货。这种方式极大地降低了企业的运营成本…