昇思25天学习打卡营第17天|GAN图像生成

模型简介

GAN模型的核心在于提出了通过对抗过程来估计生成模型这一全新框架。在这个框架中,将会同时训练两个模型——捕捉数据分布的生成模型G和估计样本是否来自训练数据的判别模型D 。

在训练过程中,生成器会不断尝试通过生成更好的假图像来骗过判别器,而判别器在这过程中也会逐步提升判别能力。这种博弈的平衡点是,当生成器生成的假图像和训练数据图像的分布完全一致时,判别器拥有50%的真假判断置信度。

gan

数据集

使用MNIST手写数字数据集

数据集下载和加载

# 数据下载
from download import downloadurl = "https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/datasets/MNIST_Data.zip"
download(url, ".", kind="zip", replace=True)import numpy as np
import mindspore.dataset as dsbatch_size = 64
latent_size = 100  # 隐码的长度train_dataset = ds.MnistDataset(dataset_dir='./MNIST_Data/train')
test_dataset = ds.MnistDataset(dataset_dir='./MNIST_Data/test')def data_load(dataset):dataset1 = ds.GeneratorDataset(dataset, ["image", "label"], shuffle=True, python_multiprocessing=False,num_samples=10000)# 数据增强mnist_ds = dataset1.map(operations=lambda x: (x.astype("float32"), np.random.normal(size=latent_size).astype("float32")),output_columns=["image", "latent_code"])mnist_ds = mnist_ds.project(["image", "latent_code"])# 批量操作mnist_ds = mnist_ds.batch(batch_size, True)return mnist_dsmnist_ds = data_load(train_dataset)iter_size = mnist_ds.get_dataset_size()
print('Iter size: %d' % iter_size)

隐码构造

为了跟踪生成器的学习进度,我们在训练的过程中的每轮迭代结束后,将一组固定的遵循高斯分布的隐码输入到生成器中,通过固定隐码所生成的图像效果来评估生成器的好坏。

import random
import numpy as np
from mindspore import Tensor
from mindspore.common import dtype# 利用随机种子创建一批隐码
np.random.seed(2323)
test_noise = Tensor(np.random.normal(size=(25, 100)), dtype.float32)
random.shuffle(test_noise)

模型构建

生成器

from mindspore import nn
import mindspore.ops as opsimg_size = 28  # 训练图像长(宽)class Generator(nn.Cell):def __init__(self, latent_size, auto_prefix=True):super(Generator, self).__init__(auto_prefix=auto_prefix)self.model = nn.SequentialCell()# [N, 100] -> [N, 128]# 输入一个100维的0~1之间的高斯分布,然后通过第一层线性变换将其映射到256维self.model.append(nn.Dense(latent_size, 128))self.model.append(nn.ReLU())# [N, 128] -> [N, 256]self.model.append(nn.Dense(128, 256))self.model.append(nn.BatchNorm1d(256))self.model.append(nn.ReLU())# [N, 256] -> [N, 512]self.model.append(nn.Dense(256, 512))self.model.append(nn.BatchNorm1d(512))self.model.append(nn.ReLU())# [N, 512] -> [N, 1024]self.model.append(nn.Dense(512, 1024))self.model.append(nn.BatchNorm1d(1024))self.model.append(nn.ReLU())# [N, 1024] -> [N, 784]# 经过线性变换将其变成784维self.model.append(nn.Dense(1024, img_size * img_size))# 经过Tanh激活函数是希望生成的假的图片数据分布能够在-1~1之间self.model.append(nn.Tanh())def construct(self, x):img = self.model(x)return ops.reshape(img, (-1, 1, 28, 28))net_g = Generator(latent_size)
net_g.update_parameters_name('generator')

判别器

# 判别器
class Discriminator(nn.Cell):def __init__(self, auto_prefix=True):super().__init__(auto_prefix=auto_prefix)self.model = nn.SequentialCell()# [N, 784] -> [N, 512]self.model.append(nn.Dense(img_size * img_size, 512))  # 输入特征数为784,输出为512self.model.append(nn.LeakyReLU())  # 默认斜率为0.2的非线性映射激活函数# [N, 512] -> [N, 256]self.model.append(nn.Dense(512, 256))  # 进行一个线性映射self.model.append(nn.LeakyReLU())# [N, 256] -> [N, 1]self.model.append(nn.Dense(256, 1))self.model.append(nn.Sigmoid())  # 二分类激活函数,将实数映射到[0,1]def construct(self, x):x_flat = ops.reshape(x, (-1, img_size * img_size))return self.model(x_flat)net_d = Discriminator()
net_d.update_parameters_name('discriminator')

