语义分割——U-Net

U-Net是继FCN之后又一个经典的语义分割网络模型,并且也是很多后续语义分割模型的“祖宗”。这个网络模型是2015年提出来的,它具有一个非常对称的结构,很像字母“U”,所以被称作U-Net。U-Net被广泛应用于医学影像领域,如分割器官、肿瘤、血管等。由于医学图像通常具有复杂的结构和较高的分辨率要求,U-Net 的结构特点使其能够有效地处理这些图像,并提供准确的分割结果,为医学诊断和治疗提供重要的支持,在遥感卫星分析和工业质量检测中也有广泛的应用。U-Net相比于FCN,它对于细节的把控能力更强。U-Net总体结构如下:

和FCN一样,它也分成了编码和解码模块两部分,编码模块也被称作下采样,解码模块也被称作上采样。

1.输入572x572,经过3x3卷积和ReLU,大小变为570x570。(卷积计算公示:output = (input-kernel_size+2*padding)/stride + 1),再做一次卷积和ReLU,大小变为568x568,通道数变成64;

2.经过最大池化层,长宽变为原来的一半,变成284x284,经过两次卷积和ReLU,大小变成280x280,通道数变成128;

3.再次经过最大池化层,长宽变为原来的一半,变成140下140,再经过两次卷积和ReLU,大小变成136x136,通道数加倍变成256;

4.再次经过最大池化层,长宽继续减半,变成68x68,再经过两次卷积和ReLU,大小变成64x64,通道数变成512;

5.再次经过最大池化层,长宽继续减半,变成32x32,再经过两次卷积和ReLU,大小变成28x28,通道数变成1024;

6.此处开始上采样。经过2x2上采样,长宽变成原来的两倍,56x56,同时通道数减半,变成512;

7.从下采样过程中的512x64x64结果中,裁剪512x56x56的部分,和上采样后得到的512x56x56结果,进行通道维度上的拼接,注意不是求和是拼接concat,得到1024x56x56的结果,再做一次卷积和ReLU,通道数减半,变成512x54x54,再做一次卷积和ReLU,这次通道数保持不变,大小变成了52x52,所以这一层最终输出512x52x52;

8.上采样,通道数进一步减小变成256x256,长宽再变大一倍,得到256x104x104,和下采样过程中通道是256x136x136的部分拼接,因为长宽不同,所以需要裁剪,将下采样过程中的结果裁剪成256x104x104,按通道维度拼接,得到512x104x104,再经过卷积和ReLU,通道数减半,变成256x102x102,再经过卷积和ReLU,得到256x100x100;

9.上采样,通道数进一步减半变成128x128,长宽再变大一倍,得到128x200x200,和下采样过程中通道是128的部分拼接,裁剪成200x200大小后,按通道维度拼接,得到256x200x200,经过卷积和ReLU,通道数减半,得到128x198x198,再经过一次卷积和ReLU,通道不变,得到128x196x196;

10.上采样,通道数进一步减半变成64x64,长宽再变大一倍,得到64x392x392,和下采样过程中通道是64的部分拼接,裁剪成392x392大小后,按通道维度拼接,得到128x392x392,经过卷积和ReLU,通道数减半,得到64x390x390,再经过一次卷积和ReLU,通道不变,得到64x388x388,最后经过一次1x1卷积,改变通道数为2,得到2x388x388,也就是最终的输出。论文最终输出是分成了两类。

不过这里有一个问题,就是输入图像的大小和输出的大小是不一致的。输入是572X572的图像,输出的结果是388X388的,当然在医学影像中可能是无所谓的,不过在大部分的语义分割场景,最好还是使得输入图像和输出结果的大小是一致的,大部分的语义分割数据集,如VOC2012的图像和标签大小也是一直的,所以我们最好使得输入和输出大小一致。

U-Net的简单实现如下:

import torch.nn as nn
import torchvision.models as models
import torchclass UNetEncoder(nn.Module):def __init__(self, pretrained=True) -> None:super(UNetEncoder, self).__init__()vgg = models.vgg16(pretrained=pretrained)self.encoder = nn.Sequential(*list(vgg.children())[:-2])self.layer1 = nn.Sequential(vgg.features[:4])self.layer2 = nn.Sequential(vgg.features[4:9])self.layer3 = nn.Sequential(vgg.features[9:16])self.layer4 = nn.Sequential(vgg.features[16:23])self.layer5 = nn.Sequential(vgg.features[23:-1])def forward(self, x):features = []x = self.layer1(x)print("x.shape: ",x.shape)features.append(x)x = self.layer2(x)print("x.shape: ",x.shape)features.append(x)x = self.layer3(x)print("x.shape: ",x.shape)features.append(x)x = self.layer4(x)print("x.shape: ",x.shape)features.append(x)x = self.layer5(x)print("x.shape: ",x.shape)features.append(x)return featuresclass UNetDecoder(nn.Module):def __init__(self, num_classes) -> None:super(UNetDecoder, self).__init__()# 定义解码器层self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)self.conv1 = nn.Conv2d(in_channels=1024, out_channels=512, kernel_size=3, padding=1)self.conv2 = nn.Conv2d(in_channels=768, out_channels=256, kernel_size=3, padding=1)self.conv3 = nn.Conv2d(in_channels=384, out_channels=128, kernel_size=3, padding=1)self.conv4 = nn.Conv2d(in_channels=192, out_channels=64, kernel_size=3, padding=1)self.classifier = nn.Conv2d(in_channels=64, out_channels=num_classes, kernel_size=1)def forward(self, features):x5 = features[-1] # 最后一层编码的结果print("x5.shape: ",x5.shape)x4 = self.upsample(x5)x4 = torch.cat([x4, features[-2]], dim=1) # 根据通道维度进行合并,第0维是batch_sizex4 = self.conv1(x4)print("x4.shape: ",x4.shape)x3 = self.upsample(x4)x3 = torch.cat([x3, features[-3]], dim=1)x3 = self.conv2(x3)print("x3.shape: ",x3.shape)x2 = self.upsample(x3)x2 = torch.cat([x2, features[-4]], dim=1)x2 = self.conv3(x2)print("x2.shape: ",x2.shape)x1 = self.upsample(x2)x1 = torch.cat([x1, features[-5]], dim=1)x1 = self.conv4(x1)print("x1.shape: ",x1.shape)out = self.classifier(x1)return outclass UNet(nn.Module):def __init__(self, num_classes, pretrained=True):super(UNet, self).__init__()self.encoder = UNetEncoder(pretrained=pretrained)self.decoder = UNetDecoder(num_classes)def forward(self, x):features = self.encoder(x)out = self.decoder(features)return outif __name__ == '__main__':model = UNet(num_classes=21, pretrained=True)x = torch.randn(1, 3, 480, 320)out = model(x)print(out.shape)
# 输出:
x.shape:  torch.Size([1, 64, 480, 320])
x.shape:  torch.Size([1, 128, 240, 160])
x.shape:  torch.Size([1, 256, 120, 80])
x.shape:  torch.Size([1, 512, 60, 40])
x.shape:  torch.Size([1, 512, 30, 20]) 
x5.shape:  torch.Size([1, 512, 30, 20])
x4.shape:  torch.Size([1, 512, 60, 40])
x3.shape:  torch.Size([1, 256, 120, 80])
x2.shape:  torch.Size([1, 128, 240, 160])
x1.shape:  torch.Size([1, 64, 480, 320])
torch.Size([1, 21, 480, 320])

可以看到,输入是四维的数据,batch_size=1,三通道的彩色图像,宽480,高320的一幅图像,输出batch_size=1,通道数21,用来分类,宽480,高320的分割结果图像。

整个U-Net代码分为编码器Encoder和解码器Decoder。Encoder就是下采样模块,采用预训练的vgg模型的特征提取模块。Decoder就是上采样模块,把特征图不断恢复为原始图像大小,并在整个过程中,按通道维度拼接特征提取过程中的特征图。

根据输出可以看到图像的变换过程,如下图所示:

下面,我们来看看这个U-Net模型在GID和VOC2012两个数据集上的分割效果:

从GID的分割结果可以看到U-Net的分割结果更为精细,细节部分比FCN提高了不少。

在VOC2012数据集上分割效果还比较一般,只能分割出大致形态,细节还不够完善。

