基础GAN生成式对抗网络(pytorch实验)

(Generative Adversarial Network)

一、理论

https://zhuanlan.zhihu.com/p/307527293?utm_campaign=shareopn&utm_medium=social&utm_psn=1815884330188283904&utm_source=wechat_session
大佬的文章中的“GEN的本质”部分
在这里插入图片描述

二、实验

1、数据集介绍

采用MNIST数据集,如下是训练集中的一张图片
在这里插入图片描述

2、代码

引入包

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt

定义生成器

class Generator(nn.Module):def __init__(self):super(Generator, self).__init__()self.net = nn.Sequential(nn.Linear(100, 256),nn.ReLU(),nn.Linear(256, 512),nn.ReLU(),nn.Linear(512, 1024),nn.ReLU(),nn.Linear(1024, 784),nn.Tanh()  # 输出范围在 -1 到 1 之间)def forward(self, x):return self.net(x)

定义判别器

class Discriminator(nn.Module):def __init__(self):super(Discriminator, self).__init__()self.net = nn.Sequential(nn.Linear(784, 1024),nn.LeakyReLU(0.2),nn.Linear(1024, 512),nn.LeakyReLU(0.2),nn.Linear(512, 256),nn.LeakyReLU(0.2),nn.Linear(256, 1),nn.Sigmoid()  # 输出范围在 0 到 1 之间)def forward(self, x):return self.net(x)

训练
先训练判别器,后训练生成器
训练时先训练判别器:将训练集数据(Training Set)打上真标签(1)和生成器(Generator)生成的假图片(Fake image)打上假标签(0)一同组成batch送入判别器(Discriminator),对判别器进行训练。计算loss时使判别器对真数据(Training Set)输入的判别趋近于真(1),对生成器(Generator)生成的假图片(Fake image)的判别趋近于假(0)。此过程中只更新判别器(Discriminator)的参数,不更新生成器(Generator)的参数。

然后再训练生成器:将高斯分布的噪声z(Random noise)送入生成器(Generator),然后将生成器(Generator)生成的假图片(Fake image)打上真标签(1)送入判别器(Discriminator)。计算loss时使判别器对生成器(Generator)生成的假图片(Fake image)的判别趋近于真(1)。此过程中只更新生成器(Generator)的参数,不更新判别器(Discriminator)的参数。

transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5,), (0.5,))
])dataset = datasets.MNIST(root='./data', train=True, transform=transform, download=True)
dataloader = DataLoader(dataset, batch_size=64, shuffle=True)def train_gan(generator, discriminator, dataloader, num_epochs=25):criterion = nn.BCELoss()optimizer_g = optim.Adam(generator.parameters(), lr=0.0002, betas=(0.5, 0.999))optimizer_d = optim.Adam(discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999))for epoch in range(num_epochs):for i, (imgs, _) in enumerate(dataloader):batch_size = imgs.size(0)real_imgs = imgs.view(batch_size, -1)  # 将图像展平成一维# 标签real_labels = torch.ones(batch_size, 1)fake_labels = torch.zeros(batch_size, 1)# 训练判别器outputs = discriminator(real_imgs)d_loss_real = criterion(outputs, real_labels)real_score = outputsz = torch.randn(batch_size, 100)fake_imgs = generator(z)outputs = discriminator(fake_imgs.detach())d_loss_fake = criterion(outputs, fake_labels)fake_score = outputsd_loss = d_loss_real + d_loss_fakeoptimizer_d.zero_grad()d_loss.backward()optimizer_d.step()# 训练生成器outputs = discriminator(fake_imgs)g_loss = criterion(outputs, real_labels)optimizer_g.zero_grad()g_loss.backward()optimizer_g.step()print(f'Epoch [{epoch+1}/{num_epochs}], d_loss: {d_loss.item()}, g_loss: {g_loss.item()}')if (epoch+1) % 10 == 0:with torch.no_grad():fake_imgs = generator(torch.randn(64, 100)).view(-1, 1, 28, 28)grid = torchvision.utils.make_grid(fake_imgs, nrow=8, normalize=True)plt.imshow(grid.permute(1, 2, 0).cpu())plt.title(f'Epoch {epoch+1}')plt.show()generator = Generator()
discriminator = Discriminator()train_gan(generator, discriminator, dataloader)

