【扩散模型(三)】IP-Adapter 源码详解1-输入篇

系列文章目录

  • 【扩散模型(一)】中介绍了 Stable Diffusion 可以被理解为重建分支(reconstruction branch)和条件分支(condition branch)
  • 【扩散模型(二)】IP-Adapter 从条件分支的视角,快速理解相关的可控生成研究
  • 【可控图像生成系列论文(一)】 简要介绍了 MimicBrush 的整体流程和方法;
  • 【可控图像生成系列论文(二)】 就MimicBrush 的具体模型结构训练数据纹理迁移进行了更详细的介绍。
  • 【可控图像生成系列论文(三)】介绍了一篇相对早期(2018年)的可控字体艺术化工作。
  • 【可控图像生成系列论文(四)】介绍了 IP-Adapter 具体是如何训练的?
  • 【可控图像生成系列论文(五)】ControlNet 和 IP-Adapter 之间的区别有哪些?
  • 本文《【扩散模型(三)】IP-Adapter 源码详解1-输入篇》作为两个系列的交汇点,将通过对经典的 IP-Adapter 源码详细阅读,进一步加深对其原理的解释。

文章目录

  • 系列文章目录
  • 整体结构图+代码中的变量名
  • 一、IP-Adapter 做了什么?
  • 二、对应的代码实现
    • 1.模型输入
    • 2.Linear 和 LN(LayerNorm)
  • 总结


整体结构图+代码中的变量名

IP-Adapter 源码:https://github.com/tencent-ailab/IP-Adapter

本文就基于 SD1.5 的 IP-Adapter 训练代码 tutorial_train.py 为例,进行代码和结构图的解释。


在这里插入图片描述

一、IP-Adapter 做了什么?

如上图所示,插入了图中的最上面一条分支(图像输入条件分支):

  1. 蓝色的(无需训练的) Image Encoder
  2. 红色的(需训练的)Linear + LN(LayerNorm)
  3. 红色的(需训练的)、针对图像(Image Prompt)的 Cross Attention。

在论文中也提到,具体分别是:

  1. Image Encoder 是 pretrained CLIP image encoder
  2. 线性层和层归一化 Linear + LN(LayerNorm1):
    • 为了有效地分解全局图像嵌入,作者使用一个小的可训练投影网络(projection network)将图像嵌入投影到长度为N的特征序列中(在本研究中使用N=4),图像特征的维数与预训练的扩散模型中文本特征的维数相同。使用的投影网络由线性层和层归一化组成。
  3. Decoupled Cross-Attention 中,做法是在原来的 UNet 的 Cross-Attention 中加了一层 Cross-Attention。
    • 如原文提到 “we add a new cross-attention layer for each cross-attention layer in the original UNet model to insert image features.”

二、对应的代码实现

在这里插入图片描述

1.模型输入

先简单看下模型的训练时的输入,即 /path/IP-Adapter/tutorial_train.py 中 main() 函数内的 dataloader 部分,下面代码通过调用 MyDataset 类来实现了 train_dataloader 的构建。

    # dataloadertrain_dataset = MyDataset(args.data_json_file, tokenizer=tokenizer, size=args.resolution, image_root_path=args.data_root_path)train_dataloader = torch.utils.data.DataLoader(train_dataset,shuffle=True,collate_fn=collate_fn,batch_size=args.train_batch_size,num_workers=args.dataloader_num_workers,)

对于实际训练使用的数据则为从 train_dataloader 中取的:

  1. batch[“images”]
    • 用来得到形状后,生成随机噪声。
    • 具体如下代码所示,通过 vae.encoder 得到 latents后
    • 通过 torch.randn_like(latents) 按照 latents 张量的形状生成一个随机的噪声张量 noise
  2. batch[“clip_images”]
    • 通过 image_encoder 得到 image_embeds 图像特征
  3. batch[“drop_image_embeds”]
    • 文中有提到会随机通过随机丢弃条件信息(如文本或图像嵌入),使得模型会学会在有条件和无条件的情况下进行预测(生成图像)
  4. batch[“text_input_ids”] 是文本输入,通过一个 text_encoder 后得到文本特征 encoder_hidden_states
  for step, batch in enumerate(train_dataloader):load_data_time = time.perf_counter() - beginwith torch.no_grad():latents = vae.encode(batch["images"].to(accelerator.device, dtype=weight_dtype)).latent_dist.sample()latents = latents * vae.config.scaling_factor# Sample noise that we'll add to the latentsnoise = torch.randn_like(latents)bsz = latents.shape[0]# Sample a random timestep for each imagetimesteps = torch.randint(0, noise_scheduler.num_train_timesteps, (bsz,), device=latents.device)timesteps = timesteps.long()# Add noise to the latents according to the noise magnitude at each timestep# (this is the forward diffusion process)noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)with torch.no_grad():image_embeds = image_encoder(batch["clip_images"].to(accelerator.device, dtype=weight_dtype)).image_embedsimage_embeds_ = []for image_embed, drop_image_embed in zip(image_embeds, batch["drop_image_embeds"]):if drop_image_embed == 1:image_embeds_.append(torch.zeros_like(image_embed))else:image_embeds_.append(image_embed)image_embeds = torch.stack(image_embeds_)with torch.no_grad():encoder_hidden_states = text_encoder(batch["text_input_ids"].to(accelerator.device))[0] # pooled_prompt_embeds?

