打造你的Pokemon大师:深度学习多分类模型构建与本地部署全攻略

打造你的Pokemon大师:深度学习多分类模型构建与本地部署全攻略

引言

在这篇文章中,我将分享如何从头构建一个用于多分类任务的神经网络,并将其部署到本地环境。这是一个实践性质的教程,适合那些对深度学习模型部署感兴趣的初学者。

数据集准备

为了本次训练,我选择了一个网上流行的宝可梦数据集,它包含5个类别,每个类别的图片都存放在各自的文件夹中。为了确保标签和类别的一致性,我使用了sorted()函数对文件夹名称进行排序,并通过enumerate为每个类别分配一个唯一的标签。

以下是数据集准备的代码实现:

import glob
import os
import cv2
from torch.utils.data import DataLoader, Dataset
import torchvision
from PIL import Image
import randomclass PokemonData(Dataset):def __init__(self, root_path, mode=None):super(PokemonData, self).__init__()self.pokemon_names = sorted(os.listdir(root_path))self.labels = {name: i for i, name in enumerate(self.pokemon_names)}self.all_imgs = []for name in self.pokemon_names:self.all_imgs.extend(glob.glob(os.path.join(root_path, name, '*')))random.shuffle(self.all_imgs)self.imgs = self.all_imgs[:int(len(self.all_imgs) * 0.8)] if mode == "train" else self.all_imgs[int(len(self.all_imgs) * 0.8):]def __len__(self):return len(self.imgs)def __getitem__(self, item):name = self.imgs[item].split("\\")[-2]img = cv2.imread(self.imgs[item])tf_img = self.transformData(img)label = self.labels[name]return label, tf_imgdef transformData(self, img):img = Image.fromarray(img)tf_img = torchvision.transforms.Compose([torchvision.transforms.Resize((224, 224)),torchvision.transforms.ToTensor(),torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])img = tf_img(img)return imgif __name__ == '__main__':root_path = "E:/pokemon"mode = "train"pd = PokemonData(root_path, mode)pd_datas = DataLoader(pd, batch_size=32, shuffle=True)for label, data in pd_datas:print(label, data.shape)

模型构建

我选择了预训练的ResNet18模型作为基础,因为它在性能和计算资源之间取得了良好的平衡。ResNet18的最后一层输出1000个类别,我们需要将其替换为适合我们数据集的输出层。

from torchvision import models
import torch
import torch.nn as nnclass CustomResNet18(nn.Module):def __init__(self):super(CustomResNet18, self).__init__()self.base_model = models.resnet18(pretrained=True)self.fc = nn.Linear(512, 5)def forward(self, x):x = self.base_model(x)x = self.fc(x)return x

训练与评估

接下来,我们训练模型并评估其性能。我们使用交叉熵损失函数和Adam优化器。

# 训练和评估代码省略,与原文中相同

模型格式转换

为了提高预测性能,我们将PyTorch模型转换为ONNX格式。ONNX是一种开放的模型格式,允许模型在不同的框架和硬件之间迁移。

import onnx
import torch# 导入自定义模型
from custom_resnet import CustomResNet18model = CustomResNet18()
model.load_state_dict(torch.load("best_model.pt"))
model.eval()x = torch.randn(1, 3, 224, 224)
torch.onnx.export(model, x, "best_model.onnx", input_names=["input"], output_names=["output"], opset_version=11)onnx_model = onnx.load("best_model.onnx")
onnx.checker.check_model(onnx_model)
print("ONNX模型格式正确!")

本地部署与预测

最后,我们使用ONNX模型进行本地预测。以下是如何加载ONNX模型并对一张图片进行分类的示例代码。

import onnxruntime as ort
import torch
import cv2
from PIL import Image
from torchvision import transforms# 加载ONNX模型
session = ort.InferenceSession("best_model.onnx")# 图像预处理
data_preproce = transforms.Compose([transforms.Resize((224, 224)),transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])img_cv = cv2.imread("test.png")
img_pil = Image.fromarray(img_cv)
img = data_preproce(img_pil)
input_tensor = torch.unsqueeze(img, 0).numpy()# 进行预测
pred = session.run(None, {"input": input_tensor})[0]
pred_softmax = torch.softmax(torch.tensor(pred), dim=1)
values, indices = torch.topk(pred_softmax, 3)# 显示预测结果
labels_dict = {'bulbasaur': 0, 'charmander': 1, 'mewtwo': 2, 'pikachu': 3, 'squirtle': 4}
revers_dict = {v: k for k, v in labels_dict.items()}
for i in indices[0].tolist():print(revers_dict[i], ":", round(values[0].tolist()[index_n.index(i)] * 100, 5), "%")

