improved-diffusion-main代码理解

目录

  • 一、 TimestepEmbedSequential
  • 二、PyTorch之Checkpoint机制
  • 三、AttentionBlock
  • 四、use_scale_shift_norm

和nanoDiffusion-main相比,improved-diffusion-main代码是相似的,但有几个不是很好理解的地方记录一下。

一、 TimestepEmbedSequential

代码中class ResBlock继承自TimestepBlock,需要执行时间步嵌入操作,其他不需要。

class TimestepBlock(nn.Module):"""Any module where forward() takes timestep embeddings as a second argument."""@abstractmethoddef forward(self, x, emb):"""Apply the module to `x` given `emb` timestep embeddings."""class TimestepEmbedSequential(nn.Sequential, TimestepBlock):"""A sequential module that passes timestep embeddings to the children thatsupport it as an extra input."""def forward(self, x, emb):for layer in self:if isinstance(layer, TimestepBlock):x = layer(x, emb)else:x = layer(x)return x
class ResBlock(TimestepBlock):

二、PyTorch之Checkpoint机制

def checkpoint(func, inputs, params, flag):"""Evaluate a function without caching intermediate activations, allowing forreduced memory at the expense of extra compute in the backward pass.:param func: the function to evaluate.:param inputs: the argument sequence to pass to `func`.:param params: a sequence of parameters `func` depends on but does notexplicitly take as arguments.:param flag: if False, disable gradient checkpointing."""if flag:args = tuple(inputs) + tuple(params)return CheckpointFunction.apply(func, len(inputs), *args)else:return func(*inputs)

checkpoint 是在 torch.no_grad() 模式下计算的目标操作的前向函数,这并不会修改原本的叶子结点的状态,有梯度的还会保持。只是关联这些叶子结点的临时生成的中间变量会被设置为不需要梯度,因此梯度链式关系会被断开。

三、AttentionBlock


class AttentionBlock(nn.Module):def __init__(self, channels, num_heads=1, use_checkpoint=False):super().__init__()self.channels = channelsself.num_heads = num_headsself.use_checkpoint = use_checkpointself.norm = normalization(channels)self.qkv = conv_nd(1, channels, channels * 3, 1)self.attention = QKVAttention()self.proj_out = zero_module(conv_nd(1, channels, channels, 1))def forward(self, x):return checkpoint(self._forward, (x,), self.parameters(), self.use_checkpoint)def _forward(self, x):b, c, *spatial = x.shapex = x.reshape(b, c, -1)qkv = self.qkv(self.norm(x))qkv = qkv.reshape(b * self.num_heads, -1, qkv.shape[2])h = self.attention(qkv)h = h.reshape(b, -1, h.shape[-1])h = self.proj_out(h)return (x + h).reshape(b, c, *spatial)class QKVAttention(nn.Module):"""A module which performs QKV attention."""def forward(self, qkv):"""Apply QKV attention.:param qkv: an [N x (C * 3) x T] tensor of Qs, Ks, and Vs.:return: an [N x C x T] tensor after attention."""ch = qkv.shape[1] // 3q, k, v = th.split(qkv, ch, dim=1)scale = 1 / math.sqrt(math.sqrt(ch))weight = th.einsum("bct,bcs->bts", q * scale, k * scale)  # More stable with f16 than dividing afterwardsweight = th.softmax(weight.float(), dim=-1).type(weight.dtype)return th.einsum("bts,bcs->bct", weight, v)@staticmethoddef count_flops(model, _x, y):b, c, *spatial = y[0].shapenum_spatial = int(np.prod(spatial))# We perform two matmuls with the same number of ops.# The first computes the weight matrix, the second computes# the combination of the value vectors.matmul_ops = 2 * b * (num_spatial ** 2) * cmodel.total_ops += th.DoubleTensor([matmul_ops])

下面这个函数是准备适合的qkv矩阵

    def _forward(self, x):b, c, *spatial = x.shapex = x.reshape(b, c, -1)-》输入转换为(b,c,N)qkv = self.qkv(self.norm(x))-》通过卷积转换为(b,3*c,N)qkv = qkv.reshape(b * self.num_heads, -1, qkv.shape[2])h = self.attention(qkv)h = h.reshape(b, -1, h.shape[-1])h = self.proj_out(h)return (x + h).reshape(b, c, *spatial)

class QKVAttention的forward就是下面的公式:
在这里插入图片描述

四、use_scale_shift_norm

    def _forward(self, x, emb):h = self.in_layers(x)emb_out = self.emb_layers(emb).type(h.dtype)while len(emb_out.shape) < len(h.shape):emb_out = emb_out[..., None]if self.use_scale_shift_norm:out_norm, out_rest = self.out_layers[0], self.out_layers[1:]scale, shift = th.chunk(emb_out, 2, dim=1)h = out_norm(h) * (1 + scale) + shifth = out_rest(h)else:h = h + emb_outh = self.out_layers(h)return self.skip_connection(x) + h

