《迁移学习》—— 将 ResNet18 模型迁移到食物分类项目中

文章目录

  • 一、迁移学习的简单介绍
    • 1.迁移学习是什么?
    • 2.迁移学习的步骤
  • 二、数据集介绍
  • 三、代码实现
    • 1. 步骤
    • 2.所用到方法介绍的文章链接
    • 3. 完整代码

一、迁移学习的简单介绍

1.迁移学习是什么?

  • 迁移学习是指利用已经训练好的模型,在新的任务上进行微调。
  • 迁移学习可以加快模型训练速度,提高模型性能,并且在数据稀缺的情况下也能很好地工作。

2.迁移学习的步骤

  • (1) 选择预训练的模型和适当的层:通常,我们会选择在大规模图像数据集(如ImageNet)上预训练的模型,如VGG、ResNet等。然后,根据新数据集的特点,选择需要微调的模型层。对于低级特征的任务(如边缘检测),最好使用浅层模型的层,而对于高级特征的任务(如分类),则应选择更深层次的模型。
  • (2) 冻结预训练模型的参数:保持预训练模型的权重不变,只训练新增加的层或者微调一些层,避免因为在数据集中过拟合导致预训练模型过度拟合。
  • (3) 在新数据集上训练新增加的层:在冻结预训练模型的参数情况下,训练新增加的层。这样,可以使新模型适应新的任务,从而获得更高的性能。
  • (4) 微调预训练模型的层:在新层上进行训练后,可以解冻一些已经训练过的层,并且将它们作为微调的目标。这样做可以提高模型在新数据集上的性能。
  • (5) 评估和测试:在训练完成之后,使用测试集对模型进行评估。如果模型的性能仍然不够好,可以尝试调整超参数或者更改微调层。

二、数据集介绍

  • 下图是数据集的结构
    • 在 food_dataset2 文件夹下含有训练数据和测试数据
    • 训练集和测试集数据中都含有 20 种食物图片,数量在200~400不等
    • trainda.txt 和 testda.txt 文本中存放了每张图片的路径及标签,用 0~19 这20个数字分别对20种食物进行标签
    • 在代码中通过trainda.txt 和 testda.txt 文本中的内容来获取每张图片及对应的标签
      在这里插入图片描述
    • 下面是trainda.txt文本中的部分内容(testda.txt 中的内容格式相同)
      在这里插入图片描述
  • 送福利!!! 私信送此数据集 !!!

三、代码实现

1. 步骤

  • 1.调用resnet18模型,并保存需要训练的模型参数
  • 2.定义一个图像预处理和数据增强字典
  • 3.定义获取每张食物图片和标签的类方法
  • 4.获取训练集和测试集数据
  • 5.对数据集进行打包
  • 6.调用交叉熵损失函数并创建优化器
  • 7.定义训练模型的函数
  • 8.定义测试模型的函数
  • 9.训练模型,并每训练一轮测试一次

2.所用到方法介绍的文章链接

  • ResNet 残差网络神经网络
    • https://blog.csdn.net/weixin_73504499/article/details/142575775?spm=1001.2014.3001.5501
  • 数据增强
    • https://blog.csdn.net/weixin_73504499/article/details/142499263?spm=1001.2014.3001.5501
  • 调整学习率
    • https://blog.csdn.net/weixin_73504499/article/details/142526863?spm=1001.2014.3001.5501

3. 完整代码

