【笔记】扩散模型(八):DALL-E 2 (unCLIP) 论文解读与代码实现

论文链接:Hierarchical Text-Conditional Image Generation with CLIP Latents

非官方实现:lucidrains/DALLE2-pytorch

DALL-E 2 是一个比较经典的文生图模型,虽然和 Stable Diffusion 的架构有些区别,但是也利用了 CLIP 的文本-图像对齐能力实现了用文本作为条件进行图像生成。由于 CLIP 是输入文本和图像获得相应的特征,而 DALL-E 2 是将输入的文本转化为特征再转换为图像,相当于把 CLIP 中的图像编码器反转了过来,所以这个方法也被称为 unCLIP。这个模型主要由三个部分组成:

  • CLIP 模型:负责将条件文本转换到文本-图像的统一特征空间中;
  • prior 模型:将文本特征转换为图像特征,用于后续的图像生成;
  • decoer 模型:将从 prior 获得的图像特征转换为具体的生成图像,相当于反转了 CLIP 中的图像 encoder。

模型的架构图如下图所示,虚线的上方是 CLIP 模型,下方是 prior 和 decoder 模型。

DALL-E 2 模型架构

DALL-E 2 的训练与采样

由于 DALL-E 2 由三个不同的部分组成,这三个模型都需要分别进行训练。

训练的第一步是训练 CLIP 模型,这部分和 CLIP 原本的训练过程是一样的,因此 DALL-E 2 可以直接使用已经训练好的 CLIP 模型。

第二步是训练 prior 模型,这个模型的作用是将 CLIP 的文本特征转换为图像特征,用于后续的生成步骤。作者个人感觉这一步不一定是必须的,因为 CLIP 中的文本特征与图像特征是对齐的,而且在 Stable Diffusion 中实际上也是直接用 CLIP 的文本特征和 latent 做交叉注意力。不过这里还是用 prior 模型做了一步转换,直观上来说可能转换一步之后可以弥补原先在 CLIP 中文本和图像特征没有对齐的那一部分。

这里的 prior 模型有两种可能的选择:

  • 自回归模型(autoregressive prior):将图像的特征转换为一系列离散的序列,用自回归的方式生成。(应该比较类似于用 Transformer 做 next token prediction 的任务)
  • 扩散模型(diffusion prior):相当于用文本特征作为条件,并用扩散模型生成图像特征。

由于两种模型的效果差不多并且扩散模型的效率更高,所以最后使用的是扩散模型。不过这里用的不是普通的基于 UNet 的扩散模型,而是使用了一个 decoder-only 的 Transformer 模型,并且预测的内容也是从预测噪声变成了直接预测 embedding。

由于 prior 模型是要将文本特征转换为图像特征,训练目标也是将输出与 CLIP 原本的图像特征对齐,如图所示:

DALL-E 2 prior 模型的训练

最后一步是训练 decoder 模型,这个模型需要以图像为条件,生成最终的目标图像。decoder 模型使用的是一个改进的 GLIDE(也是 diffusion model),训练流程和 GLIDE 是一致的。

在采样时,首先使用 CLIP 将文本进行编码,然后用 prior 将文本特征转换为图像特征,最后用 decoder 生图。

DALL-E 2 代码解读

因为 OpenAI 官方没有放出 DALL-E 2 的完整代码,这里主要参考的是文章最开始给出的非官方实现。这个模型的层次结构也很清晰:

class DALLE2(nn.Module):def __init__(self,*,prior: DiffusionPrior,decoder: Decoder,prior_num_samples = 2):super().__init__()self.prior = priorself.decoder = decoderself.prior_num_samples = prior_num_samplesself.decoder_need_text_cond = self.decoder.condition_on_text_encodings@torch.no_grad()@eval_decoratordef forward(self,text,cond_scale = 1.,prior_cond_scale = 1.,return_pil_images = False):device = module_device(self)# 预处理文本,将文本进行 tokenizationone_text = isinstance(text, str) or (not is_list_str(text) and text.shape[0] == 1)if isinstance(text, str) or is_list_str(text):text = [text] if not isinstance(text, (list, tuple)) else texttext = tokenizer.tokenize(text).to(device)# 这里相当于两步合一:CLIP 提取文本特征+生成图像特征image_embed = self.prior.sample(text, num_samples_per_batch=self.prior_num_samples, cond_scale=prior_cond_scale)text_cond = text if self.decoder_need_text_cond else None# 使用 decoder 生成图像,可以看到不仅可以用图像特征进行 condition,# 也可以使用文本特征进行 conditionimages = self.decoder.sample(image_embed=image_embed, text=text_cond, cond_scale=cond_scale)return images

