【笔记】扩散模型(九):Imagen 理论与实现

论文链接:Photorealistic Text-to-Image Diffusion Models with Deep Language Understanding

非官方实现:lucidrains/imagen-pytorch

Imagen 是 Google Research 的文生图工作,这个工作并没有沿用 Stable Diffusion 的架构,而是级联了一系列普通的 DDPM 模型。其主要的贡献有以下几个方面:

  1. 使用比较大的文本模型进行文本嵌入,可以获得比使用 CLIP 更好的文本理解能力;
  2. 在采样阶段引入了一种动态阈值的方法,可以利用更高的 guidance scale 来生成更真实、细节更丰富的图像(这里的阈值是控制 x \mathbf{x} x 的范围);
  3. 改良了 UNet,提出 Efficient UNet,使模型更简单、收敛更快、内存消耗更少。

该模型的架构如下图所示,可以看到使用了一个条件生成的 diffusion 模型以及两个超分辨率模型,每个模型都以文本模型的 embedding 作为条件,先生成一个 64 分辨率的图像,然后逐步超分辨率到 1024 大小。

Imagen 模型结构

Imagen

预训练文本模型

现在的文生图模型主流使用的文本嵌入方法是使用 CLIP 文本编码器,在直观上感觉是比较合理的,因为 CLIP 的文本特征和图像特征共享同一个空间,用来控制图像的生成过程是比较合理的。不过 CLIP 的缺点是对文本的表达能力比较有限,处理复杂文本比较困难。

这里选择的不是使用 CLIP,而是使用规模比较大、且在大规模文本语料上训练的文本模型,具体来说使用的模型有 BERT、T5 和 CLIP。经过实验(具体结果可以看原论文 Figure 4 的 a 和 b,以及 Figure A.5),主要有以下发现:

  • 缩放文本编码器对提升生成质量的作用很明显;
  • 相比增大 UNet 的尺寸,增大文本编码器的尺寸更重要;
  • 相比于 CLIP,人类更偏好 T5-XXL 的结果。

高 Guidance Scale 的改善

提高 classifier-free guidance 的 guidance scale 可以提升文本-图像的匹配程度,但是会破坏图像的质量。这个现象是因为高 guidance scale 会导致训练阶段和测试阶段出现 mismatch。具体来说,在训练时,所有的 x \mathbf{x} x 都分布在 [ − 1 , 1 ] [-1,1] [1,1] 的范围里,然而当使用比较大的 guidance scale 时,得到的 x \mathbf{x} x 会超出这个范围。这样会导致 x \mathbf{x} x 落在已经学习过的范围以外,为了解决这个问题,作者研究了静态阈值(static thresholding)和动态阈值(dynamic thresholding)两种方案,具体算法如下图所示:

静态阈值和动态阈值算法

静态阈值

这种方法就是在预测噪声后,先计算出 x 0 \mathbf{x}_0 x0,然后将其取值范围直接裁剪到 [ − 1 , 1 ] [-1,1] [1,1] 之间,然后再进行去噪。这种方法已经很多方法都使用了,例如 openai/guided-diffusion 中的这段代码就是为了进行这种处理:

def process_xstart(x):if denoised_fn is not None:x = denoised_fn(x)if clip_denoised:return x.clamp(-1, 1) # 裁剪到 [-1,1]return xif self.model_mean_type == ModelMeanType.EPSILON:pred_xstart = process_xstart(self._predict_xstart_from_eps(x_t=x, t=t, eps=model_output) # 得到 x_0)
model_mean, _, _ = self.q_posterior_mean_variance(x_start=pred_xstart, x_t=x, t=t
)

动态阈值

这个方法不是很好理解,我们可以从一个例子出发,我们平时进行 classifier-free guidance 时使用的 guidance scale 通常都是 7.5,那么一个原本分布在 [ − 1 , 1 ] [-1,1] [1,1] 之间的变量乘以这个系数之后就会变到 [ − 7.5 , 7.5 ] [-7.5,7.5] [7.5,7.5] 的范围内。如果某处的几个数分别是 { 0.2 , 0.4 , 0.6 , 0.8 } \{0.2, 0.4, 0.6, 0.8\} {0.2,0.4,0.6,0.8},乘以 7.5 后就变成了 { 1.5 , 3.0 , 4.5 , 6.0 } \{1.5,3.0,4.5,6.0\} {1.5,3.0,4.5,6.0}。如果此时直接将这些数裁剪到 [ − 1 , 1 ] [-1,1] [1,1],那么所有的数都会变成 1,原本这些数之间是有比较大的差别的,裁剪后都变成了相同的数,这样很明显是不合理的,动态阈值就是为了寻找一个比较合理的裁剪范围。

这里的做法是寻找一个 x 0 \mathbf{x}_0 x0 的 p-分位数 s s s,也就是找到大多数的数字落在什么范围内,然后先裁剪到 [ − s , s ] [-s,s] [s,s] 范围内,再全部除以 s s s 以缩放到 [ − 1 , 1 ] [-1,1] [1,1] 的范围内。实验发现这种方法能比较好地改善图像的质量,这部分的代码如下所示(摘自非官方实现):

