生成对抗网络入门案例

前言

生成对抗网络(Generative Adversarial Networks,简称GANs)是一种用于生成新样本的机器学习模型。它由两个主要组件组成:生成器(Generator)和判别器(Discriminator)。生成器尝试生成与训练数据相似的新样本,而判别器则试图区分生成器生成的样本和真实训练数据。

下面是一个简单的对抗生成网络的入门例子,用于生成手写数字图像:

实现过程

1、导入必要的库和模块

import numpy as np
import matplotlib.pyplot as plt
from tensorflow.keras.datasets import mnist
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Flatten, Reshape
from tensorflow.keras.layers import Conv2D, Conv2DTranspose
from tensorflow.keras.optimizers import Adam

2、加载MNIST数据集

(x_train, _), (_, _) = mnist.load_data()
x_train = x_train / 255.0
x_train = np.expand_dims(x_train, axis=3)

3、定义生成器模型

generator = Sequential()
generator.add(Dense(7*7*128, input_shape=(100,), activation='relu'))
generator.add(Reshape((7, 7, 128)))
generator.add(Conv2DTranspose(64, (3, 3), strides=(2, 2), padding='same', activation='relu'))
generator.add(Conv2DTranspose(1, (3, 3), strides=(2, 2), padding='same', activation='sigmoid'))

4、定义判别器模型

discriminator = Sequential()
discriminator.add(Conv2D(64, (3, 3), strides=(2, 2), padding='same', input_shape=(28, 28, 1), activation='relu'))
discriminator.add(Conv2D(128, (3, 3), strides=(2, 2), padding='same', activation='relu'))
discriminator.add(Flatten())
discriminator.add(Dense(1, activation='sigmoid'))

5、编译判别器模型

discriminator.compile(loss='binary_crossentropy', optimizer=Adam(learning_rate=0.0002, beta_1=0.5), metrics=['accuracy'])

6、冻结判别器模型的权重

discriminator.trainable = False

7、定义GAN模型

gan = Sequential()
gan.add(generator)
gan.add(discriminator)

8、编译GAN模型

gan.compile(loss='binary_crossentropy', optimizer=Adam(learning_rate=0.0002, beta_1=0.5))

9、定义训练函数

def train_gan(epochs, batch_size, sample_interval):for epoch in range(epochs):# 生成随机噪声作为输入noise = np.random.normal(0, 1, (batch_size, 100))# 生成假样本generated_images = generator.predict(noise)# 从真实样本中随机选择一批样本real_images = x_train[np.random.randint(0, x_train.shape[0], batch_size)]# 训练判别器discriminator_loss_real = discriminator.train_on_batch(real_images, np.ones((batch_size, 1)))discriminator_loss_fake = discriminator.train_on_batch(generated_images, np.zeros((batch_size, 1)))discriminator_loss = 0.5 * np.add(discriminator_loss_real, discriminator_loss_fake)# 训练生成器noise = np.random.normal(0, 1, (batch_size, 100))generator_loss = gan.train_on_batch(noise, np.ones((batch_size, 1)))# 打印损失if epoch % sample_interval == 0:print(f"Epoch {epoch}/{epochs}, Discriminator Loss: {discriminator_loss[0]}, Generator Loss: {generator_loss}")# 保存生成的图像save_images(epoch)

10、保存生成的图像

def save_images(epoch):rows, cols = 5, 5noise = np.random.normal(0, 1, (rows * cols, 100))generated_images = generator.predict(noise)generated_images = 0.5 * generated_images + 0.5fig, axs = plt.subplots(rows, cols)idx = 0for i in range(rows):for j in range(cols):axs[i, j].imshow(generated_images[idx, :, :, 0], cmap='gray')axs[i, j].axis('off')idx += 1fig.savefig(f"gan_images/mnist_{epoch}.png")plt.close()

11、训练GAN模型

epochs = 10000
batch_size = 128
sample_interval = 1000

完整代码