这里的 diffusion prior 大部分都和一般的 diffusion model 一样,不过主要需要关注两个方法。第一个是采样方法,和上述的流程一样,不过有一个上边没有介绍的细节,就是实际上采样了两个图像的 embedding,但是只使用了与文本最匹配的一个:

class DiffusionPrior(nn.Module):...@torch.no_grad()@eval_decoratordef sample(self,text,num_samples_per_batch = 2,cond_scale = 1.,timesteps = None):# 初始化时间步timesteps = default(timesteps, self.sample_timesteps)# 原文的做法是采样两个 image embedding 然后选 CLIP 匹配分数较高的一个text = repeat(text, 'b ... -> (b r) ...', r=num_samples_per_batch)batch_size = text.shape[0]image_embed_dim = self.image_embed_dim# 使用 CLIP 进行 embeddingtext_embed, text_encodings = self.clip.embed_text(text)text_cond = dict(text_embed=text_embed)if self.condition_on_text_encodings:text_cond = {**text_cond, 'text_encodings': text_encodings}# 生成图像 embeddingimage_embeds = self.p_sample_loop((batch_size, image_embed_dim), text_cond=text_cond, cond_scale=cond_scale, timesteps=timesteps)# 匹配一个比较好的图像 embedding 返回text_embeds = text_cond['text_embed']text_embeds = rearrange(text_embeds, '(b r) d -> b r d', r=num_samples_per_batch)image_embeds = rearrange(image_embeds, '(b r) d -> b r d', r=num_samples_per_batch)text_image_sims = einsum('b r d, b r d -> b r', l2norm(text_embeds), l2norm(image_embeds))top_sim_indices = text_image_sims.topk(k=1).indicestop_sim_indices = repeat(top_sim_indices, 'b 1 -> b 1 d', d=image_embed_dim)top_image_embeds = image_embeds.gather(1, top_sim_indices)return rearrange(top_image_embeds, 'b 1 d -> b d')

第二个需要关注的是训练时的损失,这里预测的对象和普通的 diffusion model 有所不同:

class DiffusionPrior(nn.Module):...def p_losses(self, image_embed, times, text_cond, noise=None):noise = default(noise, lambda: torch.randn_like(image_embed))image_embed_noisy = self.noise_scheduler.q_sample(x_start=image_embed, t=times, noise=noise)self_cond = Noneif self.net.self_cond and random.random() < 0.5:with torch.no_grad():self_cond = self.net(image_embed_noisy, times, **text_cond).detach()# 正常的 diffusion model 这里预测的是噪声,但这里直接预测了 embeddingpred = self.net(image_embed_noisy,times,self_cond = self_cond,text_cond_drop_prob = self.text_cond_drop_prob,image_cond_drop_prob = self.image_cond_drop_prob,**text_cond)if self.predict_x_start and self.training_clamp_l2norm:pred = self.l2norm_clamp_embed(pred)if self.predict_v:target = self.noise_scheduler.calculate_v(image_embed, times, noise)elif self.predict_x_start:target = image_embedelse:target = noise# 计算损失也是直接用 embedding 进行计算loss = self.noise_scheduler.loss_fn(pred, target)return loss

decoder 的采样过程也没有什么特别的地方,就是普通的 diffusion model 采样过程,这里就不展开介绍了。

总结

DALL-E 2 刚出的时候也算非常火,不过这个模型也有 diffusion model 的一些通病,比如会出现不同主体的属性混淆、文本的生成效果比较差等情况。总体来说,个人感觉这个模型不如 Stable Diffusion 优雅,从后续的很多工作也可以看出,基于 Stable Diffusion 继续进行拓展的方法才是主流,基于 DALL-E 2 的方法还是比较少的。

参考资料:

  1. DALL·E 2 解读 | 结合预训练CLIP和扩散模型实现文本-图像生成

