【深度学习基础模型】深度残差网络(Deep Residual Networks, DRN)详细理解并附实现代码。

【深度学习基础模型】深度残差网络(Deep Residual Networks, DRN)详细理解并附实现代码。

【深度学习基础模型】深度残差网络(Deep Residual Networks, DRN)详细理解并附实现代码。


文章目录

  • 【深度学习基础模型】深度残差网络(Deep Residual Networks, DRN)详细理解并附实现代码。
  • 1. 算法提出
  • 2. 概述
  • 3. 发展
  • 4. 应用
  • 5. 优缺点
  • 6. Python代码实现


参考地址:https://www.asimovinstitute.org/neural-network-zoo/
论文地址:https://arxiv.org/pdf/1512.03385

欢迎宝子们点赞、关注、收藏!欢迎宝子们批评指正!
在这里插入图片描述

1. 算法提出

深度残差网络(DRN)最初由何凯明等人于2015年在论文“Deep Residual Learning for Image Recognition”中提出。该算法的核心思想是通过残差块(Residual Block)来解决深层神经网络训练中的退化问题

传统神经网络在层数增加时,随着网络变深,训练误差反而会上升,这种现象被称为梯度消失/爆炸问题DRN通过引入跳跃连接(Skip Connection),将前几层的输入直接传递到后几层,从而有效缓解了这个问题

2. 概述

DRN的核心结构是残差块。一个典型的残差块包含一个跳跃连接,将输入直接加到输出上,如下所示:

y = F ( x ) + x y=F(x)+x y=F(x)+x

其中, x x x是残差块的输入, F ( x ) F(x) F(x)是经过几层非线性变换后的输出。通过将输入 x x x直接添加到输出 F ( x ) F(x) F(x),残差网络实际上是在学习一个残差函数。这种结构使得网络能够更容易训练,并且即使网络层数增加,网络也不会出现退化现象。

残差网络的优点在于:

  • 更深的网络结构:传统前馈神经网络(Feedforward Neural Networks, FFNN)的层数通常在几层到几十层,而DRN可以扩展到上百层甚至更深(如ResNet-152)。
  • 稳定的训练过程:通过引入跳跃连接,梯度可以更好地传播,从而缓解了梯度消失问题。

3. 发展

自2015年提出以来,残差网络成为了许多深度学习模型的基础架构。随着研究的深入,残差网络的变种也被提出,例如:

  • ResNet:最早的残差网络版本,适用于图像分类等任务。
  • ResNeXt:将残差块中的卷积运算拆分为多个并行的路径,提高了模型的可扩展性。
  • DenseNet:一种变体,进一步增加了层之间的密集连接。

4. 应用

DRN被广泛应用于各种深度学习任务中,特别是在计算机视觉领域表现出色。典型的应用包括:

  • 图像分类:ResNet在ImageNet分类任务中取得了极好的效果,常用于图像分类任务。
  • 目标检测:许多目标检测模型(如Faster R-CNN)都基于残差网络作为主干结构。
  • 语义分割:在语义分割任务中,残差网络作为特征提取器也广泛使用。

5. 优缺点

优点:

  • 有效的深度学习:DRN能够有效训练非常深的网络(可达150层甚至更多),而不会出现明显的性能退化。
  • 跳跃连接:通过跳跃连接,DRN能够更好地传播梯度,解决梯度消失问题,从而加快训练速度。
  • 强大的表达能力:可以通过残差学习获得更高的模型表达能力,适用于复杂的学习任务。

缺点:

  • 计算复杂性高:随着网络深度的增加,计算资源需求显著增加,训练时间可能较长。
  • 模型可解释性差:深度模型的复杂性可能导致难以解释其内部机制和决策过程。
  • 需要大量数据:有效训练深度残差网络通常需要大量标注数据,以防止过拟合。

6. Python代码实现

