支持4K高分辨率,PixArt-Sigma最新文生图落地经验

PixArt-Sigma是由华为诺亚方舟实验室、大连理工大学和香港大学的研究人员共同开发的一个先进的文本到图像(Text-to-Image,T2I)生成模型。

PixArt-Sigma是在PixArt-alpha的基础上进一步改进的模型,旨在生成高质量的4K分辨率图像。

PixArt-Sigma通过整合高级元素和采用由弱到强式训练方法,这种策略有助于模型逐渐学习并优化图像细节,从而提高了生成图像的保真度和与文本提示的对齐程度。

PixArt-Sigma在美学质量上与当前顶级的文本到图像产品如DALL·E 3和Midjourney V6不相上下,并且在遵循文本提示方面表现出色。

PixArt-Sigma的生成能力支持高分辨率海报和壁纸的创作,有效支持电影和游戏等行业高质量视觉内容的制作。

github项目地址:https://github.com/PixArt-alpha/PixArt-sigma。

一、环境安装

1、python环境

建议安装python版本在3.10以上。

2、pip库安装

pip install torch==2.3.0+cu118 torchvision==0.18.0+cu118 torchaudio==2.3.0 --extra-index-url https://download.pytorch.org/whl/cu118

pip install -r requirements.txt -i https://pypi.tuna.tsinghua.edu.cn/simple

3、SDXL-VAE模型下载

git lfs install

git clone https://huggingface.co/PixArt-alpha/pixart_sigma_sdxlvae_T5_diffusers

3、PixArt-Sigma模型下载

python tools/download.py

、功能测试

1、命令行运行测试

(1)python代码调用测试
 