损失函数和优化器

lr = 0.0002  # 学习率# 损失函数
adversarial_loss = nn.BCELoss(reduction='mean')# 优化器
optimizer_d = nn.Adam(net_d.trainable_params(), learning_rate=lr, beta1=0.5, beta2=0.999)
optimizer_g = nn.Adam(net_g.trainable_params(), learning_rate=lr, beta1=0.5, beta2=0.999)
optimizer_g.update_parameters_name('optim_g')
optimizer_d.update_parameters_name('optim_d')

模型训练

第一部分是训练判别器。训练判别器的目的是最大程度地提高判别图像真伪的概率。

第二部分是训练生成器。以产生更好的虚拟图像。

import os
import time
import matplotlib.pyplot as plt
import mindspore as ms
from mindspore import Tensor, save_checkpointtotal_epoch = 12  # 训练周期数
batch_size = 64  # 用于训练的训练集批量大小# 加载预训练模型的参数
pred_trained = False
pred_trained_g = './result/checkpoints/Generator99.ckpt'
pred_trained_d = './result/checkpoints/Discriminator99.ckpt'checkpoints_path = "./result/checkpoints"  # 结果保存路径
image_path = "./result/images"  # 测试结果保存路径
# 生成器计算损失过程
def generator_forward(test_noises):fake_data = net_g(test_noises)fake_out = net_d(fake_data)loss_g = adversarial_loss(fake_out, ops.ones_like(fake_out))return loss_g# 判别器计算损失过程
def discriminator_forward(real_data, test_noises):fake_data = net_g(test_noises)fake_out = net_d(fake_data)real_out = net_d(real_data)real_loss = adversarial_loss(real_out, ops.ones_like(real_out))fake_loss = adversarial_loss(fake_out, ops.zeros_like(fake_out))loss_d = real_loss + fake_lossreturn loss_d# 梯度方法
grad_g = ms.value_and_grad(generator_forward, None, net_g.trainable_params())
grad_d = ms.value_and_grad(discriminator_forward, None, net_d.trainable_params())def train_step(real_data, latent_code):# 计算判别器损失和梯度loss_d, grads_d = grad_d(real_data, latent_code)optimizer_d(grads_d)loss_g, grads_g = grad_g(latent_code)optimizer_g(grads_g)return loss_d, loss_g# 保存生成的test图像
def save_imgs(gen_imgs1, idx):for i3 in range(gen_imgs1.shape[0]):plt.subplot(5, 5, i3 + 1)plt.imshow(gen_imgs1[i3, 0, :, :] / 2 + 0.5, cmap="gray")plt.axis("off")plt.savefig(image_path + "/test_{}.png".format(idx))# 设置参数保存路径
os.makedirs(checkpoints_path, exist_ok=True)
# 设置中间过程生成图片保存路径
os.makedirs(image_path, exist_ok=True)net_g.set_train()
net_d.set_train()# 储存生成器和判别器loss
losses_g, losses_d = [], []for epoch in range(total_epoch):start = time.time()for (iter, data) in enumerate(mnist_ds):start1 = time.time()image, latent_code = dataimage = (image - 127.5) / 127.5  # [0, 255] -> [-1, 1]image = image.reshape(image.shape[0], 1, image.shape[1], image.shape[2])d_loss, g_loss = train_step(image, latent_code)end1 = time.time()if iter % 10 == 10:print(f"Epoch:[{int(epoch):>3d}/{int(total_epoch):>3d}], "f"step:[{int(iter):>4d}/{int(iter_size):>4d}], "f"loss_d:{d_loss.asnumpy():>4f} , "f"loss_g:{g_loss.asnumpy():>4f} , "f"time:{(end1 - start1):>3f}s, "f"lr:{lr:>6f}")end = time.time()print("time of epoch {} is {:.2f}s".format(epoch + 1, end - start))losses_d.append(d_loss.asnumpy())losses_g.append(g_loss.asnumpy())# 每个epoch结束后,使用生成器生成一组图片gen_imgs = net_g(test_noise)save_imgs(gen_imgs.asnumpy(), epoch)# 根据epoch保存模型权重文件if epoch % 1 == 0:save_checkpoint(net_g, checkpoints_path + "/Generator%d.ckpt" % (epoch))save_checkpoint(net_d, checkpoints_path + "/Discriminator%d.ckpt" % (epoch))

生成图像:

效果展示

