神经网络(五):U2Net图像分割网络

文章目录

  • 一、网络结构
    • 1.1第一种block结构
    • 1.2第二种block结构
    • 1.3特征图融合模块
    • 1.4损失函数
    • 1.5总体网络架构
    • 1.6代码汇总
    • 1.7普通残差块与RSU对比
  • 二、代码复现


  参考论文:U2-Net: Going deeper with nested U-structure for salient object detection
  这篇文章基于显著目标检测任务提出,显著目标检测是指将图像中最吸引人的目标或者区域分割出来,因此只有前景和背景两个类别,相当于语义分割中的二分类任务。例如,下图中展示了三张图片进行显著目标检测后的结果:
在这里插入图片描述
其中,白色区域代表前景 ,即最吸引人的目标或区域,而黑色区域代表背景。
  在 U 2 N e t U^2Net U2Net被提出前,显著目标检测领域主要面临两个问题:

  • 1.现有的SOD网络大多基于当时已有的网络架构(主干网络)进行深度特征的提取,如 A l e x N e t 、 V G G 、 R e s N e t 、 R e s N e X t AlexNet、VGG、ResNet、ResNeXt AlexNetVGGResNetResNeXt等。这些网络最初是为图像分类设计的,它们提取代表语义含义的特征(如某一具体的猫、狗),而非局部细节和全局对比度信息(SOD的主要任务是将图像划分为前景与背景,而不注重某一具体特征的提取),使得这些模型在进行SOD任务时效率低下。
  • 2.当时的SOD网络架构不断通过向现有的主干网络中添加特征聚合模块以提取多级显著特征,使得模型过于复杂。且这些图像分类模型往往通常通过牺牲特征图的高分辨率来实现更深的架构,即特征图在早期阶段会被缩小到较低的分辨率,如ResNet和DenseNet会使用步长为2的卷积和步长为2的最大池化将特征图大小减小到输入图的四分之一。但是,高分辨率在图像分割中有着重要作用,这也使得这些模型并不适用SOD任务。

  为解决上述问题,提出了 U 2 N e t U^2Net U2Net网络结构:

  • 一种两级嵌套的U形结构,专为SOD 设计,无需使用任何来自图像分类的预训练主干。
  • 提出U型残差块RSU,能在不降低特征图分辨率的情况下提取阶段内多尺度特征。

一、网络结构

   U 2 − N e t U^2-Net U2Net网络基于 U N e t UNet UNet网络设计而来。事实上,该网络的整体结构与 U N e t UNet UNet网络几乎相同,但所使用的上采样、下采样模块变成了小型的 U N e t UNet UNet网络,即, U 2 N e t U^2Net U2Net本质上是 U N e t UNet UNet网络的嵌套。而 U 2 − N e t U^2-Net U2Net网络的核心就是这些作为模块的小型 U N e t UNet UNet网络,并将其起名为 R e S i d u a l U − b l o c k ( R S U ,残差 U 块 ) ReSidual U-block(RSU,残差U块) ReSidualUblock(RSU,残差U)。网络结构如下:
在这里插入图片描述
这些模块其实可以分为两种, E n c o d e r 1 E n c o d e r 4 、 D e c o d e r 1 D e c o d e r 4 Encoder1~Encoder4、Decoder1~Decoder4 Encoder1 Encoder4Decoder1 Decoder4采用的是同一种结构的残差块,只不过深度不同,而Encoder5、Encoder6、Decoder5 采用的是另一种结构的残差块。整体流程可概况为:

  • Encoder阶段:每通过一个模块后都会下采样两倍,使用的是torch.nn.MaxPool2d
  • Decoder阶段:每通过一个模块后都会上采用两倍,使用的是torch.nn.functional.interpolate()
  • 跳跃链接:与 U N e t UNet UNet网络思路相同,将编码器的输出与解码器输出的特征图进行拼接,最后得到分割后的图像。

1.1第一种block结构

  本地和全局上下文信息对于显著对象检测和图像分割任务都非常重要,现代CNN网络设计中VGG、ResNet、DenseNet 等,一般使用1x1或3x3的小型卷积核提取特征。但在SOD任务中,由于它们的感受野太小而无法捕捉全局信息,使得浅层的输出特征图仅包含局部特征。在下图(图 ( a ) − ( c ) (a)-(c) (a)(c))中给出了具有小感受野的典型现有卷积块。为从浅层获得高分辨率特征图的更多全局信息,最直接的想法是扩大感受野,图 ( e ) (e) (e)是一种双向消息传递模块(见论文ieee),它试图通过使用扩张卷积扩大感受野来提取局部和非局部特征,以原始分辨率对输入特征图进行多次扩张卷积(尤其是在早期阶段)需要太多的计算和内存资源。
  为解决上述问题,本文提出了RSU模块(图 ( e ) (e) (e),L表示RSU的深度,图中L=7):

  • 一个输入卷积层,用于将尺寸为 ( H , W , C i n ) (H,W,C_{in}) (H,W,Cin)输入特征图x添加到中间映射 F 1 ( x ) F_1(x) F1(x)中,这是一个用于局部特征提取的普通卷积层,通道数为 C o u t C_{out} Cout
  • 高度为L的类似 U N e t UNet UNet的对称编码器-解码器结构,采用 F 1 ( x ) F_1(x) F1(x)作为输入,学习多尺度上下文信息。用 U ( x ) U(x) U(x)表示该类似 U N e t UNet UNet的结构,则提取的信息可表示为 U ( F 1 ( x ) ) U(F_1(x)) U(F1(x))。较大的 L 会生成更深的RSU、更多的池化操作、更大的感受野范围以及更丰富的局部和全局特征。配置此参数可以从具有任意空间分辨率的输入特征图中提取多尺度特征。从逐渐下采样的特征图中提取多尺度特征,并通过渐进式上采样、连接和卷积,将特征图还原为高分辨率特征图。此过程可减轻因使用大比例的直接上采样而导致的精细细节损失。
  • 通过求和融合局部特征和多尺度特征的残差连接: F 1 ( x ) + U ( F 1 ( x ) ) F_1(x)+U(F_1(x)) F1(x)+U(F1(x))

