第十九周机器学习笔记:GAN的数学理论知识与实际应用的操作

第十九周周报

  • 摘要
  • Abstratc
  • 一、机器学习——GAN Basic Theory
    • 1. Maximum Likelihood Estimation
    • 2. 复习训练GAN的过程
    • 3. Objective function与JS散度相关性推导
    • 4. GAN的实际做法
  • 总结

摘要

本周周报主要围绕生成对抗网络(GAN)的基础知识和理论进行深入探讨。首先回顾了GAN的基本概念、训练原理和应用场景。随后,周报详细分析了GAN背后的理论基础,包括如何通过高维空间中的点来理解图像生成,以及如何通过生成模型来寻找数据的分布。然后周报还描述了最大似然估计(MLE)在生成任务中的应用,并对比了传统方法与GAN的不同。然后复习了训练GAN的过程,包括如何通过判别器(Discriminator)来衡量两个分布之间的差异。最后,我们探讨了GAN的目标函数与JS散度(Jensen-Shannon Divergence)之间的关系,并讨论了在实际训练中如何通过样本来近似期望值。

Abstratc

In the weekly report, the basic concepts of Generative Adversarial Networks (GANs) are reviewed, followed by an in-depth discussion of the theoretical underpinnings of GANs. This includes an analysis of how image generation can be understood through points in high-dimensional spaces and the search for data distributions via generative models. The application of Maximum Likelihood Estimation (MLE) in generative tasks is described, with a comparison made between traditional methods and GANs. The training process of GANs is also reviewed, highlighting how discriminators are utilized to assess the divergence between two distributions. Lastly, the relationship between the objective function of GANs and Jensen-Shannon Divergence (JS divergence) is explored, along with a discussion on approximating expectations with samples during actual training.

一、机器学习——GAN Basic Theory

在之前GAN的学习中,我们了解了GAN的概念,训练的原理以及应用,这只是GAN的基础内容,接下来我们将详细的了解GAN背后的理论知识。

假设我们要生成的东西是image,我们用x呢来代表一张image
(每一个image都是高维空间中的一个点,假设产生64×64的image,那它是64×64维度空间中的一个点)
如下图所示:
为了方便解释我们将其视为二维空间中的一个点,所以它实际上是高维空间中的一个点。
我们要产生的image,它其实有一个固定的distribution,记为成Pdata
即在这整个image space里面只有非常少的部分sample出来的image看起来像是人脸,在多数的space中sample出来image它都不像是人脸。
举例来说,在下图的例子里面,可能只有蓝色的这个区域sample的image,它看起来像是人脸啊。在其他地方simple看起来的图片看起来就不像是人脸。
所以假设我们要生成的是人脸的话,它有一个固定的distribution,这个distribution在蓝色的这个区域,它的几率是高的;在蓝色区域以外,它的几率是低的。
在这里插入图片描述
机器做的事情是什么呢?
我们要机器去找出这一个distribution,而这个distribution实际上我们是不知道的
我们可以搜集很多的x(image)去了解x可能在某些地方分布比较高,但是要我们把它的function找出来是做不到的。
所以现在generated model(GAN)做的事情是——找出这个x的distribution。

1. Maximum Likelihood Estimation