通过这篇文章,我们不仅学习了如何构建和训练一个多分类神经网络,还了解了如何将其部署到本地环境并进行预测。希望这篇文章对你有所帮助!

请添加图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.xdnf.cn/news/3776.html

如若内容造成侵权/违法违规/事实不符,请联系一条长河网进行投诉反馈,一经查实,立即删除!

相关文章

TikTok云手机怎样进行TikTok矩阵运营?

由于地区限制,国内无法直接访问TikTok。本文将介绍如何借助海外版TikTok云手机实现多账号管理,也就是矩阵运营,并探讨这种方式对提升工作效率的优势。 常见的多账号管理方式 许多人尝试通过VPN访问TikTok,但通常会遇到连接不稳定、…

光控资本:进入超级事件周 A股将如何运行

从国内来看,近期最重要的作业无疑是十四届全国人大常委会第十二次会议,该会议将于11月4日至8日在北京举办。商场广泛预期,本次会议将审议上调政府债务限额的议题,并或许推进新一轮的债务化解作业。这些方针意向有望为商场带来新的…

D59【python 接口自动化学习】- python基础之异常

day59 捕获异常常见问题 学习日期:20241105 学习目标:异常 -- 75 避坑指南:编写捕获异常程序时经常出现的问题 学习笔记: 捕获位置设置不当 设置范围不当 捕获处理设置不当 嵌套try-except语法错误 总结 位置,范围…

“高效开发之路:用Spring MVC构建健壮的企业级应用”

一、SpringMVC框架概念: (一)概述 SpringMVC是Spring框架的一个模块,Spring和SpringMVC无需中间整合层整合。该模块是一个基于MVC的web框架。 作用:只要需要前后端通信,就需要springMVC帮我完成&#xff…

Unity使用Spine导致设备发烫

spine制作过程中,美术同学使用裁剪技术 将一个特效文件做固定范围显示,实际上非常消耗CPU算力。 解决办法: 交给程序来实现裁剪,只要加Mask组件即可

if-else语句+例题练手(2)

前面我们讲过循环语句的for、while、do-while的使用,即组成C语言中的循环结构,而除了循环其实还有顺序和选择,顺序结构就是顺着程序中的代码一行一行执行下去,而选择为分支结构,有if语句和switch语句,今天先讲if语句和…

HTTP服务器测试与优化

目录 1 搭建一个基础的HTTP服务器 2 长连接测试 3 测试错误报文的处理 4 测试业务处理耗时超过超时时间的处理 5 测试同时收到多条正常请求 6 大文件传输测试 7 压力测试 1 搭建一个基础的HTTP服务器 在这个部分,我们需要搭建一个最简单的HTTP服务器&#xf…

【spring】Cookie和Session的设置与获取(@CookieValue()和@SessionAttribute())

💐个人主页:初晴~ 📚相关专栏:程序猿的春天 获取Cookie 使用 Servlet 获取Cookie: Spring MVC 是基于 Servlet API 构建的原始 Web 框架,也是在 Servlet 的基础上实现的 RestController RequestMapping…

网页版五子棋—— WebSocket 协议

目录 前言 一、背景介绍 二、原理解析 1.连接过程(握手) 2.报文格式 三、代码示例 1.服务端代码 (1)TestAPI 类 (2)WebSocketConfig 类 2.客户端代码 3.代码演示 结尾 前言 从本篇文章开始&am…

【Go语言】| 第2课:变量声明与、初始化、匿名变量和作用域

😎 作者介绍:我是程序员洲洲,一个热爱写作的非著名程序员。CSDN全栈优质领域创作者、华为云博客社区云享专家、阿里云博客社区专家博主。 🤓 同时欢迎大家关注其他专栏,我将分享Web前后端开发、人工智能、机器学习、深…

