动手学深度学习(pytorch)学习记录31-批量规范化(batch normalization)[学习记录]

目录

  • 批量规范化(batch normalization)
  • 从头开始实现一个具有张量的批量规范化层
  • 简明实现

批量规范化(batch normalization)

可持续加速深层网络的的收敛速度。再结合残差块,批量规范化使得研究人员可以训练100层以上的网络。
为什么需要批量规范化层呢?
1、数据预处理的方式通常会对最终结果产生巨大影响;
2、中间层的变量可能具有更广变化范围,由于可变值的范围不同,是否需要对学习率进行调整;
3、深层的网络很复杂,容易过拟合。

批量规范化层的在训练模式中,通过小批量统计数据进行规范化;在预测模式中通过数据集统计进行规范化。

从头开始实现一个具有张量的批量规范化层

import torch
from torch import nn
from d2l import torch as d2ldef batch_norm(X, gamma, beta, moving_mean, moving_var, eps, momentum):# 通过is_grad_enabled来判断当前模式是训练模式还是预测模式if not torch.is_grad_enabled():# 如果是在预测模式下,直接使用传入的移动平均所得的均值和方差X_hat = (X - moving_mean) / torch.sqrt(moving_var + eps)else:assert len(X.shape) in (2, 4)if len(X.shape) == 2:# 使用全连接层的情况,计算特征维上的均值和方差mean = X.mean(dim=0)var = ((X - mean) ** 2).mean(dim=0)else:# 使用二维卷积层的情况,计算通道维上(axis=1)的均值和方差。# 这里我们需要保持X的形状以便后面可以做广播运算mean = X.mean(dim=(0, 2, 3), keepdim=True)var = ((X - mean) ** 2).mean(dim=(0, 2, 3), keepdim=True)# 训练模式下,用当前的均值和方差做标准化X_hat = (X - mean) / torch.sqrt(var + eps)# 更新移动平均的均值和方差moving_mean = momentum * moving_mean + (1.0 - momentum) * meanmoving_var = momentum * moving_var + (1.0 - momentum) * varY = gamma * X_hat + beta  # 缩放和移位return Y, moving_mean.data, moving_var.data

创建一个正确的BatchNorm层

class BatchNorm(nn.Module):# num_features:完全连接层的输出数量或卷积层的输出通道数。# num_dims:2表示完全连接层,4表示卷积层def __init__(self, num_features, num_dims):super().__init__()if num_dims == 2:shape = (1, num_features)else:shape = (1, num_features, 1, 1)# 参与求梯度和迭代的拉伸和偏移参数,分别初始化成1和0self.gamma = nn.Parameter(torch.ones(shape))self.beta = nn.Parameter(torch.zeros(shape))# 非模型参数的变量初始化为0和1self.moving_mean = torch.zeros(shape)self.moving_var = torch.ones(shape)def forward(self, X):# 如果X不在内存上,将moving_mean和moving_var# 复制到X所在显存上if self.moving_mean.device != X.device:self.moving_mean = self.moving_mean.to(X.device)self.moving_var = self.moving_var.to(X.device)# 保存更新过的moving_mean和moving_varY, self.moving_mean, self.moving_var = batch_norm(X, self.gamma, self.beta, self.moving_mean,self.moving_var, eps=1e-5, momentum=0.9)return Y

使用批量规范化层的 LeNet

net = nn.Sequential(nn.Conv2d(1, 6, kernel_size=5), BatchNorm(6, num_dims=4), nn.Sigmoid(),nn.AvgPool2d(kernel_size=2, stride=2),nn.Conv2d(6, 16, kernel_size=5), BatchNorm(16, num_dims=4), nn.Sigmoid(),nn.AvgPool2d(kernel_size=2, stride=2), nn.Flatten(),nn.Linear(16*4*4, 120), BatchNorm(120, num_dims=2), nn.Sigmoid(),nn.Linear(120, 84), BatchNorm(84, num_dims=2), nn.Sigmoid(),nn.Linear(84, 10))
lr, num_epochs, batch_size = 1.0, 10, 256
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size)
d2l.train_ch6(net, train_iter, test_iter, num_epochs, lr, d2l.try_gpu())
loss 0.273, train acc 0.898, test acc 0.881
19328.6 examples/sec on cuda:0

在这里插入图片描述
看从第一个批量规范化层中学到的拉伸参数gamma和偏移参数beta

net[1].gamma.reshape((-1,)), net[1].beta.reshape((-1,))
(tensor([3.9210, 1.8005, 0.3629, 1.7734, 2.6866, 3.6042], device='cuda:0',grad_fn=<ViewBackward0>),tensor([-2.6395, -2.0010,  0.3424, -2.0999, -1.0537, -3.7402], device='cuda:0',grad_fn=<ViewBackward0>))