那在有GAN之前,我们怎么做generative这件事呢?
我们使用最大似然估计(Maximum Likelihood Estimation)来完成
如下图示:
最大似然估计

  • 给定一个数据分布 Pdata(x)(这是我们采样得到,因为我们并不能列出这个distribution的式子)
  • 我们有一个由参数θ参数化的分布Pc(xi;θ)。
  • 我们希望找到θ使得 Pc(xi;θ)接近Pdata(x)。
  • 注意:Pc(xi;θ) 是一个Gaussian Mixture Model(高斯混合模型), (θ) 是Gaussian的均值(average)和方差(mean)
    在这里插入图片描述
    步骤如下:
  1. 从 Pdata(x) 中采样 (x1, x2, …, xm) 。
  2. 计算PG(xi;θ)
  3. 计算生成样本的似然度(Likelihood)
    L = ∏ i = 1 m P G ( x i ; θ ) L = \prod_{i=1}^{m} P_G(x_i; \theta) L=i=1mPG(xi;θ)
  4. 找到最大化似然的 (θ*)去maximize L。
    在这里插入图片描述
    其中θ*转化Maximum Likelihood Estimation为minimize KL 散度的推导如下:
    其中需要注意的是解释中的
    3与4,这是转化的关键
    在这里插入图片描述
    在推导的时候,我有个疑问就是为什么直接转化为KL,减一个东西不影响原来的结果吗?
    其实我们θ只影响被减去的那一项,另外一个是常数项。
    因此我们把max转化为min减去那一项结果(加个负号由max变为min)其实是不影响的。
    如下图所示:
    在这里插入图片描述
    但是问题在于PG,其也许不是高斯分布模型(使用高斯分布模型,给定一个x可以计算其被sample出来的几率)
    而是比高斯分布更加复杂的分布,例如,它是一个neural network,那你就没有办法计算PG(xi;θ)。

2. 复习训练GAN的过程

那要怎么办呢?
于是就有了一个新的想法:
因为Generator就是一个network,而我们把一个network看作是一个probability distribution。
回顾一下Gnerator的运作过程:
1.每次sample出一个z,它丢到这个generator里面,你就会得到一个x
(把Generator看作一个function,那个结果就是G(z))
2.sample不同的z得到的x呢就不一样。(z是从一个gaussian distribution里面sample出来的)
在这里插入图片描述
把这些从gaussian distribution里面sample出来的z通过Generator得到另外一大堆sample,把这些sample统统集合起来,你得到的就会是另外一个distribution。
在这里插入图片描述
那接下来目标是什么?
接下来目标是希望generator根据这个generator所定义出来的这个distribution PG它跟我们的目标跟我们的Pdata的越接近越好。
在这里插入图片描述
写一个optimization的formulation,这个formulation看起来是这个样子:
G ∗ = arg ⁡ min ⁡ G Div ⁡ ( P G , P data  ) G^{*}=\arg \min _{G} {\operatorname{Div}\left(P_{G}, P_{\text {data }}\right)} G=argGminDiv(PG,Pdata )
就是求PG与Pdata的散度

补充:arg的含义
在这里插入图片描述
公式有了,但是现在的问题就是
PG跟Pdata它们的formulation我们是不知道的,我们无法计算Divergence,所以怎么办呢?
这个就是GAN神奇的地方
在进入比较多的数学推导之前我们复习一下GAN到底是怎么做到minimize divergence这件事情
虽然不知道PG跟Pdata的distribution长什么样子,但是我们可以从这两个distribution里去sample一些data出来形成一个代表性的distribution。

  1. 把DataBase拿出来,假设我们做二次元人物头像生成的话,就把你的DataBase拿出来,然后从里面sample很多image,这个就是从Pdata这个distribution里面做sample。
  2. 从PG里面做sample其实就是random sample一个vector,然后把这个vector丢到Generator里面产生一image。因为PG由的generator所定义的,那我们在使用这个generator的时候,我们是从某一个部分distribution里面去sample 一大堆的vector,每一个vector就会产生一张image

所以我们可以从Pdata里面做sample,我们也可以从PG里面做sample。
在这里插入图片描述
接下来的问题是我们可以从PG与Pdata做sample,根据这个sample,要怎么知道这两个distribution的divergence呢?
GAN神奇的地方就是透过discriminator,我们可以来量这两个distribution间的divergence
蓝色星星是从Pdata中sample出来的,让其分数越大越好。
红色星星是从PG中sample出来的,让其分数越小越好。
然后我们训练Discriminator
在这里插入图片描述
discriminator训练的结果就会告诉我们Pdata跟PG它们之间的divergence有多大
我们会写一个objective function。
它跟两项有关,一个是跟generator有关,一个是跟discriminator有关。
在train这个discriminator的时候呢,会固定住generator。所以只跟我们的Discriminator有关。

V ( G , D ) = E x ∼ P data  [ log ⁡ D ( x ) ] + E x ∼ P G [ log ⁡ ( 1 − D ( x ) ) ] V(G, D)=E_{x \sim P_{\text {data }}}[\log D(x)]+E_{x \sim P_{G}}[\log (1-D(x))] V(G,D)=ExPdata [logD(x)]+ExPG[log(1D(x))]