描绘D和G损失与训练迭代的关系图:

plt.figure(figsize=(6, 4))
plt.title("Generator and Discriminator Loss During Training")
plt.plot(losses_g, label="G", color='blue')
plt.plot(losses_d, label="D", color='orange')
plt.xlim(-5,15)
plt.ylim(0, 3.5)
plt.xlabel("iterations")
plt.ylabel("Loss")
plt.legend()
plt.show()

可视化训练过程中通过隐向量生成的图像。

import cv2
import matplotlib.animation as animation# 将训练过程中生成的测试图转为动态图
image_list = []
for i in range(total_epoch):image_list.append(cv2.imread(image_path + "/test_{}.png".format(i), cv2.IMREAD_GRAYSCALE))
show_list = []
fig = plt.figure(dpi=70)
for epoch in range(0, len(image_list), 5):plt.axis("off")show_list.append([plt.imshow(image_list[epoch], cmap='gray')])ani = animation.ArtistAnimation(fig, show_list, interval=1000, repeat_delay=1000, blit=True)
ani.save('train_test.gif', writer='pillow', fps=1)

训练过程测试动态图

可见随着训练次数的增多,图像质量也越来越好。

模型推理

通过加载生成器网络模型参数文件来生成图像,代码如下:

import mindspore as ms# test_ckpt = './result/checkpoints/Generator199.ckpt'# parameter = ms.load_checkpoint(test_ckpt)
# ms.load_param_into_net(net_g, parameter)
# 模型生成结果
test_data = Tensor(np.random.normal(0, 1, (25, 100)).astype(np.float32))
images = net_g(test_data).transpose(0, 2, 3, 1).asnumpy()
# 结果展示
fig = plt.figure(figsize=(3, 3), dpi=120)
for i in range(25):fig.add_subplot(5, 5, i + 1)plt.axis("off")plt.imshow(images[i].squeeze(), cmap="gray")
plt.show()

总结

生成对抗网络GAN通过生成器和判别器之间的对抗训练,从而获得高质量的数据扩充。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.xdnf.cn/news/1472322.html

如若内容造成侵权/违法违规/事实不符,请联系一条长河网进行投诉反馈,一经查实,立即删除!

相关文章

昇思25天学习打卡营第8天|MindSpore保存与加载(保存和加载MindIR)

在MindIR中,一个函数图(FuncGraph)表示一个普通函数的定义,函数图一般由ParameterNode、ValueNode和CNode组成有向无环图,可以清晰地表达出从参数到返回值的计算过程。在上图中可以看出,python代码中两个函…

Python面试宝典第6题:有效的括号

题目 给定一个只包括 (、)、{、}、[、] 这些字符的字符串,判断该字符串是否有效。有效字符串需要满足以下的条件。 1、左括号必须用相同类型的右括号闭合。 2、左括号必须以正确的顺序闭合。 3、每个右括号都有一个对应的相同类型的左括号。 注意:空字符…

搞了个 WEB 串口终端,简单分享下

每次换电脑总要找各种串口终端软件,很烦。 有的软件要付费,有的软件要注册,很烦。 找到免费的,还得先下载下来,很烦。 开源的软件下载速度不稳定,很烦。 公司电脑有监控还得让 IT 同事来安装&#xff0…

后端之路——最规范、便捷的spring boot工程配置

一、参数配置化 上一篇我们学了阿里云OSS的使用,那么我们为了方便使用OSS来上传文件,就创建了一个【util】类,里面有一个【AliOSSUtils】类,虽然本人觉得没啥不方便的,但是黑马视频又说这样还是存在不便维护和管理问题…

生态系统NPP及碳源、碳汇模拟技术教程

原文链接:生态系统NPP及碳源、碳汇模拟技术教程https://mp.weixin.qq.com/s?__bizMzUzNTczMDMxMg&mid2247608293&idx3&sn2604c5c4e061b4f15bb8aa81cf6dadd1&chksmfa826602cdf5ef145c4d170bed2e803cd71266626d6a6818c167e8af0da93557c1288da21a71&a…

FPGA问题

fpga 问题 ep2c5t144 开发板 第一道坎,安装软件;没有注册,无法产生sop文件,无法下载 没有相应的库的quartus ii版本,需要另下载 第二道坎,模拟器的下载,安装; 第三道,v…

在window上搭建docker

1、打开Hyper-V安装 在地址栏输入控制面板,然后回车 勾选Hyper-V安装,如果没有找到Hyper-V,那么请走第2步 2、如果没有Hyper-V(可选)第一步无法打开 家庭版本需要开启Hyper-V 创建一个文本文档,后缀名称为.bat.名称…