import os
import re
import sys
import argparse
from datetime import datetime
from pathlib import Pathimport torch
from torch import nn
from torchvision.utils import save_image
from tqdm import tqdm
from diffusers.models import AutoencoderKL
from transformers import T5EncoderModel, T5Tokenizerfrom diffusion.model.utils import prepare_prompt_ar
from diffusion import IDDPM, DPMS, SASolverSampler
from diffusion.model.nets import PixArtMS_XL_2, PixArt_XL_2
from diffusion.data.datasets import get_chunks
import diffusion.data.datasets.utils as ds_utils
from tools.download import find_modelclass ImageGenerator:def __init__(self, args):self.args = argsself.device = "cuda" if torch.cuda.is_available() else "cpu"self.seed = args.seedself._set_env()self._load_model_components()def _set_env(self):torch.manual_seed(self.seed)torch.set_grad_enabled(False)for _ in range(30):torch.randn(1, 4, self.args.image_size, self.args.image_size)def _load_model_components(self):self.latent_size = self.args.image_size // 8self.max_sequence_length = {"alpha": 120, "sigma": 300}[self.args.version]self.pe_interpolation = self.args.image_size / 512self.micro_condition = self.args.version == 'alpha' and self.args.image_size == 1024self.sample_steps_dict = {'iddpm': 100, 'dpm-solver': 20, 'sa-solver': 25}self.sample_steps = self.args.step if self.args.step != -1 else self.sample_steps_dict[self.args.sampling_algo]self.weight_dtype = torch.float16self._load_main_model()self._load_vae()self._load_text_components()def _load_main_model(self):if self.args.image_size in [512, 1024, 2048] or self.args.version == 'sigma':self.model = PixArtMS_XL_2(input_size=self.latent_size,pe_interpolation=self.pe_interpolation,micro_condition=self.micro_condition,model_max_length=self.max_sequence_length,).to(self.device)else:self.model = PixArt_XL_2(input_size=self.latent_size,pe_interpolation=self.pe_interpolation,model_max_length=self.max_sequence_length,).to(self.device)print("Generating sample from ckpt: %s" % self.args.model_path)state_dict = find_model(self.args.model_path)state_dict['state_dict'].pop('pos_embed', None)missing, unexpected = self.model.load_state_dict(state_dict['state_dict'], strict=False)print('Missing keys: ', missing)print('Unexpected keys', unexpected)self.model.eval()self.model.to(self.weight_dtype)self.base_ratios = getattr(ds_utils, f'ASPECT_RATIO_{self.args.image_size}', ds_utils.ASPECT_RATIO_1024)def _load_vae(self):vae_path = "output/pretrained_models/sd-vae-ft-ema" if self.args.sdvae else f"{self.args.pipeline_load_from}/vae"self.vae = AutoencoderKL.from_pretrained(vae_path).to(self.device).to(self.weight_dtype)def _load_text_components(self):self.tokenizer = T5Tokenizer.from_pretrained(self.args.pipeline_load_from, subfolder="tokenizer")self.text_encoder = T5EncoderModel.from_pretrained(self.args.pipeline_load_from, subfolder="text_encoder").to(self.device)null_caption_token = self.tokenizer("", max_length=self.max_sequence_length, padding="max_length", truncation=True, return_tensors="pt").to(self.device)self.null_caption_embs = self.text_encoder(null_caption_token.input_ids, attention_mask=null_caption_token.attention_mask)[0]def generate_images(self, items: list):save_root = self._prepare_save_directory()self._visualize(items, save_root)def _prepare_save_directory(self):work_dir = 'output'try:epoch_name = re.search(r'.*epoch_(\d+).*', self.args.model_path).group(1)step_name = re.search(r'.*step_(\d+).*', self.args.model_path).group(1)except:epoch_name = 'unknown'step_name = 'unknown'img_save_dir = os.path.join(work_dir, 'vis')os.umask(0o000)  # file permission: 666; dir permission: 777os.makedirs(img_save_dir, exist_ok=True)save_root = os.path.join(img_save_dir, f"{datetime.now().date()}_{self.args.dataset}_epoch{epoch_name}_step{step_name}_scale{self.args.cfg_scale}_step{self.sample_steps}_size{self.args.image_size}_bs{self.args.bs}_samp{self.args.sampling_algo}_seed{self.seed}")print("save_root: ", save_root)os.makedirs(save_root, exist_ok=True)return save_root@torch.inference_mode()def _visualize(self, items, save_root):for chunk in tqdm(list(get_chunks(items, self.args.bs)), unit='batch'):prompts, hw, ar = self._prepare_prompts_and_configs(chunk)caption_embs, emb_masks, null_y = self._get_text_embeddings(prompts)with torch.no_grad():samples = self._run_sampling(hw, ar, caption_embs, emb_masks, null_y)self._save_images(samples, save_root)def _prepare_prompts_and_configs(self, chunk):prompts = []if self.args.bs == 1:timestamp = datetime.now().strftime("%Y%m%d%H%M%S")save_path = os.path.join(save_root, f"{timestamp}.jpg")if os.path.exists(save_path):returnprompt_clean, _, hw, ar, custom_hw = prepare_prompt_ar(chunk[0], self.base_ratios, device=self.device, show=False)latent_size_h, latent_size_w = int(hw[0, 0] // 8), int(hw[0, 1] // 8)prompts.append(prompt_clean.strip())else:hw = torch.tensor([[self.args.image_size, self.args.image_size]], dtype=torch.float, device=self.device).repeat(self.args.bs, 1)ar = torch.tensor([[1.]], device=self.device).repeat(self.args.bs, 1)for prompt in chunk:prompts.append(prepare_prompt_ar(prompt, self.base_ratios, device=self.device, show=False)[0].strip())latent_size_h, latent_size_w = self.latent_size, self.latent_sizereturn prompts, hw, ardef _get_text_embeddings(self, prompts):caption_token = self.tokenizer(prompts, max_length=self.max_sequence_length, padding="max_length", truncation=True, return_tensors="pt").to(self.device)caption_embs = self.text_encoder(caption_token.input_ids, attention_mask=caption_token.attention_mask)[0]emb_masks = caption_token.attention_maskcaption_embs = caption_embs[:, None]null_y = self.null_caption_embs.repeat(len(prompts), 1, 1)[:, None]print(f'finish embedding')return caption_embs, emb_masks, null_ydef _run_sampling(self, hw, ar, caption_embs, emb_masks, null_y):model_kwargs = dict(data_info={'img_hw': hw, 'aspect_ratio': ar}, mask=emb_masks)if self.args.sampling_algo == 'iddpm':z = torch.randn(len(prompts), 4, latent_size_h, latent_size_w, device=self.device).repeat(2, 1, 1, 1)model_kwargs['y'] = torch.cat([caption_embs, null_y])model_kwargs['cfg_scale'] = self.args.cfg_scalediffusion = IDDPM(str(self.sample_steps))samples = diffusion.p_sample_loop(self.model.forward_with_cfg, z.shape, z, clip_denoised=False, model_kwargs=model_kwargs, progress=True,device=self.device)samples, _ = samples.chunk(2, dim=0)elif self.args.sampling_algo == 'dpm-solver':z = torch.randn(len(prompts), 4, latent_size_h, latent_size_w, device=self.device)dpm_solver = DPMS(self.model.forward_with_dpmsolver,condition=caption_embs,uncondition=null_y,cfg_scale=self.args.cfg_scale,model_kwargs=model_kwargs)samples = dpm_solver.sample(z,steps=self.sample_steps,order=2,skip_type="time_uniform",method="multistep",)elif self.args.sampling_algo == 'sa-solver':sa_solver = SASolverSampler(self.model.forward_with_dpmsolver, device=self.device)samples = sa_solver.sample(S=25,batch_size=len(prompts),shape=(4, latent_size_h, latent_size_w),eta=1,conditioning=caption_embs,unconditional_conditioning=null_y,unconditional_guidance_scale=self.args.cfg_scale,model_kwargs=model_kwargs,)[0]samples = samples.to(self.weight_dtype)samples = self.vae.decode(samples / self.vae.config.scaling_factor).sampletorch.cuda.empty_cache()return samplesdef _save_images(self, samples, save_root):os.umask(0o000)  # file permission: 666; dir permission: 777for i, sample in enumerate(samples):timestamp = datetime.now().strftime("%Y%m%d%H%M%S")save_path = os.path.join(save_root, f"{timestamp}.jpg")print("Saving path: ", save_path)save_image(sample, save_path, nrow=1, normalize=True, value_range=(-1, 1))def get_args():parser = argparse.ArgumentParser()parser.add_argument('--image_size', default=1024, type=int)parser.add_argument('--version', default='sigma', type=str)parser.add_argument("--pipeline_load_from", default='PixArt-sigma-model/pixart_sigma_sdxlvae_T5_diffusers',type=str, help="Download for loading text_encoder, ""tokenizer and vae from https://huggingface.co/PixArt-alpha/pixart_sigma_sdxlvae_T5_diffusers")parser.add_argument('--txt_file', default='asset/test.txt', type=str)parser.add_argument('--model_path', default='PixArt-sigma-model/PixArt-Sigma-XL-2-1024-MS.pth', type=str)parser.add_argument('--sdvae', action='store_true', help='sd vae')parser.add_argument('--bs', default=1, type=int)parser.add_argument('--cfg_scale', default=4.5, type=float)parser.add_argument('--sampling_algo', default='dpm-solver', type=str, choices=['iddpm', 'dpm-solver', 'sa-solver'])parser.add_argument('--seed', default=0, type=int)parser.add_argument('--dataset', default='custom', type=str)parser.add_argument('--step', default=-1, type=int)parser.add_argument('--save_name', default='test_sample', type=str)return parser.parse_args()if __name__ == '__main__':args = get_args()generator = ImageGenerator(args)with open(args.txt_file, 'r') as f:items = [item.strip() for item in f.readlines()]generator.generate_images(items)

