昇思25天学习打卡营第十一天|DCGAN生成漫画头像

练习营进入第11天了,今天学习的内容是DCGAN生成漫画头像,记录一下学习内容:

GAN基础原理

这部分原理介绍参考GAN图像生成。

DCGAN原理

DCGAN(深度卷积对抗生成网络,Deep Convolutional Generative Adversarial Networks)是GAN的直接扩展。不同之处在于,DCGAN会分别在判别器和生成器中使用卷积和转置卷积层。

它最早由Radford等人在论文Unsupervised Representation Learning With Deep Convolutional Generative Adversarial Networks中进行描述。判别器由分层的卷积层、BatchNorm层和LeakyReLU激活层组成。输入是3x64x64的图像,输出是该图像为真图像的概率。生成器则是由转置卷积层、BatchNorm层和ReLU激活层组成。输入是标准正态分布中提取出的隐向量 z z z,输出是3x64x64的RGB图像。

本教程将使用动漫头像数据集来训练一个生成式对抗网络,接着使用该网络生成动漫头像图片。

数据准备与处理

首先我们将数据集下载到指定目录下并解压。
下载后的数据集目录结构如下:

./faces/faces
├── 0.jpg
├── 1.jpg
├── 2.jpg
├── 3.jpg
├── 4.jpg...
├── 70169.jpg
└── 70170.jpg

数据处理

首先为执行过程定义一些输入:

batch_size = 128          # 批量大小
image_size = 64           # 训练图像空间大小
nc = 3                    # 图像彩色通道数
nz = 100                  # 隐向量的长度
ngf = 64                  # 特征图在生成器中的大小
ndf = 64                  # 特征图在判别器中的大小
num_epochs = 3           # 训练周期数
lr = 0.0002               # 学习率
beta1 = 0.5               # Adam优化器的beta1超参数

定义create_dataset_imagenet函数对数据进行处理和增强操作。

import numpy as np
import mindspore.dataset as ds
import mindspore.dataset.vision as visiondef create_dataset_imagenet(dataset_path):"""数据加载"""dataset = ds.ImageFolderDataset(dataset_path,num_parallel_workers=4,shuffle=True,decode=True)# 数据增强操作transforms = [vision.Resize(image_size),vision.CenterCrop(image_size),vision.HWC2CHW(),lambda x: ((x / 255).astype("float32"))]# 数据映射操作dataset = dataset.project('image')dataset = dataset.map(transforms, 'image')# 批量操作dataset = dataset.batch(batch_size)return datasetdataset = create_dataset_imagenet('./faces')

通过create_dict_iterator函数将数据转换成字典迭代器,然后使用matplotlib模块可视化部分训练数据。

import matplotlib.pyplot as pltdef plot_data(data):# 可视化部分训练数据plt.figure(figsize=(10, 3), dpi=140)for i, image in enumerate(data[0][:30], 1):plt.subplot(3, 10, i)plt.axis("off")plt.imshow(image.transpose(1, 2, 0))plt.show()sample_data = next(dataset.create_tuple_iterator(output_numpy=True))
plot_data(sample_data)

在这里插入图片描述

构造网络

当处理完数据后,就可以来进行网络的搭建了。按照DCGAN论文中的描述,所有模型权重均应从mean为0,sigma为0.02的正态分布中随机初始化。

生成器

生成器G的功能是将隐向量z映射到数据空间。由于数据是图像,这一过程也会创建与真实图像大小相同的 RGB 图像。在实践场景中,该功能是通过一系列Conv2dTranspose转置卷积层来完成的,每个层都与BatchNorm2d层和ReLu激活层配对,输出数据会经过tanh函数,使其返回[-1,1]的数据范围内。

DCGAN论文生成图像如下所示:

dcgangenerator

图片来源:Unsupervised Representation Learning With Deep Convolutional Generative Adversarial Networks.

我们通过输入部分中设置的nzngfnc来影响代码中的生成器结构。nz是隐向量z的长度,ngf与通过生成器传播的特征图的大小有关,nc是输出图像中的通道数。

判别器

如前所述,判别器D是一个二分类网络模型,输出判定该图像为真实图的概率。通过一系列的Conv2dBatchNorm2dLeakyReLU层对其进行处理,最后通过Sigmoid激活函数得到最终概率。

DCGAN论文提到,使用卷积而不是通过池化来进行下采样是一个好方法,因为它可以让网络学习自己的池化特征。

模型训练

损失函数

当定义了DG后,接下来将使用MindSpore中定义的二进制交叉熵损失函数BCELoss。

# 定义损失函数
adversarial_loss = nn.BCELoss(reduction='mean')

优化器

这里设置了两个单独的优化器,一个用于D,另一个用于G。这两个都是lr = 0.0002beta1 = 0.5的Adam优化器。

训练模型