3、结果

输入一个随机噪声图像,由生成器能得到如下的图片(训练1step的结果)
在这里插入图片描述

输入一个随机噪声图像,由生成器能得到如下的图片(训练10step的结果)
在这里插入图片描述

拓展

可以看看VAE、CGAN等模型

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.xdnf.cn/news/1535154.html

如若内容造成侵权/违法违规/事实不符,请联系一条长河网进行投诉反馈,一经查实,立即删除!

相关文章

【F的领地】项目拆解:少儿英语虚拟资料项目 | 虚拟资料类目 | 学会利用 AI 去生成素材

项目介绍 前几天我分享的小学教辅资料整合项目,已经有学员私信我,说在实操的过程碰到问题了。 我给出了对应的解答,问了下收益,虽然才出十几单,但起码是行动了。 碰到问题,并不可怕,我之前博…

数业智能心大陆探索生成式AIGC创新前沿

近日,数业智能心大陆参与了第九届“创客中国”生成式人工智能(AIGC)中小企业创新创业大赛。在这场汇聚了众多创新力量的研讨过程中,广东数业智能科技有限公司基于多智能体的心理健康技术探索与应用成果,从众多参赛者中…

「Qt Widget中文示例指南」如何实现一个系统托盘图标?(二)

Qt 是目前最先进、最完整的跨平台C开发工具。它不仅完全实现了一次编写,所有平台无差别运行,更提供了几乎所有开发过程中需要用到的工具。如今,Qt已被运用于超过70个行业、数千家企业,支持数百万设备及应用。 System Tray Icon&a…

农产品自动识别系统(Java+Springboot+SSM+Vue+Maven+二维码溯源+识别农作物CNN模型PyTorch框架)

简介: 本系统有前后台的区分,分别由用户及管理员使用,其中用户还可以使用移动端登录。 用户端分为注册登录模块;个人信息模块;二维码模块;文章模块;溯源信息模块;农产品识别模块&a…

目标检测中的解耦和耦合、anchor-free和anchor-base

解耦和耦合 写在前面 在目标检测中,objectness(或 objectness score)指的是一个评分,用来表示某个预测框(bounding box)中是否包含一个目标物体。 具体来说,YOLO等目标检测算法需要在每个候选区…

深入理解Python中的生成器:高效迭代与延迟计算的艺术

在处理大量数据时,如何有效地管理内存成为了一个关键问题。Python中的生成器(Generator)提供了一种优雅的解决方案,它允许你在迭代过程中按需生成数据,而不是一次性加载所有数据到内存中。本文将详细探讨生成器的工作原…

OSSEC搭建与环境配置Ubuntu

尝试使用Ubuntu配置了OSSEC,碰见很多问题并解决了,发表博客让后来者不要踩那么多坑 环境 : server :Ubuntu22.04 64位 内存4GB 处理器4 硬盘60G agent: 1.Windows11 64位 2.Ubuntu22.04 64位 服务端配置 一、配置安装依赖项&…

信息安全数学基础(11)同余的概念及基本性质

一、同余的概念 同余是一个数学概念,用于描述两个数在除以某个数时所得的余数相同的情况。具体地,设m是一个正整数,a和b是两个整数,如果a和b除以m的余数相同,则称a和b模m同余,记作a≡b(mod m)。反之&#x…

筑牢网络安全防线:为数字时代保驾护航

《筑牢网络安全防线:为数字时代保驾护航》 一、网络安全:数字时代的关键课题 网络安全在当今数字时代的重要性愈发凸显。2024 年国家网络安全宣传周以 “网络安全为人民,网络安全靠人民” 为主题,深刻体现了网络安全与每个人息息…