if pred_objective == 'noise':x_start = noise_scheduler.predict_start_from_noise(x, t=t, noise=pred)
elif pred_objective == 'x_start':x_start = pred
elif pred_objective == 'v':x_start = noise_scheduler.predict_start_from_v(x, t=t, v=pred)if dynamic_threshold: # 动态阈值# 找到 p-分位数s = torch.quantile(rearrange(x_start, 'b ... -> b (...)').abs(),self.dynamic_thresholding_percentile,dim = -1)s.clamp_(min=1.)s = right_pad_dims_to(x_start, s)# 进行归一化x_start = x_start.clamp(-s, s) / s
else: # 静态阈值,直接截断x_start.clamp_(-1., 1.)
mean_and_variance = noise_scheduler.q_posterior(x_start=x_start, x_t=x, t=t, t_next=t_next)

级联扩散模型

为了生成高分辨率图像,模型级联了三个扩散模型,一个用来生成低分辨率图像,两个用来将低分辨率图像逐步超分到高分辨率。在训练阶段,作者发现使用带有噪声条件增强的超分模型可以生成更高质量的模型。具体来说,每次生成噪声时,还从 [ 0 , 1 ] [0,1] [0,1] 范围内随机采样一个 aug level,然后基于这个 level 进行增强。在预测噪声时,不仅输入带噪声的图像、低分辨率图像、时间步,还输入一个 aug level。在推理阶段,使用一系列 aug level 进行增强,然后分别进行推理,从中选取一个最佳样本,这样可以提升采样效果。具体的算法如下所示:

超分模型的训练和采样过程

总结

除了上述的一些贡献,Imagen 还做了一些工程上的改进,例如使用了不同的 text condition 注入方式,以及对基础的 UNet 模型进行了改进,提出了 Efficient UNet 模型等。相比同期的其他方法,Imagen 应该是为数不多可以直接生成 1024 分辨率图像的 diffusion 模型,虽然和主流的 Stable Diffusion 架构不同,但其中的一些改进思路还是值得学习一下的。

本文原文以 CC BY-NC-SA 4.0 许可协议发布于 笔记|扩散模型(九):Imagen 理论与实现,转载请注明出处。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.xdnf.cn/news/7628.html

如若内容造成侵权/违法违规/事实不符,请联系一条长河网进行投诉反馈,一经查实,立即删除!

相关文章

Windows下载安装Ollama本地运行大模型,新手详细

目录 1. 下载安装Ollama2. 环境配置- 关闭开机自启动(可选):- 配置环境变量(必须):- 配置端口(可选):- 允许浏览器跨域请求(可选): 3.…

代码随想录算法训练营Day55 | 图论理论基础、深度优先搜索理论基础、卡玛网 98.所有可达路径、797. 所有可能的路径、广度优先搜索理论基础

目录 图论理论基础 深度优先搜索理论基础 卡玛网 98.所有可达路径 广度优先搜索理论基础 图论理论基础 图论理论基础 | 代码随想录 图的基本概念 图的种类 大体分为有向图和无向图。 图中的边有方向的是有向图: 图中的边没有方向的是无向图: 图…

牛客练习赛131(dp,dfs,bfs,线段树维护等差数列)

文章目录 牛客练习赛131(dp,dfs,bfs,线段树维护等差数列)A. 小H学语文B. 小H学数学(dp、偏移值)C. 小H学生物(DFS、树上两点间路径的距离)D. 小H学历史(BFS)E. 小H学物理…

干货分享篇:Air780EP的硬件设计原理全解析(上)

一、绪论 Air780EP是一款基于移芯EC718P平台设计的LTE Cat 1无线通信模组。支持FDD-LTE/TDD-LTE的4G远距离无线传输技术。另外,模组提供了USB/UART/I2C等通用接口满足IoT行业的各种应用诉求。 二、综述 2.1 型号信息 表格 1:模块型号列表 2.2 主要性能…

Python将Word文档转为PDF

将word转pdf,只能使用办公工具,但是这些工具大都是收费。因此想用python 将word转pdf,发现很好用特此记录下。方法一:使用docx2pdf模块将docx文件转为pdf 要实现这样的功能,需要用到的就是 docx2pdf 这个python第三方库。对于doc…

无惧任天堂的法律威胁:Switch模拟器Ryujinx v1.2.72版发布

此前任天堂向多个提供 Nintendo Switch 模拟器项目发送律师函甚至直接起诉,要求这些项目立即停止更新、删除以及向任天堂提供经济赔偿。其中 Ryujinx 项目已经在 2024 年 10 月 1 日因任天堂的法律威胁而放弃项目,不过很快就有分叉版本出现,这…