这里有一个要注意的问题,就是U-Net并不是任意大小的输入都可以运行的,对于一些长宽不符合要求的会报错。因为在下采样过程中要进行卷积运算,由于长宽不一定是偶数,可能造成图像长宽变化,上采样后无法和下采样过程中的特征图进行拼接,所以训练过程中最好对图像进行裁剪操作,裁剪成适合U-Net运行的大小,或者改进U-Net模型,使其可以兼容各种不同大小的数据。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.xdnf.cn/news/4186.html

如若内容造成侵权/违法违规/事实不符,请联系一条长河网进行投诉反馈,一经查实,立即删除!

相关文章

AI之硬件对比:据传英伟达Nvidia2025年将推出RTX 5090-32GB/RTX 5080-24GB、华为2025年推出910C/910D

AI之硬件对比:据传英伟达Nvidia2025年将推出RTX 5090-32GB/RTX 5080-24GB、华为2025年推出910C/910D 目录 Nvidia的显卡 Nvidia的5090/5080/4090/4080:据传传英伟达Nvidia RTX 5090后续推出32GB版且RTX 5080后续或推出24GB版 RTX 5090相较于RTX 4090&…

Android无限层扩展多级recyclerview列表+实时搜索弹窗

业务逻辑: 点击选择,弹出弹窗,列表数据由后台提供,不限层级,可叠加无限层子级; 点击item展开收起,点击尾部icon单选选中,点击[确定]为最终选中,收起弹窗; 搜索…

SpringBoot+ClickHouse集成

前面已经完成ClickHouse的搭建&#xff0c;创建账号&#xff0c;创建数据库&#xff0c;保存数据库等&#xff0c;接下来就是在SpringBoot项目中集成ClickHouse。 一&#xff0c;引入依赖 <!-- SpringBoot集成ClickHouse --> <dependency><groupId>com.baom…

【基于轻量型架构的WEB开发】课程 12.5 数据回写 Java EE企业级应用开发教程 Spring+SpringMVC+MyBatis

12.5 数据回写 12.5.1 普通字符串的回写 接下来通过HttpServletResponse输出数据的案例&#xff0c;演示普通字符串的回写&#xff0c;案例具体实现步骤如下。 1 创建一个数据回写类DataController&#xff0c;在DataController类中定义 showDataByResponse()方法&#xff…

Java实现JWT登录认证

文章目录 什么是JWT?为什么需要令牌?如何实现?添加依赖&#xff1a;JwtUtils.java&#xff08;生成、解析Token的工具类&#xff09;jwt配置&#xff1a;登录业务逻辑&#xff1a;其他关联代码&#xff1a;测试&#xff1a; 什么是JWT? JWT&#xff08;Json Web Token&…

光伏无人机踏勘,照亮光伏未来!

光伏电站选址地分散在各地&#xff0c;想要精准获取该地的地形特点与屋顶面积等信息&#xff0c;传统的人工踏勘耗时耗力且精度无法保证&#xff0c;难以满足现代光伏项目的规模快发发展需求。光伏无人机踏勘&#xff0c;照亮光伏未来&#xff01; 在光伏无人机智能踏勘设计系统…

Vue全栈开发旅游网项目(7)-搜索界面开发及其接口联调

1.搜索界面开发 1.1 模糊查询 文件地址&#xff1a;pycharm- class SightListView(ListView):paginate_by 5def get_queryset(self):#is_validTrue&#xff1a;表中is_valid列&#xff0c;有值则被查询出来query Q(is_validTrue)#1.获得热门景点is_hot self.request.GET.…

『 Linux 』网络传输层 - TCP(二)

文章目录 TCP六个标志位TCP的连接三次握手 四次挥手为什么是三次握手和四次挥手 重传机制 TCP六个标志位 在TCP协议报文的报头中存在一个用于标志TCP报文类型的标志位(不考虑保留标志位),这些标志位以比特位选项的方式存在,即对应标志位为0则表示为假,对应标志位为1则为真; SYN…

Django学习-项目部署

WSGI定义&#xff1a; uWSGI定义&#xff1a; 安装uWSGI&#xff1a; 配置uWSGI&#xff1a; uWSGI常见问题汇总&#xff1a; 安装nginx&#xff1a; 配置&#xff1a; 启动/停止dnginx 修改uWSGI配置&#xff1a; 常见问题解决方法&#xff1a; nginx静态文件配置&#xff…

