PyTorch深度学习实战——交通标志识别

PyTorch深度学习实战——交通标记识别

    • 0. 前言
    • 1. 交通标志识别
      • 1.1 数据集介绍
      • 1.2 数据增强和批归一化
    • 3. 交通标志检测
    • 相关链接

0. 前言

在道路交通场景中,交通标志识别作为驾驶辅助系统与无人驾驶车辆中不可缺少的技术,为车辆行驶中提供了安全保障。在道路上行驶的车辆,道路周围的环境包括许多重要的交通标志信息,根据交通标志信息在道路上做出正确的驾驶行为,通常能够避免发生交通事故。交通标志识别可以检测并识别当前行驶道路上的交通标志,然后得出有关道路的必要信息。
但交通标志会受到车辆的运动状态、光照以及遮挡等环境因素的影响,因此如何使车辆在道路交通中快速准确地帮助驾驶员识别交通标志已经成为智能交通领域的热点问题之一。鉴于交通标志识别在自动驾驶等应用中具有重要作用,在节中,我们将学习使用卷积神经网络实现交通标志识别。

1. 交通标志识别

1.1 数据集介绍

德国交通标志识别基准 (German Traffic Sign Recognition Benchmark, GTSRB) 是高级驾驶辅助系统和自动驾驶领域的交通标志图像分类基准。其中共包含 43 种不同类别的交通标志。可以在官方网页中下载相关数据集。
每张图片包含一个交通标志,图像包含实际交通标志周围的环境像素,大约为交通标志尺寸的 10% (至少为 5 个像素),图像以 PPM 格式存储,图像尺寸在 15x15250x250 像素之间。

1.2 数据增强和批归一化

在介绍神经网络时,我们已经了解了利用数据增强可以提高模型准确性。在现实世界中,我们会遇到具有不同特征的图像,例如,某些图像可能更亮,某些图像中的感兴趣对象可能在图像边缘附近,而某些图像可能较为模糊。在本节中,我们将介绍如何使用数据增强和批归一化提高模型的准确率。
为了了解数据增强和批归一化对模型性能的影响,我们将使用交通标志数据集训练交通标志识别模型,并评估以下三种情况:

  • 不使用批归一化和数据增强
  • 只使用批归一化,但不使用数据增强
  • 同时使用批归一化和数据增强

除了以上不同外,数据集及其预处理方法完全相同。

3. 交通标志检测

首先考虑不使用批归一化和数据增强的情况,使用 PyTorch 实现交通标志识别。

(1) 下载数据集并导入相关库:

import os
import torch
import numpy as np
from torch.utils.data import DataLoader, Dataset
from glob import glob
from random import randint
import cv2
from pathlib import Path
import torch.nn as nn
from torch import optim
from matplotlib import pyplot as plt
import pandas as pddevice = 'cuda' if torch.cuda.is_available() else 'cpu'

(2) 指定与输出类别对应的索引:

from torchvision import transforms as T
classIds = pd.read_csv('signnames.csv')
classIds.set_index('ClassId', inplace=True)
classIds = classIds.to_dict()['SignName']
classIds = {f'{k:05d}':v for k,v in classIds.items()}
id2int = {v:ix for ix,(k,v) in enumerate(classIds.items())}

(3) 定义图像转换管道,执行图像转换操作(不使用数据增强):