本文原文以 CC BY-NC-SA 4.0 许可协议发布于 笔记|扩散模型(八):DALL-E 2 (unCLIP) 理论与实现,转载请注明出处。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.xdnf.cn/news/143423.html

如若内容造成侵权/违法违规/事实不符,请联系一条长河网进行投诉反馈,一经查实,立即删除!

相关文章

2024CCPC网络赛

vp链接&#xff1a;Dashboard - The 2024 CCPC Online Contest - Codeforces B. 军训 II 序列 a 从小到大排列或者从大到小排列时&#xff0c;不整齐度是最小的。方案数是所有相同数字的个数的排列数的乘积。如果首尾的数字不同的话&#xff0c;还要再乘个 2。 #include <…

【在Linux世界中追寻伟大的One Piece】进程间关系与守护进程

目录 1 -> 进程组 1.1 -> 什么是进程组 1.2 -> 组长进程 2 -> 会话 2.1 -> 什么是会话 2.2 -> 如何创建会话 2.3 -> 会话ID(SID) 3 -> 控制终端 4 -> 作业控制 4.1 -> 什么是作业(job)和作业控制(Job Control) 4.2 -> 作业号 4.3…

【他山之石】优化 JavaScript 的乐趣与价值(下)

前言 继本文的 上篇 发表之后&#xff0c;没想到反响还挺好&#xff0c;看来大家在 JS 优化的问题上越来越注重“与国际接轨”了。一起来看本文的下篇&#xff0c;也是干货满满。 文章目录 6. Avoid large objectsWhat the eff should I do about this? 7. Use eval8. Use str…

Linux用户账号管理

目录 一、useradd 创建新用户 二、usermod 修改用户账号 三、userdel 删除用户账号 四、passwd 设置或更改用户密码 五、who 或 w 查看当前登录用户 六、切换用户 6.1. su命令切换用户 6.2. sudo授权命令 6.2.1. sudo的特性 6.2.2. sudo的相关文件 6.3. exit退出 6…

自制数据库迁移工具-C版-04-HappySunshineV1.4-(支持Gbase8a、PG)

目录 一、环境信息 二、简述 三、架构图 四、升级点 五、支持功能 六、安装包下载地址 七、配置参数介绍 八、安装步骤 1、配置环境变量 2、生效环境变量 3、检验动态链接是否正常 4、修改配置文件MigrationConfig.txt &#xff08;1&#xff09;Gbase8a -> Gba…

Axios基本语法和前后端交互

Axios是一个js框架&#xff0c;用于发送ajax请求。 一、导入 // node中&#xff0c;使用npm安装 npm install axios // HTML中&#xff0c;使用cdn安装 <script src"https://unpkg.com/axios/dist/axios.min.js"></script> 二、基本使用 // 使用axios…

x264中的cabac编码实现