在深度学习中,特别是在处理如扩散模型(Diffusion Models)或任何需要精细控制输出特征的神经网络时,use_scale_shift_norm引入一种灵活的变换,这种变换通过缩放(scale)和平移(shift)来调整网络层的输出。
use_scale_shift_norm是一个布尔值(True或False),用于决定是否应用这种缩放和平移的归一化方法。如果use_scale_shift_norm为True,则执行以下步骤:

分割输出层:首先,代码将self.out_layers(一个包含网络层的列表)分割为两部分。out_norm是列表中的第一个层,负责进行某种形式的归一化或变换(尽管这里的名字是out_norm,但它可能不仅仅执行归一化,而是任何形式的变换层)。out_rest是列表中剩余的所有层,这些层将在缩放和平移之后应用。
提取缩放和平移参数:接下来,从emb_out(可能是嵌入层的输出或其他某种特征表示)中提取缩放(scale)和平移(shift)参数。这里假设emb_out的维度被设计为包含这两组参数,通过th.chunk(emb_out, 2, dim=1)沿着第二维(dim=1)将其分割成两部分,分别代表缩放和平移参数。
应用缩放和平移:然后,将h(可能是之前某个层的输出)通过out_norm层进行变换,之后使用从emb_out中提取的缩放和平移参数对结果进行调整。调整的方式是将out_norm(h)的输出乘以(1 + scale)并加上shift。这个步骤实质上是在对out_norm(h)的输出进行线性变换,以引入额外的灵活性和控制。
通过剩余层:最后,将调整后的输出h通过out_rest中剩余的层进行进一步的处理。
这种技术的一个关键优势是它能够以一种灵活且数据驱动的方式调整网络层的输出,而不需要在模型架构中硬编码特定的归一化或变换策略。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.xdnf.cn/news/1473121.html

如若内容造成侵权/违法违规/事实不符,请联系一条长河网进行投诉反馈,一经查实,立即删除!

相关文章

【C++题解】1413. 切割绳子

问题&#xff1a;1413. 切割绳子 类型&#xff1a;贪心&#xff0c;二分&#xff0c;noip2017普及组初赛 题目描述&#xff1a; 有 n 条绳子&#xff0c;每条绳子的长度已知且均为正整数。绳子可以以任意正整数长度切割&#xff0c;但不可以连接。现在要从这些绳子中切割出 m…

Open3D 在点云中构建八叉树

目录 一、概述 二、代码实现 2.1关键函数 2.2完整代码 三、实现效果 3.1原始点云 3.2构建后点云 一、概述 八叉树&#xff08;Octree&#xff09;是一种树状数据结构&#xff0c;用于递归地将3D空间分割成较小的立方体。八叉树特别适用于3D计算机图形学、点云处理和空间…

TreeMap、HashMap 和 LinkedHashMap 的区别

TreeMap、HashMap 和 LinkedHashMap 的区别 1、HashMap2、LinkedHashMap3、TreeMap4、总结 &#x1f496;The Begin&#x1f496;点点关注&#xff0c;收藏不迷路&#x1f496; 在 Java 中&#xff0c;TreeMap、HashMap 和 LinkedHashMap 是三种常用的集合类&#xff0c;它们在…

昇思Mindspore学习25天打卡Day16:热门LLM及其他AI应用|基于MindeNLP+MusicGen生成自己的个性化音乐

昇思Mindspore学习25天打卡Day16&#xff1a;热门LLM及其他AI应用|基于MindeNLPMusicGen生成自己的个性化音乐 1 下载模型2 生成音乐2.1 无提示生成2.2 文本提示生成2.3 音频提示生成 3 生成配置 &训练结束打上标签和时间 MusicGen是来自Meta Al的Jade Copet等人提出的基于…

连锁店收银系统源码

千呼新零售2.0系统是零售行业连锁店一体化收银系统&#xff0c;包括线下收银线上商城连锁店管理ERP管理商品管理供应商管理会员营销等功能为一体&#xff0c;线上线下数据全部打通。 私有化独立部署/全开源源码&#xff0c;系统开发语言&#xff1a; 核心开发语言: PHP、HTML…

hnust 1816: 算法10-9:简单选择排序

hnust 1816: 算法10-9&#xff1a;简单选择排序 题目描述 选择排序的基本思想是&#xff1a;每一趟比较过程中&#xff0c;在n-i1(i1,2,…,n-1)个记录中选取关键字最小的记录作为有序序列中的第i个记录。 在多种选择排序中&#xff0c;最常用且形式最为简单的是简单选择排序。…

ingress-nginx控制器证书不会自动更新问题

好久没更新了&#xff0c;正好今天遇到了一个很有意思的问题&#xff0c;在这里给大家分享下&#xff0c;同时也做下记录。 背景 最近想做个实验&#xff0c;当k8s集群中secret更新后&#xff0c;ingress-nginx控制器会不会自动加载新的证书。我用通义千问搜了下&#xff0c;…

windows 7 安装IPP协议,支持Internet打印

