昇思25天学习打卡营第18天|Pix2Pix实现图像转换

Pix2Pix概述

Pix2Pix是基于条件生成对抗网络实现的一种深度学习图像转换模型。Pix2Pix是将cGAN应用于有监督的图像到图像翻译,包括生成器和判别器。

基础原理

cGAN的生成器是将输入图片作为指导信息,由输入图像不断尝试生成用于迷惑判别器的“假”图像,由输入图像转换输出为相应“假”图像的本质是从像素到另一个像素的映射,而传统GAN的生成器是基于一个给定的随机噪声生成图像,输出图像通过其他约束条件控制生成。Pix2Pix中判别器的任务是判断从生成器输出的图像是真实的训练图像还是生成的“假”图像。在生成器与判别器的不断博弈过程中,模型会达到一个平衡点,生成器输出的图像与真实训练数据使得判别器刚好具有50%的概率判断正确。

CGAN的目标损失函数为:

L_{cGAN}(G,D)=E_{(x,y)}[log(D(x,y))]+E_{(x,z)}[log(1-D(x,G(x,z)))]

目标函数是使判别器的损失最大化,而生成器的损失最小化。

pix2pix1

数据准备

from download import downloadurl = "https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/models/application/dataset_pix2pix.tar"download(url, "./dataset", kind="tar", replace=True)from mindspore import dataset as ds
import matplotlib.pyplot as pltdataset = ds.MindDataset("./dataset/dataset_pix2pix/train.mindrecord", columns_list=["input_images", "target_images"], shuffle=True)
data_iter = next(dataset.create_dict_iterator(output_numpy=True))
# 可视化部分训练数据
plt.figure(figsize=(10, 3), dpi=140)
for i, image in enumerate(data_iter['input_images'][:10], 1):plt.subplot(3, 10, i)plt.axis("off")plt.imshow((image.transpose(1, 2, 0) + 1) / 2)
plt.show()

创建网络

生成器G结构

使用U-Net,它分为两个部分,其中左侧是由卷积和降采样操作组成的压缩路径,右侧是由卷积和上采样组成的扩张路径,扩张的每个网络块的输入由上一层上采样的特征和压缩路径部分的特征拼接而成。

pix2pix2

定义UNet Skip Connection Block

import mindspore
import mindspore.nn as nn
import mindspore.ops as opsclass UNetSkipConnectionBlock(nn.Cell):def __init__(self, outer_nc, inner_nc, in_planes=None, dropout=False,submodule=None, outermost=False, innermost=False, alpha=0.2, norm_mode='batch'):super(UNetSkipConnectionBlock, self).__init__()down_norm = nn.BatchNorm2d(inner_nc)up_norm = nn.BatchNorm2d(outer_nc)use_bias = Falseif norm_mode == 'instance':down_norm = nn.BatchNorm2d(inner_nc, affine=False)up_norm = nn.BatchNorm2d(outer_nc, affine=False)use_bias = Trueif in_planes is None:in_planes = outer_ncdown_conv = nn.Conv2d(in_planes, inner_nc, kernel_size=4,stride=2, padding=1, has_bias=use_bias, pad_mode='pad')down_relu = nn.LeakyReLU(alpha)up_relu = nn.ReLU()if outermost:up_conv = nn.Conv2dTranspose(inner_nc * 2, outer_nc,kernel_size=4, stride=2,padding=1, pad_mode='pad')down = [down_conv]up = [up_relu, up_conv, nn.Tanh()]model = down + [submodule] + upelif innermost:up_conv = nn.Conv2dTranspose(inner_nc, outer_nc,kernel_size=4, stride=2,padding=1, has_bias=use_bias, pad_mode='pad')down = [down_relu, down_conv]up = [up_relu, up_conv, up_norm]model = down + upelse:up_conv = nn.Conv2dTranspose(inner_nc * 2, outer_nc,kernel_size=4, stride=2,padding=1, has_bias=use_bias, pad_mode='pad')down = [down_relu, down_conv, down_norm]up = [up_relu, up_conv, up_norm]model = down + [submodule] + upif dropout:model.append(nn.Dropout(p=0.5))self.model = nn.SequentialCell(model)self.skip_connections = not outermostdef construct(self, x):out = self.model(x)if self.skip_connections:out = ops.concat((out, x), axis=1)return out

基于UNet的生成器