import numpy as np
import matplotlib.pyplot as plt
from tensorflow.keras.datasets import mnist
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Flatten, Reshape
from tensorflow.keras.layers import Conv2D, Conv2DTranspose
from tensorflow.keras.optimizers import Adam# 加载MNIST数据集
(x_train, _), (_, _) = mnist.load_data()
x_train = x_train / 255.0
x_train = np.expand_dims(x_train, axis=3)# 定义生成器模型
generator = Sequential()
generator.add(Dense(7*7*128, input_shape=(100,), activation='relu'))
generator.add(Reshape((7, 7, 128)))
generator.add(Conv2DTranspose(64, (3, 3), strides=(2, 2), padding='same', activation='relu'))
generator.add(Conv2DTranspose(1, (3, 3), strides=(2, 2), padding='same', activation='sigmoid'))# 定义判别器模型
discriminator = Sequential()
discriminator.add(Conv2D(64, (3, 3), strides=(2, 2), padding='same', input_shape=(28, 28, 1), activation='relu'))
discriminator.add(Conv2D(128, (3, 3), strides=(2, 2), padding='same', activation='relu'))
discriminator.add(Flatten())
discriminator.add(Dense(1, activation='sigmoid'))# 编译判别器模型
discriminator.compile(loss='binary_crossentropy', optimizer=Adam(learning_rate=0.0002, beta_1=0.5), metrics=['accuracy'])# 冻结判别器模型的权重
discriminator.trainable = False# 定义GAN模型
gan = Sequential()
gan.add(generator)
gan.add(discriminator)# 编译GAN模型
gan.compile(loss='binary_crossentropy', optimizer=Adam(learning_rate=0.0002, beta_1=0.5))# 定义训练函数
def train_gan(epochs, batch_size, sample_interval):for epoch in range(epochs):# 生成随机噪声作为输入noise = np.random.normal(0, 1, (batch_size, 100))# 生成假样本generated_images = generator.predict(noise)# 从真实样本中随机选择一批样本real_images = x_train[np.random.randint(0, x_train.shape[0], batch_size)]# 训练判别器discriminator_loss_real = discriminator.train_on_batch(real_images, np.ones((batch_size, 1)))discriminator_loss_fake = discriminator.train_on_batch(generated_images, np.zeros((batch_size, 1)))discriminator_loss = 0.5 * np.add(discriminator_loss_real, discriminator_loss_fake)# 训练生成器noise = np.random.normal(0, 1, (batch_size, 100))generator_loss = gan.train_on_batch(noise, np.ones((batch_size, 1)))# 打印损失if epoch % sample_interval == 0:print(f"Epoch {epoch}/{epochs}, Discriminator Loss: {discriminator_loss[0]}, Generator Loss: {generator_loss}")# 保存生成的图像save_images(epoch)# 保存生成的图像
def save_images(epoch):rows, cols = 5, 5noise = np.random.normal(0, 1, (rows * cols, 100))generated_images = generator.predict(noise)generated_images = 0.5 * generated_images + 0.5fig, axs = plt.subplots(rows, cols)idx = 0for i in range(rows):for j in range(cols):axs[i, j].imshow(generated_images[idx, :, :, 0], cmap='gray')axs[i, j].axis('off')idx += 1fig.savefig(f"gan_images/mnist_{epoch}.png")plt.close()# 训练GAN模型
epochs = 10000
batch_size = 128
sample_interval = 1000train_gan(epochs, batch_size, sample_interval)

这个例子使用了MNIST数据集,生成手写数字图像。生成器和判别器模型使用了卷积神经网络的结构。在训练过程中,生成器试图生成逼真的手写数字图像,而判别器则试图区分真实图像和生成图像。通过反复迭代训练生成器和判别器,GAN模型能够逐渐生成更逼真的手写数字图像。生成的图像会保存在gan_images文件夹中。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.xdnf.cn/news/148927.html

如若内容造成侵权/违法违规/事实不符,请联系一条长河网进行投诉反馈,一经查实,立即删除!

相关文章

【算法导论】线性时间排序(计数排序、基数排序、桶排序)

引言:   在排序的最终结果中,各元素的次序依赖于它们之间的比较,我们把这类排序算法称为比较排序,对于包含n个元素的输入序列来说,任何比较排序在最坏情况下都要经过 Ω ( n l g n ) \Omega(nlgn) Ω(nlgn)次比较&a…

学信息系统项目管理师第4版系列17_干系人管理

1. 项目经理和团队管理干系人的能力决定着项目的成败 2. 干系人满意度应作为项目目标加以识别和管理 3. 发展趋势和新兴实践 3.1. 识别所有干系人,而非在限定范围内 3.2. 确保所有团队成员都涉及引导干系人参与的活 3.3. 定期审查干系人群体,可与单…

星际争霸之小霸王之小蜜蜂(十六)--狂奔的花猫