(2)web端测试

未完......

更多详细的内容欢迎关注:杰哥新技术
 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.xdnf.cn/news/1488327.html

如若内容造成侵权/违法违规/事实不符,请联系一条长河网进行投诉反馈,一经查实,立即删除!

相关文章

Mongodb文档和数组的通配符索引

学习mongodb,体会mongodb的每一个使用细节,欢迎阅读威赞的文章。这是威赞发布的第97篇mongodb技术文章,欢迎浏览本专栏威赞发布的其他文章。如果您认为我的文章对您有帮助或者解决您的问题,欢迎在文章下面点个赞,或者关…

老板电器发布首个烹饪AI模型,揭秘其如何引领厨电行业变革

数字发展日新月异,智慧产品迭代更新。当前,我们或许正身处一场连科学巨人也无法预见的深度变革之中。现代科技使得普通人无需深入学习数学或编程知识,也能借助手机或电脑,体验“苏格拉底式”的在线指导,或者与“乔布斯…

【LeetCode、牛客】链表分割、链表的回文结构、160.相交链表

Hi~!这里是奋斗的明志,很荣幸您能阅读我的文章,诚请评论指点,欢迎欢迎 ~~ 🌱🌱个人主页:奋斗的明志 🌱🌱所属专栏:数据结构 📚本系列文章为个人学…

面试经典 114. 二叉树展开为链表

最近工作越来越难找,裁员越来越懂了,焦虑的睡不着,怎么办呢,只能刷面试题,卷死你们 今天这个题目没刷过,我思考了半天才只能用暴力,后来苦思冥想才想出来简单的方法,废话不多说&…

【音视频】RTSP、RTMP与流式传输

文章目录 前言RTSP与RTMPRTSP(Real-Time Streaming Protocol)RTMP(Real-Time Messaging Protocol)主要差异 什么是流式传输?流式传输的特点流式传输与传统下载的区别 使用VLC播放RTSP监控 总结 前言 在现代网络环境中…

uni-app声生命周期

应用的生命周期函数在App.vue页面 onLaunch:当uni-app初始化完成时触发(全局触发一次) onShow:当uni-app启动,或从后台进入前台时显示 onHide:当uni-app从前台进入后台 onError:当uni-app报错时触发,异常信息为err 页面的生命周期 onLoad…

html+css+js前端作业 王者荣耀官网5个页面带js

htmlcssjs前端作业 王者荣耀官网5个页面带js 下载地址 https://download.csdn.net/download/qq_42431718/89574989 目录1 目录2 目录3 项目视频 王者荣耀5个页面(带js) 页面1 页面2 页面3 页面4 页面5