我们希望从Pdata采样的D(x)值越大越好,从PG采用出来的D(x)越小越好。
所以要 maximize V(G,D)。

3. Objective function与JS散度相关性推导

然后我们上一周提到了Objective function与JS散度是有关联的,其实我们看下图中的,我们也可以直观的感受到:
在这里插入图片描述
它们的具体表达如下:

E x ∼ P data  [ log ⁡ D ( x ) ] + E x ∼ P G [ log ⁡ ( 1 − D ( x ) ) ] = − 2 log ⁡ 2 + 2 J S D ( P data II  P G ) E_{x \sim P_{\text {data }}}[\log D(x)]+E_{x \sim P_{G}}[\log (1-D(x))] =-2 \log 2+2 J S D\left(P_{\text {data}} \text { II } P_{G}\right) ExPdata [logD(x)]+ExPG[log(1D(x))]=2log2+2JSD(Pdata II PG)

数学公式推导如下:
在这里插入图片描述
在这里插入图片描述
所以我们就可以把DIV的问题转化为 max V(G,D)
在这里插入图片描述
如下图中G3就代表了最好的Div,因为G3表示Pdata与PG的divergence是最小的
在这里插入图片描述
那接下来呢,我们就是要想办法解这个min & max 的问题
GAN在train的时候:

  1. 固定住generator去update discriminator
  2. 固定住discriminator接下来去update generator
    在这里插入图片描述

要解这个optimization问题,要怎么做呢?
在maxV(G,D)中需要找一个D去maximize V(G,D)看起来有点复杂,所以我们把它用L(G)来代替,因为跟D没有关系的,我们只需要最后得到的值越大越好。
在这里插入图片描述处理过程如下:
在这里插入图片描述
在这里插入图片描述
==这边有一个问题,现在做的事未必等同于真的在minimize JS Divergence
说假设一个Generator就是G0,那V(G0,D)假设如下图左边的样子。
找到一个D0*,这个D0*的值就是G0跟Data之间的JS divergence。
但是当update G0变成G1的时候,这个时候呢,V(G1,D0*)的function可能就会变了。
本来V(G1,D0*)如下图左边图像所示
(V(G1,D0*)就是G0跟Data的JS Diversion。)
但是updateG0变成G1,这个时候就算不是在evaluate JS Dvergence。这是因为你的D0*仍然是固定的,但是V(G1,D0*)就不是在evaluate JS Divergence
因为估算JS Divergence的是要求最大的值,所以今天当你的G变了,function就变了,当function变的时候,同样的D*就不是在evaluate JS Dvergence。

但是为什么我们在进行参数θG优化的时候,是在减少JS Divergence呢?
一个前提的假设就这两个function可能是非常的像的
假设只update一点点的G,例如,从G0变到G1,,G的参数只动了一点点,那这两个function长相可能是比较像的。
所以他一样用,一样用D0*,仍然是在量JS Divergence这样的情形
(如下图的两个曲线图像所示,这边本来值很小,突然变很高的情形,可能是不会发生的。
因为G0与G1是很像的,所以这两个function是比较接近,所以只同样用固定的D0*就可以evaluate JS Divergence。)

所以在train这个GAN的时候,tips就是Generator不能够一次update太多。但是在train Discriminator的时候,理论上你应该把它train到底。
原因如下:

  1. 因为对于Generator的话,你应该只要跑比较少的iteration,以免上述的假设不成立。
  2. 对于Discriminator在train时候你其实会需要比较多的iteration,把它train到底,因为我们需要找到MAX的值,才算evaluate JS Divergence。
    在这里插入图片描述

4. GAN的实际做法

