pytorch U²-Net教程

U²-Net (U2-Net) 是一个用于图像分割的神经网络模型,特别擅长于边界复杂的物体分割任务,如前景背景分割和抠图。U²-Net 的独特之处在于其 U 形结构和嵌套 U 形块,能够有效捕捉不同尺度的特征,同时保持较小的模型大小。它非常适合在资源受限的环境下使用。

官方文档链接

U²-Net 本身并没有一个独立的 Python 库,但可以通过 官方 GitHub 仓库 获取源码和模型细节。


一、U²-Net 架构概述

U²-Net 是基于 U-Net 结构的改进模型,由多个嵌套的 U 形编码器-解码器模块组成。其创新点在于 U2 模块,它在不同尺度上提取特征,增强了对边界信息的捕捉能力。

U²-Net 结构包含:

  1. 编码器(Encoder):使用多尺度卷积核提取图像的特征,逐渐压缩特征图尺寸。
  2. 解码器(Decoder):通过逐步上采样,恢复原始分辨率,同时结合编码器的跳跃连接。
  3. U2 模块:嵌套的 U 形块,能够同时处理不同分辨率的特征,从而保留高分辨率的局部细节和低分辨率的全局语义信息。

二、基础功能

在 U²-Net 中,通常的工作流程是加载预训练模型并对输入图像进行分割。U²-Net 最常见的任务是图像前景提取,比如抠图。

1. 加载 U²-Net 模型

从官方 GitHub 下载预训练模型权重,并通过 PyTorch 加载。

import torch
import torchvision.transforms as transforms
from PIL import Image
import numpy as np# 加载预训练的 U²-Net 模型
model = torch.load('u2net.pth')
model.eval()  # 设置为评估模式# 准备图像输入
def load_image(image_path):transform = transforms.Compose([transforms.Resize((320, 320)),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])image = Image.open(image_path).convert('RGB')image = transform(image).unsqueeze(0)return image# 加载图片并转换为张量
input_image = load_image("input_image.jpg")# 前向传播,生成分割结果
with torch.no_grad():result = model(input_image)

2. 处理模型输出

U²-Net 的输出通常为前景掩码 (mask),可以通过阈值处理生成二值化图像。

def process_output(output):# 提取前景掩码mask = output[0][0].squeeze().cpu().numpy()# 归一化到0-1范围mask = (mask - np.min(mask)) / (np.max(mask) - np.min(mask))# 二值化处理mask = (mask > 0.5).astype(np.uint8)return mask# 处理输出的前景掩码
foreground_mask = process_output(result)

三、进阶功能

1. 前景提取并保存透明 PNG

U²-Net 可以用于精细化的图像前景提取。通过将背景像素设置为透明,生成透明的 PNG 图片。

from PIL import Imagedef save_foreground(image_path, mask, save_path):image = Image.open(image_path).convert('RGBA')width, height = image.sizemask = Image.fromarray(mask * 255).resize((width, height), Image.BILINEAR)# 转换为 RGBA 格式,将背景设置为透明image_data = np.array(image)mask_data = np.array(mask)# 将背景区域的 alpha 通道设置为 0(完全透明)image_data[:, :, 3] = mask_data# 保存带有透明背景的 PNG 图片output_image = Image.fromarray(image_data)output_image.save(save_path)# 使用掩码提取前景并保存
save_foreground("input_image.jpg", foreground_mask, "output_image.png")

2. 使用其他输入尺寸

虽然 U²-Net 默认是使用 320x320 的输入尺寸,但它对不同的输入尺寸有一定的适应性。我们可以根据需要调整输入图像的大小。

# 自定义输入尺寸
def load_image_custom_size(image_path, size=(320, 320)):transform = transforms.Compose([transforms.Resize(size),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])image = Image.open(image_path).convert('RGB')image = transform(image).unsqueeze(0)return image# 调整输入图像尺寸
custom_size_image = load_image_custom_size("input_image.jpg", size=(512, 512))

四、高级教程

U²-Net 的高级用法可以结合其他深度学习框架或任务,例如对分割结果进行进一步的图像处理或增强。

1. 与 OpenCV 结合处理分割结果

可以利用 OpenCV 对分割后的图像进行一些后处理,例如边缘检测、轮廓提取等。

import cv2def process_with_opencv(mask):# 使用 OpenCV 检测轮廓contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)# 绘制轮廓contour_image = np.zeros_like(mask)cv2.drawContours(contour_image, contours, -1, (255), 2)return contour_image# 使用 OpenCV 处理分割结果
contour_image = process_with_opencv(foreground_mask)
cv2.imwrite("contour_image.png", contour_image)