四步实现网站HTTPS访问

随着网络安全的重要性日益凸显,HTTPS(超文本传输安全协议)已成为现代网站的标准配置。HTTPS协议作为HTTP协议的安全版本,通过SSL协议加密数据传输,不仅能保护用户数据的安全,还能提升搜索引擎排名&#xff…

07-workqueue

想系统学习k8s源码,云原生的可以加:mkjnnm 今天我们来详细研究下 workqueue 相关代码。client-go 的 util/workqueue 包里主要有三个队列,分别是普通队列,延时队列,限速队列,后一个队列以前一个队列的实现为…

Java基础巩固——JDK 8、9新增接口的特性(接口中定义非抽象方法、静态方法和私有方法)

#Java学了这么久,项目也做了?基础知识还不巩固?快来关注我的这篇系列博客——Java基础复习巩固吧# 目录 引言 一、JDK8新特性:允许在接口中定义非抽象方法和静态方法。 注意事项 二、JDK9新特性:允许在接口中定义p…

“科技创新‘圳’在变革”2025深圳电子展

电子产业作为现代社会的核心驱动力之一,正以前所未有的速度发展。在这样的背景下,深圳作为中国的经济特区和创新高地,又一次迎来了备受瞩目的盛会——2025深圳电子展览会。本次展览会定于2025年4月9日至11日,在深圳会展中心&#…

Photos框架 - 自定义媒体资源选择器(数据部分)

引言 在iOS开发中,系统已经为我们提供了多种便捷的媒体资源选择方式,如UIImagePickerController和PHPickerViewController。这些方式不仅使用方便、界面友好,而且我们完全不需要担心性能和稳定性问题,因为它们是由系统提供的&…

Java Selenium WebDriver:代理设置与图像捕获

在网络爬虫和自动化测试领域,Selenium WebDriver 是一个非常流行的工具,它允许开发者模拟用户在浏览器中的操作。然而,出于安全或隐私的考虑,有时我们需要通过代理服务器来发送请求。本文将介绍如何在Java环境中使用Selenium WebD…

MSQP Mysql数据库权限提升工具,UDF自动检测+快速反向SHELL

项目地址:https://github.com/MartinxMax/MSQP MSQP 这是一个关于Mysql的权限提升工具 安装依赖 $ python3 -m pip install mysql-connector-python 使用方法 $ python3 msqp.py -h 权限提升:建立反向Shell 在建立反向连接前,该工具会自动检测是否具有提权条件&#xff0…

01。配置DevEcoStudio的中文界面方法

打开项目 点击File >> 点击Setting (或者按快捷键 Ctrl alt S) 选择 Plugins (扩展)>> 输入 chinese >>点击 Enable 点击 apply >OK 弹出窗口点击 Restart finish(完成)hiahia…

文件共享功能无法使用提示错误代码0x80004005【笔记】

环境情况: 其他电脑可以正常访问共享端,但有一台电脑访问提示错误代码0x80004005。 处理检查: 搜索里输入“启用或关闭Windows功能”按回车键,在“启用或关闭Windows功能”里将“SMB 1.0/CIFS文件共享支持”勾选后(故…

hipBLAS示例程序

GPT-4o (OpenAI) 当然!以下是一个简单示例,展示了如何使用hipBLAS库进行矩阵-向量乘法 (GEMV) 的操作。该示例包括初始化 hipBLAS 环境,设置矩阵和向量数据并调用hipBLAS API来执行操作。 首先,确保你已经安装了 ROCm&#xff08…

【Web】LitCTF 2024 题解(全)

目录 浏览器也能套娃? 一个....池子? 高亮主题(划掉)背景查看器 百万美元的诱惑 SAS - Serializing Authentication exx 浏览器也能套娃? 随便试一试,一眼ssrf file:///flag直接读本地文件 一个....池子? {…

政安晨【零基础玩转各类开源AI项目】基于Ubuntu系统部署MimicMotion :利用可信度感知姿势指导生成高质量人体运动视频

目录 项目介绍 项目相关工作 图像/视频生成的扩散模型 姿势引导的人体动作转移 生成长视频 方法实践 与最先进方法的比较 消融研究 部署验证 1. 下载项目: 2. 建立环境 3. 下载参数模型 A. 下载 DWPose 预训练模型:dwpose B. 从 Huggingfa…

redis的使用场景

目录 1. 热点数据缓存 1.1 什么是缓存? 1.2 缓存的原理 1.3 什么样的数据适合放入缓存中 1.4 哪个组件可以作为缓存 1.5 java使用redis如何实现缓存功能 1.5.1 需要的依赖 1.5.2 配置文件 1.5.3 代码 1.5.4 发现 1.6 使用缓存注解完成缓存功能 2. 分布式锁…