Django视图:构建动态Web页面的核心技术

Django,作为一个强大的Python Web框架,提供了一套完整的工具来构建这些动态页面。在Django的架构中,视图(Views)是处理用户请求并生成响应的关键组件。本文将深入探讨Django视图的工作原理,以及如何使用它们…

Auracast认证:蓝牙广播音频的革新之旅

低功耗音频(LE Audio)技术的突破,为蓝牙世界带来了前所未有的广播音频功能。Auracast™,作为蓝牙技术联盟精心打造的音频广播解决方案,正引领着一场全新的音频分享革命。它不仅革新了传统蓝牙技术的局限,更…

[进阶]面向对象之多态(练习)

需求: //父类animal package polymorphism.Test;public abstract class Animal {private int age;private String color;public Animal() {}public Animal(int age, String color) {this.age age;this.color color;}public int getAge() {return age;}public void setAge(i…

JSP经典设计模式流程分析:JSP+JavaBean设计模式+MVC设计模式

JSP两种经典设计模式 Model1设计模式:JSPJavaBean 架构图 什么是JavaBean JavaBean是一种JAVA语言写成的可重用组件,它遵循特定的编程规范,如类必须是公共的、具有无参构造函数,并提供getter/setter方法等。这里的JavaBean不单单指的是实体…

ESP32-WROOM-32 开篇(刚买)

简介 买了一个ESP32-WROOM-32模块的开发板, 记录板初上机细节。 模块简介 Look 连接PC 1. 解决驱动问题 https://www.silabs.com/developers/usb-to-uart-bridge-vcp-drivers?tabdownloads 下载驱动, 如下图 解压缩下载的包, 然后电机64位的版本, 一直…

grafana升级指南

已有grafana在使用,需要升级新版本的grafana,操作如下: 1.先把之前的grafana文件夹整个备份 2.在grafana官网下载OSS的zip版本,不要msi版本 3.在原来的grafana文件夹里,把新版本的文件夹都复制进来,但是…

数据库课程 CMU15-445 2023 Fall Project-1 Buffer Pool Manager

0 实验结果 1 任务总结 本章按照任务书,需要完成 LRU-K替换策略磁盘调度器——后台线程接收请求,处理数据的读/写。缓冲池管理——使用上面完成的功能,来对抽象的页操作。 1.1 LRU-K替换策略 每个函数的说明都很清楚,按照指示…

【python计算机视觉编程——9.图像分割】

python计算机视觉编程——9.图像分割 9.图像分割9.1 图割安装Graphviz下一步:正文9.1.1 从图像创建图9.1.2 用户交互式分割 9.2 利用聚类进行分割9.3 变分法 9.图像分割 9.1 图割 可以选择不装Graphviz,因为原本觉得是要用,后面发现好像用不…

齐活儿了:一文读懂ERP和MRP、MES、CRM、WMS、SRM、APS等系统

ERP,即企业资源计划系统,是驱动企业资源整合与高效管理的核心引擎。它覆盖了企业财务、人力资源、研发创新、生产制造、供应链管理、采购活动、销售市场、客户服务以及资产管理这九大核心业务领域,形成了一个全方位、多层次的企业价值链管理体…

初学者指南:如何在Windows 11中自定义任务栏颜色,全面解析!

Windows任务栏如何修改颜色?任务栏可以说是电脑桌面上比较不“起眼”的东西,但是也有不少小伙伴会想要将自己的电脑任务栏设置得好看,比如说修改电脑任务栏透明度,以及修改任务栏颜色。 电脑任务栏设置可以修改任务栏颜色&#xf…

27 顺序表 · 链表

目录 一、单链表 (一)概念 1、节点 2、链表的性质 (二)单链表的实现 (三)单链表算法题 1、移除链表元素 2、反转链表 3、链表的中间节点 4、合并两个有序的单链表 5、链表分割 6、链表的回文结构…