昇思MindSpore学习笔记4-02生成式--DCGAN生成漫画头像

摘要:

        记录了昇思MindSpore AI框架使用70171张动漫头像图片训练一个DCGAN神经网络生成式对抗网络,并用来生成漫画头像的过程、步骤。包括环境准备、下载数据集、加载数据和预处理、构造网络、模型训练等。

一、概念

深度卷积对抗生成网络DCGAN

Deep Convolutional Generative Adversarial Networks

        扩展GAN

        判别器

                组成

                        卷积层

                        BatchNorm层

                        LeakyReLU激活层

                功能

                        输入是3*64*64图像

                        输出是真图像概率

        生成器

                组成

                        转置卷积层

                        BatchNorm层

                        ReLU激活层

                功能

                        输入是标准正态分布中提取出的隐向量z

                        输出是3*64*64 RGB图像。

  • 环境准备
%%capture captured_output
# 实验环境已经预装了mindspore==2.2.14,如需更换mindspore版本,可更改下面mindspore的版本号
!pip uninstall mindspore -y
!pip install -i https://pypi.mirrors.ustc.edu.cn/simple mindspore==2.2.14

三、数据准备与处理

1.下载数据集

下载到指定目录下并解压代码如下

from download import download
url = "https://download.mindspore.cn/dataset/Faces/faces.zip"
path = download(url, "./faces", kind="zip", replace=True)

输出:

Downloading data from https://download-mindspore.osinfra.cn/dataset/Faces/faces.zip (274.6 MB)file_sizes: 100%|████████████████████████████| 288M/288M [00:52<00:00, 5.49MB/s]
Extracting zip file...
Successfully downloaded / unzipped to ./faces

2.数据集介绍

使用的动漫头像数据集共有70,171张动漫头像图片,图片大小均为96*96。

数据集目录结构如下:

./faces/faces
├── 0.jpg
├── 1.jpg
├── 2.jpg
├── 3.jpg
├── 4.jpg...
├── 70169.jpg
└── 70170.jpg

3.数据处理

(1) 执行过程参数定义:

batch_size = 128          # 批量大小
image_size = 64           # 训练图像空间大小
nc = 3                    # 图像彩色通道数
nz = 100                  # 隐向量的长度
ngf = 64                  # 特征图在生成器中的大小
ndf = 64                  # 特征图在判别器中的大小
num_epochs = 3            # 训练周期数
lr = 0.0002               # 学习率
beta1 = 0.5               # Adam优化器的beta1超参数

(2) 数据处理和增强

create_dataset_imagenet函数

import numpy as np
import mindspore.dataset as ds
import mindspore.dataset.vision as vision
​
def create_dataset_imagenet(dataset_path):"""数据加载"""dataset = ds.ImageFolderDataset(dataset_path,num_parallel_workers=4,shuffle=True,decode=True)
​# 数据增强操作transforms = [vision.Resize(image_size),vision.CenterCrop(image_size),vision.HWC2CHW(),lambda x: ((x / 255).astype("float32"))]
​# 数据映射操作dataset = dataset.project('image')dataset = dataset.map(transforms, 'image')
​# 批量操作dataset = dataset.batch(batch_size)return dataset
​
dataset = create_dataset_imagenet('./faces')

(3) 查看训练数据

matplotlib模块

数据转换成字典迭代器

        create_dict_iterator函数

import matplotlib.pyplot as plt
​
def plot_data(data):# 可视化部分训练数据plt.figure(figsize=(10, 3), dpi=140)for i, image in enumerate(data[0][:30], 1):plt.subplot(3, 10, i)plt.axis("off")plt.imshow(image.transpose(1, 2, 0))plt.show()
​
sample_data = next(dataset.create_tuple_iterator(output_numpy=True))
plot_data(sample_data)

四、构造网络

模型权重随机初始化

范围:mean为0,sigma为0.02的正态分布【数学不好】

1. 生成器