import torchimport torchvision.models as models  # 导入存有各种深度学习模型的模块from torch import nn  # 导入神经网络模块from torch.utils.data import Dataset, DataLoader  # Dataset: 抽象类,一种用于获取数据的方法  DataLoader:数据包管理工具,打包数据from torchvision import transforms  # transforms模块提供了一系列用于图像预处理和数据增强的函数和类from PIL import Image  # 用于处理图片import numpy as np""" 调用resnet18模型 """resnet_model = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)for param in resnet_model.parameters():param.requires_grad = False# 模型所有参数(即权重和偏差)的 requires_grad 属性设置成 False,从而冻结所有模型参数# 使得在反向传播过程中不会计算他们的梯度,从此减少模型的计算量,提高推理速度in_features = resnet_model.fc.in_features  # 获取resnet18模型全连接层原输入的特征个数resnet_model.fc = nn.Linear(in_features, 20)  # 创建一个全连接层输入特征个数为: in_features  输出特征个数为:数据集中事务的种类数量params_to_update = []  # 保存需要训练的参数,仅仅包含全连接层的参数for param in resnet_model.parameters():if param.requires_grad == True:params_to_update.append(param)""" 图像预处理和数据增强 """data_transforms = {'train':transforms.Compose([transforms.Resize([300, 300]),transforms.RandomRotation(45),  # 随机旋转,-45到45度之间随机transforms.CenterCrop(224),  # 中心裁剪transforms.RandomHorizontalFlip(p=0.5),  # 随机水平反转 选择一个概率transforms.RandomVerticalFlip(p=0.5),  # 随机垂直翻转# transforms.ColorJitter(brightness=0.2, contrast=0.1, saturation=0.1, hue=0.1),  # 亮度、对比度transforms.RandomGrayscale(p=0.1),  # 概率转换成灰度率,3通道就是R G Btransforms.ToTensor(),  # 转化为神经网络可以识别的 Tensor 类型transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])  # 对图片数据进行归一化,[均值],[标准差]]),'valid':transforms.Compose([transforms.Resize([224, 224]),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])}""" 定义获取每张食物图片和标签的类方法 """class food_dataset(Dataset):def __init__(self, file_path, transform=None):self.file_path = file_pathself.imgs = []self.labels = []self.transform = transformwith open(self.file_path) as f:samples = [x.strip().split(' ') for x in f.readlines()]for img_path, label in samples:self.imgs.append(img_path)self.labels.append(label)def __len__(self):return len(self.imgs)def __getitem__(self, idx):image = Image.open(self.imgs[idx])if self.transform:image = self.transform(image)label = self.labels[idx]label = torch.from_numpy(np.array(label, dtype=np.int64))return image, label""" 获取训练集和测试集数据 """training_data = food_dataset(file_path='trainda.txt', transform=data_transforms['train'])test_data = food_dataset(file_path='testda.txt', transform=data_transforms['valid'])""" 对数据集进行打包 """train_dataloader = DataLoader(training_data, batch_size=64, shuffle=True)  # 64张图片为一个包,shuffle --> 打乱顺序test_dataloader = DataLoader(test_data, batch_size=64, shuffle=True)""" 判断当前设备是否支持GPU,其中mps是苹果m系列芯片的GPU """device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"print(f"Using {device} device")# 把模型传入到 gpu 或 cpumodel = resnet_model.to(device)""" 调用交叉熵损失函数 """loss_fn = nn.CrossEntropyLoss()"""" 创建优化器并调整优化器中的学习率--> lr """optimizer = torch.optim.Adam(params_to_update, lr=0.001)scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.5)""" 定义训练模型的函数 """def train(dataloader, model, loss_fn, optimizer):model.train()  # 告诉模型,开始训练# pytorch提供2种方式来切换训练和测试的模式,分别是:model.train()和 model.eval()。# 一般用法是:在训练开始之前写上model.trian(),在测试时写上model.for X, y in dataloader:X, y = X.to(device), y.to(device)  # 把训练数据集和标签传入cpu或gpupred = model.forward(X)  # .forward可以被省略,父类中已经对次功能进行了设置。自动初始化w权值loss = loss_fn(pred, y)  # 通过交叉熵损失函数计算损失值 lossoptimizer.zero_grad()  # 梯度值清零loss.backward()  # 反向传播计算得到每个参数的梯度值woptimizer.step()  # 根据梯度更新网络w参数""" 定义测试模型的函数 """best_acc = 0  # 用于更新准确率def test(dataloader, model, loss_fn):global best_accsize = len(dataloader.dataset)num_batches = len(dataloader)model.eval()  # 测试,w就不能再更新test_loss, correct = 0, 0with torch.no_grad():for X, y in dataloader:X, y = X.to(device), y.to(device)pred = model.forward(X)test_loss += loss_fn(pred, y).item()  # test_loss是会自动累加每一个批次的损失值correct += (pred.argmax(1) == y).type(torch.float).sum().item()  # correct是会自动累加每一个批次的正确率test_loss /= num_batches  # 平均的损失值correct /= size  # 平均的正确率print(f"Test result: \n Accuracy: {(100 * correct)}%, Avg loss: {test_loss}")# 找到最好的准确率if correct > best_acc:best_acc = correct""" 定义模型训练的轮数,并每训练一轮测试一次 """epochs = 30for e in range(epochs):print(f"Epoch {e + 1}\n---------------------------")train(train_dataloader, model, loss_fn, optimizer)scheduler.step()  # 在每个epoch的训练中,使用scheduler.step()语句进行学习率更新test(test_dataloader, model, loss_fn)print('最优的训练结果为:', best_acc)
  • 结果如下
    • 此结果只是训练了30轮后的结果,可以训练更多轮,最后的准确率还会有所提高
      在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.xdnf.cn/news/1549676.html