2.Linear 和 LN(LayerNorm)

以 SD1.5 + IP-Adapter 的训练代码为例:

下方代码为 /path/IP-Adapter/tutorial_train.py 中 main() 函数内,调用了定义好的 ImageProjModel 类

#ip-adapterimage_proj_model = ImageProjModel(cross_attention_dim=unet.config.cross_attention_dim,clip_embeddings_dim=image_encoder.config.projection_dim,clip_extra_context_tokens=4,)

下方代码为 /path/IP-Adapter/ip_adapter/ip_adapter.py 被调用的 ImageProjModel 类,在构造函数 __init__ 中可以看到有前文提到的 Linear 和 LayerNorm。

class ImageProjModel(torch.nn.Module):"""Projection Model"""def __init__(self, cross_attention_dim=1024, clip_embeddings_dim=1024, clip_extra_context_tokens=4):super().__init__()self.generator = Noneself.cross_attention_dim = cross_attention_dimself.clip_extra_context_tokens = clip_extra_context_tokensself.proj = torch.nn.Linear(clip_embeddings_dim, self.clip_extra_context_tokens * cross_attention_dim)self.norm = torch.nn.LayerNorm(cross_attention_dim)def forward(self, image_embeds):embeds = image_embedsclip_extra_context_tokens = self.proj(embeds).reshape(-1, self.clip_extra_context_tokens, self.cross_attention_dim)clip_extra_context_tokens = self.norm(clip_extra_context_tokens)return clip_extra_context_tokens

总结

本文详解了IP-Adapter 训练源码中的输入部分,下篇则详解核心部分,针对图像输入的 Cross-Attention。


  1. Jimmy Lei Ba, Jamie Ryan Kiros, and Geoffrey E Hinton. Layer normalization. arXiv preprint arXiv:1607.06450, 2016 ↩︎

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.xdnf.cn/news/1473454.html

如若内容造成侵权/违法违规/事实不符,请联系一条长河网进行投诉反馈,一经查实,立即删除!

相关文章

Web美食分享平台的系统-计算机毕业设计源码45429

基于Web美食分享平台的系统设计与实现 摘 要 本研究基于Spring Boot框架,设计并实现了一个Web美食分享平台,旨在为用户提供一个交流分享美食体验的社区平台。该平台涵盖了用户注册登录、美食制作方法分享发布、点赞评论互动等功能模块,致力于…

如何在Windows 11上复制文件和文件夹路径?这里提供几种方法

在Windows 11上复制文件或文件夹的路径就像在右键单击菜单中选择一个选项或按键盘快捷键一样简单。我们将向你展示如何在电脑上以各种方式进行操作。 从右键单击菜单 复制文件或文件夹路径的最简单方法是在该项目的右键单击菜单中选择一个选项。你也可以使用此方法复制多个项…

电表读数检测数据集VOC+YOLO格式18156张12类别

数据集格式:Pascal VOC格式YOLO格式(不包含分割路径的txt文件,仅仅包含jpg图片以及对应的VOC格式xml文件和yolo格式txt文件) 图片数量(jpg文件个数):18156 标注数量(xml文件个数):18156 标注数量(txt文件个数):18156 标…

使用JAR命令打包JAR文件使用Maven打包使用Gradle打包打包Spring Boot应用

