基于 CycleGAN 对抗网络的自定义数据集训练

目录

生成对抗网络(GAN)

CycleGAN模型训练

训练数据生成

下载开源项目CycleGAN

配置训练环境

开始训练

模型测试

可视化结果


生成对抗网络(GAN)

        首先介绍一下什么是GAN网络,它是由生成器(Generator)和判别器(Discriminator)组成,二者均是由神经网络构成,通过不断的博弈来提高输出数据质量。

        生成器的目的是学习真实数据的分布,从而能够生成与真实数据相似的新样本。它接收随机噪声作为输入,并通过一系列的神经网络层将其转化为具有特定特征的输出,试图欺骗判别器使其认为生成的数据是真实的。

        判别器则负责区分输入数据是来自真实数据集还是由生成器生成的。它接收数据并输出一个概率值,表示该数据为真实数据的可能性。判别器通过不断学习来提高自己区分真实数据和生成数据的能力

        在训练过程中,生成器和判别器进行对抗性的博弈。生成器努力提高生成数据的质量,以使其能够骗过判别器;而判别器则努力提高自己的鉴别能力,不被生成器欺骗。通过不断地迭代训练,双方的性能逐渐提升,最终达到一种平衡状态,此时生成器能够生成非常逼真的样本,而判别器也具有较高的鉴别能力。

CycleGAN 是由 Jun-Yan Zhu 等人于 2017 年提出的,核心思想是通过两个生成器和两个判别器来实现无监督的图像转换2。它引入了循环一致性损失,确保转换是双向的且在转换前后能够保持图像的一致性。

CycleGAN 论文:https://arxiv.org/abs/1703.10593

上面这个图是该网络实现的风格迁移,感觉这个网络还是挺有意思的,就想着训练一下自己的数据集看下效果,那下面我们直接进入正题吧。

CycleGAN模型训练

注意:目前只尝试过图像对的训练,仅支持包含src和dst的数据集

GitHub项目:CycleGAN-based-train

整体目录架构:

训练数据生成

首先准备自己需要训练的数据集,需要包含源和目标,数据集的格式如下:

其中,O-HAZY NTIRE 2018是根目录,GT是源图像存放路径,hazy是目标图像存放路径

同时请准备好测试样本文件夹test-sample(可自定义),准备的一定要是图像文件夹,暂时不会支持单张图像的测试,格式如下:

数据集准备好后运行main.py文件,需要注意参数设置,具体请查看文件说明

# main.pyimport os
import shutil
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
import matplotlibmatplotlib.use('TkAgg')
from tqdm import tqdm# ----------------------训练数据路径-----------------------#
#   仅支持包含src和dst的数据集(图像对)
# -------------------------------------------------------#
root = r'O-HAZY NTIRE 2018'
# --------------------------------------------------------#
#       label1:src的路径名  |  label2:dst的路径名
# --------------------------------------------------------#
label1 = 'GT'
label2 = 'hazy'
# -------------------------生成图像可视化-------------------------#
#   !!! 在训练和测试均完成后进行结果检查时仅可设置为True,否则报错  !!!
#   该部分只是对结果的可视化,预测阶段请查看README
# -------------------------------------------------------------#
test = False
# ------------------------测试样本------------------------------#
test_data_path = './test-sample'
# ------------------------测试结果图像保存路径---------------------#
# !!!   里面是已经得到的测试结果和原图     !!!
# -------------------------------------------------------------#
results_path = './results/dehaze_cyclegan/test_latest/images/'def make_data(src_path, dst_path, label):src_path = src_path + f'/{label}/'image_files = [f for f in os.listdir(src_path) if f.endswith(('.jpg', '.jpeg', '.png', '.gif', '.bmp'))]with tqdm(total=len(image_files)) as pbar:for filename in image_files:file_path = os.path.join(src_path, filename)if filename.endswith(('.jpg', '.jpeg', '.png', '.gif', '.bmp')):image = Image.open(file_path)target_file = os.path.join(dst_path, filename)image.save(target_file)pbar.update(1)if __name__ == '__main__':if not test:# -------------------创建CycleGAN的训练数据路径-----------------------#if not os.path.exists('dataset'):os.makedirs('dataset')if not os.path.exists('dataset/trainA'):os.makedirs('dataset/trainA')if not os.path.exists('dataset/trainB'):os.makedirs('dataset/trainB')# --------------------------检查图像对数量----------------------------#num_images = len(os.listdir(root + f'/{label1}/'))idx = np.arange(1, num_images + 1)print(f'查找到{num_images}个图像对')make_data(root, 'dataset/trainA/', label1)make_data(root, 'dataset/trainB/', label2)# ----------------------可视化阶段-----------------------------------#else:for f in os.listdir(test_data_path):fake = f.split('.')[0] + '_fake.png'real = f.split('.')[0] + '_real.png'fig = plt.figure()ax = plt.subplot(1, 2, 1)img1 = Image.open(results_path + real)plt.imshow(img1)ax = plt.subplot(1, 2, 2)img2 = Image.open(results_path + fake)plt.imshow(img2)plt.show()