class UNetGenerator(nn.Cell):def __init__(self, in_planes, out_planes, ngf=64, n_layers=8, norm_mode='bn', dropout=False):super(UNetGenerator, self).__init__()unet_block = UNetSkipConnectionBlock(ngf * 8, ngf * 8, in_planes=None, submodule=None,norm_mode=norm_mode, innermost=True)for _ in range(n_layers - 5):unet_block = UNetSkipConnectionBlock(ngf * 8, ngf * 8, in_planes=None, submodule=unet_block,norm_mode=norm_mode, dropout=dropout)unet_block = UNetSkipConnectionBlock(ngf * 4, ngf * 8, in_planes=None, submodule=unet_block,norm_mode=norm_mode)unet_block = UNetSkipConnectionBlock(ngf * 2, ngf * 4, in_planes=None, submodule=unet_block,norm_mode=norm_mode)unet_block = UNetSkipConnectionBlock(ngf, ngf * 2, in_planes=None, submodule=unet_block,norm_mode=norm_mode)self.model = UNetSkipConnectionBlock(out_planes, ngf, in_planes=in_planes, submodule=unet_block,outermost=True, norm_mode=norm_mode)def construct(self, x):return self.model(x)

基于PatchGAN的判别器

生成的矩阵中的每个点代表原图的一小块区域(patch)。通过矩阵中的各个值来判断原图中对应每个Patch的真假。

import mindspore.nn as nnclass ConvNormRelu(nn.Cell):def __init__(self,in_planes,out_planes,kernel_size=4,stride=2,alpha=0.2,norm_mode='batch',pad_mode='CONSTANT',use_relu=True,padding=None):super(ConvNormRelu, self).__init__()norm = nn.BatchNorm2d(out_planes)if norm_mode == 'instance':norm = nn.BatchNorm2d(out_planes, affine=False)has_bias = (norm_mode == 'instance')if not padding:padding = (kernel_size - 1) // 2if pad_mode == 'CONSTANT':conv = nn.Conv2d(in_planes, out_planes, kernel_size, stride, pad_mode='pad',has_bias=has_bias, padding=padding)layers = [conv, norm]else:paddings = ((0, 0), (0, 0), (padding, padding), (padding, padding))pad = nn.Pad(paddings=paddings, mode=pad_mode)conv = nn.Conv2d(in_planes, out_planes, kernel_size, stride, pad_mode='pad', has_bias=has_bias)layers = [pad, conv, norm]if use_relu:relu = nn.ReLU()if alpha > 0:relu = nn.LeakyReLU(alpha)layers.append(relu)self.features = nn.SequentialCell(layers)def construct(self, x):output = self.features(x)return outputclass Discriminator(nn.Cell):def __init__(self, in_planes=3, ndf=64, n_layers=3, alpha=0.2, norm_mode='batch'):super(Discriminator, self).__init__()kernel_size = 4layers = [nn.Conv2d(in_planes, ndf, kernel_size, 2, pad_mode='pad', padding=1),nn.LeakyReLU(alpha)]nf_mult = ndffor i in range(1, n_layers):nf_mult_prev = nf_multnf_mult = min(2 ** i, 8) * ndflayers.append(ConvNormRelu(nf_mult_prev, nf_mult, kernel_size, 2, alpha, norm_mode, padding=1))nf_mult_prev = nf_multnf_mult = min(2 ** n_layers, 8) * ndflayers.append(ConvNormRelu(nf_mult_prev, nf_mult, kernel_size, 1, alpha, norm_mode, padding=1))layers.append(nn.Conv2d(nf_mult, 1, kernel_size, 1, pad_mode='pad', padding=1))self.features = nn.SequentialCell(layers)def construct(self, x, y):x_y = ops.concat((x, y), axis=1)output = self.features(x_y)return output

Pix2Pix的生成器和判别器初始化

实例化Pix2Pix生成器和判别器