在这里插入图片描述

   R S U − 7 RSU-7 RSU7真实结构如下图所示:
在这里插入图片描述
  回到 U 2 − N e t U^2-Net U2Net结构,该RSU的使用场景有:

  • Encoder1 和 Decoder1 采用的是 RSU-7 结构。
  • Encoder2 和 Decoder2 采用的是 RSU-6 结构。
  • Encoder3 和 Decoder3 采用的是 RSU-5 结构。
  • Encoder4 和 Decoder4 采用的是 RSU-4 结构。

可见,相邻 block 相差一次下采样和一次上采样,例如 RSU-6 相比于 RSU-7 少了一个下采样卷积和上采样卷积部分,RSU-7 是下采样 32 倍和上采样 32 倍,RSU-6 是下采样 16 倍和上采样 16 倍。代码实现如下:


import torch
import torch.nn as nn
from torchvision import models
import torch.nn.functional as Fclass REBNCONV(nn.Module):    #实现conv2d+BN+ReLU操作                                                      def __init__(self,in_ch=3,out_ch=3,dirate=1):super(REBNCONV,self).__init__()# dilation用于实现空洞卷积self.conv_s1 = nn.Conv2d(in_ch,out_ch,3,padding=1*dirate,dilation=1*dirate)self.bn_s1 = nn.BatchNorm2d(out_ch)self.relu_s1 = nn.ReLU(inplace=True)def forward(self,x):hx = xxout = self.relu_s1(self.bn_s1(self.conv_s1(hx)))return xoutdef _upsample_like(src,tar):src = F.interpolate(src,size=tar.shape[2:],mode='bilinear',align_corners=True)     return src### RSU-7 ###
class RSU7(nn.Module):#UNet07DRES(nn.Module):                          #En_1   def __init__(self, in_ch=3, mid_ch=12, out_ch=3):super(RSU7,self).__init__()self.rebnconvin = REBNCONV(in_ch,out_ch,dirate=1)              #CBR1self.rebnconv1 = REBNCONV(out_ch,mid_ch,dirate=1)              #CBR2self.pool1 = nn.MaxPool2d(2,stride=2,ceil_mode=True)           self.rebnconv2 = REBNCONV(mid_ch,mid_ch,dirate=1)self.pool2 = nn.MaxPool2d(2,stride=2,ceil_mode=True)self.rebnconv3 = REBNCONV(mid_ch,mid_ch,dirate=1)             self.pool3 = nn.MaxPool2d(2,stride=2,ceil_mode=True)self.rebnconv4 = REBNCONV(mid_ch,mid_ch,dirate=1)self.pool4 = nn.MaxPool2d(2,stride=2,ceil_mode=True)self.rebnconv5 = REBNCONV(mid_ch,mid_ch,dirate=1)self.pool5 = nn.MaxPool2d(2,stride=2,ceil_mode=True)self.rebnconv6 = REBNCONV(mid_ch,mid_ch,dirate=1)self.rebnconv7 = REBNCONV(mid_ch,mid_ch,dirate=2)self.rebnconv6d = REBNCONV(mid_ch*2,mid_ch,dirate=1)self.rebnconv5d = REBNCONV(mid_ch*2,mid_ch,dirate=1)self.rebnconv4d = REBNCONV(mid_ch*2,mid_ch,dirate=1)self.rebnconv3d = REBNCONV(mid_ch*2,mid_ch,dirate=1)self.rebnconv2d = REBNCONV(mid_ch*2,mid_ch,dirate=1)self.rebnconv1d = REBNCONV(mid_ch*2,out_ch,dirate=1)def forward(self,x):hx = xhxin = self.rebnconvin(hx)hx1 = self.rebnconv1(hxin)hx = self.pool1(hx1)hx2 = self.rebnconv2(hx)hx = self.pool2(hx2)hx3 = self.rebnconv3(hx)hx = self.pool3(hx3)hx4 = self.rebnconv4(hx)hx = self.pool4(hx4)hx5 = self.rebnconv5(hx)hx = self.pool5(hx5)hx6 = self.rebnconv6(hx)hx7 = self.rebnconv7(hx6)                                  #实现残差连接hx6d = self.rebnconv6d(torch.cat((hx7,hx6),1))hx6dup = _upsample_like(hx6d,hx5)hx5d = self.rebnconv5d(torch.cat((hx6dup,hx5),1))hx5dup = _upsample_like(hx5d,hx4)hx4d = self.rebnconv4d(torch.cat((hx5dup,hx4),1))hx4dup = _upsample_like(hx4d,hx3)hx3d = self.rebnconv3d(torch.cat((hx4dup,hx3),1))hx3dup = _upsample_like(hx3d,hx2)hx2d = self.rebnconv2d(torch.cat((hx3dup,hx2),1))hx2dup = _upsample_like(hx2d,hx1)hx1d = self.rebnconv1d(torch.cat((hx2dup,hx1),1))return hx1d + hxin