下载开源项目CycleGAN

这一步如果下载了我上传的GitHub仓库的可以直接跳过,因为我已经将该项目放置在仓库里面,不需要重复下载。当然如果没有下载,请继续往下看

方式一:git clone GitHub - junyanz/pytorch-CycleGAN-and-pix2pix: Image-to-Image Translation in PyTorch

方式二:百度网盘:pytorch-CycleGAN-and-pix2pix

链接:https://pan.baidu.com/s/1WC-kEonwm7bFujO72GZAcQ        提取码:jsw2

配置训练环境

终端打开pytorch-CycleGAN-and-pix2pix,输入以下命令

pip install -r requirements.txt

开始训练

同样的,在终端打开该项目,输入以下指令:

python train.py --dataroot ./dataset --name dehaze_cyclegan --model cycle_gan

其中,只有 --name 是可改参数,可以自己命名模型的名称,但是修改后一定要与测试时的名称一致,请一定注意这一点

此外,如果在训练过程中出现“OSError: [WinError 1455] 页面文件太小,无法完成操作”报错信息,这是由于训练环境所在磁盘虚拟内存不足导致,调整方法如下:

最后一步选择训练环境所在的磁盘进行修改即可

训练过程截图

模型测试

在终端打开该项目,输入以下指令:

cp ./checkpoints/dehaze_cyclegan/latest_net_G_A.pth ./checkpoints/dehaze_cyclegan/latest_net_G.pth
python test.py --dataroot ./test-sample --name dehaze_cyclegan --model test --no_dropout --direction AtoB

这里需要注意的是 --dataroot 是测试样本,可以自己调整路径,同时注意模型名称是否与训练的一致,不一致请修改

生成的结果会保存在results文件夹下,目录结构如下:

其中,fake是生成图像,real是原图像,同时所有图像尺寸均会被调整为256\times 256

可视化结果

运行main.py文件,需要设置3个参数:test、test_data_path、results_path(test=True),详情请查看具体文件

我想要实现图像加雾,但是这个效果看起来一般吧,也有可能是图像数据对和训练轮次太少了。但不管怎么说,终究还是成功了嘛。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.xdnf.cn/news/1534140.html

如若内容造成侵权/违法违规/事实不符,请联系一条长河网进行投诉反馈,一经查实,立即删除!

相关文章

分类预测|基于差分优化DE-支持向量机数据分类预测完整Matlab程序 DE-SVM

分类预测|基于差分优化DE-支持向量机数据分类预测完整Matlab程序 DE-SVM 文章目录 一、基本原理DE-SVM 分类预测原理和流程总结 二、实验结果三、核心代码四、代码获取五、总结 一、基本原理 DE-SVM 分类预测原理和流程 1. 差分进化优化算法(DE) 原理…

【运维监控】Prometheus+grafana监控tomcat运行情况

运维监控系列文章入口:【运维监控】系列文章汇总索引 文章目录 一、prometheus二、grafana三、tomcat与jmx_exporter配置1、下载jmx_exporter2、部署jmx_exporter3、添加tomcat的配置信息4、修改tomcat的启动文件5、重启tomcat及验证6、其他 四、集成prometheus与gr…

vue3 动态 svg 图标使用

前言 在做后台管理系统中,我们经常会用到很多图标,比如左侧菜单栏的图标 当然这里 element-ui 或者 element-plus 组件库都会提供图标 但是在有些情况下 element-ui 或者 element-plus 组件库提供的图标满足不了我们的需求时,这个时候我们就需要自己去网上找一些素材或者…

CAN通讯常见错误

CAN通讯常见错误 1.在使用CAN设备进行数据通讯时,有时候参数配置不当可能就会导致通讯的失败,如下图1所示,出现通信错误的原因是两个设备的波特率配置不一致导致。 图1 2.有时候在配置参数的时候,不能只关注波特率速度配置一致…

JEE 设计模式

Java 数据访问对象模式 Java设计模式 - 数据访问对象模式 数据访问对象模式或DAO模式将数据访问API与高级业务服务分离。 DAO模式通常具有以下接口和类。 数据访问对象接口定义模型对象的标准操作。 数据访问对象类实现以上接口。可能有多个实现,例如&#xff0c…

关于Redis缓存一致性问题的优化和实践

目录标题 导语正文分布式场景下无法做到强一致即使是达到最终一致性也很难缓存的一致性问题缓存是如何写入的 如何感知数据库的变化最佳实践一:数据库变更后失效缓存最佳实践二:带版本写入 总结与展望阿里XKV腾讯DCache 导语 Redis缓存一致性的问题是经…