训练分为两个主要部分:训练判别器和训练生成器。

  • 训练判别器

    训练判别器的目的是最大程度地提高判别图像真伪的概率。按照Goodfellow的方法,是希望通过提高其随机梯度来更新判别器,所以我们要最大化 l o g D ( x ) + l o g ( 1 − D ( G ( z ) ) log D(x) + log(1 - D(G(z)) logD(x)+log(1D(G(z))的值。

  • 训练生成器

    如DCGAN论文所述,我们希望通过最小化 l o g ( 1 − D ( G ( z ) ) ) log(1 - D(G(z))) log(1D(G(z)))来训练生成器,以产生更好的虚假图像。

在这两个部分中,分别获取训练过程中的损失,并在每个周期结束时进行统计,将fixed_noise批量推送到生成器中,以直观地跟踪G的训练进度。

结果展示

运行下面代码,描绘DG损失与训练迭代的关系图:

plt.figure(figsize=(10, 5))
plt.title("Generator and Discriminator Loss During Training")
plt.plot(G_losses, label="G", color='blue')
plt.plot(D_losses, label="D", color='orange')
plt.xlabel("iterations")
plt.ylabel("Loss")
plt.legend()
plt.show()

可视化训练过程中通过隐向量fixed_noise生成的图像。

import matplotlib.pyplot as plt
import matplotlib.animation as animationdef showGif(image_list):show_list = []fig = plt.figure(figsize=(8, 3), dpi=120)for epoch in range(len(image_list)):images = []for i in range(3):row = np.concatenate((image_list[epoch][i * 8:(i + 1) * 8]), axis=1)images.append(row)img = np.clip(np.concatenate((images[:]), axis=0), 0, 1)plt.axis("off")show_list.append([plt.imshow(img)])ani = animation.ArtistAnimation(fig, show_list, interval=1000, repeat_delay=1000, blit=True)ani.save('./dcgan.gif', writer='pillow', fps=1)showGif(image_list)

在这里插入图片描述

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.xdnf.cn/news/1474336.html

如若内容造成侵权/违法违规/事实不符,请联系一条长河网进行投诉反馈,一经查实,立即删除!

相关文章

Java的基础语法

叠甲:以下文章主要是依靠我的实际编码学习中总结出来的经验之谈,求逻辑自洽,不能百分百保证正确,有错误、未定义、不合适的内容请尽情指出! 文章目录 1.第一份程序1.1.代码编写1.2.代码运行1.2.1.命令行编译1.2.2.IEDA…

FL Studio 2024 发布,添加 FL Cloud 插件、AI 等功能

作为今年最受期待的音乐制作 DAW 更新之一,FL Studio 2024发布引入了新功能,同时采用了新的命名方式,从现在起将把发布年份纳入其名称中。DAW 的新增功能包括在 FL Cloud 中添加插件、AI 驱动的音乐创作工具和 FL Studio 的新效果。 FL Cloud…

【项目设计】负载均衡式——Online Judge

负载均衡式——Online Judge😎 前言🙌Online Judge 项目一、项目介绍二、项目技术栈三、项目使用环境四、项目宏观框架五、项目后端服务实现过程1、comm模块设计1.1 Log.hpp实现1.2 Util.hpp实现 2、compiler_server 模块设计2.1compile.hpp文件代码编写…

调制信号识别系列 (一):基准模型

调制信号识别系列 (一):基准模型 说明:本文包含对CNN和CNNLSTM基准模型的复现,模型架构参考下述两篇文章 文章目录 调制信号识别系列 (一):基准模型一、论文1、DL-PR: Generalized automatic modulation classification method b…

ThreadPoolExecutor - 管理线程池的核心类

下面是使用给定的初始参数创建一个新的 ThreadPoolExecutor &#xff08;构造方法&#xff09;。 public ThreadPoolExecutor(int corePoolSize,int maximumPoolSize,long keepAliveTime,TimeUnit unit,BlockingQueue<Runnable> workQueue,ThreadFactory threadFactory,…

【Python】搭建属于自己 AI 机器人

目录 前言 1 准备工作 1.1 环境搭建 1.2 获取 API KEY 2 写代码 2.1 引用库 2.2 创建用户 2.3 创建对话 2.4 输出内容 2.5 调试 2.6 全部代码 2.7 简短的总结 3 优化代码 3.1 规范代码 3.1.1 引用库 3.1.2 创建提示词 3.1.3 创建模型 3.1.4 规范输出&#xf…

Git详细安装和使用教程

文章目录 准备工作-gitee注册认识及安装GitGit配置用户信息本地初始化Git仓库记录每次更新到仓库查看及切换历史版本Git忽略文件和查看文件状态Git分支-查看及切换Git分支-创建分支Git分支-合并及删除分支Git分支-命令补充Git分支-冲突需求: 准备工作-gitee注册 传送门: gite…

HDF4文件转TIF格式

HDF4 HDF4&#xff08;Hierarchical Data Format version 4&#xff09;是一种用于存储和管理机器间数据的库和多功能文件格式。它是一种自描述的文件格式&#xff0c;用于存档和管理数据。 HDF4与HDF5是两种截然不同的技术&#xff0c;HDF5解决了HDF4的一些重要缺陷。因此&am…

[终端安全]-3 移动终端之硬件安全(TEE)

&#xff08;参考资料&#xff1a;TrustZone for V8-A. pdf&#xff0c;来源ARM DEVELOPER官网&#xff09; TEE&#xff08;Trusted Execution Environment&#xff0c;可信执行环境&#xff09;是用于执行敏感代码和处理敏感数据的独立安全区域&#xff1b;以ARM TrustZone为…

cs231n作业1——Softmax

参考文章&#xff1a;cs231n assignment1——softmax Softmax softmax其实和SVM差别不大&#xff0c;两者损失函数不同&#xff0c;softmax就是把各个类的得分转化成了概率。 损失函数&#xff1a; def softmax_loss_naive(W, X, y, reg):loss 0.0dW np.zeros_like(W)num_…

信号与系统笔记分享

文章目录 一、导论信号分类周期问题能量信号和功率信号系统的线性判断时变&#xff0c;时不变系统因果系统判断记忆性系统判断稳定性系统判断 二、信号时域分析阶跃函数冲激函数取样性质四种特性1 筛选特性2 抽样特性3 展缩特性4 卷积特性卷积作用 冲激偶函数奇函数性质公式推导…

Ubuntu 20.04下多版本CUDA的安装与切换 超详细教程

目录 前言一、安装 CUDA1.找到所需版本对应命令2.下载 .run 文件3.安装 CUDA4.配置环境变量4.1 写入环境变量4.2 软连接 5.验证安装 二、安装 cudnn1.下载 cudnn2.解压文件3.替换文件4.验证安装 三、切换 CUDA 版本1.切换版本2.检查版本 前言 当我们复现代码时&#xff0c;总会…

彻底解决Path with “WEB-INF“ or “META-INF“: [WEB-INF/views/index.jsp]

背景描述 项目使用的是springboot2jsp的架构。以前好好的项目复制了一份&#xff0c;然后就无法访问报错。百度了好久都乱七八糟的&#xff0c;还没有解决问题。错误消息如下&#xff1a; 2024-07-05 15:45:51.335 INFO [http-nio-12581-exec-1]org.springframework.web.ser…

Linux服务器使用总结-不定时更新

# 查看升级日志 cat /var/log/dpkg.log |grep nvidia|grep libnvidia-common

阶段三:项目开发---搭建项目前后端系统基础架构:任务13:实现基本的登录功能

任务描述 任务名称&#xff1a; 实现基本的登录功能 知识点&#xff1a; 了解前端Vue项目的基本执行过程 重 点&#xff1a; 构建项目的基本登陆功能 内 容&#xff1a; 通过实现项目的基本登录功能&#xff0c;来了解前端Vue项目的基本执行过程&#xff0c;并完成基…

firewalld(8) policies

简介 前面的文章中我们介绍了firewalld的一些基本配置以及NAT的相关配置。在前面的配置中&#xff0c;我们所有的策略都是与zone相关的&#xff0c;例如配置的rich rule&#xff0c;--direct,以及NAT,并且这些配置都是数据包进入zone或者从zone发出时设置的策略。 我们在介绍…

昇思25天学习打卡营第15天 | Vision Transformer图像分类

内容介绍&#xff1a; 近些年&#xff0c;随着基于自注意&#xff08;Self-Attention&#xff09;结构的模型的发展&#xff0c;特别是Transformer模型的提出&#xff0c;极大地促进了自然语言处理模型的发展。由于Transformers的计算效率和可扩展性&#xff0c;它已经能够训练…

在VMware虚拟机的创建以及安装linux操作系统

一、创建虚拟机 1.双击打开下载好的VMware Workstation软件 2.点击“创建新的虚拟机” 3.根据个人选择需要创建的虚拟机&#xff0c;点击下一步 4.直接点击下一步 5.选择稍后安装操作系统&#xff0c;点击下一步 、 6.选择需要的操作系统&#xff0c;点击下一步 7.根据…

YOLOv8改进---BiFPN特征融合

一、BiFPN原理 1.1 基本原理 BiFPN&#xff08;Bidirectional Feature Pyramid Network&#xff09;&#xff0c;双向特征金字塔网络是一种高效的多尺度特征融合网络&#xff0c;其基本原理概括分为以下几点&#xff1a; 双向特征融合&#xff1a;BiFPN允许特征在自顶向下和自…

【踩坑】修复pyinstaller报错 No module named pkg_resources.extern

转载请注明出处&#xff1a;小锋学长生活大爆炸[xfxuezhagn.cn] 如果本文帮助到了你&#xff0c;欢迎[点赞、收藏、关注]哦~ 报错如下&#xff1a; 修复方法&#xff1a; pip install --upgrade setuptools pippyinstaller -F -w main.py --hidden-importpkg_resources.py2_wa…