import mindspore.nn as nn
from mindspore.common import initializer as initg_in_planes = 3
g_out_planes = 3
g_ngf = 64
g_layers = 8
d_in_planes = 6
d_ndf = 64
d_layers = 3
alpha = 0.2
init_gain = 0.02
init_type = 'normal'net_generator = UNetGenerator(in_planes=g_in_planes, out_planes=g_out_planes,ngf=g_ngf, n_layers=g_layers)
for _, cell in net_generator.cells_and_names():if isinstance(cell, (nn.Conv2d, nn.Conv2dTranspose)):if init_type == 'normal':cell.weight.set_data(init.initializer(init.Normal(init_gain), cell.weight.shape))elif init_type == 'xavier':cell.weight.set_data(init.initializer(init.XavierUniform(init_gain), cell.weight.shape))elif init_type == 'constant':cell.weight.set_data(init.initializer(0.001, cell.weight.shape))else:raise NotImplementedError('initialization method [%s] is not implemented' % init_type)elif isinstance(cell, nn.BatchNorm2d):cell.gamma.set_data(init.initializer('ones', cell.gamma.shape))cell.beta.set_data(init.initializer('zeros', cell.beta.shape))net_discriminator = Discriminator(in_planes=d_in_planes, ndf=d_ndf,alpha=alpha, n_layers=d_layers)
for _, cell in net_discriminator.cells_and_names():if isinstance(cell, (nn.Conv2d, nn.Conv2dTranspose)):if init_type == 'normal':cell.weight.set_data(init.initializer(init.Normal(init_gain), cell.weight.shape))elif init_type == 'xavier':cell.weight.set_data(init.initializer(init.XavierUniform(init_gain), cell.weight.shape))elif init_type == 'constant':cell.weight.set_data(init.initializer(0.001, cell.weight.shape))else:raise NotImplementedError('initialization method [%s] is not implemented' % init_type)elif isinstance(cell, nn.BatchNorm2d):cell.gamma.set_data(init.initializer('ones', cell.gamma.shape))cell.beta.set_data(init.initializer('zeros', cell.beta.shape))class Pix2Pix(nn.Cell):"""Pix2Pix模型网络"""def __init__(self, discriminator, generator):super(Pix2Pix, self).__init__(auto_prefix=True)self.net_discriminator = discriminatorself.net_generator = generatordef construct(self, reala):fakeb = self.net_generator(reala)return fakeb

训练

包括训练判别器和生成器。训练判别器的目的是最大程度地提高判别图像真伪的概率。训练生成器是希望能产生更好的虚假图像。

代码实现:

import numpy as np
import os
import datetime
from mindspore import value_and_grad, Tensorepoch_num = 3
ckpt_dir = "results/ckpt"
dataset_size = 400
val_pic_size = 256
lr = 0.0002
n_epochs = 100
n_epochs_decay = 100def get_lr():lrs = [lr] * dataset_size * n_epochslr_epoch = 0for epoch in range(n_epochs_decay):lr_epoch = lr * (n_epochs_decay - epoch) / n_epochs_decaylrs += [lr_epoch] * dataset_sizelrs += [lr_epoch] * dataset_size * (epoch_num - n_epochs_decay - n_epochs)return Tensor(np.array(lrs).astype(np.float32))dataset = ds.MindDataset("./dataset/dataset_pix2pix/train.mindrecord", columns_list=["input_images", "target_images"], shuffle=True, num_parallel_workers=1)
steps_per_epoch = dataset.get_dataset_size()
loss_f = nn.BCEWithLogitsLoss()
l1_loss = nn.L1Loss()def forword_dis(reala, realb):lambda_dis = 0.5fakeb = net_generator(reala)pred0 = net_discriminator(reala, fakeb)pred1 = net_discriminator(reala, realb)loss_d = loss_f(pred1, ops.ones_like(pred1)) + loss_f(pred0, ops.zeros_like(pred0))loss_dis = loss_d * lambda_disreturn loss_disdef forword_gan(reala, realb):lambda_gan = 0.5lambda_l1 = 100fakeb = net_generator(reala)pred0 = net_discriminator(reala, fakeb)loss_1 = loss_f(pred0, ops.ones_like(pred0))loss_2 = l1_loss(fakeb, realb)loss_gan = loss_1 * lambda_gan + loss_2 * lambda_l1return loss_gand_opt = nn.Adam(net_discriminator.trainable_params(), learning_rate=get_lr(),beta1=0.5, beta2=0.999, loss_scale=1)
g_opt = nn.Adam(net_generator.trainable_params(), learning_rate=get_lr(),beta1=0.5, beta2=0.999, loss_scale=1)grad_d = value_and_grad(forword_dis, None, net_discriminator.trainable_params())
grad_g = value_and_grad(forword_gan, None, net_generator.trainable_params())def train_step(reala, realb):loss_dis, d_grads = grad_d(reala, realb)loss_gan, g_grads = grad_g(reala, realb)d_opt(d_grads)g_opt(g_grads)return loss_dis, loss_ganif not os.path.isdir(ckpt_dir):os.makedirs(ckpt_dir)g_losses = []
d_losses = []
data_loader = dataset.create_dict_iterator(output_numpy=True, num_epochs=epoch_num)for epoch in range(epoch_num):for i, data in enumerate(data_loader):start_time = datetime.datetime.now()input_image = Tensor(data["input_images"])target_image = Tensor(data["target_images"])dis_loss, gen_loss = train_step(input_image, target_image)end_time = datetime.datetime.now()delta = (end_time - start_time).microsecondsif i % 2 == 0:print("ms per step:{:.2f}  epoch:{}/{}  step:{}/{}  Dloss:{:.4f}  Gloss:{:.4f} ".format((delta / 1000), (epoch + 1), (epoch_num), i, steps_per_epoch, float(dis_loss), float(gen_loss)))d_losses.append(dis_loss.asnumpy())g_losses.append(gen_loss.asnumpy())if (epoch + 1) == epoch_num:mindspore.save_checkpoint(net_generator, ckpt_dir + "Generator.ckpt")

