生成模型常见的条件融合方式
目前生成模型主要有4中常见的条件融合方式以实现可控生成:条件归一化层,Decoupled Cross-Attention,self-attention层进行融合,特征值逐元素求和。本文首先介绍下各种方法现,然后进行总结,最后提出一下展望。
-
条件归一化层
过去和现在的一些工作,会通过条件归一化层将一些条件(比如类别和文本)融合到生成模型以实现条件可控生成。假设生成模型某层的特征是 x ∈ R b × h w × c x \in R^{b \times hw \times c} x∈Rb×hw×c, P P P是输入条件, f ( P ) f(P) f(P)是条件特征提取网络,常见的范式如下:
f ( P ) = γ , β y = γ x − μ σ + ϵ + β f(P) = \gamma ,\beta \\ y = \gamma \frac{x - \mu}{\sigma + \epsilon} + \beta f(P)=γ,βy=γσ+ϵx−μ+β, 其中 μ , σ \mu, \sigma μ,σ分别是 x x x的均值和方差, y y y将要作为生成模型下一层的输入, ϵ \epsilon ϵ是较小常量防止数值问题(为了方便,后文省略)。(1) Adaptive Instance Normalization (AdaIN)
f ( P ) = γ , β ( γ ∈ R b × 1 × c , β ∈ R b × 1 × c ) y = γ x − μ σ + β ( μ ∈ R b × 1 × c , σ ∈ R b × 1 × c ) f(P) = \gamma ,\beta \ (\gamma \in R^{b \times 1 \times c}, \beta \in R^{b \times 1 \times c}) \\ y = \gamma \frac{x - \mu}{\sigma} + \beta \ (\mu \in R^{b \times 1 \times c}, \sigma \in R^{b \times 1 \times c}) f(P)=γ,β (γ∈Rb×1×c,β∈Rb×1×c)y=γσx−μ+β (μ∈Rb×1×c,σ∈Rb×1×c)
(2) 为解决条件归一化层(比如AdaIN和条件BatchNorm)缺失空间信息的问题, SPatially-Adaptive DEnormalization (SPADE, 2019)提出具有空间维度的 μ , σ \mu, \sigma μ,σ:
f ( P ) = γ , β ( α ∈ R 1 × h w × c , β ∈ R 1 × h w × c ) y = γ x − μ σ + β ( μ ∈ R 1 × 1 × c , γ ∈ R 1 × 1 × c ) f(P) = \gamma ,\beta \ (\alpha \in R^{1 \times hw \times c}, \beta \in R^{1 \times hw \times c}) \\ y = \gamma \frac{x - \mu}{\sigma} + \beta \ (\mu \in R^{1 \times 1 \times c}, \gamma \in R^{1 \times 1 \times c}) f(P)=γ,β (α∈R1×hw×c,β∈R1×hw×c)y=γσx−μ+β (μ∈R1×1×c,γ∈R1×1×c)
(3) DiT为了实现恒等函数(the identity function)以加速大规模训练,提出adaLN-Zero block。具体地,其在adaLN(广泛用于GAN和Diffusion Unet)回归 γ , β \gamma, \beta γ,β的基础上,还会回归一个 α \alpha α参数(这个 α \alpha α会被0初始化以实现恒等函数):
f ( P ) = γ , β , α ( γ ∈ R b × 1 × c , β ∈ R b × 1 × c , α ∈ R b × 1 × c ) y = α × M u l t i H e a d S e l f A t t e n t i o n ( γ x − μ σ + β ) ( μ ∈ R b × h w × 1 , σ ∈ R b × h w × 1 ) f(P) = \gamma ,\beta, \alpha \ (\gamma \in R^{b \times 1 \times c}, \beta \in R^{b \times 1 \times c}, \alpha \in R^{b \times 1 \times c}) \\ y = \alpha \times MultiHeadSelfAttention(\gamma \frac{x - \mu}{\sigma} + \beta) \ (\mu \in R^{b \times hw \times1}, \sigma \in R^{b \times hw \times 1}) f(P)=γ,β,α (γ∈Rb×1×c,β∈Rb×1×c,α∈Rb×1×c)y=α×MultiHeadSelfAttention(γσx−μ+β) (μ∈Rb×hw×1,σ∈Rb×hw×1)
-
Decoupled Cross-Attention
Decoupled Cross-Attention最初由IPAdapter(2023)提出,其在去噪网络crossattention layer旁边再引入一个cross attention分支,然后对二个cross attetnion的输出进行求和送给去噪网络下一层。具体地,假设生成模型某层的特征是 x ∈ R b × h w × c x \in R^{b \times hw \times c} x∈Rb×hw×c, P P P是cross attention的条件特征, C C C是新的cross attention层的条件特征,我们可以得到:
y = Softmax ( W q x ( W k P ) ⊤ d ) W v P + λ Softmax ( W q x ( W k ′ C ) ⊤ d ) W v ′ C , y = \textrm{Softmax}(\frac{W_q x (W_k P)^\top}{\sqrt{d}}) W_v P + \lambda \ \textrm{Softmax}(\frac{W_q x (W_k' C)^\top}{\sqrt{d}}) W_v' C, y=Softmax(dWqx(WkP)⊤)WvP+λ Softmax(dWqx(Wk′C)⊤)Wv′C,
其中 d d d 是归一化因子, W q W_q Wq, W k W_k Wk, W v W_v Wv是去噪网络crossattention的投影矩阵, W k ′ W_k' Wk′, and W v ′ W_v' Wv′是新crossattention的投影矩阵, y y y将要作为生成模型下一层的输入, λ \lambda λ是权衡权重。
对于没有cross attention层的模型比如FLUX(2024),目前也有方法比如PuLID-FLUX-v0.9.0(2024)在self attention层旁引入cross attention分支。 -
self-attention层进行融合
如果提取条件特征的网络是去噪网络的副本,条件特征则通常是在去噪网络的self-attetion层进行融合。
(1) 假设生成模型层的特征是 x ∈ R b × h w × c x \in R^{b \times hw \times c} x∈Rb×hw×c, P P P是对应的条件特征:
x ′ = c o n c a t ( [ x , P ] , a x i s = 1 ) y = S e l f A t t e n t i o n ( x ′ , x ′ , x ′ ) [ : , : h w , : ] , x' = concat([x,P], axis=1) \\ y = SelfAttention(x', x', x')[:, :hw, :], x′=concat([x,P],axis=1)y=SelfAttention(x′,x′,x′)[:,:hw,:],
其中concat是特征拼接函数,模型经过SelfAttention操作后会丢弃多余维度特征以保证y的维度和x的维度一致。该方法最初由AnimateAnyone(2023)提出。
(2) 第一种方法有冗余特征,因此可以改进去噪网络的self-attention层的以实现更好的条件特征融合。具体地,假设生成模型层的特征是 x ∈ R b × h w × c x \in R^{b \times hw \times c} x∈Rb×hw×c, P P P是对应的条件特征:
x ′ = c o n c a t ( [ x , P ] , a x i s = 1 ) y = S e l f A t t e n t i o n ( x , x ′ , x ′ ) = Softmax ( W q x ( W k x ′ ) ⊤ d ) W v x ′ , x' = concat([x,P], axis=1) \\ y = SelfAttention(x, x', x') = \textrm{Softmax}(\frac{W_q x (W_k x')^\top}{\sqrt{d}}) W_v x', x′=concat([x,P],axis=1)y=SelfAttention(x,x′,x′)=Softmax(dWqx(Wkx′)⊤)Wvx′,
其中concat是特征拼接函数, S e l f A t t e n t i o n SelfAttention SelfAttention的key和value将由 x ′ x' x′得到而不是 x x x。
(3) 为了不影响去噪模型的原本的能力,可以引入新的crossattention分支进行特征融合。具体地,假设生成模型层的特征是 x ∈ R b × h w × c x \in R^{b \times hw \times c} x∈Rb×hw×c, P P P是对应的条件特征:
y = Softmax ( W q x ( W k x ) ⊤ d ) W v x + λ Softmax ( W q x ( W k ′ P ) ⊤ d ) W v ′ P , y = \textrm{Softmax}(\frac{W_q x (W_k x)^\top}{\sqrt{d}}) W_v x + \lambda \ \textrm{Softmax}(\frac{W_q x (W_k' P)^\top}{\sqrt{d}}) W_v' P, y=Softmax(dWqx(Wkx)⊤)Wvx+λ Softmax(dWqx(Wk′P)⊤)Wv′P,
其中 d d d 是归一化因子, W q W_q Wq, W k W_k Wk, W v W_v Wv是去噪网络self-attention投影矩阵, W k ′ W_k' Wk′, and W v ′ W_v' Wv′是新crossattention的投影矩阵, y y y将要作为生成模型下一层的输入, λ \lambda λ是权衡权重。 -
特征直接逐元素求和
(1) 同样假设生成模型某层的特征是 x ∈ R b × h w × c x \in R^{b \times hw \times c} x∈Rb×hw×c, P ∈ R b × h w × c P \in R^{b \times hw \times c} P∈Rb×hw×c是提取的对应的条件特征。ControlNet(2023)等方法直接将 x x x和 P P P进行逐元素求和:
y = x + λ P , y = x + \lambda P, y=x+λP,其中 y y y将要作为生成模型下一层的输入, λ \lambda λ是权衡权重。
(2) ControlNeXt(2024)为了解决训练开始阶段 x x x和 P P P特征分布不配备造成的收敛缓慢的问题,提出cross normalization,即使用 x x x和 P P P的均值和方差去平衡去噪网络和条件特征提取网络对最终输出的影响。具体地,假设生成模型某层的特征是 x ∈ R b × h w × c x \in R^{b \times hw \times c} x∈Rb×hw×c, P ∈ R b × h w × c P \in R^{b \times hw \times c} P∈Rb×hw×c)是提取的对应的条件特征:
y = x + λ ( P − μ P σ P σ x + μ x ) , y = x + \lambda \ (\frac{P - \mu_P}{\sigma_P} \sigma_x + \mu_x), y=x+λ (σPP−μPσx+μx),其中 μ x ∈ R b × 1 × 1 \mu_x \in R^{b \times 1 \times 1} μx∈Rb×1×1和 σ x ∈ R b × 1 × 1 \sigma_x \in R^{b \times 1 \times 1} σx∈Rb×1×1是 x x x均值和标准差, μ P ∈ R b × 1 × 1 \mu_P \in R^{b \times 1 \times 1} μP∈Rb×1×1和 σ P ∈ R b × 1 × 1 \sigma_P \in R^{b \times 1 \times 1} σP∈Rb×1×1是 P P P的均值和标准差, y y y将要作为生成模型下一层的输入, λ \lambda λ是权衡权重。
总结:
上文提到了条件归一化层,Decoupled Cross-Attention,self-attention层进行融合,特征值逐元素求和四种特征融合方式。显然地:
(1) 条件归一化层目前主要是用于一些简单的条件,比如文本和timestep。
(2) 特征值逐元素求和目前主要是用于保留条件的空间信息。
(3) Decoupled Cross-Attention和self-attention层进行融合目前主要是用于保留条件的语义和细节。在Cross-Attention层进行融合的话需要考虑到文本可控性和条件信息注入程度之间的权衡问题,即2个分支哪个分支占主导地位。
值得注意的是,尽管本文主要关注这四种特征融合方式,但是依然有其它的方式比如使用条件特征替换预训练模型的文本特征。
展望:
目前的条件融合方式主要是基于Diffusion Unet结构探索得到。对于FLUX这种将文本和图像耦合在一起的新兴模型,如何高效地将除文本外的额外信息融合到模型值得探索。