如若内容造成侵权/违法违规/事实不符,请联系一条长河网进行投诉反馈,一经查实,立即删除!

相关文章

Linux防火墙-常用命令,零基础入门到精通,收藏这一篇就够了

我们经过上小章节讲了Linux的部分进阶命令,我们接下来一章节来讲讲Linux防火墙。由于目前以云服务器为主,而云服务器基本上就不会使用系统自带的防火墙,而是使用安全组来代替了防火墙的功能,可以简单理解安全组就是web版的防火墙&…

Windows环境下训练开源图像超分项目 ECBSR 教程

ECBSR 介绍 ECBSR(Edge-oriented Convolution Block for Real-time Super Resolution)是一种针对移动设备设计的轻量级超分辨率网络。它的核心是一种可重参数化的构建模块,称为边缘导向卷积块(ECB),这种模…

Qt 首次配置 Qt Creator 14.01 for Python

前言: 如何用QT实现Python的配置的交互界面。本文从0开始,进行实践的介绍。 在上一节里面,我们做了社区版本的配置: https://blog.csdn.net/yellow_hill/article/details/142597007?spm1001.2014.3001.5501 这一节&#xff0…

vue+UEditor附件上传问题

🏆本文收录于《全栈Bug调优(实战版)》专栏,主要记录项目实战过程中所遇到的Bug或因后果及提供真实有效的解决方案,希望能够助你一臂之力,帮你早日登顶实现财富自由🚀;同时,欢迎大家关注&&am…

快速实现AI搜索!Fivetran 支持 Milvus 作为数据迁移目标

Fivetran 现已支持 Milvus 向量数据库作为数据迁移的目标,能够有效简化 RAG 应用和 AI 搜索中数据源接入的流程。 数据是 AI 应用的支柱,无缝连接数据是充分释放数据潜力的关键。非结构化数据对于企业搜索和检索增强生成(RAG)聊天…

Java SPI 原理、样例

在 Java 中,SPI(Service Provider Interface)全称为服务提供者接口,它是一种用于实现框架扩展和插件化的机制。 一、SPI 作用 允许在运行时动态地为接口查找服务实现,而不需要在代码中显式地指定具体的实现类。 这使得…

利用多模态输入的自我中心运动跟踪与理解框架:EgoLM

随着增强现实(AR)和虚拟现实(VR)技术的发展,对自我中心(第一人称视角)运动的精确跟踪和理解变得越来越重要。传统的单一模态方法在处理复杂场景时存在诸多局限性。为了解决这些问题,研究者们提出了一种基于多模态输入的自我中心运动跟踪与理解框架——EgoLM。本文将详细…

MySQL-数据库约束

1.约束类型 类型说明NOT NULL非空约束 指定非空约束的列不能存储NULL值 DEFAULT默认约束当没有给列赋值时使用的默认值UNIQUE唯一约束指定唯一约束的列每行数据必须有唯一的值PRIMARY KEY主键约束NOT NULL和UNIQUE的结合,可以指定一个列霍多个列,有助于…

文章解读与仿真程序复现思路——中国电机工程学报EI\CSCD\北大核心《考虑异步区域调频资源互济的电能、惯性与一次调频联合优化出清模型》