在这里插入图片描述

1.2第二种block结构

  数据经过 E n 1 − E n 4 En_1-En4 En1En4下采样处理后对应特征图的分辨率就已经相对比较小了,如果再继续下采样就会丢失很多上下文信息。为保留上下文信息,在 E n c o d e r 5 、 E n c o d e r 6 、 D e c o d e r 5 Encoder5、Encoder6、Decoder5 Encoder5Encoder6Decoder5中将原始RSU中的上采样、下采样结构换成了空洞卷积操作,从而得到了 R S U − 4 F RSU-4F RSU4F,其中 F F F表示 R S U RSU RSU是扩张版本。此时 R S U − 4 F RSU-4F RSU4F的所有中间特征图都与其输入特征图具有相同的分辨率。在这里插入图片描述
需要注意,在 E n c o d e r 5 Encoder5 Encoder5中特征图大小已经到了18*18,非常小(也因此不需要再下采样),故采用了空洞卷积操作,目的在不改变特征图大小的情况下增大感受野。故在代码中使用了dalition=2、4、8 E n c o d e r 6 、 D e c o d e r 5 Encoder6、Decoder5 Encoder6Decoder5同理。这一特点在原图中显示为使用了虚线构成的长方体块。

import torch
import torch.nn as nn
from torchvision import models
import torch.nn.functional as Fclass REBNCONV(nn.Module):                                                          #CBLdef __init__(self,in_ch=3,out_ch=3,dirate=1):super(REBNCONV,self).__init__()self.conv_s1 = nn.Conv2d(in_ch,out_ch,3,padding=1*dirate,dilation=1*dirate)self.bn_s1 = nn.BatchNorm2d(out_ch)self.relu_s1 = nn.ReLU(inplace=True)def forward(self,x):hx = xxout = self.relu_s1(self.bn_s1(self.conv_s1(hx)))return xout### RSU-4F ###
class RSU4F(nn.Module):#UNet04FRES(nn.Module):def __init__(self, in_ch=3, mid_ch=12, out_ch=3):super(RSU4F,self).__init__()self.rebnconvin = REBNCONV(in_ch,out_ch,dirate=1)self.rebnconv1 = REBNCONV(out_ch,mid_ch,dirate=1)self.rebnconv2 = REBNCONV(mid_ch,mid_ch,dirate=2)self.rebnconv3 = REBNCONV(mid_ch,mid_ch,dirate=4)self.rebnconv4 = REBNCONV(mid_ch,mid_ch,dirate=8)self.rebnconv3d = REBNCONV(mid_ch*2,mid_ch,dirate=4)self.rebnconv2d = REBNCONV(mid_ch*2,mid_ch,dirate=2)self.rebnconv1d = REBNCONV(mid_ch*2,out_ch,dirate=1)def forward(self,x):hx = xhxin = self.rebnconvin(hx)hx1 = self.rebnconv1(hxin)hx2 = self.rebnconv2(hx1)hx3 = self.rebnconv3(hx2)hx4 = self.rebnconv4(hx3)hx3d = self.rebnconv3d(torch.cat((hx4,hx3),1))hx2d = self.rebnconv2d(torch.cat((hx3d,hx2),1))hx1d = self.rebnconv1d(torch.cat((hx2d,hx1),1))return hx1d + hxin

1.3特征图融合模块

  在通过编码、解码器的运算后,最后通过特征图融合模块(红框标出)将 D e 1 、 D e 2 、 D e 3 、 D e 4 、 D e 5 、 E n 6 De_1、De_2、De_3、De_4、De_5、En_6 De1De2De3De4De5En6模块的输出分别通过一个3x3的卷积层(卷积层的卷积核个数均为1),并通过双线性插值将得到的特征图还原回输入图像的大小,之后将得到的6个特征图进行拼接(Concatenation),最后再经过一个1x1的卷积层以及sigmoid激活函数,最终得到融合之后的图像。
在这里插入图片描述

1.4损失函数

   U 2 N e t U^2Net U2Net使用多监督算法构建损失函数。网络输出不仅仅包含最终特征图,还包含前面6个不同尺度的特征图,即,不仅要监督网络输出,还要监督中间融合特征图。 损失函数计算公式:
在这里插入图片描述
其中, M = 1 , 2 , 3 , . . . , 6 M=1,2,3,...,6 M=1,2,3,...,6 l s i d e m l_{side}^{m} lsidem表示特征图 S u p 1 、 S u p 2 、 . . . 、 S u p 6 Sup1、Sup2、...、Sup6 Sup1Sup2...Sup6的损失,而 l f u s e l_{fuse} lfuse表示最终特征图的损失, w w w则表示两种损失的权重参数(论文给出的源码中全为1)。 l s i d e l_{side} lside l f u s e l_{fuse} lfuse采用二值交叉熵(standard binary cross-entropy)进行计算:
在这里插入图片描述
其中, ( r , c ) (r,c) (r,c)表示像素坐标值, ( H , W ) (H,W) (H,W)表示图像高度和宽度, P G ( r , c ) P_{G(r,c)} PG(r,c)表示标签图像素灰度值, P S ( r , c ) P_{S(r,c)} PS(r,c)表示预测的图像素灰度值。