1 windows 7 安装IPP协议,支持Internet打印 #控制面板--打开或关闭Windows功能 3 复制Printers 文件夹 到 c:\inetpub\wwwroot\,复制msw3prt.dll到c:\windows\system32\ 4 打开IIs管理器 #报错:模块列表中不存在此处理程序所需的指定模块。如果您添加脚本映射处理程序映射&…

AndroidKille不能用?更新apktool插件-cnblog

AndroidKiller不更新插件容易报错 找到apktool管理器 填入apktool位置&#xff0c;并输入apktool名字 选择默认的apktool版本 x掉&#xff0c;退出重启 可以看到反编译完成了

网络基础:IS-IS协议

IS-IS&#xff08;Intermediate System to Intermediate System&#xff09;是一种链路状态路由协议&#xff0c;最初由 ISO&#xff08;International Organization for Standardization&#xff09;为 CLNS&#xff08;Connectionless Network Service&#xff09;网络设计。…

数据结构——(双)链表

文章目录 1. 定义 2. 双链表和单链表的区别 3. 代码示例 3.1 双链表节点和结构定义 3.2 初始化双链表 3.3 返回双链表的长度 3.4 在指定位置插入元素 3.5 在末尾插入元素 3.6 删除指定位置的元素并返回被删除的元素 3.7 删除末尾元素 3.8 获取指定位置的元素 3.9 修…

磁盘分区工具 -- 傲梅分区助手 v10.4.1 技术员版

软件简介 傲梅分区助手是一款功能强大的磁盘分区工具&#xff0c;它专为Windows系统设计&#xff0c;帮助用户更高效地管理他们的硬盘。该软件支持多种分区操作&#xff0c;包括创建、格式化、调整大小、移动、合并和分割分区。此外&#xff0c;它还提供了复制硬盘和分区的功能…

C++:Level3阶段测试

1、黑客小知识&#xff1a; &#xff08;1&#xff09;常用的黑客头文件有____和____。 &#xff08;2&#xff09;创建文件的函数叫做________。 &#xff08;3&#xff09;我更新了____个黑客头文件。 &#xff08;4&#xff09;万能头文件包含的黑客头文件是________。 …

速刷edurank(1)

python安全开发 python安全开发 python安全开发前言一、平台edu二、使用步骤1.引入库2.功能**完整代码**完整代码 总结 前言 目的&#xff1a;想快速的搜集edu的域名 一、平台edu https://src.sjtu.edu.cn/rank/firm/0/?page2 二、使用步骤 1.引入库 代码如下&#xff08…

气压传感器在自动驾驶汽车还有哪些应用场景

气压传感器在近年来被广泛应用于各种新兴领域&#xff0c;以下是其中几个最新的应用&#xff1a; 1、自动驾驶汽车&#xff1a;自动驾驶汽车需要精确的气压传感器来监测道路上的气压变化&#xff0c;帮助车辆进行准确的定位和导航。气压传感器可以提供高精度、可靠的气压数据&…

实验3-Spark基础-Spark的安装

文章目录 1. 下载安装 Scala1.1 下载 Scala 安装包1.2 基础环境准备1.3 安装 Scala 2. 下载安装 Spark2.1 下载 Spark 安装包2.2 安装 Spark2.3 配置 Spark2.4 创建配置文件 spark-env.sh 3. pyspark 启动4. 建立/user/spark文件夹 1. 下载安装 Scala 1.1 下载 Scala 安装包 下…

免费鼠标连点器有吗?需要付费吗?鼠标连点器电脑版免费推荐6款!

在数字化时代&#xff0c;鼠标连点器成为了许多用户提高工作效率、优化游戏体验的得力助手。然而&#xff0c;面对市场上琳琅满目的鼠标连点器软件&#xff0c;很多用户都会产生疑问&#xff1a;是否有免费的鼠标连点器&#xff1f;它们真的需要付费吗&#xff1f;今天&#xf…

nuxt、vue树形图d3.js

直接上代码 //安装 npm i d3 --save<template><div class"d3"><div :id"id" class"d3-content"></div></div> </template> <script> import * as d3 from "d3";export default {props: {d…

【康复学习--LeetCode每日一题】3033. 修改矩阵

题目&#xff1a; 给你一个下标从 0 开始、大小为 m x n 的整数矩阵 matrix &#xff0c;新建一个下标从 0 开始、名为 answer 的矩阵。使 answer 与 matrix 相等&#xff0c;接着将其中每个值为 -1 的元素替换为所在列的 最大 元素。 返回矩阵 answer 。 示例 1&#xff1a;…

基于.NET开源游戏框架MonoGame实现的开源项目合集

前言 今天分享一些基于.NET开源游戏框架MonoGame实现的开源项目合集。 MonoGame项目介绍 MonoGame是一个简单而强大的.NET框架&#xff0c;使用C#编程语言可以创建桌面PC、视频游戏机和移动设备游戏。它已成功用于创建《怒之铁拳4》、《食肉者》、《超凡蜘蛛侠》、《星露谷物…