迅为RK3588开发板Android多屏显示之多屏同显和多屏异显

迅为RK3588开发板是一款低功耗、高性能的处理器&#xff0c;适用于基于arm的PC和Edge计算设备、个人移动互联网设备等数字多媒体应用&#xff0c;RK3588支持8K视频编解码&#xff0c;内置GPU可以完全兼容OpenGLES 1.1、2.0和3.2。RK3588引入了新一代完全基于硬件的最大4800万像…

QML项目实战:自定义Button

目录 一.添加模块 ​1.QtQuick.Controls 2.1 2.QtGraphicalEffects 1.12 二.自定义Button 1.颜色背景设置 2.设置渐变色背景 3.文本设置 4.点击设置 5.阴影设置 三.效果 1.当enabled为true 2.按钮被点击时 3.当enabled为false 四.代码 一.添加模块 1.QtQuick.Con…

基于C#实现Windows后台窗口操作与图像处理技术分析

在Windows编程中&#xff0c;操作后台窗口是一项复杂而有用的技术。它可以用来自动化用户界面测试、应用程序机器人等场景。本文将深入探讨如何在C#中绑定后台窗口、获取后台窗口界面图片&#xff0c;以及在图片中寻找指定图标并获取坐标。本技术文章结合最先进的资料与实践经验…

数据库基础(1) . 关系型数据库

1.数据库 database 1.1.数据持久化 数据持久化&#xff08;Data Persistence&#xff09;指的是将程序中的数据保存到某种持久化的存储介质&#xff08;如硬盘、SSD、磁带等&#xff09;上的过程&#xff0c;使得即使在程序终止后&#xff0c;数据依然可以被保留下来并在下次…

Python学习的自我理解和想法(27)

学的是b站的课程&#xff08;千锋教育&#xff09;&#xff0c;跟老师写程序&#xff0c;不是自创的代码&#xff01; 今天是学Python的第27天&#xff0c;学的内容是python操作pptx和pdf&#xff0c;但是这节博客只会介绍如何新建pptx和加密pdf。开学了&#xff0c;时间不多&…

鸿蒙移动应用开发-------初始arkts

一. 什么是arkts ArkTS是HarmonyOS优选的主力应用开发语言。 ArkTS围绕应用开发在TypeScript&#xff08;简称TS&#xff09;生态基础上做了进一步扩展&#xff0c;保持了TS的基本风格&#xff0c;同时通过规范定义强化开发期静态检查和分析&#xff0c;提升程序执行稳定性和…

Linux(CentOS)安装 JDK

1、下载 JDK 官网&#xff1a;https://www.oracle.com/ 2、上传 JDK 文件到 CentOS&#xff0c;使用FinalShell远程登录工具&#xff0c;并且使用 root 用户登录 3、解压 JDK 创建目录 /export/server mkdir -p /export/server 解压到目录 /export/server tar -zxvf jdk-17…

qt QStandardItemModel详解

1、概述 QStandardItemModel是Qt框架中提供的一个基于项的模型类&#xff0c;用于存储和管理数据&#xff0c;这些数据可以以表格的形式展示在视图控件&#xff08;如QTableView、QTreeView等&#xff09;中。QStandardItemModel支持丰富的数据操作&#xff0c;包括添加、删除…

SpringBoot框架在在线教育领域的应用

4系统概要设计 4.1概述 本系统采用B/S结构(Browser/Server,浏览器/服务器结构)和基于Web服务两种模式&#xff0c;是一个适用于Internet环境下的模型结构。只要用户能连上Internet,便可以在任何时间、任何地点使用。系统工作原理图如图4-1所示&#xff1a; 图4-1系统工作原理…

【论文分享】基于多源大数据的高密度城市健康资源可达性与公平性评价

评估城市健康设施的可达性和公平性对于有效配置城市健康资源至关重要。本次我们给大家带来一篇SCI论文的全文翻译。该论文从新的视角定义和分类城市中的健康相关设施&#xff0c;考虑居民的主动和被动健康寻求行为&#xff0c;构建一个综合性框架来评估健康设施的邻近性、互补性…