1.5总体网络架构

在这里插入图片描述
   U 2 N e t U^2Net U2Net主要由三部分组成:

  • 一个六级编码器:在 E n 1 、 E n 2 、 E n 3 、 E n 4 En_1、En_2、En_3、En_4 En1En2En3En4中分别使用 R S U 7 、 R S U 6 、 R S U 5 、 R S U 4 RSU7、RSU6、RSU5、RSU4 RSU7RSU6RSU5RSU4 L L L通常根据输入特征图的空间分辨率进行配置。对于高宽较大的特征图,使用更大的 L L L来捕获更多的大比例度信息。 E n 5 En_5 En5 E n 6 En_6 En6中特征图的分辨率相对较低,进一步降低这些特征图的采样会导致有用的上下文丢失。因此,在 E n 5 En_5 En5 E n 6 En_6 En6阶段,都使用 R S U − 4 F RSU-4F RSU4F,其中 F F F表示 R S U RSU RSU是扩张版本,其用膨胀卷积代替池化和上采样操作这意味着 R S U − 4 F RSU-4F RSU4F的所有中间特征图都与其输入特征图具有相同的分辨率。
  • 一个五级解码器: D e 5 De_5 De5阶段同样使用 R S U − 4 F RSU-4F RSU4F,并且每个解码器都使用前一阶段的上采样特征图和来自其对称编码器阶段的特征图的串联作为输入(跳跃连接)。
  • 特征图融合模块:将六个侧输出显著性特征图上采样到输入图像的尺寸,之后使用融合操作,并通过1x1卷积层和sigmoid函数生成最终的显著性特征图。

  研究中将3320320的图像裁剪为3288288大小输入模型,最终得到1288288的图像分割结果(二值图像):
在这里插入图片描述

1.6代码汇总