JavaWeb——Web入门(6/9)-HTTP协议:协议解析(客户端的 HTTP 协议解析、服务端的 HTTP 协议解析、Web服务器的作用)

目录 概述 客户端的 HTTP 协议解析 服务端的 HTTP 协议解析 Web服务器的作用 概述 了解完 HTTP 协议的请求数据格式以及响应数据格式之后,接下来我们来讲了解 HTTP 协议的解析。 HTTP 协议的解析分为客户端和服务端两个部分,客户端浏览器中内置了解…

操作系统-实验报告单(2)

目录 1 实验目标 2 实验工具 3 实验内容、实验步骤及实验结果 一、自定义操作系统并启动 1. 最简单操作系统的编写并生成镜像文件 2.虚拟机启动操作系统 【思考题:1、仔细阅读helloos.nas,结合操作系统启动过程尝试分析它的作用;2、若…

城镇住房保障:SpringBoot系统优化技巧

3系统分析 3.1可行性分析 通过对本城镇保障性住房管理系统实行的目的初步调查和分析,提出可行性方案并对其一一进行论证。我们在这里主要从技术可行性、经济可行性、操作可行性等方面进行分析。 3.1.1技术可行性 本城镇保障性住房管理系统采用SSM框架,JA…

FlyMcu串口下载STLink Utility

1、FlyMcu FlyMcu串口下载,同STC-ISP(51单片机下载)。 使用步骤: 1、STM32的USART1通过串口转usb连接到电脑 2、通过keil生成Hex、bin文件 生成bin、hex文件可参考 keil生成bin文件(简单)-CSDN博客 创建…

aws(学习笔记第十课) 对AWS的EBS如何备份(snapshot)以及使用snapshot恢复数据,AWS实例存储

aws(学习笔记第十课) 对AWS的EBS如何备份(snapshot)以及使用snapshot,AWS实例存储 学习内容: 对AWS的EBS如何备份AWS实例存储EBS和实例存储的不足 1. 对AWS的EBS如何备份(snapshot)以及使用snapshot恢复数…

论文2—《基于柔顺控制的智能神经导航手术机器人系统设计》文献阅读分析报告

论文报告:基于卷积神经网络的手术机器人控制系统设计 摘要 本研究针对机器人辅助微创手术中定向障碍和缺乏导航信息的问题,设计了一种智能控制导航手术机器人系统。该系统采用可靠和安全的定位技术、7自由度机械臂以及避免关节角度限制的逆运动学控制策…

《数据结构与算法》二叉树基础OJ练习

二叉树的基础知识详见:《数据结构与算法》二叉树-CSDN博客 1 单值二叉树 思路 我们把树分成当前树(用根和左孩子还有右孩子进行比较,如果左孩子或者右孩子为空那就不比了,如果左右孩子或者其中一个存在就比较,相等就是…

栈和队列(C 语言)

目录 一、栈1. 栈的概念2. 栈的结构3. 栈的实现思路4. 栈的实现代码 二、队列1. 队列的概念2. 队列的结构3. 队列的实现思路4. 队列的实现代码5. 循环队列 一、栈 1. 栈的概念 栈是一种特殊的线性表,只允许在固定的一端进行插入和删除操作,该端被称为栈…

自动化测试工具Ranorex Studio(二十五)-库的拆分

默认地,每一个Ranorex Studio项目包含一个对象库文件,这个文件自动用在每一个新创建的录制中。你可以在一个单独的库文件中管理一个测试套件项目中所有的UI元素,但是在一个自动化测试项目中多个对象库的存在还是有一些原因的: .测…

Centos下安装Maven(无坑版)

Linux 安装 Maven Maven 压缩包下载与解压 华为云下载源,自行选择版本 下面的示例使用的是 3.8.1 版本 wget https://repo.huaweicloud.com/apache/maven/maven-3/3.8.1/binaries/apache-maven-3.8.1-bin.tar.gz解压 tar -zxvf apache-maven-3.8.1-bin.tar.gz移…

99、Python并发编程:多线程的问题、临界资源以及同步机制

引言 多线程技术的引入,可以帮助我们实现并发编程,一方面可以充分利用CPU计算资源,另一方面,可以在用户体验上带来极大的改善。但是,多线程技术也存在一些问题。本文就来简单聊一下多线程引入导致的问题,以…

jmeter常用配置元件介绍总结之取样器

系列文章目录 1.windows、linux安装jmeter及设置中文显示 2.jmeter常用配置元件介绍总结之安装插件 3.jmeter常用配置元件介绍总结之取样器 jmeter常用配置元件介绍总结之取样器 2.取样器2.1.HTTP请求2.2.Debug Sampler2.3.JSR223 Sampler2.4.JDBC Connection Configuration和J…

Python练习11

Python日常练习 题目: 编写一个石头剪刀布游戏,该程序要求完成如下功能: (1) 显示游戏规则,提醒用户输入一个1-3的整数或者直接回车。 用户输入回车时游戏结束。 用户输入不合法(包括输入的…