2. 自定义损失函数与训练

如果需要训练自己的 U²-Net 模型,可以基于 Binary Cross Entropy 损失函数进行训练。以下是一个自定义损失函数的示例。

import torch.nn as nnclass U2NetLoss(nn.Module):def __init__(self):super(U2NetLoss, self).__init__()self.bce_loss = nn.BCELoss()def forward(self, d0, d1, d2, d3, d4, d5, d6, labels):# 对不同尺度的预测进行加权损失计算loss0 = self.bce_loss(d0, labels)loss1 = self.bce_loss(d1, labels)loss2 = self.bce_loss(d2, labels)loss3 = self.bce_loss(d3, labels)loss4 = self.bce_loss(d4, labels)loss5 = self.bce_loss(d5, labels)loss6 = self.bce_loss(d6, labels)return loss0 + loss1 + loss2 + loss3 + loss4 + loss5 + loss6

3. 模型优化与推理加速

U²-Net 的推理速度在某些情况下可能是瓶颈,尤其在移动端。可以通过模型量化、剪枝或者使用推理加速库(如 TensorRT)来提高效率。


五、总结

U²-Net 是一个轻量级、功能强大的模型,专注于高质量的前景分割任务。它具有以下特点:

  1. 多尺度特征捕捉:通过 U2 模块,U²-Net 能够捕捉到不同尺度的细节,适用于精细的边缘分割任务。
  2. 易于使用:通过 PyTorch 实现,能够轻松加载预训练模型并进行推理。
  3. 适应性强:U²-Net 适用于不同分辨率的输入图像,具有良好的推广性。

如果你有更多问题或需要代码测试,请随时告诉我!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.xdnf.cn/news/1542617.html

如若内容造成侵权/违法违规/事实不符,请联系一条长河网进行投诉反馈,一经查实,立即删除!

相关文章

顶点缓存对象(VBO)与顶点数组对象(VAO)

我们的顶点数组在CPU端的内存里是以数组的形式存在,想要GPU去绘制三角形,那么需要将这些数据传输给GPU。那这些数据在显存端是怎么存储的呢?VBO上场了,它代表GPU上的一段存储空间对象,表现为一个unsigned int类型的变量,GPU端内存对象的一个ID编号、地址、大小。一个VBO对…

Python爬虫之urllib模块详解

Python爬虫入门 此专栏为Python爬虫入门到进阶学习。 话不多说,直接开始吧。 urllib模块 Python中自带的一个基于爬虫的模块,其实这个模块都几乎没什么人用了,我就随便写写了。 - 作用:可以使用代码模拟浏览器发起请求。&…

基于python的文本聚类分析与可视化实现,使用kmeans聚类,手肘法分析

1、数据预处理 由于在数据分析之前数据集通常都存在数据重复、脏数据等问题,所以为了提高 数据分析结果的质量,在应用之前就必须对数据集进行数据预处理。数据预处理的方法通常有清洗、集成、转换、规约这四个方面,接下来详细介绍这对爬取…

leetcode第七题:字符反转

给你一个 32 位的有符号整数 x ,返回将 x 中的数字部分反转后的结果。 如果反转后整数超过 32 位的有符号整数的范围 [−231, 231 − 1] ,就返回 0。 假设环境不允许存储 64 位整数(有符号或无符号)。 示例 1: 输入…

分布式安装LNMP

目录 搭建LNMP架构 安装mysql 1.上传mysql软件包,关闭防火墙和核心防护 2.安装环境依赖包,桌面安装可能有自带的数据库除 3.配置软件模块 4.编译及安装 5.创建mysql用户 6.修改mysql 配置文件 7.更改mysql安装目录和配置文件的属主属组 8.设置…

认识结构体

目录 一.结构体类型的声明 1.结构的声明 2.定义结构体变量 3.结构体变量初始化 4.结构体的特殊声明 二.结构体对齐(重点难点) 1.结构体对齐规则 2.结构体对齐练习 (一)简单结构体对齐 (二)嵌套结构体对齐 3.为什么存在内存对齐 4.修改默认对齐数 三.结构体传参 1…

Object类代码结构