本专栏栏目提供文章与程序复现思路,具体已有的论文与论文源程序可翻阅本博主免费的专栏栏目《论文与完整程序》 论文与完整源程序_电网论文源程序的博客-CSDN博客https://blog.csdn.net/liang674027206/category_12531414.html 电网论文源程序-CSDN博客电网论文源…

Android页面跳转与返回机制详解

在Android开发中,页面跳转是实现应用功能交互的重要手段之一。本文将从Activity之间的跳转、Activity与Fragment之间的跳转、Fragment之间的跳转以及页面返回的问题四个方面进行详细解析。 一、Activity之间的跳转 Activity是Android应用的基本构建块,…

7.6透视变换

基本概念 在计算机视觉和图像处理领域中,透视变换(Perspective Transformation)是一种重要的几何变换,用于模拟从一个视角到另一个视角的变换,比如从鸟瞰视角到正面视角的变换。透视变换通常用于图像配准、增强现实、…

《志愿军·存亡之战》首映礼热血与感动并存,陈飞宇一年后再报这串番号

9月27日,国庆档电影《志愿军:存亡之战》在北京举行首映礼。导演陈凯歌,总制片人陈红,编剧张珂,演员朱一龙、辛柏青、张子枫、朱亚文、陈飞宇、张宥浩等在映后齐亮相。其中陈飞宇饰演的孙醒,作为贯穿一、二两…

如何快速自定义一个Spring Boot Starter!!

目录 引言: 一. 我们先创建一个starter模块 二. 创建一个自动配置类 三. 测试启动 引言: 在我们项目中,可能经常用到别人的第三方依赖,又是引入依赖,又要自定义配置,非常繁琐,当我们另一个项…

【C++报错已解决】std::ios_base::floatfield

🎬 鸽芷咕:个人主页 🔥 个人专栏: 《C干货基地》《粉丝福利》 ⛺️生活的理想,就是为了理想的生活! 专栏介绍 在软件开发和日常使用中,BUG是不可避免的。本专栏致力于为广大开发者和技术爱好者提供一个关于BUG解决的经…

8609 哈夫曼树

### 思路 1. **选择最小权值节点**:在哈夫曼树构建过程中,选择两个权值最小且父节点为0的节点。 2. **构建哈夫曼树**:根据权值构建哈夫曼树,确保左子树权值小于右子树权值。 3. **生成哈夫曼编码**:从叶子节点到根节点…

极限基本类型小结

极限基本类型小结 在之前的文章中已经看过了极限的多种基本类型,下面展示一些各种基本类型的代表性的图像,通过观察下面的图像可以帮助我们回顾函数在趋近于某一点时函数值的行为(这也叫极限值),也生动的描述了各种极…

初始爬虫9

1.元素定位后的操作 “find_element“仅仅能够获取元素,不能够直接获取其中的数据,如果需要获取数据需要使用以下方法”。下面列出了两个方法: 获取文本 element.text 通过定位获取的标签对象的 text 属性,获取文本内容 获取属性…

C语言进阶版第13课—字符函数和字符串函数2

文章目录 1. strstr函数的使用和模拟实现1.1 strstr函数的使用1.2 模拟实现strstr函数1.3 strstr函数和strncpy函数、puts函数的混合使用 2. strtok函数的使用**3. strerror函数的使用** 1. strstr函数的使用和模拟实现 1.1 strstr函数的使用 strstr函数是用来通过一个字符串来…

Linux进程-2

一:进程优先级 基本概念 cpu资源分配的先后顺序,就是指进程的优先权(priority)。 优先权高的进程有优先执行权利。配置进程优先权对多任务环境的linux很有用,可以改善系统性能。 还可以把进程运行到指定的CPU上&#…

Mysql数据库相关操作总结

目录 1.背景知识 2.创建数据库 2.1创建指令 2.2字符集 3.查看数据库 3.选中数据库 4.删除数据库 5.数据表的操作 5.1基本数据类型 5.2创建表 5.3查看所有的表 5.4查看表的结构 5.5删除表 6.CRUD增删查改 6.1新增和效果查看 6.3删除 6.4查找 1.背景知识 数据库就…