MindSpore Learning Notes 4-03 Generative Models -- the Diffusion Model (DDPM)

Abstract:

        Notes on how the DDPM model in the MindSpore AI framework works -- progressively adding noise to image data in the forward process and progressively removing it in the reverse process -- and on the practical steps for using it.

I. Concepts

1. Diffusion models

DDPM (denoising diffusion probabilistic model)

Used for (un)conditional image/audio/video generation, for example:

        OpenAI

                GLIDE

                DALL-E

        University of Heidelberg

                Latent Diffusion

        Google Brain

                image generation

2. The diffusion process

A fixed (or predefined) forward diffusion process q

        gradually adds Gaussian noise to an image until pure noise is left

A learned reverse denoising diffusion process p_\theta

        a neural network is trained to gradually denoise an image starting from pure noise, turning noise from a simple distribution into an actual data sample

3. How the diffusion model works

(1) Forward process

        add noise to the image

        optimize the neural network with a tractable loss function

The real data distribution is q(x_0)

        since x_0 \sim q(x_0), an image x_0 can be sampled from it

Define the forward diffusion process q(x_t|x_{t-1})

        a variance schedule 0<\beta_1<\beta_2<...<\beta_T<1 over the time steps t

        Gaussian noise is added at each time step t

        a Markov process:
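
                q(x_{1:T}|x_0)=\prod_{t=1}^{T}q(x_t|x_{t-1})   (the factorization as given in the DDPM paper; the original formula image is missing here)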

A normal (Gaussian) distribution is defined by two parameters

        a mean \mu

        a variance \sigma^2 \geq 0

        at each time step t, the new (slightly noisier) image is drawn from a conditional Gaussian with mean \mu_t=\sqrt{1-\beta_t}x_{t-1} and variance \sigma_t^2=\beta_t

        sampling \epsilon \sim N(0,I)

        and setting x_t=\sqrt{1-\beta_t}x_{t-1}+\sqrt{\beta_t}\epsilon

                \beta_t is not constant across time steps t

                        it follows a dynamic variance schedule

                        \beta_t can be linear, quadratic, cosine, etc. in the time step

                        following the schedule yields x_0,...,x_t,...,x_T

                        when T is large enough, x_T is pure Gaussian noise
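
A minimal numeric sketch of this construction (an illustration, not part of the original notes; it assumes the linear schedule \beta_1=10^{-4} to \beta_T=0.02 that is also used later):

import numpy as np

np.random.seed(0)
T = 200
betas = np.linspace(1e-4, 0.02, T)       # a linear variance schedule
x = np.ones(4)                           # a toy 4-pixel "image" x_0
for beta in betas:                       # one q(x_t | x_{t-1}) step each
    x = np.sqrt(1 - beta) * x + np.sqrt(beta) * np.random.randn(4)
print(x)  # after 200 steps only sqrt(alpha_bar_T) ~ 0.37 of the signal survives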

(2) Reverse process

        the conditional probability distribution p(x_{t-1}|x_t)

        sample random Gaussian noise x_T

        gradually denoise it

        end up with a sample from the real distribution of x_0

A neural network is used to approximate (learn) the conditional probability distribution p_\theta(x_{t-1}|x_t)

        where \theta are the network parameters

The parameters of this Gaussian distribution are:

        a mean parameterized by \mu_\theta

        a variance parameterized by \Sigma_\theta

The reverse process is written as p_\theta (x_{t-1}|x_t)=N(x_{t-1};\mu_\theta (x_t,t),\Sigma_\theta (x_t,t))

        the mean and variance depend on the noise level t

        the neural network is supposed to learn these means and variances

        DDPM keeps the variance fixed

        so the neural network only learns the mean \mu_\theta of the conditional probability distribution

An objective function to learn the mean of the reverse process can be derived:

the combination of q and p_\theta can be seen as a variational autoencoder (VAE)

        minimize the negative log-likelihood of the ground-truth data sample x_0

        the variational lower bound (ELBO) is a sum of losses, one per time step

                 L=L_0+L_1+...+L_T

                each loss term L_t (except L_0) is a KL divergence between two Gaussians

                which can be written as an L2 loss with respect to the means!

A direct consequence of how the forward process is constructed is that
x_t can be sampled at any noise level conditioned on x_0

        with \alpha_t := 1-\beta_t

        and \bar{\alpha}_t:=\prod_{s=1}^{t}\alpha_s ,        q(x_t|x_0)=N(x_t;\sqrt{\bar{\alpha}_t}x_0,(1-\bar{\alpha}_t)I)

Gaussian noise can thus be sampled, scaled appropriately, and added to x_0 to obtain x_t directly

The \bar{\alpha}_t are functions of the known \beta_t variance schedule and can be precomputed

During training, t is sampled at random and the corresponding random terms L_t of the loss function L are optimized

Advantages

The mean is reparameterized

so the neural network learns (predicts) the added noise appearing in the KL terms of the loss

the network becomes a noise predictor rather than a (direct) mean predictor

The mean is computed as \mu_\theta (x_t,t)=\frac{1}{\sqrt{\alpha_t}}(x_t-\frac{\beta_t}{\sqrt{1-\bar{\alpha}_t}}\epsilon_\theta (x_t,t))

The objective function for a time step t is \left \| \epsilon -\epsilon _\theta (x_t,t) \right \|^2 =\left \| \epsilon -\epsilon _\theta (\sqrt{\bar{\alpha}_t}x_0+\sqrt{(1-\bar{\alpha}_t)}\epsilon ,t) \right \|^2

                        where t is a random time step and \epsilon \sim N(0,I)

                        x_0 is the initial image

                        \epsilon is the pure noise sampled at time step t

                        \epsilon _\theta (x_t,t) is the neural network

The neural network is optimized with a simple mean squared error (MSE) between the true noise and the predicted Gaussian noise

The training algorithm is therefore as follows:
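
(the algorithm figure is missing in the original; in outline, following Algorithm 1 of Ho et al., 2020:)

repeat

        x_0 \sim q(x_0)

        t \sim Uniform(\{1,...,T\})

        \epsilon \sim N(0,I)

        take a gradient step on \nabla _\theta \left \| \epsilon -\epsilon _\theta (\sqrt{\bar{\alpha}_t}x_0+\sqrt{1-\bar{\alpha}_t}\epsilon ,t) \right \|^2

until converged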

4. The neural network that predicts the noise

The neural network receives a noisy image at a particular time step and returns the predicted noise.

The predicted noise is a tensor with the same size/resolution as the input image.

So the network takes in and puts out tensors of the same shape.

An autoencoder-style architecture is used

        the encoder encodes the image into a smaller hidden representation, the "bottleneck"

        the decoder decodes that bottleneck back into an actual image

Residual connections improve the gradient flow.

The forward and reverse processes run over a finite number of time steps T (T=1000 in the DDPM paper).

Starting at t=0, a real image x_0 is sampled from the data distribution.

Below, an ImageNet cat image is used to demonstrate adding noise.

The forward process

        samples some Gaussian noise at each time step t

        and adds it to the image of the previous time step

        given a sufficiently large T and a well-behaved schedule for adding noise

        the result at t = T is an isotropic Gaussian distribution

II. Environment Setup

Install and import the required libraries: MindSpore, download, dataset, matplotlib, and tqdm.

%%capture captured_output
# The experiment environment has mindspore==2.2.14 preinstalled; change the version number below to switch versions
!pip uninstall mindspore -y
!pip install -i https://pypi.mirrors.ustc.edu.cn/simple mindspore==2.2.14
# check the current mindspore version
!pip show mindspore

Output:

import math
from functools import partial
%matplotlib inline
import matplotlib.pyplot as plt
from tqdm.auto import tqdm
import numpy as np
from multiprocessing import cpu_count
from download import download

import mindspore as ms
import mindspore.nn as nn
import mindspore.ops as ops
from mindspore import Tensor, Parameter
from mindspore import dtype as mstype
from mindspore.dataset.vision import Resize, Inter, CenterCrop, ToTensor, RandomHorizontalFlip, ToPIL
from mindspore.common.initializer import initializer
from mindspore.amp import DynamicLossScaler

ms.set_seed(0)

III. Building the Diffusion Model

1. Define helper functions and classes

def rearrange(head, inputs):
    b, hc, x, y = inputs.shape
    c = hc // head
    return inputs.reshape((b, head, c, x * y))

def rsqrt(x):
    res = ops.sqrt(x)
    return ops.inv(res)

def randn_like(x, dtype=None):
    if dtype is None:
        dtype = x.dtype
    res = ops.standard_normal(x.shape).astype(dtype)
    return res

def randn(shape, dtype=None):
    if dtype is None:
        dtype = ms.float32
    res = ops.standard_normal(shape).astype(dtype)
    return res

def randint(low, high, size, dtype=ms.int32):
    res = ops.uniform(size, Tensor(low, dtype), Tensor(high, dtype), dtype=dtype)
    return res

def exists(x):
    return x is not None

def default(val, d):
    if exists(val):
        return val
    return d() if callable(d) else d

def _check_dtype(d1, d2):
    if ms.float32 in (d1, d2):
        return ms.float32
    if d1 == d2:
        return d1
    raise ValueError('dtype is not supported.')

class Residual(nn.Cell):
    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def construct(self, x, *args, **kwargs):
        return self.fn(x, *args, **kwargs) + x

2. Define aliases for the upsample and downsample operations

def Upsample(dim):
    return nn.Conv2dTranspose(dim, dim, 4, 2, pad_mode="pad", padding=1)

def Downsample(dim):
    return nn.Conv2d(dim, dim, 4, 2, pad_mode="pad", padding=1)

3. Position embeddings

The network's time parameter is encoded with sinusoidal position embeddings for the time step t.

The SinusoidalPositionEmbeddings module

takes a tensor of shape (batch_size, 1) as input

        the noise levels of a batch of noisy images

and turns it into a tensor of shape (batch_size, dim)

        where dim is the dimensionality of the position embeddings

These embeddings are added to each residual block.

class SinusoidalPositionEmbeddings(nn.Cell):
    def __init__(self, dim):
        super().__init__()
        self.dim = dim
        half_dim = self.dim // 2
        emb = math.log(10000) / (half_dim - 1)
        emb = np.exp(np.arange(half_dim) * - emb)
        self.emb = Tensor(emb, ms.float32)

    def construct(self, x):
        emb = x[:, None] * self.emb[None, :]
        emb = ops.concat((ops.sin(emb), ops.cos(emb)), axis=-1)
        return emb

4. ResNet/ConvNeXT blocks

ConvNeXT blocks are chosen here to build the U-Net model.

class Block(nn.Cell):
    def __init__(self, dim, dim_out, groups=1):
        super().__init__()
        self.proj = nn.Conv2d(dim, dim_out, 3, pad_mode="pad", padding=1)
        self.norm = nn.GroupNorm(groups, dim_out)
        self.act = nn.SiLU()

    def construct(self, x, scale_shift=None):
        x = self.proj(x)
        x = self.norm(x)

        if exists(scale_shift):
            scale, shift = scale_shift
            x = x * (scale + 1) + shift

        x = self.act(x)
        return x


class ConvNextBlock(nn.Cell):
    def __init__(self, dim, dim_out, *, time_emb_dim=None, mult=2, norm=True):
        super().__init__()
        self.mlp = (
            nn.SequentialCell(nn.GELU(), nn.Dense(time_emb_dim, dim))
            if exists(time_emb_dim)
            else None
        )

        self.ds_conv = nn.Conv2d(dim, dim, 7, padding=3, group=dim, pad_mode="pad")
        self.net = nn.SequentialCell(
            nn.GroupNorm(1, dim) if norm else nn.Identity(),
            nn.Conv2d(dim, dim_out * mult, 3, padding=1, pad_mode="pad"),
            nn.GELU(),
            nn.GroupNorm(1, dim_out * mult),
            nn.Conv2d(dim_out * mult, dim_out, 3, padding=1, pad_mode="pad"),
        )

        self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity()

    def construct(self, x, time_emb=None):
        h = self.ds_conv(x)
        if exists(self.mlp) and exists(time_emb):
            assert exists(time_emb), "time embedding must be passed in"
            condition = self.mlp(time_emb)
            condition = condition.expand_dims(-1).expand_dims(-1)
            h = h + condition

        h = self.net(h)
        return h + self.res_conv(x)

5. Attention modules

Attention: regular multi-head self-attention

        with the usual scaling of the queries

LinearAttention: a linear attention variant

        whose time and memory requirements scale linearly in the sequence length

class Attention(nn.Cell):
    def __init__(self, dim, heads=4, dim_head=32):
        super().__init__()
        self.scale = dim_head ** -0.5
        self.heads = heads
        hidden_dim = dim_head * heads

        self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, pad_mode='valid', has_bias=False)
        self.to_out = nn.Conv2d(hidden_dim, dim, 1, pad_mode='valid', has_bias=True)
        self.map = ops.Map()
        self.partial = ops.Partial()

    def construct(self, x):
        b, _, h, w = x.shape
        qkv = self.to_qkv(x).chunk(3, 1)
        q, k, v = self.map(self.partial(rearrange, self.heads), qkv)

        q = q * self.scale

        # 'b h d i, b h d j -> b h i j'
        sim = ops.bmm(q.swapaxes(2, 3), k)
        attn = ops.softmax(sim, axis=-1)
        # 'b h i j, b h d j -> b h i d'
        out = ops.bmm(attn, v.swapaxes(2, 3))
        out = out.swapaxes(-1, -2).reshape((b, -1, h, w))

        return self.to_out(out)


class LayerNorm(nn.Cell):
    def __init__(self, dim):
        super().__init__()
        self.g = Parameter(initializer('ones', (1, dim, 1, 1)), name='g')

    def construct(self, x):
        eps = 1e-5
        var = x.var(1, keepdims=True)
        mean = x.mean(1, keep_dims=True)
        return (x - mean) * rsqrt((var + eps)) * self.g


class LinearAttention(nn.Cell):
    def __init__(self, dim, heads=4, dim_head=32):
        super().__init__()
        self.scale = dim_head ** -0.5
        self.heads = heads
        hidden_dim = dim_head * heads
        self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, pad_mode='valid', has_bias=False)

        self.to_out = nn.SequentialCell(
            nn.Conv2d(hidden_dim, dim, 1, pad_mode='valid', has_bias=True),
            LayerNorm(dim)
        )

        self.map = ops.Map()
        self.partial = ops.Partial()

    def construct(self, x):
        b, _, h, w = x.shape
        qkv = self.to_qkv(x).chunk(3, 1)
        q, k, v = self.map(self.partial(rearrange, self.heads), qkv)

        q = ops.softmax(q, -2)
        k = ops.softmax(k, -1)

        q = q * self.scale
        v = v / (h * w)

        # 'b h d n, b h e n -> b h d e'
        context = ops.bmm(k, v.swapaxes(2, 3))
        # 'b h d e, b h d n -> b h e n'
        out = ops.bmm(context.swapaxes(2, 3), q)

        out = out.reshape((b, -1, h, w))
        return self.to_out(out)

6. Group normalization

The U-Net interleaves its convolution/attention layers with group normalization.

The PreNorm class

        applies groupnorm before the attention layer

class PreNorm(nn.Cell):
    def __init__(self, dim, fn):
        super().__init__()
        self.fn = fn
        self.norm = nn.GroupNorm(1, dim)

    def construct(self, x):
        x = self.norm(x)
        return self.fn(x)

7. Conditional U-Net

The network \epsilon _\theta (x_t,t)

        input

                a batch of noisy images, of shape (batch_size, num_channels, height, width)

                a batch of noise levels, of shape (batch_size, 1)

        output

                the noise, a tensor of shape (batch_size, num_channels, height, width)

8. How the network is built

A convolutional layer is applied to the batch of noisy images

Position embeddings are computed for the noise levels

A sequence of downsampling stages is applied

        each downsampling stage consists of

                2 ResNet/ConvNeXT blocks

                groupnorm

                attention

                a residual connection

                a downsample operation

In the middle of the network, ResNet or ConvNeXT blocks are applied again

interleaved with attention

A sequence of upsampling stages is applied

        each upsampling stage consists of

                2 ResNet/ConvNeXT blocks

                groupnorm

                attention

                a residual connection

                an upsample operation

Finally, a ResNet/ConvNeXT block is applied

followed by a convolutional layer

class Unet(nn.Cell):
    def __init__(
        self,
        dim,
        init_dim=None,
        out_dim=None,
        dim_mults=(1, 2, 4, 8),
        channels=3,
        with_time_emb=True,
        convnext_mult=2,
    ):
        super().__init__()

        self.channels = channels

        init_dim = default(init_dim, dim // 3 * 2)
        self.init_conv = nn.Conv2d(channels, init_dim, 7, padding=3, pad_mode="pad", has_bias=True)

        dims = [init_dim, *map(lambda m: dim * m, dim_mults)]
        in_out = list(zip(dims[:-1], dims[1:]))

        block_klass = partial(ConvNextBlock, mult=convnext_mult)

        if with_time_emb:
            time_dim = dim * 4
            self.time_mlp = nn.SequentialCell(
                SinusoidalPositionEmbeddings(dim),
                nn.Dense(dim, time_dim),
                nn.GELU(),
                nn.Dense(time_dim, time_dim),
            )
        else:
            time_dim = None
            self.time_mlp = None

        self.downs = nn.CellList([])
        self.ups = nn.CellList([])
        num_resolutions = len(in_out)

        for ind, (dim_in, dim_out) in enumerate(in_out):
            is_last = ind >= (num_resolutions - 1)

            self.downs.append(
                nn.CellList([
                    block_klass(dim_in, dim_out, time_emb_dim=time_dim),
                    block_klass(dim_out, dim_out, time_emb_dim=time_dim),
                    Residual(PreNorm(dim_out, LinearAttention(dim_out))),
                    Downsample(dim_out) if not is_last else nn.Identity(),
                ])
            )

        mid_dim = dims[-1]
        self.mid_block1 = block_klass(mid_dim, mid_dim, time_emb_dim=time_dim)
        self.mid_attn = Residual(PreNorm(mid_dim, Attention(mid_dim)))
        self.mid_block2 = block_klass(mid_dim, mid_dim, time_emb_dim=time_dim)

        for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])):
            is_last = ind >= (num_resolutions - 1)

            self.ups.append(
                nn.CellList([
                    block_klass(dim_out * 2, dim_in, time_emb_dim=time_dim),
                    block_klass(dim_in, dim_in, time_emb_dim=time_dim),
                    Residual(PreNorm(dim_in, LinearAttention(dim_in))),
                    Upsample(dim_in) if not is_last else nn.Identity(),
                ])
            )

        out_dim = default(out_dim, channels)
        self.final_conv = nn.SequentialCell(
            block_klass(dim, dim), nn.Conv2d(dim, out_dim, 1)
        )

    def construct(self, x, time):
        x = self.init_conv(x)

        t = self.time_mlp(time) if exists(self.time_mlp) else None

        h = []

        for block1, block2, attn, downsample in self.downs:
            x = block1(x, t)
            x = block2(x, t)
            x = attn(x)
            h.append(x)
            x = downsample(x)

        x = self.mid_block1(x, t)
        x = self.mid_attn(x)
        x = self.mid_block2(x, t)

        len_h = len(h) - 1
        for block1, block2, attn, upsample in self.ups:
            x = ops.concat((x, h[len_h]), 1)
            len_h -= 1
            x = block1(x, t)
            x = block2(x, t)
            x = attn(x)
            x = upsample(x)

        return self.final_conv(x)

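As a quick sanity check (a hypothetical smoke test, not part of the original notes), the network can be run on dummy data to confirm that the predicted noise comes back with the same shape as the input batch:

# Hypothetical smoke test: output shape must equal input shape.
test_net = Unet(dim=28, channels=1, dim_mults=(1, 2, 4))
test_x = randn((2, 1, 28, 28))                 # a dummy batch of 2 "noisy images"
test_t = Tensor(np.array([5, 100]), ms.int32)  # one noise level per image
print(test_net(test_x, test_t).shape)          # expected: (2, 1, 28, 28)
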
IV. The Forward Diffusion Process

1. Define the schedule over T time steps

def linear_beta_schedule(timesteps):
    beta_start = 0.0001
    beta_end = 0.02
    return np.linspace(beta_start, beta_end, timesteps).astype(np.float32)

First, a linear schedule with T = 200 time steps is used.

Define the various quantities derived from \beta_t

        such as \bar{\alpha}_t, the cumulative product of the \alpha_t

        each variable is a 1-D tensor storing the value for every t from 1 to T

        the extract function gathers the proper schedule entry for each index t in a batch

# diffuse for 200 time steps
timesteps = 200

# define the beta schedule
betas = linear_beta_schedule(timesteps=timesteps)

# define the alphas
alphas = 1. - betas
alphas_cumprod = np.cumprod(alphas, axis=0)
alphas_cumprod_prev = np.pad(alphas_cumprod[:-1], (1, 0), constant_values=1)

sqrt_recip_alphas = Tensor(np.sqrt(1. / alphas))
sqrt_alphas_cumprod = Tensor(np.sqrt(alphas_cumprod))
sqrt_one_minus_alphas_cumprod = Tensor(np.sqrt(1. - alphas_cumprod))

# compute the variance of the posterior q(x_{t-1} | x_t, x_0)
posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)

p2_loss_weight = (1 + alphas_cumprod / (1 - alphas_cumprod)) ** -0.
p2_loss_weight = Tensor(p2_loss_weight)

def extract(a, t, x_shape):
    b = t.shape[0]
    out = Tensor(a).gather(t, -1)
    return out.reshape(b, *((1,) * (len(x_shape) - 1)))
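
A quick illustration (hypothetical, not in the original notes) of what extract does: for a batch of time steps it gathers the matching schedule entries and reshapes them so they broadcast over an image batch.

t_demo = Tensor(np.array([0, 100, 199]), ms.int32)
coeff = extract(sqrt_alphas_cumprod, t_demo, (3, 1, 28, 28))
print(coeff.shape)  # (3, 1, 1, 1) -- ready to multiply a (3, 1, 28, 28) batch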

2. Add noise to a cat image at each time step of the diffusion process

# download the cat image
url = 'https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/datasets/image_cat.zip'
path = download(url, './', kind="zip", replace=True)

Output:

Downloading data from https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/datasets/image_cat.zip (170 kB)

file_sizes: 100%|████████████████████████████| 174k/174k [00:00<00:00, 1.45MB/s]
Extracting zip file...
Successfully downloaded / unzipped to ./

from PIL import Image

image = Image.open('./image_cat/jpg/000000039769.jpg')
base_width = 160
image = image.resize((base_width, int(float(image.size[1]) * float(base_width / float(image.size[0])))))
image.show()

Output: (the resized cat image is displayed)

Add noise to MindSpore tensors

Define image transforms that

        convert from a PIL image to a MindSpore tensor

        divide by 255, then rescale so the values lie in the [-1,1] range (assuming the image data consists of integers in {0,1,...,255})
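
A toy check of the value range (my own illustration, assuming 8-bit input pixels), mirroring what ToTensor plus the lambda below do:

pixels = np.array([0, 128, 255], dtype=np.float32)
scaled = pixels / 255.       # ToTensor maps {0,...,255} into [0, 1]
print(scaled * 2 - 1)        # [-1.  0.00392157  1.] -- values now lie in [-1, 1]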

from mindspore.dataset import ImageFolderDataset

image_size = 128
transforms = [
    Resize(image_size, Inter.BILINEAR),
    CenterCrop(image_size),
    ToTensor(),
    lambda t: (t * 2) - 1
]


path = './image_cat'
dataset = ImageFolderDataset(dataset_dir=path, num_parallel_workers=cpu_count(),
                             extensions=['.jpg', '.jpeg', '.png', '.tiff'],
                             num_shards=1, shard_id=0, shuffle=False, decode=True)
dataset = dataset.project('image')
transforms.insert(1, RandomHorizontalFlip())
dataset_1 = dataset.map(transforms, 'image')
dataset_2 = dataset_1.batch(1, drop_remainder=True)
x_start = next(dataset_2.create_tuple_iterator())[0]
print(x_start.shape)

Output:

(1, 3, 128, 128)

3. Define the reverse transform

It takes a tensor containing values in [-1,1]

and outputs a PIL image

import numpy as np

reverse_transform = [
    lambda t: (t + 1) / 2,
    lambda t: ops.permute(t, (1, 2, 0)), # CHW to HWC
    lambda t: t * 255.,
    lambda t: t.asnumpy().astype(np.uint8),
    ToPIL()
]

def compose(transform, x):
    for d in transform:
        x = d(x)
    return x

Verify:

reverse_image = compose(reverse_transform, x_start[0])
reverse_image.show()

Output: (the original cat image is recovered)

4. Define the forward diffusion process

def q_sample(x_start, t, noise=None):
    if noise is None:
        noise = randn_like(x_start)
    return (extract(sqrt_alphas_cumprod, t, x_start.shape) * x_start +
            extract(sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise)

Test it at a specific time step:

def get_noisy_image(x_start, t):
    # add noise
    x_noisy = q_sample(x_start, t=t)

    # turn it back into a PIL image
    noisy_image = compose(reverse_transform, x_noisy[0])

    return noisy_image

# set a time step
t = Tensor([40])
noisy_image = get_noisy_image(x_start, t)
print(noisy_image)
noisy_image.show()

Output:

<PIL.Image.Image image mode=RGB size=128x128 at 0x7F54569F3950>

Visualize this at several different time steps:

import matplotlib.pyplot as plt

def plot(imgs, with_orig=False, row_title=None, **imshow_kwargs):
    if not isinstance(imgs[0], list):
        imgs = [imgs]

    num_rows = len(imgs)
    num_cols = len(imgs[0]) + with_orig
    _, axs = plt.subplots(figsize=(200, 200), nrows=num_rows, ncols=num_cols, squeeze=False)
    for row_idx, row in enumerate(imgs):
        row = [image] + row if with_orig else row
        for col_idx, img in enumerate(row):
            ax = axs[row_idx, col_idx]
            ax.imshow(np.asarray(img), **imshow_kwargs)
            ax.set(xticklabels=[], yticklabels=[], xticks=[], yticks=[])

    if with_orig:
        axs[0, 0].set(title='Original image')
        axs[0, 0].title.set_size(8)
    if row_title is not None:
        for row_idx in range(num_rows):
            axs[row_idx, 0].set(ylabel=row_title[row_idx])

    plt.tight_layout()

plot([get_noisy_image(x_start, Tensor([t])) for t in [0, 50, 100, 150, 199]])

Define the loss function:

def p_losses(unet_model, x_start, t, noise=None):
    if noise is None:
        noise = randn_like(x_start)
    x_noisy = q_sample(x_start=x_start, t=t, noise=noise)
    predicted_noise = unet_model(x_noisy, t)

    loss = nn.SmoothL1Loss()(noise, predicted_noise)  # todo
    loss = loss.reshape(loss.shape[0], -1)
    loss = loss * extract(p2_loss_weight, t, loss.shape)
    return loss.mean()

V. Data Preparation and Processing

1. Download the dataset

The Fashion-MNIST images are used here

        linearly scaled into [-1,1]

        all of the same size, 28x28

        randomly flipped horizontally

Download the dataset with download

and extract it to the path ./

# download the Fashion-MNIST dataset
url = 'https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/datasets/dataset.zip'
path = download(url, './', kind="zip", replace=True)

Output:

Downloading data from https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/notebook/datasets/dataset.zip (29.4 MB)

file_sizes: 100%|██████████████████████████| 30.9M/30.9M [00:00<00:00, 43.4MB/s]
Extracting zip file...
Successfully downloaded / unzipped to ./

from mindspore.dataset import FashionMnistDataset

image_size = 28
channels = 1
batch_size = 16

fashion_mnist_dataset_dir = "./dataset"
dataset = FashionMnistDataset(dataset_dir=fashion_mnist_dataset_dir, usage="train",
                              num_parallel_workers=cpu_count(), shuffle=True,
                              num_shards=1, shard_id=0)

2. Define the transform operations

The image preprocessing

        randomly flips horizontally

        converts to tensor and rescales

        so the values lie in the [-1,1] range

transforms = [
    RandomHorizontalFlip(),
    ToTensor(),
    lambda t: (t * 2) - 1
]

dataset = dataset.project('image')
dataset = dataset.shuffle(64)
dataset = dataset.map(transforms, 'image')
dataset = dataset.batch(16, drop_remainder=True)

x = next(dataset.create_dict_iterator())
print(x.keys())

Output:

dict_keys(['image'])

3. Sampling

The model is sampled from during training, to track progress.

Sampling algorithm (Algorithm 2 of the DDPM paper):

reverse the diffusion process

        start at T and sample pure noise from a Gaussian distribution

        the neural network gradually denoises it, using the conditional probability it has learned, ending at time step t=0

        with the reparameterization

                the noise predictor is plugged into the mean

        a slightly less noisy image x_{t-1} is derived at each step

        at the end, an image that looks as if it came from the real data distribution is obtained
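
(the algorithm figure is missing in the original; in outline, following Algorithm 2 of Ho et al., 2020, which the code below implements:)

x_T \sim N(0,I)

for t=T,...,1:

        z \sim N(0,I) if t>1, else z=0

        x_{t-1}=\frac{1}{\sqrt{\alpha_t}}(x_t-\frac{1-\alpha_t}{\sqrt{1-\bar{\alpha}_t}}\epsilon_\theta (x_t,t))+\sigma_t z

return x_0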

def p_sample(model, x, t, t_index):
    betas_t = extract(betas, t, x.shape)
    sqrt_one_minus_alphas_cumprod_t = extract(sqrt_one_minus_alphas_cumprod, t, x.shape)
    sqrt_recip_alphas_t = extract(sqrt_recip_alphas, t, x.shape)
    model_mean = sqrt_recip_alphas_t * (x - betas_t * model(x, t) / sqrt_one_minus_alphas_cumprod_t)

    if t_index == 0:
        return model_mean
    posterior_variance_t = extract(posterior_variance, t, x.shape)
    noise = randn_like(x)
    return model_mean + ops.sqrt(posterior_variance_t) * noise


def p_sample_loop(model, shape):
    b = shape[0]
    # start from pure noise
    img = randn(shape, dtype=None)
    imgs = []

    for i in tqdm(reversed(range(0, timesteps)), desc='sampling loop time step', total=timesteps):
        img = p_sample(model, img, ms.numpy.full((b,), i, dtype=mstype.int32), i)
        imgs.append(img.asnumpy())
    return imgs


def sample(model, image_size, batch_size=16, channels=3):
    return p_sample_loop(model, shape=(batch_size, channels, image_size, image_size))

VI. Training

# define a dynamic learning rate
lr = nn.cosine_decay_lr(min_lr=1e-7, max_lr=1e-4, total_step=10*3750, step_per_epoch=3750, decay_epoch=10)

# define the Unet model
unet_model = Unet(
    dim=image_size,
    channels=channels,
    dim_mults=(1, 2, 4,)
)

name_list = []
for (name, par) in list(unet_model.parameters_and_names()):
    name_list.append(name)
i = 0
for item in list(unet_model.trainable_params()):
    item.name = name_list[i]
    i += 1

# define the optimizer
optimizer = nn.Adam(unet_model.trainable_params(), learning_rate=lr)
loss_scaler = DynamicLossScaler(65536, 2, 1000)

# define the forward pass
def forward_fn(data, t, noise=None):
    loss = p_losses(unet_model, data, t, noise)
    return loss

# compute gradients
grad_fn = ms.value_and_grad(forward_fn, None, optimizer.parameters, has_aux=False)

# apply the gradient update
def train_step(data, t, noise):
    loss, grads = grad_fn(data, t, noise)
    optimizer(grads)
    return loss

import time

# epochs is set to 1 here to save time; increase it as needed
epochs = 1

for epoch in range(epochs):
    begin_time = time.time()
    for step, batch in enumerate(dataset.create_tuple_iterator()):
        unet_model.set_train()
        batch_size = batch[0].shape[0]
        t = randint(0, timesteps, (batch_size,), dtype=ms.int32)
        noise = randn_like(batch[0])
        loss = train_step(batch[0], t, noise)

        if step % 500 == 0:
            print(" epoch: ", epoch, " step: ", step, " Loss: ", loss)
    end_time = time.time()
    times = end_time - begin_time
    print("training time:", times, "s")
    # show a random sample
    unet_model.set_train(False)
    samples = sample(unet_model, image_size=image_size, batch_size=64, channels=channels)
    plt.imshow(samples[-1][5].reshape(image_size, image_size, channels), cmap="gray")
print("Training Success!")

Output (this log is from a run with epochs = 10):

 epoch:  0  step:  0  Loss:  0.43375123
 epoch:  0  step:  500  Loss:  0.113769315
 epoch:  0  step:  1000  Loss:  0.08649178
 epoch:  0  step:  1500  Loss:  0.067664884
 epoch:  0  step:  2000  Loss:  0.07234038
 epoch:  0  step:  2500  Loss:  0.043936778
 epoch:  0  step:  3000  Loss:  0.058127824
 epoch:  0  step:  3500  Loss:  0.049789283
training time: 922.3438229560852 s
 epoch:  1  step:  0  Loss:  0.05088563
 epoch:  1  step:  500  Loss:  0.051174678
 epoch:  1  step:  1000  Loss:  0.04455947
 epoch:  1  step:  1500  Loss:  0.055165425
 epoch:  1  step:  2000  Loss:  0.043942295
 epoch:  1  step:  2500  Loss:  0.03274461
 epoch:  1  step:  3000  Loss:  0.048117325
 epoch:  1  step:  3500  Loss:  0.063063145
training time: 937.5596783161163 s
 epoch:  2  step:  0  Loss:  0.052893892
 epoch:  2  step:  500  Loss:  0.05721748
 epoch:  2  step:  1000  Loss:  0.057248186
 epoch:  2  step:  1500  Loss:  0.048806388
 epoch:  2  step:  2000  Loss:  0.05007638
 epoch:  2  step:  2500  Loss:  0.04337231
 epoch:  2  step:  3000  Loss:  0.043207955
 epoch:  2  step:  3500  Loss:  0.034530163
training time: 947.6374666690826 s
 epoch:  3  step:  0  Loss:  0.04867614
 epoch:  3  step:  500  Loss:  0.051636297
 epoch:  3  step:  1000  Loss:  0.03338969
 epoch:  3  step:  1500  Loss:  0.0420174
 epoch:  3  step:  2000  Loss:  0.052145053
 epoch:  3  step:  2500  Loss:  0.03905913
 epoch:  3  step:  3000  Loss:  0.07621498
 epoch:  3  step:  3500  Loss:  0.06484105
training time: 957.7780408859253 s
 epoch:  4  step:  0  Loss:  0.046281893
 epoch:  4  step:  500  Loss:  0.03783619
 epoch:  4  step:  1000  Loss:  0.0587488
 epoch:  4  step:  1500  Loss:  0.06974746
 epoch:  4  step:  2000  Loss:  0.04299112
 epoch:  4  step:  2500  Loss:  0.027945498
 epoch:  4  step:  3000  Loss:  0.045338146
 epoch:  4  step:  3500  Loss:  0.06362417
training time: 955.6116819381714 s
 epoch:  5  step:  0  Loss:  0.04781142
 epoch:  5  step:  500  Loss:  0.032488734
 epoch:  5  step:  1000  Loss:  0.061507083
 epoch:  5  step:  1500  Loss:  0.039130375
 epoch:  5  step:  2000  Loss:  0.034972396
 epoch:  5  step:  2500  Loss:  0.039485026
 epoch:  5  step:  3000  Loss:  0.06690869
 epoch:  5  step:  3500  Loss:  0.05355365
training time: 951.7758958339691 s
 epoch:  6  step:  0  Loss:  0.04807706
 epoch:  6  step:  500  Loss:  0.021469856
 epoch:  6  step:  1000  Loss:  0.035354104
 epoch:  6  step:  1500  Loss:  0.044303045
 epoch:  6  step:  2000  Loss:  0.040063944
 epoch:  6  step:  2500  Loss:  0.02970439
 epoch:  6  step:  3000  Loss:  0.041152682
 epoch:  6  step:  3500  Loss:  0.02062454
training time: 955.2220208644867 s
 epoch:  7  step:  0  Loss:  0.029668871
 epoch:  7  step:  500  Loss:  0.028485576
 epoch:  7  step:  1000  Loss:  0.029675964
 epoch:  7  step:  1500  Loss:  0.052743085
 epoch:  7  step:  2000  Loss:  0.03664278
 epoch:  7  step:  2500  Loss:  0.04454907
 epoch:  7  step:  3000  Loss:  0.043067697
 epoch:  7  step:  3500  Loss:  0.0619511
training time: 952.6654670238495 s
 epoch:  8  step:  0  Loss:  0.055328347
 epoch:  8  step:  500  Loss:  0.035807922
 epoch:  8  step:  1000  Loss:  0.026412832
 epoch:  8  step:  1500  Loss:  0.051044375
 epoch:  8  step:  2000  Loss:  0.05474911
 epoch:  8  step:  2500  Loss:  0.044595096
 epoch:  8  step:  3000  Loss:  0.034082986
 epoch:  8  step:  3500  Loss:  0.02653109
training time: 961.9374921321869 s
 epoch:  9  step:  0  Loss:  0.039675284
 epoch:  9  step:  500  Loss:  0.046295933
 epoch:  9  step:  1000  Loss:  0.031403508
 epoch:  9  step:  1500  Loss:  0.028816734
 epoch:  9  step:  2000  Loss:  0.06530296
 epoch:  9  step:  2500  Loss:  0.051451046
 epoch:  9  step:  3000  Loss:  0.037913296
 epoch:  9  step:  3500  Loss:  0.030541396
training time: 974.643147945404 s
Training Success!

VII. Inference (Sampling from the Model)

To sample from the model, just use the sampling function defined above:

# sample 64 images
unet_model.set_train(False)
samples = sample(unet_model, image_size=image_size, batch_size=64, channels=channels)

Output:

sampling loop time step:   0%|          | 0/200 [00:00<?, ?it/s]

# show one random sample
random_index = 5
plt.imshow(samples[-1][random_index].reshape(image_size, image_size, channels), cmap="gray")

Output:

<matplotlib.image.AxesImage at 0x7f5175ea1690>

The model generated a piece of clothing!

Create a GIF of the denoising process:

import matplotlib.animation as animation

random_index = 53

fig = plt.figure()
ims = []
for i in range(timesteps):
    im = plt.imshow(samples[i][random_index].reshape(image_size, image_size, channels),
                    cmap="gray", animated=True)
    ims.append([im])

animate = animation.ArtistAnimation(fig, ims, interval=50, blit=True, repeat_delay=100)
animate.save('diffusion.gif')
plt.show()

Output: (the denoising animation is saved as diffusion.gif)