以下是一个使用深度残差网络进行图像分类的示例,基于PyTorch框架:

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader# 定义残差块
class ResidualBlock(nn.Module):def __init__(self, in_channels, out_channels, stride=1):super(ResidualBlock, self).__init__()self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False)self.bn1 = nn.BatchNorm2d(out_channels)self.relu = nn.ReLU(inplace=True)self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False)self.bn2 = nn.BatchNorm2d(out_channels)# 如果输入维度和输出维度不匹配,通过1x1卷积进行匹配self.shortcut = nn.Sequential()if stride != 1 or in_channels != out_channels:self.shortcut = nn.Sequential(nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False),nn.BatchNorm2d(out_channels))def forward(self, x):out = self.relu(self.bn1(self.conv1(x)))out = self.bn2(self.conv2(out))out += self.shortcut(x)  # 跳跃连接out = self.relu(out)return out# 定义ResNet模型
class ResNet(nn.Module):def __init__(self, block, num_blocks, num_classes=10):super(ResNet, self).__init__()self.in_channels = 64self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)self.bn1 = nn.BatchNorm2d(64)self.relu = nn.ReLU(inplace=True)self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))self.fc = nn.Linear(512, num_classes)def _make_layer(self, block, out_channels, num_blocks, stride):layers = []layers.append(block(self.in_channels, out_channels, stride))self.in_channels = out_channelsfor _ in range(1, num_blocks):layers.append(block(self.in_channels, out_channels))return nn.Sequential(*layers)def forward(self, x):out = self.relu(self.bn1(self.conv1(x)))out = self.layer1(out)out = self.layer2(out)out = self.layer3(out)out = self.layer4(out)out = self.avg_pool(out)out = out.view(out.size(0), -1)out = self.fc(out)return out# 实例化ResNet18模型
def ResNet18():return ResNet(ResidualBlock, [2, 2, 2, 2])  # 定义ResNet18结构# 数据预处理
transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])# 加载CIFAR-10数据集
train_dataset = datasets.CIFAR10(root='./data', train=True, transform=transform, download=True)
test_dataset = datasets.CIFAR10(root='./data', train=False, transform=transform, download=True)
train_loader = DataLoader(train_dataset, batch_size=100, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=100, shuffle=False)# 定义设备、损失函数和优化器
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = ResNet18().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)# 训练模型
def train_model(num_epochs=5):for epoch in range(num_epochs):model.train()for images, labels in train_loader:images, labels = images.to(device), labels.to(device)outputs = model(images)loss = criterion(outputs, labels)optimizer.zero_grad()loss.backward()optimizer.step()print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}')# 评估模型
def test_model():model.eval()correct = 0total = 0with torch.no_grad():for images, labels in test_loader:images, labels = images.to(device), labels.to(device)outputs = model(images)_, predicted = torch.max(outputs.data, 1)total += labels.size(0)correct += (predicted == labels).sum().item()print(f'测试集准确率: {100 * correct / total:.2f}%')# 运行训练和测试
train_model(num_epochs=5)
test_model()

代码解释:

  • ResidualBlock:实现了残差块,其中包括卷积层、批量归一化(Batch Normalization)、ReLU激活函数和跳跃连接。通过跳跃连接,将输入直接加到输出中,以实现残差学习。
  • ResNet:定义了ResNet模型结构,包括多个残差块的堆叠。_make_layer方法用于构建每一层的残差块。
  • 数据预处理:使用transforms.Compose对CIFAR-10数据集进行转换,进行标准化处理。
  • 模型训练:在train_model函数中,模型通过多轮训练,不断优化损失函数。
  • 模型评估:在test_model函数中,模型评估在测试集上的性能,并输出准确率。

该代码实现了基于深度残差网络的图像分类任务,展示了DRN在实际应用中的有效性。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.xdnf.cn/news/1554899.html

如若内容造成侵权/违法违规/事实不符,请联系一条长河网进行投诉反馈,一经查实,立即删除!

相关文章

使用前端三剑客实现一个备忘录

一,界面介绍 这个备忘录的界面效果如下: 可以实现任务的增删,并且在任务被勾选后会被放到已完成的下面。 示例: (1),增加一个任务 (2),勾选任务 &#xff…

Chat登录时出现SSO信息出错的解决方法

目录 1. 问题所示2. 问题所示3. 解决方法 1. 问题所示 此贴主要是总结回顾,对此放置在运维专栏 出现如下问题,很懵,以为是节点挂了还是网址蹦了 一直刷新,登录之后就出现这个问题 2. 问题所示 对于SSO,也就是单点登…

ExcelToWord-Excel套打Word-Word邮件合并工具分享

Excel to Word转换工具分享 在日常工作或学习中,我们常常需要将Excel中的数据导出到Word文档中,以便更好地展示信息。市场上有许多Excel to Word的转换工具,它们各有特色。今天,我们就来推荐几款这样的工具,并探讨一下…

基于Springboot+Vue的教师科研管理系统 (含源码数据库)

1.开发环境 开发系统:Windows10/11 架构模式:MVC/前后端分离 JDK版本: Java JDK1.8 开发工具:IDEA 数据库版本: mysql5.7或8.0 数据库可视化工具: navicat 服务器: SpringBoot自带 apache tomcat 主要技术: Java,Springboot,mybatis,mysql,vue 2.视频演示地址 3.功能 系统中…

用Python实现运筹学——Day 12: 线性规划在物流优化中的应用

一、学习内容 线性规划在物流优化中可以用于解决诸如配送路径优化、货物运输调度等问题。配送中心的路径优化问题本质上是寻找一条最优路径,在满足需求点的需求条件下,最小化配送的总运输成本或时间。常见的物流优化问题包括: 配送中心的货…

Python小示例——质地不均匀的硬币概率统计

在概率论和统计学中,随机事件的行为可以通过大量实验来研究。在日常生活中,我们经常用硬币进行抽样,比如抛硬币来决定某个结果。然而,当我们处理的是“质地不均匀”的硬币时,事情就变得复杂了。质地不均匀的硬币意味着…

【C++】—— 类和对象(中)