简明实现

直接使用深度学习框架中定义的BatchNorm

net = nn.Sequential(nn.Conv2d(1, 6, kernel_size=5), nn.BatchNorm2d(6), nn.Sigmoid(),nn.AvgPool2d(kernel_size=2, stride=2),nn.Conv2d(6, 16, kernel_size=5), nn.BatchNorm2d(16), nn.Sigmoid(),nn.AvgPool2d(kernel_size=2, stride=2), nn.Flatten(),nn.Linear(256, 120), nn.BatchNorm1d(120), nn.Sigmoid(),nn.Linear(120, 84), nn.BatchNorm1d(84), nn.Sigmoid(),nn.Linear(84, 10))

通常高级API变体运行速度快得多,因为它的代码已编译为C++或CUDA,而我们的自定义代码由Python实现。

d2l.train_ch6(net, train_iter, test_iter, num_epochs, lr, d2l.try_gpu())
loss 0.276, train acc 0.898, test acc 0.873
31860.4 examples/sec on cuda:0

在这里插入图片描述

· 本文使用了d2l包,这极大地减少了代码编辑量,需要安装d2l包才能运行本文代码

封面图片来源
欢迎点击我的主页查看更多文章。
本人学习地址https://zh-v2.d2l.ai/
恳请大佬批评指正。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.xdnf.cn/news/143090.html

如若内容造成侵权/违法违规/事实不符,请联系一条长河网进行投诉反馈,一经查实,立即删除!

相关文章

Java项目实战II基于Java+Spring Boot+MySQL的校园社团信息管理系统(源码+数据库+文档)

目录 一、前言 二、技术介绍 三、系统实现 四、论文参考 五、核心代码 六、源码获取 全栈码农以及毕业设计实战开发&#xff0c;CSDN平台Java领域新星创作者&#xff0c;专注于大学生项目实战开发、讲解和毕业答疑辅导。获取源码联系方式请查看文末 一、前言 在当今高校…

回归预测|2024年2月最新优化算法角蜥优化HLOA|基于角蜥优化BP神经网络数据回归Matlab程序HLOA-BP【优化效果好】

回归预测|2024年2月最新优化算法角蜥优化HLOA|基于角蜥优化BP神经网络数据回归Matlab程序HLOA-BP【优化效果好】 文章目录 一、基本原理1. 角蜥优化算法&#xff08;HLOA&#xff09;简介2. BP 神经网络&#xff08;BP Neural Network&#xff09;简介3. HLOA-BP 回归预测流程总…

Qt开发技巧(四)“tr“使用,时间类使用,Qt容器取值,类对象的删除,QPainter画家类,QString的转换,用好 QVariant类型

继续讲一些Qt技巧操作 1.非必要不用"tr" 如果程序运行场景确定是某一固定语言&#xff0c;就不需要用tr,"tr"之主要针对多语种翻译的&#xff0c;因为tr的本意是包含英文&#xff0c;然后翻译到其他语言比如中文&#xff0c;不要滥用tr&#xff0c;如果没有…

[Python数据可视化]Plotly Express: 地图数据可视化的魅力

在数据分析和可视化的世界中&#xff0c;地图数据可视化是一个强大而直观的工具&#xff0c;它可以帮助我们更好地理解和解释地理数据。Python 的 Plotly Express 库提供了一个简单而强大的方式来创建各种地图。本文将通过一个简单的示例&#xff0c;展示如何使用 Plotly Expre…

FastAdmin CMS 操作手册

FastAdmin CMS 操作手册 概述&#xff1a; 安装&#xff1a; 配置&#xff1a; 模板&#xff1a; 模板目录&#xff1a; 标签&#xff1a; 全局&#xff1a; 文章&#xff1a; 专题&#xff1a; 栏目&#xff1a; 公共参数&#xff1a; 单页&#xff1a; 特殊标签&#xff1a;…

macOS平台TextRank环境配置

1.安装python 2.安装numpy: pip3 install numpy 3.安装networkx: pip3 install networkx 4.安装math2: pip3 install math2 math2安装成功

npm安装时候报错certificate has expired

打开了一个很久没用的电脑&#xff0c;npm和node都装好了&#xff0c;安装包的时候一直报错 request to https://registry.npm.taobao.org/create-react-app failed, reason: certificate has expired而且先报错rollbackFailedOptional 然而npm没什么问题&#xff0c;是ssl过…

TSRPC+Cocos

TSRPC文档: https://tsrpc.cn/docs/get-started/api.html 创建 先创建一个默认的会话项目&#xff0c;找一个文件夹在控制台运行以下代码&#xff1a; npx create-tsrpc-applatest first-api --presets browser # 或者 yarn create tsrpc-app first-api --presets browser运…

【C++语言】C/C++内存管理