trn_tfms = T.Compose([T.ToPILImage(),T.Resize(32),T.CenterCrop(32),# T.ColorJitter(brightness=(0.8,1.2), # contrast=(0.8,1.2), # saturation=(0.8,1.2), # hue=0.25),# T.RandomAffine(5, translate=(0.01,0.1)),T.ToTensor(),T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])val_tfms = T.Compose([T.ToPILImage(),T.Resize(32),T.CenterCrop(32),T.ToTensor(),T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

在以上代码中,对输入图像进行了一系列转换——首先将图像尺寸调整为 128 (最小边为 128),然后从图像中心进行裁剪。此外,使用 .ToTensor() 方法对图像进行缩放(使像素值位于 01 之间),最后对图像进行归一化处理,以便使用预训练模型。
取消以上代码中的注释并重新运行即可执行数据增强。此外,我们并不会对 val_tfms 执行数据增强,因为在模型训练期间没有使用这些图像。但是,val_tfms 图像需要通过与 trn_tfms 相同的转换管道。

(4) 定义数据集类 GTSRB

class GTSRB(Dataset):"""Road Sign Detection dataset."""def __init__(self, files, transform=None):self.files = filesself.transform = transformdef __len__(self):return len(self.files)def __getitem__(self, ix):fpath = self.files[ix]clss = os.path.basename(Path(fpath).parent)img = cv2.imread(fpath)img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)return img, classIds[clss]def choose(self):return self[randint(len(self))]def collate_fn(self, batch):imgs, classes = list(zip(*batch))if self.transform:imgs = [self.transform(img)[None] for img in imgs]classes = [torch.tensor([id2int[clss]]) for clss in classes]imgs, classes = [torch.cat(i).to(device) for i in [imgs, classes]]return imgs, classes

(5) 创建训练、验证数据集和数据加载器:

all_files = glob('GTSRB/Final_Training/Images/*/*.ppm')
np.random.shuffle(all_files)from sklearn.model_selection import train_test_split
trn_files, val_files = train_test_split(all_files, random_state=1)trn_ds = GTSRB(trn_files, transform=trn_tfms)
val_ds = GTSRB(val_files, transform=val_tfms)
trn_dl = DataLoader(trn_ds, 32, shuffle=True, collate_fn=trn_ds.collate_fn)
val_dl = DataLoader(val_ds, 32, shuffle=False, collate_fn=val_ds.collate_fn)

(6) 定义模型 SignClassifier

def convBlock(ni, no):return nn.Sequential(nn.Dropout(0.2),nn.Conv2d(ni, no, kernel_size=3, padding=1),nn.ReLU(inplace=True),#nn.BatchNorm2d(no),nn.MaxPool2d(2),)class SignClassifier(nn.Module):def __init__(self):super().__init__()self.model = nn.Sequential(convBlock(3, 64),convBlock(64, 64),convBlock(64, 128),convBlock(128, 64),nn.Flatten(),nn.Linear(256, 256),nn.Dropout(0.2),nn.ReLU(inplace=True),nn.Linear(256, len(id2int)))self.loss_fn = nn.CrossEntropyLoss()def forward(self, x):return self.model(x)def compute_metrics(self, preds, targets):ce_loss = self.loss_fn(preds, targets)acc = (torch.max(preds, 1)[1] == targets).float().mean()return ce_loss, acc

当需要在模型中使用 BatchNormalization (批归一化)时,需要取消注释以上代码中注释行。

(7) 定义使用批数据对模型进行训练和验证的函数:

def train_batch(model, data, optimizer, criterion):ims, labels = data_preds = model(ims)optimizer.zero_grad()loss, acc = criterion(_preds, labels)loss.backward()optimizer.step()return loss.item(), acc.item()@torch.no_grad()
def validate_batch(model, data, criterion):ims, labels = data_preds = model(ims)loss, acc = criterion(_preds, labels)return loss.item(), acc.item()

(8) 定义模型并对其进行训练:

model = SignClassifier().to(device)
criterion = model.compute_metrics
optimizer = optim.Adam(model.parameters(), lr=1e-3)
n_epochs = 50train_loss_epochs_no_aug_no_bn = []
train_acc_epochs_no_aug_no_bn = []
val_loss_epochs_no_aug_no_bn = []
val_acc_epochs_no_aug_no_bn = []
for ex in range(n_epochs):train_loss = []train_acc = []val_loss = []val_acc = []N = len(trn_dl)for bx, data in enumerate(trn_dl):loss, acc = train_batch(model, data, optimizer, criterion)train_loss.append(loss)train_acc.append(acc)N = len(val_dl)for bx, data in enumerate(val_dl):loss, acc = validate_batch(model, data, criterion)val_loss.append(loss)val_acc.append(acc)train_loss_epochs_no_aug_no_bn.append(np.average(train_loss))train_acc_epochs_no_aug_no_bn.append(np.average(train_acc))val_loss_epochs_no_aug_no_bn.append(np.average(val_loss))val_acc_epochs_no_aug_no_bn.append(np.average(val_acc))if ex == 10:optimizer = optim.Adam(model.parameters(), lr=1e-4)epochs = np.arange(50)+1
import matplotlib.pyplot as plt
plt.subplot(121)
plt.plot(epochs, train_loss_epochs_no_aug_no_bn, 'bo', label='Training loss')
plt.plot(epochs, val_loss_epochs_no_aug_no_bn, 'r', label='Test loss')
plt.title('Training and Test loss over increasing epochs \n with no batchnormalization and augmentation')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.legend()
plt.grid('off')
plt.subplot(122)
plt.plot(epochs, train_acc_epochs_no_aug_no_bn, 'bo', label='Training accuracy')
plt.plot(epochs, val_acc_epochs_no_aug_no_bn, 'r', label='Test accuracy')
plt.title('Training and Test accuracy over increasing epochs \n with no batchnormalization and augmentation')
plt.xlabel('Epochs')
plt.ylabel('Accuracy')
plt.legend()
plt.grid('off')
plt.show()

在三种不同实验设定中,模型在训练过程中的训练和验证准确率如下:

模型性能变化

根据以上结果,我们可以看出:

  • 当没有使用批归一化时,模型的准确率较低
  • 只使用批归一化但未使用数据增强时,模型的准确性会大大提高,但模型在训练数据上出现过拟合现象
  • 同时使用批归一化和数据增强的模型具有很高的准确性和较小的过拟合

相关链接

PyTorch深度学习实战(1)——神经网络与模型训练过程详解
PyTorch深度学习实战(2)——PyTorch基础
PyTorch深度学习实战(3)——使用PyTorch构建神经网络
PyTorch深度学习实战(4)——常用激活函数和损失函数详解
PyTorch深度学习实战(5)——计算机视觉基础
PyTorch深度学习实战(6)——神经网络性能优化技术
PyTorch深度学习实战(7)——批大小对神经网络训练的影响
PyTorch深度学习实战(8)——批归一化
PyTorch深度学习实战(9)——学习率优化
PyTorch深度学习实战(10)——过拟合及其解决方法
PyTorch深度学习实战(11)——卷积神经网络
PyTorch深度学习实战(12)——数据增强
PyTorch深度学习实战(13)——可视化神经网络中间层输出
PyTorch深度学习实战(14)——类激活图
PyTorch深度学习实战(15)——迁移学习

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.xdnf.cn/news/141221.html

如若内容造成侵权/违法违规/事实不符,请联系一条长河网进行投诉反馈,一经查实,立即删除!

相关文章

tomcat在idea上的配置

tomcat在idea上的配置主要包含以下几个步骤: 1、创建一个maven web工程 2、配置tomcat 1、创建一个maven web工程 第一个是仓库配置文件的路径,第二个是你的仓库路径。 2、配置tomcat 配置tomcat有以下两种方式: 1、集成配置 2、插件配置…

【数据结构】链表和LinkedList的理解和使用

目录 1.前言 2.链表 2.1链表的概念以及结构 2.2链表的实现 3.LinkedList的使用 3.1什么是LinkedList 3.2LinkedList的使用 2.常用的方法介绍 4. ArrayList和LinkedList的区别 1.前言 在上一篇文章中我们介绍了顺序表,ArrayList的底层原理和具体的使用&#x…

数字IC笔试千题解--单选题篇(二)

前言 出笔试题汇总,是为了总结秋招可能遇到的问题,做题不是目的,在做题的过程中发现自己的漏洞,巩固基础才是目的。 所有题目结果和解释由笔者给出,答案主观性较强,若有错误欢迎评论区指出,资料…

Spring面试题18:Spring中可以注入一个null和一个空字符串吗?Spring中如何注入一个java集合?

该文章专注于面试,面试只要回答关键点即可,不需要对框架有非常深入的回答,如果你想应付面试,是足够了,抓住关键点 面试官:Spring中可以注入一个null和一个空字符串吗? 在Spring中是可以注入null和空字符串的。 注入null:可以使用@Value注解,将属性值设为null。例如:…

使用 PyTorch 的计算机视觉简介 (3/6)

一、说明 在本单元中,我们将了解卷积神经网络(CNN),它是专门为计算机视觉设计的。 卷积层允许我们从图像中提取某些图像模式,以便最终分类器基于这些特征。 二、卷积神经网络 计算机视觉不同于通用分类,因…

C++之va_start、vasprintf、va_end应用总结(二百二十六)

简介: CSDN博客专家,专注Android/Linux系统,分享多mic语音方案、音视频、编解码等技术,与大家一起成长! 优质专栏:Audio工程师进阶系列【原创干货持续更新中……】🚀 人生格言: 人生…

【算法】排序——插入排序及希尔排序

目录 前言 一、排序的概念及其应用 1.1排序的概念 1.2排序的应用 1.3常见的排序算法 二、插入排序的实现 基于插入排序的优化——希尔排序(缩小增量排序 个人主页 代码仓库 C语言专栏 初阶数据结构专栏 Linux专栏 LeetCode刷题 算法专栏 前言 这…

Neo4j图数据库_web页面关闭登录实现免登陆访问_常用的cypher语句_删除_查询_创建关系图谱---Neo4j图数据库工作笔记0013

由于除了安装,那么真实使用的时候,就是导入数据了,有了关系和节点的csv文件以后如果用 cypher进行导入数据和创建关系图谱,还有进行查询,以及如果导入错误如何清空,大概是这些 用的最多的,单独把这些拿进来,总结一下,用的会比较方便. 1.实现免登陆访问: /data/module/neo4j-…

基于微信小程序的在线小说阅读系统,附数据库、教程

1 功能简介 Java基于微信小程序的在线小说阅读系统 微信小程序的在线小说阅读系统,系统的整体功能需求分为两部分,第一部分主要是后台的功能,后台功能主要有小说信息管理、注册用户管理、系统系统等功能。微信小程序主要分为首页、分类和我的…

【数据结构--排序】堆排序

💐 🌸 🌷 🍀 🌹 🌻 🌺 🍁 🍃 🍂 🌿 🍄🍝 🍛 🍤 📃个人主页 :阿然成长日记 …

Linux的socket通信

关于套接字通信定义如下: 套接字对应程序猿来说就是一套网络通信的接口,使用这套接口就可以完成网络通信。网络通信的主体主要分为两部分:客户端和服务器端。在客户端和服务器通信的时候需要频繁提到三个概念:IP、端口、通信数据&…

测试C#图像文本识别模块Tesseract的基本用法

微信公众号“dotNET跨平台”的文章《c#实现图片文体提取》(参考文献3)介绍了C#图像文本识别模块Tesseract,后者是tesseract-ocr(参考文献2) 的C#封装版本,目前版本为5.2,关于Tesseract的详细介绍…

马尔可夫链预测举例——钢琴销售的存贮策略

问题概述 一家钢琴专卖店,根据以往的销售经验,平均每周只能售出一架钢琴,现在经理指定的存贮策略是,每周末检查库存存量,仅当库存量为零时,才订购3架供下周销售;否则就不订购。试估计这种策略下…

pytest一些常见的插件

Pytest拥有丰富的插件架构,超过800个以上的外部插件和活跃的社区,在PyPI项目中以“ pytest- *”为标识。 本篇将列举github标星超过两百的一些插件进行实战演示。 插件库地址:http://plugincompat.herokuapp.com/ 1、pytest-html&#xff1…

数据库——理论基础

目录 1.1 什么是数据库 1.2 数据库管理系统(DBMS) 1.3 数据库和文件系统的区别 1.4 数据库的发展史 1.5常见的数据库 1.5.1关系型数据库 1.5.2 非关系型数据库 1.6 DBMS支持的数据模型 1.1 什么是数据库 数据:描述事物的符号记录 数…

opencv实现仿射变换和透射变换

##1, 什么是仿射变换? 代码实现 import numpy as np import cv2 as cv import matplotlib.pyplot as plt#设置字体 from pylab import mpl mpl.rcParams[font.sans-serif] [SimHei]#图像的读取 img cv.imread("lena.png")#仿射变换 row…

clickhouse简单安装部署

目录 前言(来源于官方文档): 一.下载并上传 1.下载地址:点我跳转下载 2.上传至Linux 二.解压和配置 1.解压顺序 注意:必须按照以下顺序解压,并且每解压一个都要执行该解压后文件的install/doinst.sh文件 解压步骤&#xff…

Mycat管理及监控

Mycat管理 -h 是你自己的ip地址 相关命令及含义 Mycat-eye(图形化界面监控) 仅限于Linux系统

基于Xml方式Bean的配置-初始化方法和销毁方法

SpringBean的配置详解 Bean的初始化和销毁方法配置 Bean在被实例化后&#xff0c;可以执行指定的初始化方法完成一些初始化的操作&#xff0c;Bean在销毁之前也可以执行指定的销毁方法完成一些操作&#xff0c;初始化方法名称和销毁方法名称通过 <bean id"userService…