【C】—— 类和对象(中) 文章目录 【C】—— 类和对象(中)前言1. 类的默认成员函数2. 构造函数3. 析构函数4. 拷贝构造函数5. 赋值运算符重载5.1 运算符重载5.2 赋值运算符重载 结语 前言 小伙伴们大家好呀,昨天的 【C】——类和对象(上) 大家理解的怎么样了 今天…

网约班车升级手机端退票

背景 作为老古董程序员,不,应该叫互联网人员,因为我现在做的所有的事情,都是处于爱好,更多的时间是在和各行各业的朋友聊市场,聊需求,聊怎么通过IT互联网 改变实体行业的现状,准确的…

卡码网KamaCoder 53. 寻宝

题目来源:53. 寻宝(第七期模拟笔试) C题解(来源代码随想录):最小生成树 prim prim三部曲 第一步,选距离生成树最近节点第二步,最近节点加入生成树第三步,更新非生成树节…

随时随地,轻松翻译:英汉互译软件的便捷之旅

翻译英汉互译工具,就如同一位随时待命的语言助手,在这纷繁复杂的语言世界中为我们搭建起理解与沟通的桥梁。接下来,让我们一同深入了解这些神奇的英汉互译工具,探索它的诸多功能和独特魅力。 1.福晰在线翻译 链接直达>>h…

Python案例--三数排序

一、引言 在信息爆炸的时代,我们每天都会接触到大量的数据。无论是工作中的报表、学习中的数据集,还是日常生活中的购物清单,数据的有序性对于提高效率和决策质量都至关重要。排序算法作为数据处理的基础工具,其重要性不言而喻。…

RTSP协议讲解

1.RTSP协议 rtsp,英文全称 Real Time Streaming Protocol,RFC2326,实时流传输协议,是 TCP/IP 协议体系中的一个应用层协议。 RTSP 交互流程 1)OPTIONS C--->S 客户端向服务器端发现 OPTIONS,请求可用…

netty之SpringBoot+Netty+Elasticsearch收集日志信息数据存储

前言 将大量的业务以及用户行为数据存储起来用于分析处理,但是由于数据量较大且需要具备可分析功能所以将数据存储到文件系统更为合理。尤其是一些互联网高并发级应用,往往数据库都采用分库分表设计,那么将这些分散的数据通过binlog汇总到一个…

Go基础学习11-测试工具gomock和monkey的使用

文章目录 基础回顾MockMock是什么安装gomockMock使用1. 创建user.go源文件2. 使用mockgen生成对应的Mock文件3. 使用mockgen命令生成后在对应包mock下可以查看生成的mock文件4. 编写测试代码5. 运行代码并查看输出 GomonkeyGomonkey优势安装使用对函数进行monkey对结构体中方法…

SQL专项练习第二天

在数据处理和分析中,Hive 是一个强大的工具。本文将通过五个 Hive 相关的问题展示其在不同场景下的应用技巧。 先在home文件夹下建一个hivedata文件夹,把我们所需的数据写成txt文件导入到/home/hivedata/文件夹下面。 一、找出连续活跃 3 天及以上的用户…

茄子病虫害数据集。四类:果肉腐烂、蛀虫、健康、黄斑病。4000张图片,已经按照8:2的比例划分好训练集、验证集 txt格式 含类别yaml文件 已经标注好

茄子病虫害数据集。可用于筛选茄子品质、质量,训练采摘机器人视觉算法模型……数据集大部分图片来源于真实果园拍摄的图片(生长在果树之上的),图片分辨率高,数据集分为四类:果肉腐烂、蛀虫、健康、黄斑病。…

Pandas数据分析基础

目录标题 Pandas读取和写入数据数据读取读取csv读取excel数据输出 Pandas基础操作索引数据信息统计计算位置计算数据选择 Pandas高级操作复杂查询类型转换数据排序添加修改高级过滤数据迭代高阶函数 Pandas读取和写入数据 Pandas将数据加载到DataFrame后,就可以使用…

算法知识点————贪心

贪心:只考虑局部最优解,不考虑全部最优解。有时候得不到最优解。 DP:考虑全局最优解。DP的特点:无后效性(正在求解的时候不关心前面的解是怎么求的); 二者都是在求最优解的,都有最优…

TB6612电机驱动模块(STM32)

目录 一、介绍 二、模块原理 1.原理图 2.电机驱动原理 三、程序设计 main.c文件 Motor.h文件 Motor.c文件 四、实验效果 五、资料获取 项目分享 一、介绍 TB6612FNG 是东芝半导体公司生产的一款直流电机驱动器件,它具有大电流 MOSFET-H 桥结构&#xff…

【每天学个新注解】Day 15 Lombok注解简解(十四)—@UtilityClass、@Helper

UtilityClass 生成工具类的注解 将一个类通过注解变成一个工具类,并没有什么用,本来代码中的工具类数量就极为有限,并不能达到减少重复代码的目的 1、如何使用 加在需要委托将其变为工具类的普通类上。 2、代码示例 例: Uti…