以上都是假设的,那么实际上你在做GAN的时候,其实是怎么做的呢?
之前说过要计算objective function就要对里面的x取期望(Ex),但是在实际上你没有办法获取其期望,所以我们都是用sample n个data来代替期望。
实际上我们在做的时候,我们就是在maximize如图的式子,而不是真的去maximize它的期望
即把sample出来的这n个data的通统算出来,然后再把它统统平均起来,就当做是expectation
在这里插入图片描述
所以在train Discriminator的时候,就是在train一个binary classifier(二元分类器),说明如下:

  1. 实际上Discriminator是一个binary classified
  2. 这个binary classified是一个这个logistics regression
  3. 它的output有接一个sigmoid(即output的值是介于0到1之间的)
  4. 然后从Pdata里面里sample n个data出来,这n个data就是Positive examples或者class 1 examples。
  5. 然后呢你从PG里面再sample另外 n个data出来,这n个data就当做是negative examples或者class 2 examples
  6. 接下来就train binary classified(即Discriminator),会minimize cross entropy。然后发现如果你在minimize cross entropy,把式子写出来,它会等同于上面maximize objective function。
    在这里插入图片描述

总结:
我们复习一下以上的过程,算法分为两步
第一步是maxV(G,D),即train Discriminator,以下是我个人的觉得重要的总结

  1. 我们train Discriminator的目的是为了量evaluate James divergence。当V(G,D)的值最大的时候,Discriminator才是在evaluates diverges,所以V的值要被maximize,为了让V的值最大,所以一定要对Discriminator train 很多次(虽然很难达到,但是可以train个接近的值)。
    在这里插入图片描述
    第二步就是min maxV(G,D),即train Generator。
  2. 我们train Generator是为了要minimize JS Divergence即,减少JS Divergence的值,minimize下图的式子的时,第一项呢是可以不用考虑它的,所以你把第一项拿掉,只去minimize第二项之前说过Generator你不能够train太多,因为一旦train太多的话,你的Discriminator就没有办法evaluate James divergence。
    在这里插入图片描述
    目前为止,我们讲说今天在train generator的时候,实际不是下图上半部分,而是如下图下半部分:
    在这里插入图片描述
    原因如下而在一篇paper里面,一开始就不是在minimize这个式子。
    log(1-D(x)),它长的是这个样子:
    在这里插入图片描述
    而一开始在做training的时候,D(x)的值通常是很小的,因为Discriminator会知道说你的generator产生出来的image它是fake的,所以它会给它很小的值。所以一开始D(x)的值会落在上图中靠近坐标轴左边的地方,那它的微分是很小的。所以在training的时候会造成你在training的一些问题

所以他说呢,作者把它把它换成-log(D(x)),-log(D(x))它长的是这个样子:
在这里插入图片描述
这两个式子的趋势是一样的,但是他们在同一个位置的斜率就变得不一样。
-log(D(x))在一开始设置D(x)还很小的时候,你算出来的微分会比较大,所以觉得说这样子training是比较容易的。

最后再来直观感受一下Discriminator 和 Generator互动的过程:
重复步骤交替训练 Discriminator 和 Generator。每次迭代,Generator 尝试生成更真实的数据,而 Discriminator 则尝试更好地区分真实和假数据。这个过程可以看作是两个模型之间的“对抗”或“游戏”,其中 Generator 试图生成越来越好的数据,而 Discriminator 则试图更好地区分。随着训练的进行,Generator 生成的数据分布将越来越接近真实数据分布。最终,Discriminator 将无法区分真实数据和生成数据,或者达到一个平衡点,此时 Generator 生成的数据在统计上与真实数据无法区分。
在这里插入图片描述

总结

本周因为要准备考试和课程论文,进度缓慢,之后需要加快进度。
本周的学习了生成对抗网络(GAN)的理论基础和训练过程。首先了解了GAN如何通过高维空间中的点来模拟图像生成,并探讨了如何通过生成模型来寻找数据的分布。其中学习了最大似然估计(MLE)在生成任务中的应用,并理解了GAN与传统方法的不同之处。我复习了训练GAN的过程,包括如何通过判别器(Discriminator)来衡量两个分布之间的差异。对GAN的目标函数与JS散度之间的关系进行了推导。学习了如何通过样本来近似期望值,并理解了在训练判别器和生成器时的不同策略。在训练判别器时需要多次迭代以接近最大值,而在训练生成器时则应避免过大的更新步长,以保持判别器能够有效地评估JS散度。最后,我直观地感受到了判别器和生成器之间的互动过程,这是一个“对抗”或“游戏”的过程,其中生成器试图生成越来越真实的数据,而判别器则试图更好地区分真实和假数据。随着训练的进行,生成器生成的数据分布将越来越接近真实数据分布,最终达到一个平衡点,此时生成器生成的数据在统计上与真实数据无法区分。通过本周的学习,我对GAN的理论知识和实际应用有了更深入的理解,为进一步的研究和实践打下了坚实的基础。
下一周打算先把GAN的拓展放一放,学习一些新内容,因为前面的部分遗忘的七七八八了,打算以以往的内容复习为主。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.xdnf.cn/news/6241.html

