Open-Sora代码详细解读(2):时空3D VAE

Diffusion Models视频生成

前言:目前开源的DiT视频生成模型不是很多,Open-Sora是开发者生态最好的一个,涵盖了DiT、时空DiT、3D VAE、Rectified Flow、因果卷积等Diffusion视频生成的经典知识点。本篇博客从Open-Sora的代码出发,深入解读背后的原理。

目录

3D VAE原理

代码剖析

2D VAE

时间VAE

因果3D卷积


3D VAE原理

之前绝大多数都是2D VAE,特别是SDXL的VAE相当好用,很多人都拿来直接用了。但是在DiT-based的模型中,时间序列上如果再不做压缩的话,就已经很难训得动了。因此非常有必要在时间序列上进行压缩,3D VAE应运而生。

Open-Sora的方案是在2D VAE的基础上,再添加一个时间VAE,相比于EasyAnimate 和 CogVideoX的方案的Full Attention 存在劣势,但是可以充分利用到2D VAE的权重,成本更低。

代码剖析

2D VAE

来自华为pixart sdxl vae:

    vae_2d = dict(type="VideoAutoencoderKL",from_pretrained="PixArt-alpha/pixart_sigma_sdxlvae_T5_diffusers",subfolder="vae",micro_batch_size=micro_batch_size,local_files_only=local_files_only,)

时间VAE

    vae_temporal = dict(type="VAE_Temporal_SD",from_pretrained=None,)