生成器G

        隐向量z映射数据空间

        数据源是图像

        生成与图像大小相同的 RGB 图像

        Conv2dTranspose转置卷积层

        每个层与BatchNorm2d层和ReLu激活层配对

        tanh函数

        输出[-1,1]范围内数据

DCGAN生成图像过程如下所示:

生成器结构参数:

        nz         隐向量z的长度

        ngf         有关生成器传播的特征图大小

        nc         输出图像通道数

生成器代码:

import mindspore as ms
from mindspore import nn, ops
from mindspore.common.initializer import Normal
​
weight_init = Normal(mean=0, sigma=0.02)
gamma_init = Normal(mean=1, sigma=0.02)
​
class Generator(nn.Cell):"""DCGAN网络生成器"""
​def __init__(self):super(Generator, self).__init__()self.generator = nn.SequentialCell(nn.Conv2dTranspose(nz, ngf * 8, 4, 1, 'valid', weight_init=weight_init),nn.BatchNorm2d(ngf * 8, gamma_init=gamma_init),nn.ReLU(),nn.Conv2dTranspose(ngf * 8, ngf * 4, 4, 2, 'pad', 1, weight_init=weight_init),nn.BatchNorm2d(ngf * 4, gamma_init=gamma_init),nn.ReLU(),nn.Conv2dTranspose(ngf * 4, ngf * 2, 4, 2, 'pad', 1, weight_init=weight_init),nn.BatchNorm2d(ngf * 2, gamma_init=gamma_init),nn.ReLU(),nn.Conv2dTranspose(ngf * 2, ngf, 4, 2, 'pad', 1, weight_init=weight_init),nn.BatchNorm2d(ngf, gamma_init=gamma_init),nn.ReLU(),nn.Conv2dTranspose(ngf, nc, 4, 2, 'pad', 1, weight_init=weight_init),nn.Tanh())
​def construct(self, x):return self.generator(x)
​
generator = Generator()

2. 判别器

判别器D

        二分类网络模型

                Conv2d

                BatchNorm2d

                LeakyReLU

                Sigmoid激活函数

                输出判定图像真实概率

判别器代码:

class Discriminator(nn.Cell):"""DCGAN网络判别器"""
​def __init__(self):super(Discriminator, self).__init__()self.discriminator = nn.SequentialCell(nn.Conv2d(nc, ndf, 4, 2, 'pad', 1, weight_init=weight_init),nn.LeakyReLU(0.2),nn.Conv2d(ndf, ndf * 2, 4, 2, 'pad', 1, weight_init=weight_init),nn.BatchNorm2d(ngf * 2, gamma_init=gamma_init),nn.LeakyReLU(0.2),nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 'pad', 1, weight_init=weight_init),nn.BatchNorm2d(ngf * 4, gamma_init=gamma_init),nn.LeakyReLU(0.2),nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 'pad', 1, weight_init=weight_init),nn.BatchNorm2d(ngf * 8, gamma_init=gamma_init),nn.LeakyReLU(0.2),nn.Conv2d(ndf * 8, 1, 4, 1, 'valid', weight_init=weight_init),)self.adv_layer = nn.Sigmoid()
​def construct(self, x):out = self.discriminator(x)out = out.reshape(out.shape[0], -1)return self.adv_layer(out)
​
discriminator = Discriminator()

五、模型训练

1. 损失函数

二进制交叉熵损失函数MindSpore.nn.BCELoss

# 定义损失函数
adversarial_loss = nn.BCELoss(reduction='mean')

2. 优化器

Adam优化器

        lr = 0.0002

        beta1 = 0.5

# 为生成器和判别器设置优化器
optimizer_D = nn.Adam(discriminator.trainable_params(), learning_rate=lr, beta1=beta1)
optimizer_G = nn.Adam(generator.trainable_params(), learning_rate=lr, beta1=beta1)
optimizer_G.update_parameters_name('optim_g.')
optimizer_D.update_parameters_name('optim_d.')