本人详解 作者:王文峰,参加过 CSDN 2020年度博客之星,《Java王大师王天师》 公众号:JAVA开发王大师,专注于天道酬勤的 Java 开发问题中国国学、传统文化和代码爱好者的程序人生,期待你的关注和支持!本人外号:神秘小峯 山峯 转载说明:务必注明来源(注明:作者:王文峰…

vue 模糊查询加个禁止属性

vue 模糊查询加个禁止属性 父组件通过属性传,是否禁止输入-------默认可以输入

VirtualBox 安装 Ubuntu Server24.04

环境: ubuntu-2404-server、virtualbox 7.0.18 新建虚拟机 分配 CPU 核心和内存(根据自己电脑实际硬件配置选择) 分配磁盘空间(根据自己硬盘实际情况和需求分配即可) 设置网卡,网卡1 负责上网&#xff0c…

字符串相似度算法完全指南:编辑、令牌与序列三类算法的全面解析与深入分析

在自然语言处理领域,人们经常需要比较字符串,这些字符串可能是单词、句子、段落甚至是整个文档。如何快速判断两个单词或句子是否相似,或者相似度是好还是差。这类似于我们使用手机打错一个词,但手机会建议正确的词来修正它&#…

【VUE基础】VUE3第三节—核心语法之ref标签、props

ref标签 作用&#xff1a;用于注册模板引用。 用在普通DOM标签上&#xff0c;获取的是DOM节点。 用在组件标签上&#xff0c;获取的是组件实例对象。 用在普通DOM标签上&#xff1a; <template><div class"person"><h1 ref"title1">…

使用 PyTorch 创建的多步时间序列预测的 Encoder-Decoder 模型

Encoder-decoder 模型在序列到序列的自然语言处理任务&#xff08;如语言翻译等&#xff09;中提供了最先进的结果。多步时间序列预测也可以被视为一个 seq2seq 任务&#xff0c;可以使用 encoder-decoder 模型来处理。本文提供了一个用于解决 Kaggle 时间序列预测任务的 encod…

笔记13:switch多分支选择语句

引例&#xff1a; 输入1-5中的任意一共数字&#xff0c;对应的打印字符A,B,C,D,E int num 0; printf("Input a number[1,5]:"); scanf("%d"&#xff0c;&num); if( num 1)printf("A\n"); else if(num2)printf("B\n"); else i…

ZYNQ7020的bank引脚分区

一张图看ZYNQ7000的资源分布 从图中看出BANK33 34 35是ZYNQ的PL部分 也就是FPGA部分PS部分在BANK0 500 501&#xff0c;DDR控制器连接在PS部分BANK33的电压可调

ePTFE膜(膨体聚四氟乙烯膜)应用前景广阔 本土企业技术水平不断提升

ePTFE膜&#xff08;膨体聚四氟乙烯膜&#xff09;应用前景广阔 本土企业技术水平不断提升 ePTFE膜全称为膨体聚四氟乙烯膜&#xff0c;指以膨体聚四氟乙烯&#xff08;ePTFE&#xff09;为原材料制成的薄膜。ePTFE膜具有耐化学腐蚀、防水透气性好、耐候性佳、耐磨、抗撕裂等优…

CTF常用sql注入(三)无列名注入

0x06 无列名 适用于无法正确的查出结果&#xff0c;比如把information_schema给过滤了 join 联合 select * from users;select 1,2,3 union select * from users;列名被替换成了1,2,3&#xff0c; 我们再利用子查询和别名查 select 2 from (select 1,2,3 union select * f…

中英双语介绍伦敦金融城(City of London)

中文版 伦敦金融城&#xff0c;通常称为“金融城”或“城”&#xff08;The City&#xff09;&#xff0c;是英国伦敦市中心的一个著名金融区&#xff0c;具有悠久的历史和全球性的影响力。以下是关于伦敦金融城的详细介绍&#xff0c;包括其地理位置、人口、主要公司、历史背…

关于在自行封装的组件库中(使用vue-class-component)使用Vue-i18n无法正常翻译的解决办法

文章目录 介绍背景现象1解决办法 现象2原因分析解决办法 最终方案 介绍 大家或多或少都用过别人封装的组件库&#xff0c;甚至有人或者公司内有自行封装的一些公用组件库&#xff0c;而国际化翻译现在已经是各大项目中必不可少的一个插件了&#xff0c;但组件库中使用 i18n 进…

计算机网络 0319

OSPF协议&#xff1a;开放式最短路径优先 协议 基于代价的路由协议 适合与大型的网络 DR 指定路由器 BDR 备用指定路由器 OSPF的组播地址 224.0.0.5 224.0.0.6 RIP组播地址&#xff1a;224.0.0.9 OSPF数据包 过程&#xff1a;先各个发送hello包认识&#xff0c;成为邻居…

深圳航空顶象验证码逆向,和百度验证码训练思路

声明(lianxi a15018601872) 本文章中所有内容仅供学习交流使用&#xff0c;不用于其他任何目的&#xff0c;抓包内容、敏感网址、数据接口等均已做脱敏处理&#xff0c;严禁用于商业用途和非法用途&#xff0c;否则由此产生的一切后果均与作者无关&#xff01; 前言(lianxi a…

CC2530寄存器编程学习笔记_点灯

下面是我的CC2530的学习笔记之点灯部分。 第一步&#xff1a;分析原理图 找到需要对应操作的硬件 图 1 通过这个图1我们可以找到LED1和LED2连接的引脚&#xff0c;分别是P1_0和P1_1。 第二步 分析原理图 图 2 通过图2 确认P1_0和P1_1引脚连接到LED&#xff0c;并且这些引…

51单片机———LED点阵屏显示图形动画

单片机上的一小块屏幕就是LED点阵屏&#xff0c;与数码管一样&#xff0c;内部由LED灯组成&#xff0c;只是点阵屏使用的LED灯更多&#xff0c;LED灯呈矩形分布而非“8”字形&#xff1b;并且点阵屏和数码管一样&#xff0c;有两种接法共阳极和共阳极&#xff1b; 16*16LED点阵…

springboot集成tika解析word,pdf,xls文件文本内容

介绍 Apache Tika 是一个开源的内容分析工具包&#xff0c;用于从各种文档格式中提取文本和元数据。它支持多种文档类型&#xff0c;包括但不限于文本文件、HTML、PDF、Microsoft Office 文档、图像文件等。Tika 的主要功能包括内容检测、文本提取和元数据提取。 官网 https…