import torch
import torch.nn as nn
import torch.nn.functional as Fclass REBNCONV(nn.Module):def __init__(self,in_ch=3,out_ch=3,dirate=1):super(REBNCONV,self).__init__()self.conv_s1 = nn.Conv2d(in_ch,out_ch,3,padding=1*dirate,dilation=1*dirate)self.bn_s1 = nn.BatchNorm2d(out_ch)self.relu_s1 = nn.ReLU(inplace=True)def forward(self,x):hx = xxout = self.relu_s1(self.bn_s1(self.conv_s1(hx)))return xout## upsample tensor 'src' to have the same spatial size with tensor 'tar'
def _upsample_like(src,tar):src = F.upsample(src,size=tar.shape[2:],mode='bilinear')return src### RSU-7 ###
class RSU7(nn.Module):#UNet07DRES(nn.Module):def __init__(self, in_ch=3, mid_ch=12, out_ch=3):super(RSU7,self).__init__()self.rebnconvin = REBNCONV(in_ch,out_ch,dirate=1)self.rebnconv1 = REBNCONV(out_ch,mid_ch,dirate=1)self.pool1 = nn.MaxPool2d(2,stride=2,ceil_mode=True)self.rebnconv2 = REBNCONV(mid_ch,mid_ch,dirate=1)self.pool2 = nn.MaxPool2d(2,stride=2,ceil_mode=True)self.rebnconv3 = REBNCONV(mid_ch,mid_ch,dirate=1)self.pool3 = nn.MaxPool2d(2,stride=2,ceil_mode=True)self.rebnconv4 = REBNCONV(mid_ch,mid_ch,dirate=1)self.pool4 = nn.MaxPool2d(2,stride=2,ceil_mode=True)self.rebnconv5 = REBNCONV(mid_ch,mid_ch,dirate=1)self.pool5 = nn.MaxPool2d(2,stride=2,ceil_mode=True)self.rebnconv6 = REBNCONV(mid_ch,mid_ch,dirate=1)self.rebnconv7 = REBNCONV(mid_ch,mid_ch,dirate=2)self.rebnconv6d = REBNCONV(mid_ch*2,mid_ch,dirate=1)self.rebnconv5d = REBNCONV(mid_ch*2,mid_ch,dirate=1)self.rebnconv4d = REBNCONV(mid_ch*2,mid_ch,dirate=1)self.rebnconv3d = REBNCONV(mid_ch*2,mid_ch,dirate=1)self.rebnconv2d = REBNCONV(mid_ch*2,mid_ch,dirate=1)self.rebnconv1d = REBNCONV(mid_ch*2,out_ch,dirate=1)def forward(self,x):hx = xhxin = self.rebnconvin(hx)hx1 = self.rebnconv1(hxin)hx = self.pool1(hx1)hx2 = self.rebnconv2(hx)hx = self.pool2(hx2)hx3 = self.rebnconv3(hx)hx = self.pool3(hx3)hx4 = self.rebnconv4(hx)hx = self.pool4(hx4)hx5 = self.rebnconv5(hx)hx = self.pool5(hx5)hx6 = self.rebnconv6(hx)hx7 = self.rebnconv7(hx6)hx6d =  self.rebnconv6d(torch.cat((hx7,hx6),1))hx6dup = _upsample_like(hx6d,hx5)hx5d =  self.rebnconv5d(torch.cat((hx6dup,hx5),1))hx5dup = _upsample_like(hx5d,hx4)hx4d = self.rebnconv4d(torch.cat((hx5dup,hx4),1))hx4dup = _upsample_like(hx4d,hx3)hx3d = self.rebnconv3d(torch.cat((hx4dup,hx3),1))hx3dup = _upsample_like(hx3d,hx2)hx2d = self.rebnconv2d(torch.cat((hx3dup,hx2),1))hx2dup = _upsample_like(hx2d,hx1)hx1d = self.rebnconv1d(torch.cat((hx2dup,hx1),1))return hx1d + hxin### RSU-6 ###
class RSU6(nn.Module):#UNet06DRES(nn.Module):def __init__(self, in_ch=3, mid_ch=12, out_ch=3):super(RSU6,self).__init__()self.rebnconvin = REBNCONV(in_ch,out_ch,dirate=1)self.rebnconv1 = REBNCONV(out_ch,mid_ch,dirate=1)self.pool1 = nn.MaxPool2d(2,stride=2,ceil_mode=True)self.rebnconv2 = REBNCONV(mid_ch,mid_ch,dirate=1)self.pool2 = nn.MaxPool2d(2,stride=2,ceil_mode=True)self.rebnconv3 = REBNCONV(mid_ch,mid_ch,dirate=1)self.pool3 = nn.MaxPool2d(2,stride=2,ceil_mode=True)self.rebnconv4 = REBNCONV(mid_ch,mid_ch,dirate=1)self.pool4 = nn.MaxPool2d(2,stride=2,ceil_mode=True)self.rebnconv5 = REBNCONV(mid_ch,mid_ch,dirate=1)self.rebnconv6 = REBNCONV(mid_ch,mid_ch,dirate=2)self.rebnconv5d = REBNCONV(mid_ch*2,mid_ch,dirate=1)self.rebnconv4d = REBNCONV(mid_ch*2,mid_ch,dirate=1)self.rebnconv3d = REBNCONV(mid_ch*2,mid_ch,dirate=1)self.rebnconv2d = REBNCONV(mid_ch*2,mid_ch,dirate=1)self.rebnconv1d = REBNCONV(mid_ch*2,out_ch,dirate=1)def forward(self,x):hx = xhxin = self.rebnconvin(hx)hx1 = self.rebnconv1(hxin)hx = self.pool1(hx1)hx2 = self.rebnconv2(hx)hx = self.pool2(hx2)hx3 = self.rebnconv3(hx)hx = self.pool3(hx3)hx4 = self.rebnconv4(hx)hx = self.pool4(hx4)hx5 = self.rebnconv5(hx)hx6 = self.rebnconv6(hx5)hx5d =  self.rebnconv5d(torch.cat((hx6,hx5),1))hx5dup = _upsample_like(hx5d,hx4)hx4d = self.rebnconv4d(torch.cat((hx5dup,hx4),1))hx4dup = _upsample_like(hx4d,hx3)hx3d = self.rebnconv3d(torch.cat((hx4dup,hx3),1))hx3dup = _upsample_like(hx3d,hx2)hx2d = self.rebnconv2d(torch.cat((hx3dup,hx2),1))hx2dup = _upsample_like(hx2d,hx1)hx1d = self.rebnconv1d(torch.cat((hx2dup,hx1),1))return hx1d + hxin### RSU-5 ###
class RSU5(nn.Module):#UNet05DRES(nn.Module):def __init__(self, in_ch=3, mid_ch=12, out_ch=3):super(RSU5,self).__init__()self.rebnconvin = REBNCONV(in_ch,out_ch,dirate=1)self.rebnconv1 = REBNCONV(out_ch,mid_ch,dirate=1)self.pool1 = nn.MaxPool2d(2,stride=2,ceil_mode=True)self.rebnconv2 = REBNCONV(mid_ch,mid_ch,dirate=1)self.pool2 = nn.MaxPool2d(2,stride=2,ceil_mode=True)self.rebnconv3 = REBNCONV(mid_ch,mid_ch,dirate=1)self.pool3 = nn.MaxPool2d(2,stride=2,ceil_mode=True)self.rebnconv4 = REBNCONV(mid_ch,mid_ch,dirate=1)self.rebnconv5 = REBNCONV(mid_ch,mid_ch,dirate=2)self.rebnconv4d = REBNCONV(mid_ch*2,mid_ch,dirate=1)self.rebnconv3d = REBNCONV(mid_ch*2,mid_ch,dirate=1)self.rebnconv2d = REBNCONV(mid_ch*2,mid_ch,dirate=1)self.rebnconv1d = REBNCONV(mid_ch*2,out_ch,dirate=1)def forward(self,x):hx = xhxin = self.rebnconvin(hx)hx1 = self.rebnconv1(hxin)hx = self.pool1(hx1)hx2 = self.rebnconv2(hx)hx = self.pool2(hx2)hx3 = self.rebnconv3(hx)hx = self.pool3(hx3)hx4 = self.rebnconv4(hx)hx5 = self.rebnconv5(hx4)hx4d = self.rebnconv4d(torch.cat((hx5,hx4),1))hx4dup = _upsample_like(hx4d,hx3)hx3d = self.rebnconv3d(torch.cat((hx4dup,hx3),1))hx3dup = _upsample_like(hx3d,hx2)hx2d = self.rebnconv2d(torch.cat((hx3dup,hx2),1))hx2dup = _upsample_like(hx2d,hx1)hx1d = self.rebnconv1d(torch.cat((hx2dup,hx1),1))return hx1d + hxin### RSU-4 ###
class RSU4(nn.Module):#UNet04DRES(nn.Module):def __init__(self, in_ch=3, mid_ch=12, out_ch=3):super(RSU4,self).__init__()self.rebnconvin = REBNCONV(in_ch,out_ch,dirate=1)self.rebnconv1 = REBNCONV(out_ch,mid_ch,dirate=1)self.pool1 = nn.MaxPool2d(2,stride=2,ceil_mode=True)self.rebnconv2 = REBNCONV(mid_ch,mid_ch,dirate=1)self.pool2 = nn.MaxPool2d(2,stride=2,ceil_mode=True)self.rebnconv3 = REBNCONV(mid_ch,mid_ch,dirate=1)self.rebnconv4 = REBNCONV(mid_ch,mid_ch,dirate=2)self.rebnconv3d = REBNCONV(mid_ch*2,mid_ch,dirate=1)self.rebnconv2d = REBNCONV(mid_ch*2,mid_ch,dirate=1)self.rebnconv1d = REBNCONV(mid_ch*2,out_ch,dirate=1)def forward(self,x):hx = xhxin = self.rebnconvin(hx)hx1 = self.rebnconv1(hxin)hx = self.pool1(hx1)hx2 = self.rebnconv2(hx)hx = self.pool2(hx2)hx3 = self.rebnconv3(hx)hx4 = self.rebnconv4(hx3)hx3d = self.rebnconv3d(torch.cat((hx4,hx3),1))hx3dup = _upsample_like(hx3d,hx2)hx2d = self.rebnconv2d(torch.cat((hx3dup,hx2),1))hx2dup = _upsample_like(hx2d,hx1)hx1d = self.rebnconv1d(torch.cat((hx2dup,hx1),1))return hx1d + hxin### RSU-4F ###
class RSU4F(nn.Module):#UNet04FRES(nn.Module):def __init__(self, in_ch=3, mid_ch=12, out_ch=3):super(RSU4F,self).__init__()self.rebnconvin = REBNCONV(in_ch,out_ch,dirate=1)self.rebnconv1 = REBNCONV(out_ch,mid_ch,dirate=1)self.rebnconv2 = REBNCONV(mid_ch,mid_ch,dirate=2)self.rebnconv3 = REBNCONV(mid_ch,mid_ch,dirate=4)self.rebnconv4 = REBNCONV(mid_ch,mid_ch,dirate=8)self.rebnconv3d = REBNCONV(mid_ch*2,mid_ch,dirate=4)self.rebnconv2d = REBNCONV(mid_ch*2,mid_ch,dirate=2)self.rebnconv1d = REBNCONV(mid_ch*2,out_ch,dirate=1)def forward(self,x):hx = xhxin = self.rebnconvin(hx)hx1 = self.rebnconv1(hxin)hx2 = self.rebnconv2(hx1)hx3 = self.rebnconv3(hx2)hx4 = self.rebnconv4(hx3)hx3d = self.rebnconv3d(torch.cat((hx4,hx3),1))hx2d = self.rebnconv2d(torch.cat((hx3d,hx2),1))hx1d = self.rebnconv1d(torch.cat((hx2d,hx1),1))return hx1d + hxin##### U^2-Net ####
class U2NET(nn.Module):def __init__(self,in_ch=3,out_ch=1):super(U2NET,self).__init__()self.stage1 = RSU7(in_ch,32,64)self.pool12 = nn.MaxPool2d(2,stride=2,ceil_mode=True)self.stage2 = RSU6(64,32,128)self.pool23 = nn.MaxPool2d(2,stride=2,ceil_mode=True)self.stage3 = RSU5(128,64,256)self.pool34 = nn.MaxPool2d(2,stride=2,ceil_mode=True)self.stage4 = RSU4(256,128,512)self.pool45 = nn.MaxPool2d(2,stride=2,ceil_mode=True)self.stage5 = RSU4F(512,256,512)self.pool56 = nn.MaxPool2d(2,stride=2,ceil_mode=True)self.stage6 = RSU4F(512,256,512)# decoderself.stage5d = RSU4F(1024,256,512)self.stage4d = RSU4(1024,128,256)self.stage3d = RSU5(512,64,128)self.stage2d = RSU6(256,32,64)self.stage1d = RSU7(128,16,64)self.side1 = nn.Conv2d(64,out_ch,3,padding=1)self.side2 = nn.Conv2d(64,out_ch,3,padding=1)self.side3 = nn.Conv2d(128,out_ch,3,padding=1)self.side4 = nn.Conv2d(256,out_ch,3,padding=1)self.side5 = nn.Conv2d(512,out_ch,3,padding=1)self.side6 = nn.Conv2d(512,out_ch,3,padding=1)self.outconv = nn.Conv2d(6*out_ch,out_ch,1)def forward(self,x):hx = x#stage 1hx1 = self.stage1(hx)hx = self.pool12(hx1)#stage 2hx2 = self.stage2(hx)hx = self.pool23(hx2)#stage 3hx3 = self.stage3(hx)hx = self.pool34(hx3)#stage 4hx4 = self.stage4(hx)hx = self.pool45(hx4)#stage 5hx5 = self.stage5(hx)hx = self.pool56(hx5)#stage 6hx6 = self.stage6(hx)hx6up = _upsample_like(hx6,hx5)#-------------------- decoder --------------------hx5d = self.stage5d(torch.cat((hx6up,hx5),1))hx5dup = _upsample_like(hx5d,hx4)hx4d = self.stage4d(torch.cat((hx5dup,hx4),1))hx4dup = _upsample_like(hx4d,hx3)hx3d = self.stage3d(torch.cat((hx4dup,hx3),1))hx3dup = _upsample_like(hx3d,hx2)hx2d = self.stage2d(torch.cat((hx3dup,hx2),1))hx2dup = _upsample_like(hx2d,hx1)hx1d = self.stage1d(torch.cat((hx2dup,hx1),1))#side outputd1 = self.side1(hx1d)d2 = self.side2(hx2d)d2 = _upsample_like(d2,d1)d3 = self.side3(hx3d)d3 = _upsample_like(d3,d1)d4 = self.side4(hx4d)d4 = _upsample_like(d4,d1)d5 = self.side5(hx5d)d5 = _upsample_like(d5,d1)d6 = self.side6(hx6)d6 = _upsample_like(d6,d1)d0 = self.outconv(torch.cat((d1,d2,d3,d4,d5,d6),1))return F.sigmoid(d0), F.sigmoid(d1), F.sigmoid(d2), F.sigmoid(d3), F.sigmoid(d4), F.sigmoid(d5), F.sigmoid(d6)