3. 训练模型

训练判别器

        提高判别图像真伪的概率

        Goodfellow方法:提高随机梯度更新判别器

        最大化logD(x)+log(1-D(G(z)))

训练生成器

        最小化log(1−D(G(z)))

        产生更好的虚拟图像

两个部分分别

        获取训练损失

        每个周期结束统计

        批量推送fixed_noise到生成器

        跟踪G的训练进度

模型训练正向逻辑:

def generator_forward(real_imgs, valid):# 将噪声采样为发生器的输入z = ops.standard_normal((real_imgs.shape[0], nz, 1, 1))
​# 生成一批图像gen_imgs = generator(z)
​# 损失衡量发生器绕过判别器的能力g_loss = adversarial_loss(discriminator(gen_imgs), valid)
​return g_loss, gen_imgs
​
def discriminator_forward(real_imgs, gen_imgs, valid, fake):# 衡量鉴别器从生成的样本中对真实样本进行分类的能力real_loss = adversarial_loss(discriminator(real_imgs), valid)fake_loss = adversarial_loss(discriminator(gen_imgs), fake)d_loss = (real_loss + fake_loss) / 2return d_loss
​
grad_generator_fn = ms.value_and_grad(generator_forward, None,optimizer_G.parameters,has_aux=True)
grad_discriminator_fn = ms.value_and_grad(discriminator_forward, None,optimizer_D.parameters)
​
@ms.jit
def train_step(imgs):valid = ops.ones((imgs.shape[0], 1), mindspore.float32)fake = ops.zeros((imgs.shape[0], 1), mindspore.float32)
​(g_loss, gen_imgs), g_grads = grad_generator_fn(imgs, valid)optimizer_G(g_grads)d_loss, d_grads = grad_discriminator_fn(imgs, gen_imgs, valid, fake)optimizer_D(d_grads)
​return g_loss, d_loss, gen_imgs

循环训练网络

        迭代50次收集生成器、判别器的损失一次

        绘制损失函数的图像

import mindspore
​
G_losses = []
D_losses = []
image_list = []
​
total = dataset.get_dataset_size()
for epoch in range(num_epochs):generator.set_train()discriminator.set_train()# 为每轮训练读入数据for i, (imgs, ) in enumerate(dataset.create_tuple_iterator()):g_loss, d_loss, gen_imgs = train_step(imgs)if i % 100 == 0 or i == total - 1:# 输出训练记录print('[%2d/%d][%3d/%d]   Loss_D:%7.4f  Loss_G:%7.4f' % (epoch + 1, num_epochs, i + 1, total, d_loss.asnumpy(), g_loss.asnumpy()))D_losses.append(d_loss.asnumpy())G_losses.append(g_loss.asnumpy())
​# 每个epoch结束后,使用生成器生成一组图片generator.set_train(False)fixed_noise = ops.standard_normal((batch_size, nz, 1, 1))img = generator(fixed_noise)image_list.append(img.transpose(0, 2, 3, 1).asnumpy())
​# 保存网络模型参数为ckpt文件mindspore.save_checkpoint(generator, "./generator.ckpt")mindspore.save_checkpoint(discriminator, "./discriminator.ckpt")

输出:

[ 1/3][  1/549]   Loss_D: 0.2635  Loss_G: 4.8150
[ 1/3][101/549]   Loss_D: 0.4023  Loss_G: 4.9807
[ 1/3][201/549]   Loss_D: 0.2425  Loss_G: 1.6335
[ 1/3][301/549]   Loss_D: 0.5856  Loss_G: 0.6079
[ 1/3][401/549]   Loss_D: 0.1922  Loss_G: 4.3977
[ 1/3][501/549]   Loss_D: 0.1065  Loss_G: 2.3724
[ 1/3][549/549]   Loss_D: 0.1893  Loss_G: 1.6483
[ 2/3][  1/549]   Loss_D: 0.3370  Loss_G: 4.4347
[ 2/3][101/549]   Loss_D: 0.4681  Loss_G: 0.8623
[ 2/3][201/549]   Loss_D: 0.1856  Loss_G: 3.7501
[ 2/3][301/549]   Loss_D: 0.1932  Loss_G: 2.6333
[ 2/3][401/549]   Loss_D: 0.1310  Loss_G: 2.2524
[ 2/3][501/549]   Loss_D: 0.2531  Loss_G: 1.4690
[ 2/3][549/549]   Loss_D: 0.1192  Loss_G: 5.7166
[ 3/3][  1/549]   Loss_D: 0.0716  Loss_G: 2.9886
[ 3/3][101/549]   Loss_D: 0.1345  Loss_G: 2.6544
[ 3/3][201/549]   Loss_D: 0.1097  Loss_G: 2.8604
[ 3/3][301/549]   Loss_D: 0.2066  Loss_G: 6.1513
[ 3/3][401/549]   Loss_D: 0.0797  Loss_G: 3.2336
[ 3/3][501/549]   Loss_D: 0.2618  Loss_G: 4.0991
[ 3/3][549/549]   Loss_D: 0.5600  Loss_G:10.7509

4. 结果展示

描绘D和G损失与训练迭代的关系图:

plt.figure(figsize=(10, 5))
plt.title("Generator and Discriminator Loss During Training")
plt.plot(G_losses, label="G", color='blue')
plt.plot(D_losses, label="D", color='orange')
plt.xlabel("iterations")
plt.ylabel("Loss")
plt.legend()
plt.show()

输出:

显示隐向量fixed_noise训练生成的图像

import matplotlib.pyplot as plt
import matplotlib.animation as animation
​
def showGif(image_list):show_list = []fig = plt.figure(figsize=(8, 3), dpi=120)for epoch in range(len(image_list)):images = []for i in range(3):row = np.concatenate((image_list[epoch][i * 8:(i + 1) * 8]), axis=1)images.append(row)img = np.clip(np.concatenate((images[:]), axis=0), 0, 1)plt.axis("off")show_list.append([plt.imshow(img)])
​ani = animation.ArtistAnimation(fig, show_list, interval=1000, repeat_delay=1000, blit=True)ani.save('./dcgan.gif', writer='pillow', fps=1)
​
showGif(image_list)

输出:

训练次数增多,图像质量越好

num_epochs达到50以上,生成动漫头像图片与数据集较为相似

加载生成器网络模型参数文件来生成图像代码:

# 从文件中获取模型参数并加载到网络中
mindspore.load_checkpoint("./generator.ckpt", generator)
​
fixed_noise = ops.standard_normal((batch_size, nz, 1, 1))
img64 = generator(fixed_noise).transpose(0, 2, 3, 1).asnumpy()
​
fig = plt.figure(figsize=(8, 3), dpi=120)
images = []
for i in range(3):images.append(np.concatenate((img64[i * 8:(i + 1) * 8]), axis=1))
img = np.clip(np.concatenate((images[:]), axis=0), 0, 1)
plt.axis("off")
plt.imshow(img)
plt.show()

输出:

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.xdnf.cn/news/1472548.html

如若内容造成侵权/违法违规/事实不符,请联系一条长河网进行投诉反馈,一经查实,立即删除!

相关文章

Runway Gen-3 实测,这就是 AI 视频生成的 No.1!视频高清化EvTexture 安装配置使用!

Runway Gen-3 实测,这就是 AI 视频生成的 No.1!视频高清化EvTexture 安装配置使用! 由于 Runway 作为一个具体的工具或平台,其详细信息在搜索结果中没有提供,我将基于假设 Runway 是一个支持人工智能和机器学习模型的创意工具,提供一个关于使用技巧和类似开源项目的文稿总…

视频字幕提取在线工具有哪些?总结5个字幕提取工具