系列文章目录 星际争霸之小霸王之小蜜蜂(十五)--剧将终场 星际争霸之小霸王之小蜜蜂(十四)--资本家的眼泪 星际争霸之小霸王之小蜜蜂(十三)--接着奏乐接着舞 星际争霸之小霸王之小蜜蜂(十二…

Lucene学习总结之Lucene的索引文件格式

当我们真正进入到Lucene源代码之中的时候,我们会发现: Lucene的索引过程,就是按照全文检索的基本过程,将倒排表写成此文件格式的过程。Lucene的搜索过程,就是按照此文件格式将索引进去的信息读出来,然后计算每篇文档打…

Android开源 Skeleton 骨架屏 V1.3.0

目录 一、简介 二、效果图 三、引用 Skeleton 添加jitpack 仓库 添加依赖: 四、新增 “块”骨架屏 1、bind方法更改和变化: 2、load方法更改和变化: 五、关于上一个版本 一、简介 骨架屏的作用是在网络请求较慢时,提供基础占位&…

Android LitePal byte[]类型字段不被创建

我创建了以下实体类,主要是用户分享的内容、分享的照片、分享的标题,然后百度了一下LitePal可以识别byte[],因为需要文件的上传与读取: public class Context extends LitePalSupport {private Integer ContextId;private String…

LLMs 用强化学习进行微调 RLHF: Fine-tuning with reinforcement learning

让我们把一切都整合在一起,看看您将如何在强化学习过程中使用奖励模型来更新LLM的权重,并生成与人对齐的模型。请记住,您希望从已经在您感兴趣的任务上表现良好的模型开始。您将努力使指导发现您的LLM对齐。首先,您将从提示数据集…

用向量数据库Milvus Cloud 搭建AI聊天机器人

加入大语言模型(LLM) 接着,需要在聊天机器人中加入 LLM。这样,用户就可以和聊天机器人开展对话了。本示例中,我们将使用 OpenAI ChatGPT 背后的模型服务:GPT-3.5。 聊天记录 为了使 LLM 回答更准确,我们需要存储用户和机器人的聊天记录,并在查询时调用这些记录,可以用…

软件设计模式系列之二十五——访问者模式

访问者模式(Visitor Pattern)是一种强大的行为型设计模式,它允许你在不改变被访问对象的类的前提下,定义新的操作和行为。本文将详细介绍访问者模式,包括其定义、举例说明、结构、实现步骤、Java代码实现、典型应用场景…

软件工程从理论到实践客观题汇总(头歌第九章至第十七章)

九、软件体系结构设计 1、软件体系结构设计概述 2、软件体系结构模型的表示方法 3、软件体系结构设计过程 4、设计初步的软件体系结构 5、重用已有软件资源 6、精化软件体系结构 7、设计软件部署模型 8、文档化和评审软件体系结构设计 十、软件用户界面设计 1、用户界面设计概…

Ae 效果:CC Page Turn

扭曲/CC Page Turn Distort/CC Page Turn CC Page Turn (CC 翻页)主要用于模拟书页翻动的效果。通过使用该效果,用户可以创建出像书页或杂志页面翻动的视觉效果,增强影片的交互性和视觉吸引力。 ◆ ◆ ◆ 效果属性说明 Contro…

Is This The Intelligent Model(这是智能模型吗)

Is This The Intelligent Model 这是智能模型吗 Ruoqi Sun Academy of Military Science Defense Innovation Institute, Beijing, 100091, China E-mail: ruoqisun7163.com The exposed models are called artificial intelligent models[1-3]. These models rely on knowled…

第2篇 机器学习基础 —(1)机器学习方式及分类、回归

前言:Hello大家好,我是小哥谈。机器学习是一种人工智能的分支,它使用算法和数学模型来使计算机系统能够从经验数据中学习和改进,而无需显式地编程。机器学习的目标是通过从数据中发现模式和规律,从而使计算机能够自动进…

Leetcode 231.2的幂

给你一个整数 n,请你判断该整数是否是 2 的幂次方。如果是,返回 true ;否则,返回 false 。 如果存在一个整数 x 使得 n 2x ,则认为 n 是 2 的幂次方。 示例 1: 输入:n 1 输出:tr…

【重拾C语言】二、顺序程序设计(基本符号、数据、语句、表达式、顺序控制结构、数据类型、输入/输出操作)

目录 前言 二、顺序程序设计 2.1 求绿化带面积——简单程序 2.2基本符号: 2.2.1 字符集 可视字符 不可视字符 2.2.2 C特定符 关键字 分隔符 运算符 2.2.3 标识符 2.2.4 间隔符 2.2.5 注释 2.3 数据 2.3.1 字面常量(Literal Constants&am…

【RP-RV1126】烧录固件使用记录

文章目录 烧录完整固件进入MASKROM模式固件烧录升级中:升级完成: 烧录部分进入Loader模式选择文件切换loader模式 烧录完整固件 完整固件就是update.img包含了所有的部件,烧录后可以直接运行。 全局编译:./build.sh all生成固件…

SDL2绘制ffmpeg解析的mp4文件

文章目录 1.FFMPEG利用命令行将mp4转yuv4202.ffmpeg将mp4解析为yuv数据2.1 核心api: 3.SDL2进行yuv绘制到屏幕3.1 核心api 4.完整代码5.效果展示 本项目采用生产者消费者模型,生产者线程:使用ffmpeg将mp4格式数据解析为yuv的帧,消费者线程&am…

单层神经网络

神经网络 人工神经网络(Artificial Neural Network,ANN),简称神经网络(Neural Network,NN),是一种模仿生物神经网络的结构和功能的数学模型或计算模型。1943年,McCulloc…

计数排序(Counting Sort)详解

计数排序(Counting Sort)是一种非比较排序算法,其核心思想是通过计数每个元素的出现次数来进行排序,适用于整数或有限范围内的非负整数排序。这个算法的特点是速度快且稳定,适用于某些特定场景。在本文中,我…

多层神经网络和激活函数

多层神经网络的结构 多层神经网络就是由单层神经网络进行叠加之后得到的,所以就形成了层的概念,常见的多层神经网络有如下结构: 1)输入层(Input layer),众多神经元(Neuron&#xff…