一、C/C内存分布 我们先来看一看C/C中有哪些区域&#xff0c;为什么C/C中区分这些区域呢&#xff1f;&#xff1f;不同的数据有不同的存储需求&#xff0c;各个区域满足不同的需求。我们有临时用的数据&#xff0c;该数据是存储在栈帧区域的&#xff1b;在一些数据结构中&#…

『功能项目』回调函数处理死亡【54】

我们打开上一篇53伤害数字UI显示的项目&#xff0c; 本章要做的事情是使用回调函数处理怪物Boss01死亡后增加主角经验值的功能&#xff0c;以及生成一个七秒的升级特效 首先增加一个技能特效重命名为PlayerUpGradeEffect 修改脚本&#xff1a;BossCtrl.cs 修改脚本&#xff1a…

【Linux系统编程】信号的保存与处理

目录 一&#xff0c;信号的保存 1-1&#xff0c;core与Term终止信号 1-2&#xff0c;进程退出与信号的关系 1-3&#xff0c;信号在内核中的表示 1-4&#xff0c;信号操作函数 二&#xff0c;信号的处理 2-1&#xff0c;信号被处理的时期 2-2&#xff0c;内核实现信号的…

zip-password-finder

1.zip-password-finder 对于传统ZIP文件密码的破解&#xff0c;采用密码匹配的方式进行实现&#xff0c;该github库的地址是&#xff1a; GitHub - agourlay/zip-password-finder: Find the password of protected ZIP files.Find the password of protected ZIP files. Cont…

整数二分算法和浮点数二分算法

整数二分算法和浮点数二分算法 二分 现实中运用到二分的就是猜数字的游戏 假如有A同学说B同学所说数的大小&#xff0c;B同学要在1~100中间猜中数字65&#xff0c;当B同学每次说的数都是范围的一半时这就算是一个二分查找的过程 二分查找的前提是这个数字序列要有单调性 基…

java--JDBC-连接池----JDBC小总结

一.连接池 1.连接池概述 目的&#xff1a;为了解决建立数据库连接耗费资源和时间很多的问题&#xff0c;提高性能。 Connection对象在JDBC使用的时候就会去创建一个对象,使用结束以后就会将这个对象给销毁了(close).每次创建和销毁对象都是耗时操作.需要使用连接池对其进行优…

类加载器详细介绍

类加载器我们要聊一个神秘而又重要的角色——Java类加载器。这家伙&#xff0c;就像是个超级英雄&#xff0c;总是在关键时刻挺身而出&#xff0c;为我们的Java程序提供强大的支持。我会尽量用简单易懂的方式来介绍它。 一 、类加载器介绍 1、类加载器是什么&#xff1f; 想象…

【python设计模式2】创建型模式1

目录 简单工厂模式 工厂方法模式 简单工厂模式 简单工厂模式不是23中设计模式中的&#xff0c;但是必须要知道。简单工厂模式不直接向客户端暴露对象创建的细节&#xff0c;而是通过一个工厂类来负责创建产品类的实例。简单工程模式的角色有&#xff1a;工厂角色、抽象产品角…

Redis(redis基础,SpringCache,SpringDataRedis)

文章目录 前言一、Redis基础1. Redis简介2. Redis下载与安装3. Redis服务启动与停止3 Redis数据类型4. Redis常用命令5. 扩展数据类型 二、在Java中操作Redis1. Spring Data Redis的使用1.1. 介绍1.2. 环境搭建1.3. 编写配置类&#xff0c;创建RedisTemplate对象1.4. 通过Redis…

Linux入门学习:git

文章目录 1. 创建仓库2. 仓库克隆3. 上传文件4. 相关问题4.1 git进程阻塞4.2 git log4.3 上传的三个步骤在做什么 本文介绍如何在Linux操作系统下简单使用git&#xff0c;对自己的代码进行云端保存。 1. 创建仓库 &#x1f539;这里演示gitee的仓库创建。 2. 仓库克隆 &…

Zookeeper 3.8.4 安装和参数解析

安装 zookeeper 之前必须先安装 JDK&#xff0c;有关Linux环境JDK可以参考我以前写的博文 1、关于Linux服务器配置java环境遇到的问题 2、Linux环境安装openJDK 3、Centos7.3云服务器上安装Nginx、MySQL、JDK、Tomcat环境 文章目录 1. zookeeper 安装2. 参数解析 1. zookeeper …

VScode快速配置c++(菜鸟版)

1.vscode是什么 Visual Stdio Code简称VS Code&#xff0c;是一款跨平台的、免费且开源的现代轻量级代码编辑器&#xff0c;支持几乎 主流开发语言的语法高亮、智能代码补全、自定义快捷键、括号匹配和颜色区分、代码片段提示、代码对比等特性&#xff0c;也拥有对git的开箱即…