In a previous post, using huggingface as an example, we analyzed how noise is added to a sample during training and how the unet is trained to predict that noise. In this post we use the open-source diffusers library to analyze how a diffusion model removes noise at inference time.
Before reading this article, you should already be familiar with the principles of DDPM, LDM, and the Stable Diffusion model; otherwise it will be hard to follow.
Take StableDiffusionXLInpaintPipeline as an example. Its __call__ function contains the loop below, from which two things can be read off:
- The unet takes the latent matrix at timestep t as input and predicts the noise; the predicted noise has the same shape as its input.
- self.scheduler.step subtracts the predicted noise from the timestep-t latent matrix with certain weights, producing the latent matrix at timestep t-1, which then serves as the input for the next loop iteration.
```python
with self.progress_bar(total=num_inference_steps) as progress_bar:
    for i, t in enumerate(timesteps):
        ...
        noise_pred = self.unet(
            latent_model_input,
            t,
            encoder_hidden_states=prompt_embeds,
            timestep_cond=timestep_cond,
            cross_attention_kwargs=self.cross_attention_kwargs,
            added_cond_kwargs=added_cond_kwargs,
            return_dict=False,
        )[0]
        ...
        latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
        ...
```
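To make the shape contract of this loop concrete, here is a toy numpy sketch. `fake_unet` and `fake_step` are hypothetical stand-ins, not the diffusers API: the point is only that the noise prediction has the same shape as the latents, and that each step produces latents of the same shape for the next iteration.

```python
import numpy as np

def fake_unet(latents, t):
    # Hypothetical stand-in for self.unet: returns a "predicted noise"
    # array with the same shape as the input latents.
    rng = np.random.default_rng(t)
    return rng.standard_normal(latents.shape)

def fake_step(noise_pred, t, latents, weight=0.1):
    # Hypothetical stand-in for self.scheduler.step: subtracts the
    # predicted noise with some weight to get the timestep t-1 latents.
    return latents - weight * noise_pred

latents = np.zeros((1, 4, 8, 8))            # latent matrix at the last timestep
for t in [999, 666, 333, 0]:                # a tiny 4-step schedule
    noise_pred = fake_unet(latents, t)
    assert noise_pred.shape == latents.shape  # prediction matches input shape
    latents = fake_step(noise_pred, t, latents)

print(latents.shape)  # → (1, 4, 8, 8): shape preserved across the loop
```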
Stepping into the call to self.scheduler.step — here the scheduler is a DPMSolverMultistepScheduler — execution eventually reaches dpm_solver_first_order_update:
```python
def step(
    self,
    model_output: torch.Tensor,
    timestep: int,
    sample: torch.Tensor,
    generator=None,
    variance_noise: Optional[torch.Tensor] = None,
    return_dict: bool = True,
) -> Union[SchedulerOutput, Tuple]:
    ...
    if self.config.solver_order == 1 or self.lower_order_nums < 1 or lower_order_final:
        prev_sample = self.dpm_solver_first_order_update(model_output, sample=sample, noise=noise)
    elif self.config.solver_order == 2 or self.lower_order_nums < 2 or lower_order_second:
        prev_sample = self.multistep_dpm_solver_second_order_update(self.model_outputs, sample=sample, noise=noise)
    else:
        prev_sample = self.multistep_dpm_solver_third_order_update(self.model_outputs, sample=sample)

    if self.lower_order_nums < self.config.solver_order:
        self.lower_order_nums += 1
    ...
    return SchedulerOutput(prev_sample=prev_sample)
```
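The if/elif chain above implements a warm-up: a multistep solver of order k needs k past model outputs, so the first step must use the first-order update, the second step can use second order, and so on until solver_order is reached. A minimal sketch of that selection logic (`order_for_step` is a hypothetical helper written for illustration, not a diffusers API; the `lower_order_final` flag, which forces lower orders near the end of the schedule, is left at its default here):

```python
def order_for_step(solver_order, lower_order_nums, lower_order_final=False):
    # Mirrors the if/elif chain in step(): fall back to a lower-order
    # update while fewer than solver_order model outputs are available.
    if solver_order == 1 or lower_order_nums < 1 or lower_order_final:
        return 1
    elif solver_order == 2 or lower_order_nums < 2:
        return 2
    else:
        return 3

orders = []
lower_order_nums = 0
for _ in range(5):  # five denoising steps with solver_order=3
    orders.append(order_for_step(3, lower_order_nums))
    if lower_order_nums < 3:
        lower_order_nums += 1

print(orders)  # → [1, 2, 3, 3, 3]: warm-up, then full order
```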
Digging further into the implementation, we can see the solver offers four algorithm types, including "dpmsolver++". Roughly speaking, "dpmsolver++" and "dpmsolver" both combine the timestep-t latent matrix and the model prediction with certain weights to remove noise deterministically, while "sde-dpmsolver++" and "sde-dpmsolver" additionally inject a fresh random noise term (though its weight may be small).
```python
def dpm_solver_first_order_update(
    self,
    model_output: torch.Tensor,
    *args,
    sample: torch.Tensor = None,
    noise: Optional[torch.Tensor] = None,
    **kwargs,
) -> torch.Tensor:
    """One step for the first-order DPMSolver (equivalent to DDIM).

    Args:
        model_output (`torch.Tensor`):
            The direct output from the learned diffusion model.
        sample (`torch.Tensor`):
            A current instance of a sample created by the diffusion process.

    Returns:
        `torch.Tensor`:
            The sample tensor at the previous timestep.
    """
    timestep = args[0] if len(args) > 0 else kwargs.pop("timestep", None)
    prev_timestep = args[1] if len(args) > 1 else kwargs.pop("prev_timestep", None)
    # some code omitted
    ...
    sigma_t, sigma_s = self.sigmas[self.step_index + 1], self.sigmas[self.step_index]
    alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
    alpha_s, sigma_s = self._sigma_to_alpha_sigma_t(sigma_s)
    lambda_t = torch.log(alpha_t) - torch.log(sigma_t)
    lambda_s = torch.log(alpha_s) - torch.log(sigma_s)

    h = lambda_t - lambda_s
    if self.config.algorithm_type == "dpmsolver++":
        x_t = (sigma_t / sigma_s) * sample - (alpha_t * (torch.exp(-h) - 1.0)) * model_output
    elif self.config.algorithm_type == "dpmsolver":
        x_t = (alpha_t / alpha_s) * sample - (sigma_t * (torch.exp(h) - 1.0)) * model_output
    elif self.config.algorithm_type == "sde-dpmsolver++":
        assert noise is not None
        x_t = (
            (sigma_t / sigma_s * torch.exp(-h)) * sample
            + (alpha_t * (1 - torch.exp(-2.0 * h))) * model_output
            + sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise
        )
    elif self.config.algorithm_type == "sde-dpmsolver":
        assert noise is not None
        x_t = (
            (alpha_t / alpha_s) * sample
            - 2.0 * (sigma_t * (torch.exp(h) - 1.0)) * model_output
            + sigma_t * torch.sqrt(torch.exp(2 * h) - 1.0) * noise
        )
    return x_t
```
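As a sanity check on the first two branches, here is a numpy sketch with illustrative sigma values (not taken from any real schedule). One caveat: in diffusers the model output passed into this function has already gone through convert_model_output, so for "dpmsolver++" it is an x0 (data) prediction, while for "dpmsolver" it is a noise prediction. Under those conventions, if the model's prediction were exact, both deterministic first-order updates would map a sample at noise level sigma_s exactly onto the corresponding sample at the lower noise level sigma_t:

```python
import numpy as np

def sigma_to_alpha_sigma_t(sigma):
    # Same conversion as _sigma_to_alpha_sigma_t in the scheduler:
    # alpha_t = 1/sqrt(sigma^2 + 1), sigma_t = sigma * alpha_t.
    alpha_t = 1.0 / np.sqrt(sigma**2 + 1.0)
    return alpha_t, sigma * alpha_t

sigma_s, sigma_t = 2.0, 1.0               # illustrative: noise level decreases
alpha_t, sig_t = sigma_to_alpha_sigma_t(sigma_t)
alpha_s, sig_s = sigma_to_alpha_sigma_t(sigma_s)
h = (np.log(alpha_t) - np.log(sig_t)) - (np.log(alpha_s) - np.log(sig_s))

rng = np.random.default_rng(0)
x0, eps = rng.standard_normal(4), rng.standard_normal(4)
sample = alpha_s * x0 + sig_s * eps       # noisy sample at the current step

# "dpmsolver++" branch: model_output is the (here: exact) x0 prediction
x_t_pp = (sig_t / sig_s) * sample - (alpha_t * (np.exp(-h) - 1.0)) * x0
# "dpmsolver" branch: model_output is the (here: exact) noise prediction
x_t = (alpha_t / alpha_s) * sample - (sig_t * (np.exp(h) - 1.0)) * eps

target = alpha_t * x0 + sig_t * eps       # the same sample, one noise level down
assert np.allclose(x_t_pp, target) and np.allclose(x_t, target)
```

With an imperfect model the two parameterizations accumulate different errors, which is one reason "dpmsolver++" is the common default, but at this single-step, exact-prediction level they agree.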
Further reading:
- DDIM
- From DDIM to DPM-Solver++