K8S简单部署,以及UI界面配置

准备两台服务器K8Smaster和K8Sminion 分别在两台服务器上执行以下代码 #添加hosts解析&#xff1b; cat >/etc/hosts<<EOF 127.0.0.1 localhost localhost.localdomain 192.168.45.133 master1 192.168.45.135 node2 EOF #临时关闭selinux和防火墙&#xff1b; sed …

创业初期,找了个没有成本的地方当办公场地

大家好&#xff0c;我是小悟。 如果我问你&#xff0c;创业的第一步是什么&#xff1f;或许你会说资金、团队、市场定位&#xff0c;这些确实都是创业不可或缺的因素。找办公场地也是很重要的一个环节&#xff0c;但如果我现在告诉你&#xff0c;把图书馆作为办公场地&#xf…

一个记事本(可复制源码)

htmlcssjs做了一个记事本&#xff0c;可复制源码 html <!DOCTYPE html> <html lang"zh"> <head> <meta charset"UTF-8"> <meta name"viewport" content"widthdevice-width, initial-scale1.0">…

川大华西团队发表关于早期癌症研究的综述,探索AI在预测癌症进展的应用|顶刊精析·24-11-05

小罗碎碎念 这篇文章于2024-10-21发表在《Nature Reviews Cancer》&#xff0c;是一篇关于早期癌症研究的综述文章&#xff0c;标题为《Emerging strategies to investigate the biology of early cancer》。 作者角色姓名单位名称&#xff08;中文&#xff09;第一作者Ran Zho…

AI 翻唱

本文记录用 So-vits-svc 4.1 训练模型全过程。 需要用到的工具 UVR&#xff1a;用于人声歌声分离&#xff0c;降噪。 (项目传送门) Slicer-gui(Audio-Slicer)&#xff1a;用于音频裁剪。(项目传送门) So-vits-svc 4.1&#xff1a;训练模型&#xff0c;GitHub项目中详细介绍…

讲讲⾼可用的原则?

大家好&#xff0c;我是锋哥。今天分享关于【讲讲⾼可用的原则&#xff1f;】面试题。希望对大家有帮助&#xff1b; 讲讲⾼可用的原则&#xff1f; 1000道 互联网大厂Java工程师 精选面试题-Java资源分享网 在当今信息化时代&#xff0c;随着互联网技术的快速发展&#xff0…

Java 基于SpringBoot+Vue 的公交智能化系统,附源码、文档

博主介绍&#xff1a;✌程序员徐师兄、7年大厂程序员经历。全网粉丝12w、csdn博客专家、掘金/华为云/阿里云/InfoQ等平台优质作者、专注于Java技术领域和毕业项目实战✌ &#x1f345;文末获取源码联系&#x1f345; &#x1f447;&#x1f3fb; 精彩专栏推荐订阅&#x1f447;…

Leetcode 二叉树中的最大路径和

算法思想 这道题要求在一棵二叉树中找到路径和最大的路径。路径可以从树中任意一个节点开始&#xff0c;到任意一个节点结束&#xff0c;但路径上的节点必须是连续的。 算法使用递归的方式来遍历树中的每个节点&#xff0c;并在遍历过程中计算包含当前节点的最大路径和。具体…

《2024中国城市音乐产业发展指数报告》重磅发布

11月4日,《2024中国城市音乐产业发展指数研究报告》(以下简称“报告”)在成都首次公开发布。该报告由中国音像与数字出版协会音乐产业促进工作委员会指导编制,道略产业研究院、四川音乐学院孙洪斌教授团队深度参与。 该指数评价对象涵盖直辖市、副省级城市和省会城市等共36个城…

解锁金融未来,Python带你玩转大数据!

厌倦了复杂的金融报表&#xff0c;想用数据驱动投资决策&#xff0c;却不知从何下手&#xff1f; 别担心&#xff01; 《Python金融大数据分析快速入门与案例详解》带你轻松入门&#xff0c;掌握数据分析利器&#xff0c;成为金融领域的弄潮儿&#xff01; 为什么选择这本书&…