Object Object是所有类的父类。 方法结构如下 一些不知道的方法 private static native void registerNatives(); * JNI机制 * 这里定义了一个 native 方法 registerNatives(),它没有方法体。 * native 关键字表示这个方法的实现是由本地代码 * (通常…

【Pytorch】一文快速教你高效使用torch.no_grad()

🎬 鸽芷咕:个人主页 🔥 个人专栏: 《C干货基地》《粉丝福利》 ⛺️生活的理想,就是为了理想的生活! 博主简介 博主致力于嵌入式、Python、人工智能、C/C领域和各种前沿技术的优质博客分享,用最优质的内容带来最舒适的…

BERT的代码实现

目录 1.BERT的理论 2.代码实现 2.1构建输入数据格式 2.2定义BERT编码器的类 2.3BERT的两个任务 2.3.1任务一:Masked Language Modeling MLM掩蔽语言模型任务 2.3.2 任务二:next sentence prediction 3.整合代码 4.知识点个人理解 1.BERT的理论 B…

Linux 静态库与动态库的制作与使用

在Linux中,库library是一组函数和资源的集合,他们可以被不同的程序共享和使用,库的主要目的是代码重用,减少内存占用,并简化程序的维护。 Linux操作系统支持的函数库分为:静态库和动态库。 静态库&#xf…

【线程池】Tomcat线程池

版本:tomcat-embed-core-10.1.8.jar 前言 最近面试被问到 Tomcat 线程池,因为之前只看过 JDK 线程池,没啥头绪。在微服务横行的今天,确实还是有必要研究研究 Tomcat 的线程池 Tomcat 线程池和 JDK 线程池最大的不同就是它先把最…

二分+优先队列例题总结(icpc vp+牛客小白月赛)

题目 思路分析 要求输出最小的非负整数k,同时我们还要判断是否存在x让整个序列满足上述条件。 当k等于某个值时,我们可以得到x的一个取值区间,若所有元素得到的x的区间都有交集(重合)的话,那么说明存在x满足条件。因为b[i]的取值为1e9&…

Maven-一、分模块开发

Maven进阶 文章目录 Maven进阶前言创建新模块向新模块装入内容使用新模块把模块部署到本地仓库补充总结 前言 分模块开发可以把一个完整项目中的不同功能分为不同模块管理,然后模块间可以相互调用,该篇以一个SSM项目为目标展示如何使用maven分模块管理。…

没错,我给androidx修了一个bug!

不容易啊,必须先截图留恋😁 这个bug是发生在xml中给AppcompatTextView设置textFontWeight,但是却无法生效。修复bug的代码也很简单,总共就几行代码,但是在找引起这个bug的原因和后面给androidx提pr却花了很久。 //App…

云手机的海外原生IP有什么用?

在全球数字化进程不断加快的背景下,企业对网络的依赖程度日益加深。云手机作为一项创新的工具,正逐步成为企业优化网络结构和全球业务拓展的必备。尤其是云手机所具备的海外原生IP功能,为企业进入国际市场提供了独特的竞争优势。 什么是海外原…

DNF Decouple and Feedback Network for Seeing in the Dark

DNF: Decouple and Feedback Network for Seeing in the Dark 在深度学习领域,尤其是在低光照图像增强的应用中,RAW数据的独特属性展现出了巨大的潜力。然而,现有架构在单阶段和多阶段方法中都存在性能瓶颈。单阶段方法由于域歧义&#xff0c…

如何使用 3 种简单的方法将手写内容转换为文本

手写比文本更具艺术性,这就是许多人追求手写字体的原因。有时,我们必须将手写内容转换为文本,以便于存储和阅读。本文将指导您如何轻松转换它。 此外,通常以扫描的手写内容编辑文本很困难,但使用奇客免费OCR&#xff…

视觉距离与轴距离的转换方法

1.找一个明显的参照物,用上方固定的相机拍一下。保存好图片 2.轴用定长距离如1mm移动一下。 3.再用上相机再取一张图。 4.最后用halcon 将两图叠加 显示 效果如下 从图上可以明显的看出有两个图,红色标识的地方。 这时可以用halcon的工具画一个长方形…

Cesium 绘制可编辑点

Cesium Point点 实现可编辑的pointEntity 实体 文章目录 Cesium Point点前言一、使用步骤二、使用方法二、具体实现1. 开始绘制2.绘制事件监听三、 完整代码前言 支持 鼠标按下 拖动修改点,释放修改完成。 一、使用步骤 1、点击 按钮 开始 绘制,单击地图 绘制完成 2、编辑…

误差评估,均方误差、均方根误差、标准差、方差

均方根误差 RMSE/RMS 定义 RMSE是观察值与真实值偏差的平方,对于一组观测值 y i y_i yi​ 和对应的真值 t i t_i ti​ R M S E 1 n ∑ i 1 n ( y i − t i ) ,其中n是观测次数 RMSE\sqrt{\frac1n \sum_{i1}^n (y_i-t_i)} \text{,其中n是…