油猴Safari浏览器插件:Tampermonkey for Mac 下载

Tampermonkey 是一个强大的浏览器扩展,用于运行用户脚本,这些脚本可以自定义和增强网页的功能。它允许用户在网页上执行各种自动化任务,比如自动填写表单、移除广告、改变页面布局等。适用浏览器: Tampermonkey 适用于多数主流浏览…

第二周:计算机网络概述(下)

一、计算机网络性能指标(速率、带宽、延迟) 1、速率 2、带宽 3、延迟/时延 前面讲分组交换的时候介绍了,有一种延迟叫“传输延迟”,即发送一个报文,从第一个分组的发送,到最后一个分组的发送完成的这段时…

PyQT: 开发一款ROI绘制小程序

在一些基于图像或者视频流的应用中,比如电子围栏/客流统计等,我们需要手动绘制一些感兴趣(Region of Interest,简称ROI)区域。 在这里,我们基于Python和PyQt5框架开发了一款桌面应用程序,允许用…

Renesas R7FA8D1BH (Cortex®-M85)串口应用总结

目录 概述 1 软硬件 1.1 软硬件环境信息 1.2 开发板信息 1.3 调试器信息 2 FSP和KEIL配置串口 2.1 配置参数 2.2 生成基于Keil的软件架构 3 FSP代码 3.1 FSP中UART接口函数 3.2 案例代码介绍 3.3 案例代码存在的问题 4 UART代码实现 4.1 功能函数介绍 4.2 完整…

如何在Qt使用uchardet库

如何在 Qt 中使用 uchardet 库 文章目录 如何在 Qt 中使用 uchardet 库一、简介二、uchardet库的下载三、在Qt中直接调用四、编译成库文件后调用4.1 编译工具下载4.2 uchardet源码编译4.3 测试编译文件4.4 Qt中使用 五、一些小问题5.1 测试文件存在的问题5.2 uchardet库相关 六…

嵌入式学习——硬件(UART)——day55

1. UART 1.1 定义 UART(Universal Asynchronous Receiver/Transmitter,通用异步收发器)是一种用于串行通信的硬件设备或模块。它的主要功能是将数据在串行和并行格式之间进行转换。UART通常用于计算机与外围设备或嵌入式系统之间的数据传输。…

【Linux系统编程】深入剖析:四大IO模型机制与应用(阻塞、非阻塞、多路复用、信号驱动IO 全解读)

目录 概述: 1. 阻塞IO (Blocking IO) 2. 非阻塞IO (Non-blocking IO) 3. IO多路复用 (I/O Multiplexing) 4. 信号驱动IO (Signal-driven IO) 阻塞式IO 非阻塞式IO 信号驱动IO(Signal-driven IO) 信号IO实例: IO多路复用…

nginx 搭理禅道

1.安装nginx。 2.安装禅道。 3.nginx 配置文件 location /zentao/ { proxy_pass http://192.168.100.66/zentao/;proxy_set_header Host $host;proxy_set_header X-Real-IP $remote_addr;proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;proxy_set_header X-F…

Labview_Workers5.0 学习笔记

1.Local Request 个人理解该类型的请求针对自身的,由EHL或者MHL到该vi的MHL中。 使用快速放置快捷键"Ctrl9"创建方法如下: 创建后的API接口命名均为rql开头,并且在所选main.vi中的MHL创建对应的条件分支。 此时使用该API函数就…

巴图自动化PN转Modbus RTU协议转换网关模块快速配置

工业领域中常用的通讯协议有:Profinet协议,Modbus协议,ModbusTCP协议,Profibus协议,Profibus DP协议,EtherCAT协议,EtherNET协议,CAN,CanOpen等,它们在自动化…

78110A雷达信号模拟软件

78110A雷达信号模拟软件 78110A雷达信号模拟软件(简称雷达信号模拟软件)主要用于模拟产生雷达发射信号和目标回波信号,软件将编译生成的雷达信号任意波数据下载到信号发生器中,主要是1466-V矢量信号发生器,可实现雷达信号模拟产生。软件可模…

Zabbix 配置SNMP监控

Zabbix SNMP监控介绍 Zabbix提供了强大的SNMP监控功能,可以用于监控网络设备、服务器和其他支持SNMP协议的设备。SNMP(Simple Network Management Protocol,简单网络管理协议)是一种广泛用于网络管理的协议。它用于监控网络设备&…