MEGANet: Multi-Scale Edge-Guided Attention Network for Weak Boundary Polyp Segmentation
Paper link:
https://arxiv.org/pdf/2309.03329
Problem addressed:
MEGANet tackles the problem of segmenting polyps with weak boundaries. Polyp images typically have complex backgrounds, highly variable shapes, and blurred boundaries, all of which make segmentation difficult. By combining edge information with attention mechanisms, MEGANet preserves high-frequency edge information and thereby improves segmentation accuracy. The network consists of three parts (a minimal wiring sketch follows this list):
- Encoder: extracts features from the input image.
- Decoder: generates the segmentation result from the encoder features.
- Edge-Guided Attention module (EGA): uses the Laplacian operator to enhance polyp boundary information and guides the model toward edge-related features.
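In the paper, EGA blocks sit between the encoder and decoder at multiple scales. The sketch below is only a minimal, hypothetical illustration of how the three parts fit together: it assumes the EGA and make_laplace_pyramid definitions from the plug-and-play code at the end of this post are in scope, a GPU is available (gauss_kernel there defaults to cuda=True), and toy placeholder encoder/decoder blocks rather than the paper's actual backbone.

import torch
import torch.nn as nn
import torch.nn.functional as F

# Assumes EGA and make_laplace_pyramid from the plug-and-play code below are in scope.

class ToyMEGANet(nn.Module):
    """Hypothetical 2-level encoder-decoder with one EGA refinement at the top level."""
    def __init__(self, channels=64):
        super().__init__()
        self.enc1 = nn.Sequential(nn.Conv2d(3, channels, 3, 1, 1), nn.ReLU(inplace=True))
        self.enc2 = nn.Sequential(nn.Conv2d(channels, channels, 3, 2, 1), nn.ReLU(inplace=True))
        self.head2 = nn.Conv2d(channels, 1, 1)      # coarse decoder prediction at level 2
        self.ega1 = EGA(channels)                   # edge-guided refinement of level-1 features
        self.head1 = nn.Conv2d(channels, 1, 1)      # final prediction head

    def forward(self, img):
        gray = img.mean(dim=1, keepdim=True)        # 1-channel input for the Laplacian pyramid
        edges = make_laplace_pyramid(gray, 2, 1)    # high-frequency edge maps, one per scale
        f1 = self.enc1(img)                         # encoder features, full resolution
        f2 = self.enc2(f1)                          # encoder features, half resolution
        pred2 = self.head2(f2)                      # coarse prediction from the deeper level
        pred2 = F.interpolate(pred2, size=f1.shape[2:], mode='bilinear', align_corners=True)
        f1 = self.ega1(edges[0], f1, pred2)         # EGA: edge map + encoder feature + prediction
        return self.head1(f1)                       # refined segmentation logits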
Solution details:
- EGA module:
  - Receives the embedded features from the encoder, the high-frequency features produced by the Laplacian operator, and the prediction from the decoder.
  - Multiplies the high-frequency features element-wise with a boundary attention map and a reverse attention map to obtain fused features (a small numerical check of the boundary attention follows this list).
  - Uses an attention mask to focus the model on important regions and suppress background noise.
  - Refines the features further with a CBAM block, capturing feature correlations between boundary and background regions.
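The boundary-attention branch is easy to sanity-check in isolation: the Laplacian residual of a prediction map responds mainly at object boundaries and stays close to zero elsewhere. The snippet below is just such a toy check, not part of MEGANet; it assumes the make_laplace function from the plug-and-play code at the end of this post and a GPU (gauss_kernel there defaults to cuda=True).

import torch

# Toy check: the Laplacian residual of a binary "prediction" concentrates at the boundary.
h = w = 64
yy, xx = torch.meshgrid(torch.arange(h), torch.arange(w), indexing='ij')
mask = (((yy - 32) ** 2 + (xx - 32) ** 2) < 15 ** 2).float()   # synthetic circular "polyp"
pred = mask[None, None].cuda()                                  # shape (1, 1, H, W)
edge = make_laplace(pred, 1).abs()                              # boundary-attention map
print(edge[0, 0, 32, 16:20])   # noticeably larger values around the circle's boundary (x ≈ 17)
print(edge[0, 0, 32, 30:34])   # close to zero well inside the polyp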
Applying the solution to object detection:
The EGA module can also be used in object detection to strengthen object boundary information and improve detection accuracy. Possible insertion points include:
- Feature extraction stage: insert EGA into the backbone, for example between certain stages of ResNet or EfficientNet, to enhance boundary information in the feature maps (a sketch is given right after this list).
- Box regression stage: add EGA before the regression head, for example in RetinaNet or YOLO, to guide the model toward more precise box boundaries.
- Classification stage: add EGA after RoI pooling, for example in Faster R-CNN, to enhance the features of object regions and improve classification accuracy.
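None of these detection variants are provided by the paper, so the following is only a minimal, hypothetical sketch of the first option: refining one ResNet stage with EGA before it is handed to the detection head. It assumes the EGA and make_laplace definitions from the plug-and-play code below, torchvision's resnet50 (torchvision >= 0.13 for the weights argument), an auxiliary 1-channel "objectness" head that is not part of the original work, and a GPU.

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.models import resnet50

# Assumes EGA and make_laplace from the plug-and-play code below are in scope.

class EGABackbone(nn.Module):
    """Hypothetical detection backbone: ResNet-50 up to layer3, refined by EGA."""
    def __init__(self):
        super().__init__()
        r = resnet50(weights=None)
        self.stem = nn.Sequential(r.conv1, r.bn1, r.relu, r.maxpool, r.layer1, r.layer2)
        self.layer3 = r.layer3                    # 1024-channel stage to be refined
        self.aux_head = nn.Conv2d(1024, 1, 1)     # auxiliary coarse prediction for the attention maps
        self.ega = EGA(1024)                      # edge-guided refinement of the stage output

    def forward(self, img):
        gray = img.mean(dim=1, keepdim=True)      # 1-channel image for the Laplacian edge map
        edge = make_laplace(gray, 1)              # high-frequency edge features of the input
        f = self.layer3(self.stem(img))           # backbone features at 1/16 resolution
        pred = self.aux_head(f)                   # coarse prediction driving reverse/boundary attention
        return self.ega(edge, f, pred)            # refined features for the detection head

For RetinaNet/YOLO or Faster R-CNN, the same block would instead be placed before the regression head or after RoI pooling, with in_channels matched to the feature maps at that point.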
Note that applying EGA to object detection requires a few adjustments:
- Choose a suitable edge detection method: the Laplacian operator may not fit every detection task, so pick an edge detector that matches the task (a Sobel-based alternative is sketched right after this list).
- Adapt the EGA structure: adjust the module's structure and parameters to the detection network and the requirements of the task.
- Training strategy: the model needs to be retrained, with the training schedule (learning rate, optimizer, etc.) tuned accordingly.
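For the first adjustment, a common fixed-filter alternative to the Laplacian is the Sobel operator. The sketch below is a hypothetical drop-in replacement for the edge-feature input of EGA, not something taken from the paper.

import torch
import torch.nn as nn
import torch.nn.functional as F

class SobelEdge(nn.Module):
    """Fixed Sobel filters producing a 1-channel gradient-magnitude edge map."""
    def __init__(self):
        super().__init__()
        gx = torch.tensor([[-1., 0., 1.], [-2., 0., 2.], [-1., 0., 1.]])
        gy = gx.t()
        # register as a buffer so it follows .to()/.cuda() but is not trained
        self.register_buffer('weight', torch.stack([gx, gy]).unsqueeze(1))  # (2, 1, 3, 3)

    def forward(self, img):
        gray = img.mean(dim=1, keepdim=True)                               # (B, 1, H, W)
        grad = F.conv2d(F.pad(gray, (1, 1, 1, 1), mode='reflect'), self.weight)
        return grad.pow(2).sum(dim=1, keepdim=True).sqrt()                 # gradient magnitude

# Usage sketch: feed the Sobel map to EGA in place of the Laplacian edge feature.
# edge_feature = SobelEdge().cuda()(img)
# out = ega(edge_feature, x, pred)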
In summary, the EGA module of MEGANet offers an effective way to handle weak-boundary segmentation and can be transferred to object detection to improve detection accuracy.
Plug-and-play code:
import torch
import torch.nn.functional as F
import torch.nn as nn


def gauss_kernel(channels=3, cuda=True):
    # 5x5 Gaussian kernel, replicated per channel for depthwise filtering
    kernel = torch.tensor([[1., 4., 6., 4., 1.],
                           [4., 16., 24., 16., 4.],
                           [6., 24., 36., 24., 6.],
                           [4., 16., 24., 16., 4.],
                           [1., 4., 6., 4., 1.]])
    kernel /= 256.
    kernel = kernel.repeat(channels, 1, 1, 1)
    if cuda:
        kernel = kernel.cuda()
    return kernel


def downsample(x):
    return x[:, :, ::2, ::2]


def conv_gauss(img, kernel):
    img = F.pad(img, (2, 2, 2, 2), mode='reflect')
    out = F.conv2d(img, kernel, groups=img.shape[1])
    return out


def upsample(x, channels):
    # zero-insertion upsampling followed by Gaussian smoothing
    cc = torch.cat([x, torch.zeros(x.shape[0], x.shape[1], x.shape[2], x.shape[3], device=x.device)], dim=3)
    cc = cc.view(x.shape[0], x.shape[1], x.shape[2] * 2, x.shape[3])
    cc = cc.permute(0, 1, 3, 2)
    cc = torch.cat([cc, torch.zeros(x.shape[0], x.shape[1], x.shape[3], x.shape[2] * 2, device=x.device)], dim=3)
    cc = cc.view(x.shape[0], x.shape[1], x.shape[3] * 2, x.shape[2] * 2)
    x_up = cc.permute(0, 1, 3, 2)
    return conv_gauss(x_up, 4 * gauss_kernel(channels))


def make_laplace(img, channels):
    # high-frequency residual: image minus its blurred, down/up-sampled version
    filtered = conv_gauss(img, gauss_kernel(channels))
    down = downsample(filtered)
    up = upsample(down, channels)
    if up.shape[2] != img.shape[2] or up.shape[3] != img.shape[3]:
        up = nn.functional.interpolate(up, size=(img.shape[2], img.shape[3]))
    diff = img - up
    return diff


def make_laplace_pyramid(img, level, channels):
    # Laplacian pyramid: high-frequency residuals per level plus the final low-pass image
    current = img
    pyr = []
    for _ in range(level):
        filtered = conv_gauss(current, gauss_kernel(channels))
        down = downsample(filtered)
        up = upsample(down, channels)
        if up.shape[2] != current.shape[2] or up.shape[3] != current.shape[3]:
            up = nn.functional.interpolate(up, size=(current.shape[2], current.shape[3]))
        diff = current - up
        pyr.append(diff)
        current = down
    pyr.append(current)
    return pyr


class ChannelGate(nn.Module):
    def __init__(self, gate_channels, reduction_ratio=16):
        super(ChannelGate, self).__init__()
        self.gate_channels = gate_channels
        self.mlp = nn.Sequential(
            nn.Flatten(),
            nn.Linear(gate_channels, gate_channels // reduction_ratio),
            nn.ReLU(),
            nn.Linear(gate_channels // reduction_ratio, gate_channels)
        )

    def forward(self, x):
        # channel attention from global average- and max-pooled descriptors
        avg_out = self.mlp(F.avg_pool2d(x, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3))))
        max_out = self.mlp(F.max_pool2d(x, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3))))
        channel_att_sum = avg_out + max_out
        scale = torch.sigmoid(channel_att_sum).unsqueeze(2).unsqueeze(3).expand_as(x)
        return x * scale


class SpatialGate(nn.Module):
    def __init__(self):
        super(SpatialGate, self).__init__()
        kernel_size = 7
        self.spatial = nn.Conv2d(2, 1, kernel_size, stride=1, padding=(kernel_size - 1) // 2)

    def forward(self, x):
        # spatial attention from channel-wise max and mean maps
        x_compress = torch.cat((torch.max(x, 1)[0].unsqueeze(1), torch.mean(x, 1).unsqueeze(1)), dim=1)
        x_out = self.spatial(x_compress)
        scale = torch.sigmoid(x_out)  # broadcasting
        return x * scale


class CBAM(nn.Module):
    def __init__(self, gate_channels, reduction_ratio=16):
        super(CBAM, self).__init__()
        self.ChannelGate = ChannelGate(gate_channels, reduction_ratio)
        self.SpatialGate = SpatialGate()

    def forward(self, x):
        x_out = self.ChannelGate(x)
        x_out = self.SpatialGate(x_out)
        return x_out


# Edge-Guided Attention Module (EGA)
class EGA(nn.Module):
    def __init__(self, in_channels):
        super(EGA, self).__init__()
        self.fusion_conv = nn.Sequential(
            nn.Conv2d(in_channels * 3, in_channels, 3, 1, 1),
            nn.BatchNorm2d(in_channels),
            nn.ReLU(inplace=True))
        self.attention = nn.Sequential(
            nn.Conv2d(in_channels, 1, 3, 1, 1),
            nn.BatchNorm2d(1),
            nn.Sigmoid())
        self.cbam = CBAM(in_channels)

    def forward(self, edge_feature, x, pred):
        residual = x
        xsize = x.size()[2:]
        pred = torch.sigmoid(pred)

        # reverse attention: emphasize background regions
        background_att = 1 - pred
        background_x = x * background_att

        # boundary attention: Laplacian of the prediction highlights object boundaries
        edge_pred = make_laplace(pred, 1)
        pred_feature = x * edge_pred

        # high-frequency feature from the input image's edge map
        edge_input = F.interpolate(edge_feature, size=xsize, mode='bilinear', align_corners=True)
        input_feature = x * edge_input

        fusion_feature = torch.cat([background_x, pred_feature, input_feature], dim=1)
        fusion_feature = self.fusion_conv(fusion_feature)

        attention_map = self.attention(fusion_feature)
        fusion_feature = fusion_feature * attention_map

        out = fusion_feature + residual
        out = self.cbam(out)
        return out


if __name__ == '__main__':
    # simulated input tensors
    edge_feature = torch.randn(1, 1, 128, 128).cuda()
    x = torch.randn(1, 64, 128, 128).cuda()
    pred = torch.randn(1, 1, 128, 128).cuda()  # pred is usually single-channel

    # instantiate the EGA block
    block = EGA(64).cuda()

    # forward pass through the EGA instance
    output = block(edge_feature, x, pred)

    # print input and output shapes
    print(edge_feature.size())
    print(x.size())
    print(pred.size())
    print(output.size())
If you are interested in YOLO improvements, you are welcome to join the QQ group (828370883), where questions are answered.