1.7普通残差块与RSU对比

在这里插入图片描述
  普通残差块的操作可概况为 H ( x ) = F 2 ( F 1 ( x ) ) + x H(x)=F_2(F_1(x))+x H(x)=F2(F1(x))+x,其中, F 1 、 F 2 F_1、F_2 F1F2代表权重层,此处设为卷积运算。RSU 和残差块之间的主要设计区别在于,RSU 用类似 U N e t UNet UNet的结构替换了普通的单流卷积,并将原始特征替换为由权重层转换的局部特征: H R S U ( x ) = U ( F 1 ( x ) ) + F 1 ( x ) H_{RSU}(x)=U(F_1(x))+F_1(x) HRSU(x)=U(F1(x))+F1(x),其中 U U U表示多层U型结构。种设计更改使网络能够直接从每个残差块中提取来自多个尺度的特征。更值得注意的是,由于 U 结构导致的计算开销很小,因为大多数操作都应用于下采样的特征图。下图中给出了RSU中 F 1 、 U F_1、U F1U的含义:
在这里插入图片描述

  残差块性能比较:
在这里插入图片描述

  • PLN:普通卷积块。
  • RES:残差块。
  • DSE:密集块。
  • INC:初始块。
  • RSU:U型残差块。

二、代码复现

https://github.com/xuebinqin/U-2-Net/tree/master

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.xdnf.cn/news/1546816.html