平时在沉浸式追剧的时候&#xff0c;我们常常都会被影视剧中的各种金句爆梗而逗得开怀大笑~而真正要用到时候却总是一片头脑空白。其实要记住它们最好的办法便是将其提取留档下来&#xff0c;每次有需要的时候打开就能一下子回顾到~ 今天就来带大家盘一盘视频字幕提取的软件好…

动态规划入门,从简单递归到记忆化搜索到动态规划

动态规划入门&#xff0c;从简单递归到记忆化搜索到动态规划 打家劫舍 class Solution {private int nums[];public int rob(int[] nums) {this.nums nums;return dfs(nums.length - 1);}public int dfs(int i){if (i < 0){return 0;}int res Math.max(dfs(i - 1), dfs(i…

【分布式数据仓库Hive】Hive的安装配置及测试

目录 一、数据库MySQL安装 1. 检查操作系统是否有MySQL安装残留 2. 删除残留的MySQL安装&#xff08;使用yum&#xff09; 3. 安装MySQL依赖包、客户端和服务器 4. MySQL登录账户root设置密码 5. 启动MySQL服务 6. 登录MySQL&#xff0c;进入数据库操作提示符 7. 授权H…

SpringBoot 启动流程六

SpringBoot启动流程六 这句话是创建一个上下文对象 就是最终返回的那个上下文 我们这个creatApplicationContext方法 是调用的这个方法 传入一个类型 我们通过打断点的方式 就可以看到context里面的东西 加载容器对象 当我们把依赖改成starter-web时 这个容器对象会进行…

深度解析Java世界中的对象镜像:浅拷贝与深拷贝的奥秘与应用

在Java编程的浩瀚宇宙中&#xff0c;对象拷贝是一项既基础又至关重要的技术。它直接关系到程序的性能、资源管理及数据安全性。然而&#xff0c;提及对象拷贝&#xff0c;不得不深入探讨其两大核心类型&#xff1a;浅拷贝&#xff08;Shallow Copy&#xff09;与深拷贝&#xf…

等保2.0标准相比之前的有哪些重大变化?

在数字化的浪潮中&#xff0c;网络安全如同一艘坚固的航船&#xff0c;承载着国家与民族的希望&#xff0c;驶向信息化的彼岸。等级保护制度&#xff08;等保&#xff09;作为中国网络安全的守护神&#xff0c;经过岁月的洗礼与智慧的积淀&#xff0c;迎来了等保2.0的时代&…

Android仿天眼查人物关系图

效果图预览 绘制思路 这里使用了中学解析几何知识 XPoint OPointX OPointXcosθ&#xff1b; YPoint OPointY OPointYsinθ&#xff1b; canvas.drawText(lists.get(i).getName(), XPoint (float) Math.cos(pere * i 5) * radius[i % radius.length] - 30, YPoint (fl…

C语言 | Leetcode C语言题解之第213题打家劫舍II

题目&#xff1a; 题解&#xff1a; int robRange(int* nums, int start, int end) {int first nums[start], second fmax(nums[start], nums[start 1]);for (int i start 2; i < end; i) {int temp second;second fmax(first nums[i], second);first temp;}retur…

昇思25天学习打卡营第9天|保存与加载

保存与加载 在训练网络模型的过程中&#xff0c;保存中间和最后结果可以用来微调和后续的模型推理与部署。 # 首先定义一个模型 def network():model nn.SequentialCell(nn.Flatten(),nn.Dense(28*28, 512),nn.ReLU(),nn.Dense(512, 512),nn.ReLU(),nn.Dense(512, 10))retur…

Hive-存储-文件格式

一、前言 数据存储是Hive的基础&#xff0c;选择合适的底层数据存储格式&#xff0c;可以在不改变Hql的前提下得到大的性能提升。类似mysql选择适合场景的存储引擎。 Hive支持的存储格式有 文本格式&#xff08;TextFile&#xff09; 二进制序列化文件 &#xff08;SequenceF…