@MODELS.register_module()
class VAE_Temporal(nn.Module):def __init__(self,in_out_channels=4,latent_embed_dim=4,embed_dim=4,filters=128,num_res_blocks=4,channel_multipliers=(1, 2, 2, 4),temporal_downsample=(True, True, False),num_groups=32,  # for nn.GroupNormactivation_fn="swish",):super().__init__()self.time_downsample_factor = 2 ** sum(temporal_downsample)# self.time_padding = self.time_downsample_factor - 1self.patch_size = (self.time_downsample_factor, 1, 1)self.out_channels = in_out_channels# NOTE: following MAGVIT, conv in bias=False in encoder first convself.encoder = Encoder(in_out_channels=in_out_channels,latent_embed_dim=latent_embed_dim * 2,filters=filters,num_res_blocks=num_res_blocks,channel_multipliers=channel_multipliers,temporal_downsample=temporal_downsample,num_groups=num_groups,  # for nn.GroupNormactivation_fn=activation_fn,)self.quant_conv = CausalConv3d(2 * latent_embed_dim, 2 * embed_dim, 1)self.post_quant_conv = CausalConv3d(embed_dim, latent_embed_dim, 1)self.decoder = Decoder(in_out_channels=in_out_channels,latent_embed_dim=latent_embed_dim,filters=filters,num_res_blocks=num_res_blocks,channel_multipliers=channel_multipliers,temporal_downsample=temporal_downsample,num_groups=num_groups,  # for nn.GroupNormactivation_fn=activation_fn,)def get_latent_size(self, input_size):latent_size = []for i in range(3):if input_size[i] is None:lsize = Noneelif i == 0:time_padding = (0if (input_size[i] % self.time_downsample_factor == 0)else self.time_downsample_factor - input_size[i] % self.time_downsample_factor)lsize = (input_size[i] + time_padding) // self.patch_size[i]else:lsize = input_size[i] // self.patch_size[i]latent_size.append(lsize)return latent_sizedef encode(self, x):time_padding = (0if (x.shape[2] % self.time_downsample_factor == 0)else self.time_downsample_factor - x.shape[2] % self.time_downsample_factor)x = pad_at_dim(x, (time_padding, 0), dim=2)encoded_feature = self.encoder(x)moments = self.quant_conv(encoded_feature).to(x.dtype)posterior = DiagonalGaussianDistribution(moments)return posteriordef decode(self, z, num_frames=None):time_padding = (0if (num_frames % self.time_downsample_factor == 0)else self.time_downsample_factor - num_frames % self.time_downsample_factor)z = self.post_quant_conv(z)x = self.decoder(z)x = x[:, :, time_padding:]return xdef forward(self, x, sample_posterior=True):posterior = self.encode(x)if sample_posterior:z = posterior.sample()else:z = posterior.mode()recon_video = self.decode(z, num_frames=x.shape[2])return recon_video, posterior, z

因果3D卷积

class CausalConv3d(nn.Module):def __init__(self,chan_in,chan_out,kernel_size: Union[int, Tuple[int, int, int]],pad_mode="constant",strides=None,  # allow custom stride**kwargs,):super().__init__()kernel_size = cast_tuple(kernel_size, 3)time_kernel_size, height_kernel_size, width_kernel_size = kernel_sizeassert is_odd(height_kernel_size) and is_odd(width_kernel_size)dilation = kwargs.pop("dilation", 1)stride = strides[0] if strides is not None else kwargs.pop("stride", 1)self.pad_mode = pad_modetime_pad = dilation * (time_kernel_size - 1) + (1 - stride)height_pad = height_kernel_size // 2width_pad = width_kernel_size // 2self.time_pad = time_padself.time_causal_padding = (width_pad, width_pad, height_pad, height_pad, time_pad, 0)stride = strides if strides is not None else (stride, 1, 1)dilation = (dilation, 1, 1)self.conv = nn.Conv3d(chan_in, chan_out, kernel_size, stride=stride, dilation=dilation, **kwargs)def forward(self, x):x = F.pad(x, self.time_causal_padding, mode=self.pad_mode)x = self.conv(x)return x

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.xdnf.cn/news/1534903.html

如若内容造成侵权/违法违规/事实不符,请联系一条长河网进行投诉反馈,一经查实,立即删除!

相关文章

嵌入式软件黑盒测试技术与案例分析培训

黑盒测试,也称为基于需求的测试,是目前嵌入式软件领域普遍开展的一种测试过程。目前,随着人们对软件质量要求的不断提升,行业对软件测试和验证的要求也在不断提高,对测试的充分性和准确性要求越来越苛刻。当前行业内&a…

物联网平台架构图

在数字化时代,物联网(IoT)正逐渐成为连接物理世界与数字世界的桥梁。物联网架构,作为这一桥梁的核心,是一个多层次、分布式的网络系统,它通过将各种物理设备与传感器连接到互联网上,实现设备之间…

GLSL 棋盘shader

今日永杰开金 float size 100.;vec2 checkerboard mod(floor(gl_FragCoord.xy / size), 2.);float c mod(checkerboard.x checkerboard.y, 2.);gl_FragColor vec4(vec3(c), 1);或 vec2 uv floor(S * p.xy * vec2(iResolution.x / iResolution.y, 1) / iResolution.xy); …

华为SMU02B1管理模块WEB登录与账户密码信息

1、将电脑的IP地址与SMU02B1的IP地址配置在同一个网段中。例如,如果监控的IP地址为192.168.0.11,子网掩码为255.255.255.0,默认网关为192.168.0.1,则电脑的IP地址设置成192.168.0.12,子网掩码设置成255.255.255.0&…

Python+Pytest框架,“conftest.py文件编写如何获取token和获取日志“?

1、新增"conftest.py" import pytest import loggingfrom api_keyword.api_key import ApiKey from config import *# 获取token # 1. 正常的请求对应的接口并且提取数据 # 2. pytest.fixture()测试夹具(测试前置、后置操作)pytest.fixture(s…

ESP32开发 -- VSCODE+PlatformIO环境安装

参看官网安装:PlatformIO IDE for VSCode 一、安装PlatformIO IDE 参看:日常生活小技巧 – Visual Studio Code 简单使用 扩展中搜索platformIO IDE 当安装完提示重启之后。 打开一个要创建新工程的文件夹: 点击 Create New Project&…

【高等数学学习记录】函数

【高等数学&学习记录】函数 从事测绘工作多年,深刻感受到基础知识的重要及自身在这方面的短板。 为此,打算重温测绘工作所需基础知识。练好基本功,为测绘工作赋能。 1 知识点 1.1 函数 设数集 D ⊂ R D\subset R D⊂R,称映射…

java开发中间件学习记录(持续更新中~)

1 Redis 2JVM 3 java基础底层 4Mysql 5 spring 6 微服务 7.......(持续更新) One:Redis篇 1:Redis 1.穿透 1.1缓存穿透 1.1.1布隆过滤器 1.2缓存击穿 2:击穿 1.3:缓存雪崩 1.4:双写一致 1.5.持久化(RDB,AOF) 1.6…

电脑桌面数据误删如何恢复?提供一份实用指南

电脑桌面作为我们工作和学习的主要界面,存放着大量重要的文件。一旦这些数据不慎被删除,不仅会影响我们的工作效率,还可能造成无法挽回的损失。幸运的是,通过一些有效的方法,我们有机会恢复这些误删的桌面数据。本文将…

Leetcode面试经典150题-79.搜索单词

题目比较简单,回溯最基础的题,记得除非覆盖,否则一定要恢复现场就行 解法都在代码里,不懂就留言或者私信 class Solution {public boolean exist(char[][] board, String word) {int m board.length; int n board[0].length;i…

AI周报(9.8-9.14)

AI应用-NEKO Health用AI颠覆体检 Neko Health 由 Spotify 创始人丹尼尔埃克和哈亚尔马尔尼尔森共同创立,致力于通过每年的全身扫描和由 AI 驱动的洞察力来改善预防性医疗保健,能够检测诸如心脏病和皮肤癌等疾病。 该公司通过使用人工智能软件支持的全身…

基于Python的量化交易回测框架Backtrader初识记录(二)

版权声明:本文为博主原创文章,如需转载请贴上原博文链接:基于Python的量化交易回测框架Backtrader初识记录(二)-CSDN博客 前言:在上一篇文章 基于Python的量化交易回测框架Backtrader初识记录(一…

转置卷积与反卷积的区分

transposed convolution(转置卷积)和deconvolution(反卷积)是两个完全不同的概念。 deconvolution为“inverse of convolution”、“inverse filter”,翻译为反卷积、解卷积。在信号处理中,反卷积是指从卷积…

Golang协程泄漏定位和排查

Golang协程泄漏定位和排查 1 场景:无缓冲channel写阻塞2 排查和定位思路2.1 Golang pprof2.2 协程数监控2.3 操作系统内存泄漏 参考 1 场景:无缓冲channel写阻塞 package mainimport ("log""net/http"_ "net/http/pprof"…

JavaScript --函数的作用域(全局和局部)

全局作用域 全局作用域&#xff0c;就算不在一个script标签也能调用 <!DOCTYPE html> <html lang"en"> <head><meta charset"UTF-8"><meta http-equiv"X-UA-Compatible" content"IEedge"><meta nam…

【win工具】win安装flameshot并设置截图快捷键

1.下载flameshot软件2.windows端配置flameshot快捷键3.取消win自带截图快捷键 1.下载flameshot软件 https://flameshot.org/#download installer版本为安装包 portable版本为免安装版 2.windows端配置flameshot快捷键 https://cloud.tencent.com/developer/article/2114952 W…

第三方软件测评机构分享:软件性能测试的测试方法和内容

软件性能测试是对软件系统在特定负载和条件下的性能进行评估的过程。它旨在确定软件的响应时间、稳定性、资源消耗及其可扩展性&#xff0c;以确保其在实际环境中能够满足用户的需求。通过性能测试&#xff0c;开发团队能够发现潜在的瓶颈问题&#xff0c;优化应用程序架构&…

Spring3-IoC1-IoC容器、基于xml管理bean

目录 IoC容器 概述 依赖注入 IoC容器在Spring中的实现 基于xml管理bean 获取bean 依赖注入 setter注入 构造器注入 特殊值处理 字面量赋值 null值 xml实体 CDATA节 特殊类型属性注入 对象类型属性注入 方式一&#xff1a;引用外部bean 方式二&#xff1a;内部…

基于OpenSSL的密码管理系统-应用密码学课程报告

第1章 概要设计 1.1 设计目的 本研究旨在设计并实现一个基于OpenSSL的密码管理系统&#xff0c;该系统具备密钥对的生成、密钥上传、密钥的核对、身份认证、文件与邮件的加密和解密、数字签名及数字证书管理等常用功能。研究的意义主要体现在以下几个方面&#xff1a; 提升网…

M3U8是什么,如何解析下载

M3U8是什么&#xff1f;如何解析下载 M3U8是苹果公司推出的视频播放标准&#xff0c;准确来说是一种索引文件&#xff0c;使用M3U8文件实际上是通过它来解析对应的放在服务器上的视频网络地址&#xff0c;从而实现在线播放。M3U8文件使用UTF-8字符编码。M3U8是一种常见的流媒体…