如若内容造成侵权/违法违规/事实不符,请联系一条长河网进行投诉反馈,一经查实,立即删除!

相关文章

网站建设完成后,国家为什么要求域名备案后才能上线

网站建设完成后,域名备案后才能上线的原因有很多。以下是对这些原因的详细分析: 维护国家互联网安全:备案制度有助于防范网络恐怖分子利用国内网站发布恐怖言论和宣传恐怖活动,从而保护国家的互联网安全。保护用户合法权益&#…

虚拟实训室建设需要投入哪些设备?实际效果如何?

随着虚拟现实技术的飞速发展,虚拟现实实训室作为现代教学的重要组成部分,正逐步成为提升教学质量和学生实践能力的重要手段。本文将从虚拟现实实训室建设所需的软硬件设备投入以及实际效果两个方面进行探讨。 软、硬件设备投入 硬件设备方面,…

Linux高级IO之五种IO模型

文章目录 Linux高级IO之五种IO模型IO的理解阻塞式IO非阻塞IO信号驱动式IOIO多路转接异步IO同步和异步 Linux高级IO之五种IO模型 IO的理解 IO模型其实更像是网络部分的延伸和拓展,在学习完计算机网络之后,结合Linux系统,我们就该认识到&…

[数据库实验三]安全性

目录 一、实验目的与要求: 二、实验内容: 三、实验小结 一、实验目的与要求: 1、设计用户子模式 2、根据实际需要创建用户角色及用户,并授权 3、针对不同级别的用户定义不同的视图,以保证系统的安全性 二、实验内…

SAP ABAP ‘‘ 和 `` 的区别

DATA(LV_01) 100. DATA(LV_02) 200.’ ’ 输出为 Char 输出为 String 如下直接定义赋值就会报错 DATA ls_value TYPE TABLE OF string. *ls_value VALUE #( ( A ) ). "报错行 ls_value VALUE #( ( A ) ).使用的场景:动态SQL取数 DATA OPTION TYPE STRI…

生成速度更快!AI绘画工具新版 SD WebUI Forge 保姆级安装教程,更低的显存更快的生成速度!

大家好,我是程序员晓晓 不知道平时经常使用 SD WebUI 的小伙伴发现没有,随着安装插件和模型越来越多,WebUI 时不时会出现卡顿或爆显存的情况,尤其在低显存的硬件上更加明显,只能不停的重启来解决。 估计是 WebUI 的作…

进击J8:Inception v1算法实战与解析