推理

from mindspore import load_checkpoint, load_param_into_netparam_g = load_checkpoint(ckpt_dir + "Generator.ckpt")
load_param_into_net(net_generator, param_g)
dataset = ds.MindDataset("./dataset/dataset_pix2pix/train.mindrecord", columns_list=["input_images", "target_images"], shuffle=True)
data_iter = next(dataset.create_dict_iterator())
predict_show = net_generator(data_iter["input_images"])
plt.figure(figsize=(10, 3), dpi=140)
for i in range(10):plt.subplot(2, 10, i + 1)plt.imshow((data_iter["input_images"][i].asnumpy().transpose(1, 2, 0) + 1) / 2)plt.axis("off")plt.subplots_adjust(wspace=0.05, hspace=0.02)plt.subplot(2, 10, i + 11)plt.imshow((predict_show[i].asnumpy().transpose(1, 2, 0) + 1) / 2)plt.axis("off")plt.subplots_adjust(wspace=0.05, hspace=0.02)
plt.show()

总结

Pix2Pix作为GAN的一种变体,再生成图像和扩充数据方面有着重要作用。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.xdnf.cn/news/1473159.html

如若内容造成侵权/违法违规/事实不符,请联系一条长河网进行投诉反馈,一经查实,立即删除!

相关文章

【JavaSE】异常(Exception)

目录 异常是什么异常的分类异常的处理方法throw抛出异常异常的声明异常的捕获和处理finally 自定义异常类 异常是什么 异常就是程序在进行时的不正常行为,就像之前数组时会遇到空指针异常(NullPointerException),数组越界异常&am…

lua入门(1) - 基本语法

本文参考自: Lua 基本语法 | 菜鸟教程 (runoob.com) 需要更加详细了解的还请参看lua 上方链接 交互式编程 Lua 提供了交互式编程模式。我们可以在命令行中输入程序并立即查看效果。 Lua 交互式编程模式可以通过命令 lua -i 或 lua 来启用: 如下图: 按…

用Python轻松转换PDF为CSV

数据的可访问性和可操作性是数据管理的核心要素。PDF格式因其跨平台兼容性和版面固定性,在文档分享和打印方面表现出色,尤其适用于报表、调查结果等数据的存储。然而,PDF的非结构化特性限制了其在数据分析领域的应用。相比之下,CS…

Golang | Leetcode Golang题解之第218题天际线问题