typedef struct { /* state */ int i_low; //概率状态的范围low int i_range; //当前概率状态 范围range /* bit stream */ int i_queue; //stored with an offset of -8 for faster asm 队列中可输出的bits 个数&#xff0c;-8 开始&#xff0c;是为了方便asm优化 int i_byt…

数据防泄密系统的构建与功能分析(实用物料)

一、构建1、需求分析&#xff1a;明确企业需要保护的敏感数据类型&#xff08;如商业机密、研发资料等&#xff09;及其潜在的泄露途径&#xff08;如网络传输、文件共享、打印复印等&#xff09;。 2、策略&#xff1a;根据需求分析结果&#xff0c;制定详细的数据防泄密策略…

数字逻辑电路-加法器

目录 半加器和全加器 半加器 ​全加器 集成全加器 利用全加器实现二进制的乘法功能 加法器 半加器和全加器 半加器 不考虑低位进位的加法。 本位为s&#xff0c;进位为c。 全加器 多了一个相邻低位来的进位数。 集成全加器 左上角和右下角那两个是不用的。 利用全加器…

Selenium通过ActionBuilder模拟鼠标操作直接移动到指定坐标的注意事项

在目前&#xff08;2024-09-18&#xff09;得Selenium官方手册中&#xff0c;模拟鼠标操作基本上都是通过ActionChains完成的&#xff0c;唯独有一动作&#xff0c;是通过ActionBuilder完成的。 而前者ActionChains&#xff0c;主要是通过offset&#xff0c;也就是坐标偏移量来…

RK3568笔记五十九:FastSAM部署

若该文为原创文章,转载请注明原文出处。 记录FastSAM训练到部署全过程,转换模型和yolov8一样。 一、介绍 Fast Segment Anything Model (FastSAM) 是一种基于 CNN 的新型实时解决方案,可用于 Segment Anything 任务。该任务旨在根据各种可能的用户交互提示分割图像中的任何…

AT24CXX系列eeprom的相关知识总结

常用的eeprom存储器件有很多容量类型&#xff0c;AT系列的eeprom有at24c01,at24c02…at24c1024等。我们来做一个总结。 1.常见的型号含义 at24c01&#xff1a;表示1kbit&#xff08;128BYTE*8&#xff09; at24c02&#xff1a;表示2kbit&#xff08;256BYTE*8&#xff09; . .…

pybind11 学习笔记

pybind11 学习笔记 0. 一个例子1. 官方文档1.1 Installing the Library1.1.1 Include as A Submodule1.1.2 Include with PyPI1.1.3 Include with Conda-forge 1.2 First Steps1.2.1 Separate Files1.2.2 PYBIND11_MODULE() 宏1.2.3 example.cpython-38-x86_64-linux-gnu.so 的…

二百六十四、Java——Java采集Kafka主题A的JSON数据,解析成一条条数据,然后写入Kafka主题B中

一、目的 由于Hive是单机环境&#xff0c;因此庞大的原始JSON数据在Hive中解析的话就太慢了&#xff0c;必须放在Hive之前解析成一个个字段、一条条CSV数据 二、IDEA创建SpringBoot项目 三、项目中各个文件 3.1 pom.xml <?xml version"1.0" encoding"UTF…

java: 警告: 源发行版 17 需要目标发行版 17(100% 解决)

1. 问题说明 Idea启动Springboot服务报错&#xff1a;java: 警告: 源发行版 17 需要目标发行版 17 2. 解决方案 Project Structure指定jdk版本为我们当前使用的版本&#xff1b; Java Compiler指定jdk为我们当前使用的版本&#xff1b; Invalidate Caches重启Idea。 如果还…

小商品市场配电系统安全用电解决方案

1.概述 随着市场经济的快速发展和人民生活水平的不断提高,全国各地相继建起了大批大型小商品批发市场,此类市场以其商品种类繁多、价格实惠、停车方便等特点吸引了大量的顾客,成为人们日常光顾的重要场所,地方便了广大人民群众的日常生活。 小商品市场集商品销售和短时货物储…

如何利用生成式AI创建图像和可视化效果

每个小型出版商在创建博客文章或新闻文章的过程中&#xff0c;都有一个恐慌时刻&#xff1a; “我用什么做我的特色图片&#xff1f;” 广告公司和媒体公司都有创意总监、摄影师和艺术家随时为他们创作图片。但我们其他人怎么办呢&#xff1f; 我们中的一些人会不顾更好的判…

数据中心扩展之路:创新的数据中心布线解决方案

在不断发展的数据管理领域中&#xff0c;现代技术的迅猛发展既带来了机遇&#xff0c;也带来了挑战&#xff0c;尤其是对不断扩展的数据中心而言。随着这些基础设施的快速发展和转型&#xff0c;对高效可靠的数据中心布线解决方案的需求日益增长。本文将探讨飞速&#xff08;FS…

redis常见类型设置、获取键值的基础命令

redis常见类型设置、获取键值的基础命令 获取键值的数据类型 命令&#xff1a;TYPE keyname 常见数据类型设置、获取键值的基本命令 string类型 置键值&#xff1a;set keyname valuename获取键值&#xff1a;get keyname删除&#xff1a; del keyname list类型 从左边向列表…

关于在Qlabel遮罩方面的踩坑实录

先看目标效果&#xff1a; 想要实现封面图标的遮罩效果&#xff0c;有两个思路&#xff1a; 一、在鼠标移动到这个item上面时&#xff0c;重新绘制pixmap 例如以下代码&#xff1a; #include <QApplication> #include <QWidget> #include <QPixmap> #incl…