AlexNet项目图片分类通用模型代码

目录

一:建立AlexNet模型(在model文件中写)

1.构造5层卷积层

2.构造3层神经网络层

3.forward函数

4.模型最终代码

二:训练数据(在train中写)

1.读出数据

2.训练

3. 测试模型更新参数

4.完整的训练代码:

三:预测和模型评分(在predict文件中写)

 四:代码使用:

点个赞呗!!!!!!


 

一:建立AlexNet模型(在model文件中写)

AlexNet网络结构相对简单,使用了8层卷积神经网络,前5层是卷积层,剩下的3层是全连接层

1.构造5层卷积层

Conv2d:构造卷积层,参数:(输入通道数,输出通道数等价于卷积核个数,卷积核大小,步长,加0)

ReLU:激活函数,inplace设置是否改变数据

MaxPool2d:池化操作,参数:(核大小,步长)

import torch
import torch.nn as nnclass AlexNet(nn.Module):def __init__(self,num_classes=1000):super().__init__()self.features = nn.Sequential(nn.Conv2d(3, 48, kernel_size=11, stride=4, padding=2),nn.ReLU(inplace=True),nn.MaxPool2d(kernel_size=3, stride=2),nn.Conv2d(48, 128, kernel_size=5, padding=2),nn.ReLU(inplace=True),nn.MaxPool2d(kernel_size=3, stride=2),nn.Conv2d(128, 192, kernel_size=3, padding=1),nn.ReLU(inplace=True),nn.Conv2d(192, 192, kernel_size=3, padding=1),nn.ReLU(inplace=True),nn.Conv2d(192, 128, kernel_size=3, padding=1),nn.ReLU(inplace=True),nn.MaxPool2d(kernel_size=3, stride=2) # output [128, 6, 6])

2.构造3层神经网络层

Dropout:将比例数据置空,比如数据为(1,2,3,4,5,6),当参数p=0.5时,数据会变成:(1,0,3,0,5,0)。p代表置空的比例,当然这个置空是随机的

Linear:线性变换,参数:(输入数据的通道数,输出数据的通道数)

ReLu:和上面一样

self.classifier = nn.Sequential(nn.Dropout(p=0.5),nn.Linear(128 * 6 * 6, 2048),nn.ReLU(inplace=True),nn.Dropout(p=0.5),nn.Linear(2048, 2048),nn.ReLU(inplace=True),nn.Linear(2048, num_classes),)

3.forward函数

torch.nn.Flatten(start_dim=1, end_dim=-1)
start_dim与end_dim代表合并的维度,开始的默认值为1,结束的默认值为 - 1,因此常被使用在神经网络当中,将每个batch的数据拉伸成一维

forward的作用:先让数据经过5层卷积,在经过3层全连接层

    def forward(self, x):x = self.features(x)x = torch.flatten(x, start_dim=1)x = self.classifier(x)return x

4.模型最终代码

import torch
import torch.nn as nnclass AlexNet(nn.Module):def __init__(self, num_classes=1000):super().__init__()self.features = nn.Sequential(nn.Conv2d(3, 48, kernel_size=11, stride=4, padding=2),nn.ReLU(inplace=True),nn.MaxPool2d(kernel_size=3, stride=2),nn.Conv2d(48, 128, kernel_size=5, padding=2),nn.ReLU(inplace=True),nn.MaxPool2d(kernel_size=3, stride=2),nn.Conv2d(128, 192, kernel_size=3, padding=1),nn.ReLU(inplace=True),nn.Conv2d(192, 192, kernel_size=3, padding=1),nn.ReLU(inplace=True),nn.Conv2d(192, 128, kernel_size=3, padding=1),nn.ReLU(inplace=True),nn.MaxPool2d(kernel_size=3, stride=2) # output [128, 6, 6])self.classifier = nn.Sequential(nn.Dropout(p=0.5),nn.Linear(128 * 6 * 6, 2048),nn.ReLU(inplace=True),nn.Dropout(p=0.5),nn.Linear(2048, 2048),nn.ReLU(inplace=True),nn.Linear(2048, num_classes),)def forward(self, x):torch.nn.Flatten(start_dim=1, end_dim=-1)x = self.features(x)x = torch.flatten(x, start_dim=1)x = self.classifier(x)return x

二:训练数据(在train中写)

1.读出数据

RandomResizedCrop:对图片的区域随机改变图片的大小

RandomHorizontalFlip:对图片进行随机翻转

ToTensor:将图片转为Tensor数据类型

Normalize:对数据进行标准化,参数:(标准化的平均值元组,方差元组)

datasets.ImageFolder:读数据类别返回一个字典{0:类别一,1:类别二}, 这行代码可以获取数据的类别数以及对应的类别标签。以字典的形式保存

        参数:(root:读取文件的路径(注意:路径文件中不能直接放图片,应该放各个图片类别的文件),transfrom:对图片进行预处理函数)

torch.utils.data.DataLoader:读取数据:参数(dataset:数据加载的数据集,batch_size:每次加载多少样本数,suffle:是否打乱数据,num_workers:最多并行加载数量)

import os
import sys
import jsonfrom tqdm import tqdm
import torch
import torch.nn as nn
from torchvision import transforms,datasetsfrom model import AlexNetdef main():# 看看是否使用cpudevice=torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')print(f'using:',{device})# 要读入图片的目录路径,以下代码假设你这个路径下有文件(train训练集, val测试集)image_path=os.path.join('./','训练集和测试集的图片路径')# print(image_path)# 判断这个路径是否存在,若不存在则报错image path done nit existassert os.path.exists(image_path),'image path done ont exist'# 创建读入数据后对数据处理的方法集合data_transform={# 将训练数据和测试数据的处理集使用对象的方法,以便后面使用'train':transforms.Compose([transforms.RandomResizedCrop(224),transforms.RandomHorizontalFlip(),transforms.ToTensor(),transforms.Normalize((0.485, 0.456, 0.406),(0.229, 0.224, 0.225))]),'val':transforms.Compose([transforms.Resize((224, 224)),transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])}train_dataset=datasets.ImageFolder(root=os.path.join(image_path,'train'),transform=data_transform['train'])train_num=len(train_dataset)# print(train_num)flower_list=train_dataset.class_to_idx# print(flower_list)cla_dict=dict((val,key) for key,val in flower_list.items())# print(cla_dict)json_str=json.dumps(cla_dict,indent=4)with open('class_indices.json','w') as json_file:json_file.write(json_str)# print(json_str)batch_size=32# 这行代码可以得到你电脑的cpu最大进程数量,如果大于16那么就按照16来nw=min([os.cpu_count(),batch_size if batch_size>1 else 0,16])print(f'using {nw} dataloader workers every process')# 读取训练集数据train_loader=torch.utils.data.DataLoader(train_dataset,batch_size=batch_size,shuffle=True,num_workers=nw)validata_dataset=datasets.ImageFolder(root=os.path.join(image_path,'val'),transform=data_transform['val'])val_num = len(validata_dataset)# 读取测试集数据validata_loader=torch.utils.data.DataLoader(validata_dataset,batch_size=batch_size,shuffle=False,num_workers=nw)

2.训练

每行代码都要注释哦!!!

    net.to(device)# 使用交叉熵损失函数loss_fn=nn.CrossEntropyLoss()# 使用优化器类,这里使用Adam优化器optimizer=torch.optim.Adam(net.parameters(),lr=0.0002)# 迭代数量epochs=10# 训练后的参数保存地址save_path='./AlexNet.pth'best_acc=0.0# 训练样本数train_step=len(train_loader)for epoch in range(epochs):# 开启训练模式net.train()# 初始化每次迭代的总损失running_loss=0.0# tqdm是一个进度条,将要迭代的数据放入,可以查看迭代的进度# stdout :它使用其参数直接显示在控制台窗口上。train_bar=tqdm(train_loader,file=sys.stdout)for step,data in enumerate(train_bar):# images是要训练的图片,labels是这张图片的类别images,labels=data# 将优化器的梯度置零optimizer.zero_grad()# 将图片加入到cpu中训练后返回outputsoutputs=net(images.to(device))# 计算损失loss=loss_fn(outputs,labels.to(device))# 反向传播计算参数loss.backward()# 跟新优化器中的参数optimizer.step()# 累加损失running_loss+=loss.item()# 输出语句train_bar.desc = f'train epoch {epoch + 1}/ {epochs} loss: {loss:.3f}'

3. 测试模型更新参数

       # 开启预测模型net.eval()acc=0.0# 关闭torch中的梯度记录with torch.no_grad():# 开启进度条val_bar=tqdm(validata_loader,file=sys.stdout)# 开始迭代for val_data in val_bar:# 测试集图片,测试集目标值val_images,val_labels=val_data# 开始预测outputs=net(val_images.to(device))# 获取预测的结果集,因为用的是softmax,所以取没张图片的结果集的最大值就是这张图片的预测结果predict_y=torch.max(outputs,dim=1)[1]# 计算总损失acc+=torch.eq(predict_y,val_labels.to(device)).sum().item()# 计算平均损失val_accuracy=acc/val_num# 打印结果print(f'[epoch {epoch + 1}] train_loss: {running_loss / train_step:.3f},         val_accuracy:{val_accuracy:.3f}')# 如果这次的结果比上次的好,就更新参数,否则不变if val_accuracy>best_acc:best_acc=val_accuracytorch.save(net.state_dict(),save_path)

4.完整的训练代码:

import os
import sys
import jsonfrom tqdm import tqdm
import torch
import torch.nn as nn
from torchvision import transforms,datasetsfrom model import AlexNetdef main():# 看看是否使用cpudevice=torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')print(f'using:',{device})# 要读入图片的目录路径,以下代码假设你这个路径下有文件(train训练集, val测试集)image_path=os.path.join('./','训练集和测试集的图片路径')# print(image_path)# 判断这个路径是否存在,若不存在则报错image path done nit existassert os.path.exists(image_path),'image path done ont exist'# 创建读入数据后对数据处理的方法集合data_transform={# 将训练数据和测试数据的处理集使用对象的方法,以便后面使用'train':transforms.Compose([transforms.RandomResizedCrop(224),transforms.RandomHorizontalFlip(),transforms.ToTensor(),transforms.Normalize((0.485, 0.456, 0.406),(0.229, 0.224, 0.225))]),'val':transforms.Compose([transforms.Resize((224, 224)),transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])}train_dataset=datasets.ImageFolder(root=os.path.join(image_path,'train'),transform=data_transform['train'])train_num=len(train_dataset)# print(train_num)flower_list=train_dataset.class_to_idx# print(flower_list)cla_dict=dict((val,key) for key,val in flower_list.items())# print(cla_dict)json_str=json.dumps(cla_dict,indent=4)with open('class_indices.json','w') as json_file:json_file.write(json_str)# print(json_str)batch_size=32nw=min([os.cpu_count(),batch_size if batch_size>1 else 0,16])print(f'using {nw} dataloader workers every process')train_loader=torch.utils.data.DataLoader(train_dataset,batch_size=batch_size,shuffle=True,num_workers=nw)validata_dataset=datasets.ImageFolder(root=os.path.join(image_path,'val'),transform=data_transform['val'])val_num = len(validata_dataset)validata_loader=torch.utils.data.DataLoader(validata_dataset,batch_size=batch_size,shuffle=False,num_workers=nw)net=AlexNet(num_classes=5)net.to(device)# 使用交叉熵损失函数loss_fn = nn.CrossEntropyLoss()# 使用优化器类,这里使用Adam优化器optimizer = torch.optim.Adam(net.parameters(), lr=0.0002)# 迭代数量epochs = 10#  训练后的参数保存地址save_path = './AlexNet.pth'best_acc = 0.0# 训练样本数train_step = len(train_loader)for epoch in range(epochs):# 开启训练模式net.train()# 初始化每次迭代的总损失running_loss = 0.0# tqdm是一个进度条,将要迭代的数据放入,可以查看迭代的进度# stdout :它使用其参数直接显示在控制台窗口上。train_bar = tqdm(train_loader, file=sys.stdout)for step, data in enumerate(train_bar):# images是要训练的图片,labels是这张图片的类别images, labels = data# 将优化器的梯度置零optimizer.zero_grad()# 将图片加入到cpu中训练后返回outputsoutputs = net(images.to(device))# 计算损失loss = loss_fn(outputs, labels.to(device))# 反向传播计算参数loss.backward()# 跟新优化器中的参数optimizer.step()# 累加损失running_loss += loss.item()# 输出语句train_bar.desc = f'train epoch {epoch + 1}/ {epochs} loss: {loss:.3f}'# 开启预测模型net.eval()acc = 0.0# 关闭torch中的梯度记录with torch.no_grad():# 开启进度条val_bar = tqdm(validata_loader, file=sys.stdout)# 开始迭代for val_data in val_bar:# 测试集图片,测试集目标值val_images, val_labels = val_data# 开始预测outputs = net(val_images.to(device))# 获取预测的结果集,因为用的是softmax,所以取没张图片的结果集的最大值就是这张图片的预测结果predict_y = torch.max(outputs, dim=1)[1]# 计算总损失acc += torch.eq(predict_y, val_labels.to(device)).sum().item()# 计算平均损失val_accuracy = acc / val_num# 打印结果print(f'[epoch {epoch + 1}] train_loss: {running_loss / train_step:.3f},         val_accuracy:{val_accuracy:.3f}')# 如果这次的结果比上次的好,就更新参数,否则不变if val_accuracy > best_acc:best_acc = val_accuracytorch.save(net.state_dict(), save_path)

三:预测和模型评分(在predict文件中写)

import os
import jsonimport torch
from PIL import Image
from torchvision import transforms
import matplotlib.pyplot as pltfrom model import AlexNetdef main():device=torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')# 处理图片函数data_transform=transforms.Compose([transforms.Resize((224,224)),transforms.ToTensor(),transforms.Normalize((0.5,0.5,0.5),(0.5,0.5,0.5))])# 测试图片的路径img_path='./1.jpeg'# 判断这张图片是否存在assert os.path.exists(img_path),f'{img_path} does not exist'# 读取图片img=Image.open(img_path)# 将图片展示出来plt.imshow(img)# 使用上面的函数处理图片img=data_transform(img)# print(img.shape)# 对图片维度进行扩充img=torch.unsqueeze(img,dim=0)print(img.shape)# 训练的时候生成的图片类别文件json_path='./class_indices.json'# 判断类别文件是否存在assert os.path.exists(json_path),f'{json_path} done not exist'# 读取文件with open(json_path,'r') as f:class_dict=json.load(f)print(class_dict)# 建立模型model=AlexNet(num_classes=5).to(device)# 训练后的参数文件weights_path='./AlexNet.pth'# 判断参数文件是否存在assert os.path.exists(weights_path),f'file {weights_path} does not exist'# 模型加载参数model.load_state_dict(torch.load(weights_path))# 开启预测模式model.eval()# 关闭梯度with torch.no_grad():# 预测output=model(img.to(device))print(output)# 对图片维度压缩回来output=torch.squeeze(output).cpu()# 使用softmax函数分类predict=torch.softmax(output,dim=0)# 获得图片的预测概率最大的那个就是这张图片的类别predict_class=torch.argmax(predict).numpy()# 完成打印print_res = f"class: {class_dict[str(predict_class)]}, prob: {predict[predict_class].numpy():.3f}"# 将图片类别写在图片上plt.title(print_res)# 展示图片plt.show()if __name__ == '__main__':main()

 四:代码使用:

准备好需要训练和预测的数据集比如:

 文件中有训练集和测试集

 

将flower_data这个路径写到image_path这个变量中就可以

点个赞呗!!!!!!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.xdnf.cn/news/1539217.html

如若内容造成侵权/违法违规/事实不符,请联系一条长河网进行投诉反馈,一经查实,立即删除!

相关文章

问题——IMX6UL的uboot无法ping主机或Ubuntu

主要描述可能的方向,不涉具体过程,详细操作可以查阅网上相关教程 跟随正点原子教程测试以太网端口时,即便按照步骤多次尝试也无法ping通,后补充了些许网络工程基础知识解决了这个问题。 uboot无法ping主机或Ubuntu有多种可能&…

Redis集群知识及实战

1. 为什么使用集群 在哨兵模式中,仍然只有一个Master节点。当并发写请求较大时,哨兵模式并不能缓解写压力。我们知道只有主节点才具有写能力,那如果在一个集群中,能够配置多个主节点,是不是就可以缓解写压力了呢&…

总结拓展十:SAP开发计划(下)

第一节 接口功能开发说明书设计 1、软件系统接口作用 答:系统接口,是实现系统间数据传输的功能。 2、软件系统接口特点 1)采用Web Service技术作为平台,有众多的数据传输协议标准,通过API与外界交流数据。 2&…

Vscode搭配latex简易教程

1. 找镜像网站下载texlive的iso文件 清华源镜像 下载之后直接打开iso文件,打开install-tl-windows.bat文件,进行安装即可,安装大概30分钟左右 2. VScode端配置 2.1 下载这三个插件 2.2 打开设置 2.3 追加内容到配置json文件当中 // Latex…

14_input子系统my_touch_device,my_touch_handlerLinux内核模块

01_basicLinux内核模块_the kernel was built by:x86 64-linux-gnu-gcc-12(ub-CSDN博客文章浏览阅读678次,点赞3次,收藏3次。环境IDubuntuMakefilemodules:clean:basic.creturn 0;运行效果。_the kernel was built by:x86 64-linux-gnu-gcc-12(ubuntu 12…

贷款年利率迷局:年利率3.8%为何变成2.07%?

朋友们,聊聊贷款那点事儿,特别是那个让人又爱又恨的年利率,听起来简单,3.8%就是一年给银行3.8%的贷款总额当利息,对吧?但别急,这里头学问大着呢!有时候,你发现标着3.8%的…

keil调试变量值被篡改问题

今天遇到一个代码中变量值被篡改的问题,某个数组的第一个值运行一段时间之后变成了0,如图: 看现象基本可以断定是内存越界导致的,但是要如果定位是哪里内存越界呢? keil提供了两个工具 1、set access breakpoint at(设置访问断点…

ES6标准---【八】【学习ES6看这一篇就够了!!!】

目录 前言 export命令 输出变量 输出函数/类 export中的as别名 export必须一一对应 export接口的响应性 注意 import命令 import命令的语法 import命令里的as别名 import的只读性 import命令具有提升性 import的一些约定 import的静态执行 import的唯一执行性 模…

基于SmartUpload组件实现文件上传功能的案例

SmartUpload组件简介 SmartUpload组件 专门用于实现文件上传及下载的免费组件SmartUpload组件特点 使用简单:编写少量代码,完成上传下载功能能够控制上传内容能够控制上传文件的大小、类型缺点:目前已停止更新服务 SmartUpload组件应用 单文…

【Java】多线程前置知识 初识Thread

多线程前置知识 & 初识Thread 冯诺依曼体系结构初步认识存储设备CPU指令 操作系统初识操作系统内核态和用户态 进程/任务进程是什么进程的管理进程的调度虚拟内存地址进程间的通信 线程线程的出现线程是什么线程可能出现的问题线程与进程的联系和区别 协程初识Thread类Thre…

Java lambda表达式的变量捕获

有人看到这个lambda表达式能够访问isQuit这个变量而且还是可以被修改的变量,就发出疑问了,之前不是说lambda不能不或变量吗? 1.规则 java的lambda表达式变量捕获规则只是针对于外部作用域的局部变量来说的!!&#xf…

Linux环境变量进程地址空间

目录 一、初步认识环境变量 1.1常见的环境变量 1.2环境变量的基本概念 二、命令行参数 2.1通过命令行参数获取环境变量 2.2本地变量和内建命令 2.3环境变量的获取 三、进程地址空间 3.1进程(虚拟)地址空间的引入 3.2进程地址空间的布局和理解 …

【机器学习】:深潜智能的底层逻辑、前沿探索与未来展望】

欢迎来到 破晓的历程的 博客 ⛺️不负时光,不负己✈️ 在科技的浩瀚星空中,机器学习犹如一颗璀璨的新星,以其独特的魅力和无限潜力,引领着我们向智能的深处探索。今天,我们将一同踏上这场深度之旅,不仅解析…

pdf图片怎么提取出来?这6个pdf图片提取工具全搞定,值得推荐!

在我们的日常办公和学习中,pdf文件成为了信息传递的重要载体。然而,有时我们在pdf文档中发现一些精彩的图片,想将其提取出来供个人使用或分享给他人。无论是为了更灵活的处理,还是为了发送特定的图像,提取pdf中的图片都…

国产新港海岸NCS8622Type-C/DP1.4 to HDMI2.0 Converter

NCS8622是一款高性能低功耗的Type-C/DP1.4至HDMI2.0转换器,设计用于将USB Type-C源或DP1.4源连接到HDMI2.0。
 NCS8622集成了符合DP1.4标准的接收器和符合HDMI2.0标准的发射器。 此外,CC控制器用于CC通信以实现DP替代模式。
 DP接收器集成了HDCP 1.…

gazebo 仿真阶段性问题汇总二

目录 写在前面的话遇到的问题问题一:启动了多个 robot_state_publisher解决办法 问题二:rviz 启动报错解决办法 问题三:rviz 中 wheel 一直指向 base_link解决方法 问题四:摄像头和opencv坐标系的问题解决方法 问题五:…

Submariner 部署全过程

Submariner 部署全过程 部署集群配置 broker 集群: pod-cidr:11.244.0.0/16 service-cidr 11.96.0.0/12 broker 172.100.0.109 node 172.100.0.108 集群 1( pve3 ): pod-cidr:10.244.0.0/16 service-…

微信支付开发-支付工厂JsApi产品代码

一、JSAPI支付产品、APP支付产品、小程序支付产品流程图 二、H5支付产品、Native支付产品 三、工厂父类抽象类代码开发 <?php /*** 微信父类抽象类* User: 龙哥三年风水* Date: 2024/9/19* Time: 11:33*/ namespace Payment\WechatPay; abstract class WechatPaymentHandl…

翻页时钟 2.0-自动置顶显示,点击小时切换显示标题栏不显示标题栏-供大家学习研究参考

更新内容 自动置顶显示点击小时切换显示标题栏&#xff0c;&#xff08;显示标题栏后可移动时钟位置&#xff0c;鼠标拖动边框调整时钟大小&#xff09;不显示标题栏时&#xff0c;透明部分光标可穿透修正一个显示bu 下载地址&#xff1a; https://download.csdn.net/download…

一站式项目管理系统如何实现全链条数字化管理?

在当今数字化高速发展的时代&#xff0c;项目申报领域也面临着管理方式的革新挑战。从传统的人工管理到如今追求高效、精准的数字化管理模式转变&#xff0c;是行业发展的必然趋势。如启服云项目管理系统之类的出现&#xff0c;为项目申报管理带来了新的思路。 立项阶段的数字化…