🍨 本文为🔗365天深度学习训练营 中的学习记录博客🍖 原作者:K同学啊 一、实验目的: 了解并学习图2中的卷积层运算量的计算过程了解并学习卷积层的并行结构与1x1卷积核部分内容(重点)尝试根据模…

基于 Redis 实现滑动窗口的限流

⏳ 限流场景:突发流量,恶意流量,业务本身需要 基于 Redis 实现滑动窗口的限流是一种常见且高效的做法。Redis 是一种内存数据库,具有高性能和支持原子操作的特点,非常适合用来实现限流功能。下面是一个使用 Redis 实现…

浅拷贝和深拷贝(Java 与 JavaScript)

一、Java 浅拷贝和深拷贝 在Java中,浅拷贝和深拷贝的主要区别在于对对象的引用和内容的复制方式。 浅拷贝 Java 的类型有基本数据类型和引用类型,基本数据类型是可以由 CPU 直接操作的类型,无论是深拷贝还是浅拷贝,都是会复制出…

海外媒体投稿:提高效果的6个国内外媒体套餐内容方法

媒体推广已经成为每个企业形象宣传产品与服务关键方式之一。国内外媒体套餐内容推广方法提供了许多有效的办法,助力企业能够更好地推广自己的产品。下面我们就详细介绍这其中的6个方法,以帮助提升推广效果。 1.明确目标群体和产品定位在制订推广策略以前…

《数字图像处理(面向新工科的电工电子信息基础课程系列教材)》例10-9

灰度共生矩阵的相关性 相关性(Correlation) 公式 Correlation ∑ i 1 N g ∑ j 1 N g ( i − μ x ) ( j − μ y ) P ( i , j ) σ x σ y \text{Correlation} \frac{\sum_{i1}^{N_g} \sum_{j1}^{N_g} (i - \mu_x)(j - \mu_y) P(i,j)}{\sigma_x \…

HTML【知识改变命运】02标签/元素说明

html 的标签/元素-说明 1:说明2&#xff1a;标签的使用注意点 1:说明 1&#xff1a;html标签在<>中 2&#xff1a;html标签一般是双标签&#xff0c;如起始标签 结束标签 3&#xff1a;html也有单标签&#xff0c;如 4&#xff1a;双标签之间的文本内容就是html的元素内…

开发微信记账本小程序之技术要点记录

我喜欢极简风格&#xff0c;所以我搭建了这款微信记账本小程序。在开发微信记账本小程序过程中&#xff0c;有一些值得关注的技术要点&#xff0c;我则简而记之。 1、空态界面 在没有数据时&#xff0c;我设计了空状态时的占位提示。 在框架中&#xff0c;我使用了 wx:if&qu…

C++安全密码生成与强度检测

目标 密码生成 // Function to generate a random password std::string generatePassword(int length, bool includeUpper, bool includeNumbers, bool includeSymbols) {std::string lower "abcdefghijklmnopqrstuvwxyz";std::string upper "ABCDEFGHIJKLM…

IP地址的打卡路径是什么?

众多周知&#xff0c;IP地址使我们浏览网站的“必需品”&#xff0c;他会在我们进行网络活动时起到通关文牒一般的作用。那么&#xff0c;放我们进行网络活动时&#xff0c;我们的“通关文牒”上面会在哪些地点留下痕迹&#xff0c;IP地址的流转路径是什么&#xff1f; 第一关…

企业如何选择合适的半导体设计小企业软件

在半导体行业日益精细化与智能化的今天&#xff0c;企业选择合适的半导体设计小企业软件&#xff0c;已成为提升研发效率、缩短产品上市周期、增强市场竞争力的关键。面对市场上琳琅满目的软件产品&#xff0c;企业需从多方面考量&#xff0c;以确保所选软件既能满足当前需求&a…

SpringMVC中出现的sql语句错误

1、原sql语句&#xff1a;select major_id AS majorId ,major_name AS majorName,tuition,dept_id as deptId from tb_major where major_id ? 出现问题&#xff1a;Request processing failed: org.springframework.jdbc.BadSqlGrammarException: StatementCallback; bad SQ…

java并发之并发关键字

并发关键字 关键字一&#xff1a;volatile 可以这样说&#xff0c;volatile 关键字是 Java 虚拟机提供的轻量级的同步机制。 功能 volatile 有 2 个主要功能&#xff1a; 可见性。一个线程对共享变量的修改&#xff0c;其他线程能够立即得知这个修改。普通变量不能做到这一点&a…

将Docker镜像推送到阿里云仓库,使用Docker-compose将mysql、redis、jar包整合在一起

进入阿里云&#xff1a; https://cr.console.aliyun.com 阿里云镜像控制台 选择个人实例 创建命名空间 创建镜像仓库 下一步之后&#xff0c;创建我们的本地仓库 创建好之后可以在个人实例里看到我们刚创建好的镜像仓库 点击我们的仓库进去里面&#xff0c;可以看到里面有我们…

4.5 了解大数据处理基本流程

文章目录 1. 引言2. 数据采集2.1 数据库采集2.2 实时数据采集2.3 网络爬虫采集 3. 数据预处理3.1 数据清洗3.2 数据集成3.3 数据归约3.4 数据转换 4. 数据处理与分析4.1 数据处理4.2 数据分析 5. 数据可视化与应用5.1 数据可视化5.2 ECharts框架5.3 课堂作业 6. 结语 1. 引言 …