(vue)el-tabs选中最后一项后更新数据后无法展开

(vue)el-tabs选中最后一项后更新数据后无法展开 效果&#xff1a; 原因&#xff1a;选中时绑定的值在数据更新后找不到 思路&#xff1a;更新数据时把选中的v-model的属性赋为初始值 写法&#xff1a; <el-form-item label"字段选择"><el-tabsv-model&qu…

相关款式1111

一、花梨木迎客松 1. 风速打单 发现只有在兄弟店铺有售卖 六月份成交订单数有62笔 2. 生意参谋 兄弟店铺商品访客数&#xff1a;3548&#xff0c;支付件数&#xff1a;95件 二. 竹节茶刷&#xff08;引流&#xff09; 1. 风速打单 六月订单数有165笔 兄弟&#xff1a;…

喜报 | 怿星携高性价比国产方案亮相IAEIS峰会并荣获“优秀创新产品奖”

近日&#xff0c;由深圳市汽车电子行业协会主办的主题为&#xff1a;“布局全球产业链&#xff0c;促进智能网联汽车产业高质量发展”IAEIS 2024第十三届国际汽车电子产业峰会”暨“2023年度汽车电子科学技术奖”颁奖典礼在深圳隆重举行。 怿星科技携高性价比的「车载网络通信 …

【45 Pandas+Pyecharts | 去哪儿海南旅游攻略数据分析可视化】

文章目录 &#x1f3f3;️‍&#x1f308; 1. 导入模块&#x1f3f3;️‍&#x1f308; 2. Pandas数据处理2.1 读取数据2.2 查看数据信息2.3 日期处理&#xff0c;提取年份、月份2.4 经费处理2.5 天数处理 &#x1f3f3;️‍&#x1f308; 3. Pyecharts数据可视化3.1 出发日期_…

PyCharm中如何将某个文件设置为默认运行文件

之前在使用JetBrain公司的另一款软件IDEA的时候&#xff0c;如果在选中static main函数后按键altenter可以默认以后运行Main类的main函数。最近在使用PyCharm学习Python&#xff0c;既然同为一家公司的产品而且二者的风格如此之像&#xff0c;所以我怀疑PyCharm中肯定也有类似的…

获取VC账号,是成为亚马逊供应商的全面准备与必要条件

成为亚马逊的供应商&#xff0c;拥有VC&#xff08;Vendor Central&#xff09;账号&#xff0c;是众多制造商和品牌所有者的共同目标。这不仅代表了亚马逊对供应商的高度认可&#xff0c;也意味着获得了更多的销售机会和更广阔的市场前景。 全面准备与必要条件是获取VC账号的关…

中科院分区表中被“On Hold”的TOP期刊!爱思唯尔会对中国学者下手吗?

关注GZH【欧亚科睿学术】&#xff0c;GET完整版2023JCR分区列表&#xff01; ↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓↓ 目前&#xff0c;共37本期刊被科睿唯安标记为“On Hold”状态&#xff0c;其中包含5本中科院TOP期刊&#xff0c;主要来自Elsevier出版社旗下。…

简单配置VScode轻量级C++竞赛环境

1. 安装拓展 Chinese是中文&#xff0c;需要重启才可以运行&#xff0c;C/C拓展只是进行语法代码提示&#xff0c;不需要进行任何配置修改&#xff0c;默认即可。 2. 创建文件 如上图创建好各级文件夹&#xff0c;其中C是工作文件夹&#xff0c;.vscode是配置文件夹&#xff0…

前端根据目录生成模块化路由routes

根据约定大于配置的逻辑&#xff0c;如果目录结构约定俗成&#xff0c;前端是可以根据目录结构动态生成路由所需要的 route 结构的&#xff0c;这个过程是要在编译时 进行&#xff0c;生成需要的代码&#xff0c;保证运行时的代码正确即可 主流的打包工具都有对应的方法读取文…