【API安全】威胁猎人发布超大流量解决方案

随着数字化进程加速,企业API接口数量激增,已经成为连接内外部服务的重要桥梁。然而,对于拥有庞大的外部客户群体和错综复杂的内部业务系统的大型企业而言,API安全管控面临超大流量下的性能瓶颈与数据安全双重挑战。 性能上&#…

【软件测试】常用的开发、测试模型

哈喽,哈喽,大家好~ 我是你们的老朋友:保护小周ღ 今天给大家带来的是 【软件测试】常用的开发、测试模型,首先了解, 什么是软件的生命周期, 测试的生命周期, 常见的开发模型: 瀑布, 螺旋, 增量, 迭代, 敏捷. 常用的测试模型, …

Serverless 安全新杀器:云安全中心护航容器安全

作者:胡志广(独鳌) 云安全中心对于 Serverless 容器用户的价值 从云计算发展之初,各大云厂商及传统安全厂商就开始围绕云计算的形态来做安全解决方案。传统安全与云计算安全的形态与做法开始发生变化,同时随着这 10 多年的发展,…

ThreeJS入门(002):学习思维路径

查看本专栏目录 - 本文是第 002篇入门文章 文章目录 如何使用这个思维导图 Three.js 学习思维导图可以帮助你系统地了解 Three.js 的各个组成部分及其关系。下面是一个简化的 Three.js 学习路径思维导图概述,它包含了学习 Three.js 的主要概念和组件。你可以根据这个…

Redis 入门 - 收官

《Redis 入门》系列文章总算完成了,希望这个系列文章可以想入门或刚入门的同学提供帮助,希望能让你形成学习Redis系统性概念。 当时为什么要写这个系列文章,是因为我自己就是迷迷糊糊一路踩坑走过来的,我踩完的坑就踩完了&#x…

Kamailio-基于Zabbix+Kamcli的SIP指标监控

什么是Kamailio? Kamailio 是一个开源的 Session Initiation Protocol (SIP) 服务器,它主要用于建立和管理实时通信会话,如语音和视频通话,与opensips这个产品是同根同源的存在。它们相似,没有更好,是有更合适。 此…

LLM - 理解 多模态大语言模型 (MLLM) 的指令微调与相关技术 (四)

欢迎关注我的CSDN:https://spike.blog.csdn.net/ 本文地址:https://spike.blog.csdn.net/article/details/142063880 免责声明:本文来源于个人知识与公开资料,仅用于学术交流,欢迎讨论,不支持转载。 完备(F…

获取京东商品详情数据API接口优惠券信息(通过商品id获取商品详情页数据)调用说明文档

在当今数字化时代,应用程序之间的互操作性已成为推动业务创新和技术进步的关键因素。API(Application Programming Interface,应用程序编程接口)作为这一生态系统中不可或缺的一环,扮演着连接不同软件服务、数据资源和…

AE 让合成重复循环播放

在合成上点右键 > Time > Enable Time Remapping 按住 Alt 键,点秒表图标 输入 loop_out("cycle", 0) 将子合成拖到此合成结束的位置 结束

Ton的编译过程(上)

系列文章目录 FunC编写初始准备 文章目录 系列文章目录预先准备第一个FunC合约深入compileFunc的内部compileFunc初探艾丽卡的疑惑package.json 初览index.js 预先准备 首先请大家跟着艾丽卡一步一步的完成FunC编写初始准备 这里面环境的搭建。 接下来,请做好下面…

通过python提取PDF文件指定页的图片

整体思路 要从 PDF 文件中提取指定页和指定位置的图片,可以分几个步骤来实现: 1.1 准备所需工具与库 在 Python 中处理 PDF 和图像时,需要使用几个库: PyMuPDF (fitz):用于读取和处理 PDF 文件,可以精确…

Android 测试机

要测手机应用,直接挂电脑上跑虚拟机的话,怀疑电脑都要起火了。 eBay 上买了个新的机器,也才 100 美元多点,机器都没有拆过,电池是完全无电的状态。 操作系统是 Android 12 的版本,升级到 Android 14 后&am…

表格标记<table>

一.表格标记、 1table&#xff1a;表格标记 2.caption:表单标题标记 3.tr:表格行标记 4.td:表格中数据单元格标记 5.th:标题单元格 table标记是表格中最外层标记&#xff0c;tr表示表格中的行标记&#xff0c;一对<tr>表示表格中的一行&#xff0c;在<tr>中可…

Spring Boot集成Akka Stream快速入门Demo

1.什么是Akka Stream&#xff1f; Akka Streams是一个用于处理和传输元素序列的库。它建立在Akka Actors之上&#xff0c;使流的摄入和处理变得简单。由于它是建立在Akka Actors之上的&#xff0c;它为Akka现有的actor模型提供了一个更高层次的抽象。Akka流由3个主要部分组成-…