如若内容造成侵权/违法违规/事实不符,请联系一条长河网进行投诉反馈,一经查实,立即删除!

相关文章

计算机网络八股文个人总结

1.TCP/IP模型和OSI模型的区别 在计算机网络中,TCP/IP 模型和 OSI 模型是两个重要的网络协议模型。它们帮助我们理解计算机通信的工作原理。以下是它们的主要区别,以通俗易懂的方式进行解释: 1. 模型层数 OSI 模型:有 7 层&#…

Web Workers 学习笔记

最近在开发中遇到了一个需求,一大堆的图片都需要调用两个接口。这对单线程的 JavaScript 运行环境构成了挑战,容易影响用户体验。所以决定学习 Web Workers 并记录一下。 Web Workers 的作用就是提供一个多线程环境,允许将一些繁重任务&…

YOLO11改进|注意力机制篇|引入HAT超分辨率重建模块

目录 一、HAttention注意力机制1.1HAttention注意力介绍1.2HAT核心代码二、添加HAT注意力机制2.1STEP12.2STEP22.3STEP32.4STEP4三、yaml文件与运行3.1yaml文件3.2运行成功截图一、HAttention注意力机制 1.1HAttention注意力介绍 HAT模型 通过结合卷积特征提取与多尺度注意力机…

关于wordpress instagram feed 插件 (现更名为Smash Balloon Social Photo Feed)

插件地址: Smash Balloon Social Photo Feed – Easy Social Feeds Plugin – WordPress 插件 | WordPress.org China 简体中文 安装后,配置教程: Setting up the Instagram Feed Pro WordPress Plugin - Smash Balloon 从这里面开始看就…

ElasticSearch认识

ElasticSearch是什么? Elasticsearch 是一个基于 Apache Lucene 构建的开源分布式搜索引擎和分析引擎。它专为云计算环境设计,提供了一个分布式的、高可用的实时分析和搜索平台。Elasticsearch 可以处理大量数据,并且具备横向扩展能力&#…

在 Google Chrome 上查找并安装 SearchGPT 扩展

ChatGPT 搜索 (SearchGPT),一个嵌入在流行的 ChatGPT 聊天机器人中的全新搜索引擎,可以改变人们搜索网页的方式。如果你想让它更容易找到并使用它,可以通过安装它的 Chrome 扩展程序。 ChatGPT 搜索是一个快速、精准且无广告的搜索引擎&…

两道算法题

一、算法一 Amazon would like to enforce a password policy that when a user changes their password, the new password cannot be similar to the current one. To determine whether two passwords are similar, they take the new password, choose a set of indices a…

嵌入式硬件电子电路设计(三)电源电路之负电源

引言:在对信号线性度放大要求非常高的应用需要使用双电源运放,比如高精度测量仪器、仪表等;那么就需要给双电源运放提供正负电源。 目录 负电源电路原理 负电源的作用 如何产生负电源 负电源能作功吗? 地的理解 负电压产生电路 BUCK电…

A019基于SpringBoot的校园闲置物品交易系统

🙊作者简介:在校研究生,拥有计算机专业的研究生开发团队,分享技术代码帮助学生学习,独立完成自己的网站项目。 代码可以查看文章末尾⬇️联系方式获取,记得注明来意哦~🌹 赠送计算机毕业设计600…

字节青训-小S的倒排索引