题目: 题解: type pair struct{ right, height int } type hp []pairfunc (h hp) Len() int { return len(h) } func (h hp) Less(i, j int) bool { return h[i].height > h[j].height } func (h hp) Swap(i, j int) { h[i], h[j]…

python: create Envircomnet in Visual Studio Code 创建虚拟环境

先配置python开发环境 1.在搜索栏输入“>" 或是用快捷组合键ctrlshiftP键 就会显示”>",再输入"python:" 选择已经安装好的python的版本,选定至当前项目中,都是按回车 就可以看到创建了一个虚拟环境的默认的文件夹名".venv" 2 …

14-38 剑和诗人12 - RAG+ 思维链 ⇒ 检索增强思维(RAT)

在快速发展的 NLP 和 LLM 领域,研究人员不断探索新技术来增强这些模型的功能。其中一种备受关注的技术是检索增强生成 (RAG) 方法,它将 LLM 的生成能力与从外部来源检索相关信息的能力相结合。然而,最近一项名为检索增强思维 (RAT) 的创新通过…

大数据平台之数据同步

数据同步也成为CDC (Chanage Data Capture) 。Change Data Capture (CDC) 是一种用于跟踪和捕获数据库中数据变更的技术,它可以在数据发生变化时实时地将这些变更捕获并传递到下游系统。以下是一些常用的开源 CDC 方案: 1. Flink CDC Flink CDC 是基于 …

iptables与firewalld

iptables Linux上常用的防火墙软件 1、 防火墙的策略 防火墙策略一般分为两种,一种叫通策略,一种叫堵策略,通策略,默认门是关着的,必须要定义谁能进。堵策略则是,大门是洞开的,但是你必须有身…

迅捷PDF编辑器合并PDF

迅捷PDF编辑器是一款专业的PDF编辑软件,不仅支持任意添加文本,而且可以任意编辑PDF原有内容,软件上方的工具栏中还有丰富的PDF标注、编辑功能,包括高亮、删除线、下划线这些基础的,还有规则或不规则框选、箭头、便利贴…

C++ volatile 关键字

C volatile (只有release下才会生效) 1、告诉编译器volatile修饰的变量不要进行指令顺序的优化,以保证代码编写者的真实意图; int a 0;int b 10;int c 100;int* p &a;p &b;p &c;如果不加volatile修饰 p , 编译…

香橙派AIpro测评:yolo8+usb鱼眼摄像头的Camera图像获取及识别

一、前言 近期收到了一块受到业界人士关注的开发板"香橙派AIpro",因为这块板子具有极高的性价比,同时还可以兼容ubuntu、安卓等多种操作系统,今天博主便要在一块832g的香橙派AI香橙派AIpro进行YoloV8s算法的部署并使用一个外接的鱼眼USB摄像头…

Bellman equation的不同形式及变化

总忘记贝尔曼方程的推导过程,自己推一遍吧 matrix-vector form就省略了 对于matrix-vector form形式的状态价值贝尔曼方程求解,若已知MDP的动态(转移矩阵P和奖励函数R),则计算复杂度的贡献主要来自矩阵求逆&#xff…

HTTP与HTTPS的主要区别

HTTP(超文本传输协议)与HTTPS(超文本传输安全协议)的主要区别在于安全性、数据传输方式、默认使用的端口以及对网站的影响。 一、安全性: HTTP是一种无加密的协议,数据在传输过程中以明文形式发送&#x…

日志自动分析-Web---360星图GoaccessALBAnolog

目录 1、Web-360星图(IIS/Apache/Nginx) 2、Web-GoAccess (任何自定义日志格式字符串) 源码及使用手册 安装goaccess 使用 输出 3-Web-自写脚本(任何自定义日志格式字符串) 4、Web-机器语言analog(任何自定义日…

Ros2中动作通信的goal_handle类型在不同回调函数中的区别

在进行Ros2学习和进行项目的开发途中,准确来说实在动作通信项目的实战中,我给出了以下示例的ActionServer端初始化,并且使用goal_handle进行下一步操作。 self.server ActionServer(self,Nav,"nav",execute_callbackself.execute,…

APP渗透-android12夜神模拟器+Burpsuite实现

一、夜神模拟器下载地址:https://www.yeshen.com/ 二、使用openssl转换证书格式 1、首先导出bp证书 2、将cacert.der证书在kali中转换 使用openssl生成pem格式证书,并授予最高权限 openssl x509 -inform der -in cacert.der -out cacert.pem chmod 777 cacert…

七、MyBatis-Plus高级用法:最优化持久层开发-个人版

七、MyBatis-Plus高级用法:最优化持久层开发 目录 文章目录 七、MyBatis-Plus高级用法:最优化持久层开发目录 一、MyBatis-Plus快速入门1.1 简介1.2 快速入门回顾复习 二、MyBatis-Plus核心功能2.1 基于Mapper接口CRUDInsert方法Delete方法Update方法Se…

加入运动健康数据开放平台,共赢鸿蒙未来

HarmonyOS SDK运动健康服务(Health Service Kit)是为华为生态应用打造的基于华为帐号和用户授权的运动健康数据开放平台。在获取用户授权后,开发者可以使用运动健康服务提供的开放能力获取运动健康数据,基于多种类型数据构建运动健…

伯克利、斯坦福和CMU面向具身智能端到端操作联合发布开源通用机器人Policy,可支持多种机器人执行多种任务

不同于LLM或者MLLM那样用于上百亿甚至上千亿参数量的大模型,具身智能端到端大模型并不追求参数规模上的大,而是指其能吸收大量的数据,执行多种任务,并能具备一定的泛化能力,如笔者前博客里的RT1。目前该领域一个前沿工…

51单片机基础11——蓝牙模块控制亮灭

串口初试——蓝牙模块 蓝牙模块的使用1. 软硬件条件2. 蓝牙模块3. 代码(分文件处理之后的代码) 蓝牙模块的使用 1. 软硬件条件 单片机型号:STC89C52RC开发环境:KEIL4烧录软件串口通信软件:stc-isp蓝牙模块:HC-04LED模块(高电平点…