问题描述 小S正在帮助她的朋友们建立一个搜索引擎。为了让用户能够更快地找到他们感兴趣的帖子,小S决定使用倒排索引。倒排索引的工作原理是:每个单词都会关联一个帖子ID的列表,这些帖子包含该单词,且ID按从小到大的顺序排列。 例…

2024 CSS - 基础保姆级教程系列一

CSS盒子模型 <style>.box {width: 200px;height: 100px;padding: 20px;} </style> <div class"box">盒子模型 </div><style>.box {width: 200px;height: 100px;padding: 20px;box-sizing: border-box;} </style> <div class&…

道品科技水肥一体化如何确定灌溉需水量呢?

在农业生产进程之中&#xff0c;持续攀升的生产成本&#xff0c;使农民苦不堪言。其一&#xff0c;水肥用量递增&#xff0c;致使成本上扬&#xff1b;其二&#xff0c;种植成效并不显著&#xff0c;所增经济收益颇为有限。另外&#xff0c;不科学的滴灌施肥亦破坏了农业环境架…

北航软件工程考研难度分析!

C哥专业提供——计软考研院校选择分析专业课备考指南规划 总体情况概述 北航软件工程学硕2024届呈现"稳中有降"态势。2024届复试线335分&#xff0c;较2023届上升25分&#xff0c;但较2022届下降10分。实际录取24人&#xff08;含实验室方向&#xff09;&#xff0c…

网页,app,微信小程序互相跳转

1.网页打开小程序 配置&#xff1a;登录小程序账号&#xff0c;找到账号设置&#xff0c;在基本设置中找到隐私与安全 在明文scheme中点击配置&#xff0c;填写要跳转的小程序页面地址即可 此处只展示一种实现方法&#xff0c;其他参照获取 URL Scheme | 微信开放文档 <a …

SQL,力扣题目1767,寻找没有被执行的任务对【递归】

一、力扣链接 LeetCode_1767 二、题目描述 表&#xff1a;Tasks ------------------------- | Column Name | Type | ------------------------- | task_id | int | | subtasks_count | int | ------------------------- task_id 具有唯一值的列。 ta…

【工具】在线一维码生成器

在国外网站上看到一款条形码生成器&#xff0c;它是开源的&#xff0c;很好用。但是访问慢&#xff0c;也不支持下载一维码&#xff0c; 于是我把他翻译了过来&#xff0c;加上下载条码功能&#xff0c;再加了配色&#xff0c;让界面看上来更丰富 一个可直接使用的工具&#x…

PHM技术应用:发电机线棒高温预警

目录 1 案例背景 1.1 事件描述 1.2 设备概况 1.3 事件过程 2 系统动力学模型 典型工况 故障树 潜在业务提升 3 异常预警规则模型 4 故障排查逻辑 5 小结 1 案例背景 1.1 事件描述 某发电厂的某台发电机组&#xff0c;在满功率工况下&#xff0c;因发电机下层线棒温…

Spark on YARN:Spark集群模式之Yarn模式的原理、搭建与实践

Spark 的介绍与搭建&#xff1a;从理论到实践-CSDN博客 Spark 的Standalone集群环境安装与测试-CSDN博客 PySpark 本地开发环境搭建与实践-CSDN博客 Spark 程序开发与提交&#xff1a;本地与集群模式全解析-CSDN博客 目录 一、Spark on YARN 的优势 &#xff08;一&#…

是时候用开源降低AI落地门槛了

过去三十多年&#xff0c;从Linux到KVM&#xff0c;从OpenStack到Kubernetes&#xff0c;IT领域众多关键技术都来自开源。开源技术不仅大幅降低了IT成本&#xff0c;也降低了企业技术创新的门槛。 那么&#xff0c;在生成式AI时代&#xff0c;开源能够为AI带来什么&#xff1f;…

机器学习—矩阵乘法的规则

有一个23的矩阵A&#xff0c;有两行三列&#xff0c;把这个矩阵的列想象成三个向量a1,a2,a3&#xff0c;用一个转置&#xff0c;把它相乘&#xff0c;首先&#xff0c;什么是转置&#xff0c;把一个矩阵进行行变列&#xff0c;列变行的操作&#xff0c;所以这些行现在是一个转置…