深度学习:迁移学习

目录

一、迁移学习

1.什么是迁移学习

2.迁移学习的步骤

1、选择预训练的模型和适当的层

2、冻结预训练模型的参数

3、在新数据集上训练新增加的层

4、微调预训练模型的层

5、评估和测试

二、迁移学习实例

1.导入模型

2.冻结模型参数

3.修改参数

4.创建类,数据增强,导入数据

5.定义训练集和测试集函数

6.将模型传入GPU,并有序调整学习率

7.进行训练和测试


一、迁移学习

1.什么是迁移学习

        迁移学习是指利用已经训练好的模型,在新的任务上进行微调。迁移学习可以加快模型训练速度,提高模型性能,并且在数据稀缺的情况下也能很好地工作。

 

2.迁移学习的步骤

1、选择预训练的模型和适当的层

        通常,我们会选择在大规模图像数据集(如ImageNet)上预训练的模型,如VGG、ResNet等。然后,根据新数据集的特点,选择需要微调的模型层。对于低级特征的任务(如边缘检测),最好使用浅层模型的层,而对于高级特征的任务(如分类),则应选择更深层次的模型。

 

2、冻结预训练模型的参数

        保持预训练模型的权重不变,只训练新增加的层或者微调一些层,避免因为在数据集中过拟合导致预训练模型过度拟合。

 

3、在新数据集上训练新增加的层

        在冻结预训练模型的参数情况下,训练新增加的层。这样,可以使新模型适应新的任务,从而获得更高的性能。

 

4、微调预训练模型的层

        在新层上进行训练后,可以解冻一些已经训练过的层,并且将它们作为微调的目标。这样做可以提高模型在新数据集上的性能。

 

5、评估和测试

        在训练完成之后,使用测试集对模型进行评估。如果模型的性能仍然不够好,可以尝试调整超参数或者更改微调层。

 

二、迁移学习实例

  • 该实例使用的模型是ResNet-18残差神经网络模型

 

1.导入模型

  • 导入所要用的库,加载ResNet18模型
import torch
import torchvision.models as models
from torch import nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import numpy as np"""将resnet18模型迁移到食物分类项目中"""
resent_model = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)  # 既调用了resnet18网络,又使用了训练好的模型 在这里下载了模型

 

2.冻结模型参数

  • 将导入的模型参数冻结
for param in resent_model.parameters():param.requires_grad = False  # 设置每个参数的requires_grad属性为False,表示在训练过程中这些参数不需要计算梯度,也就是说它们不会在反向传播中更新。# print(param)
# 模型所有参数(即权重和偏差)的requires_grad属性设置为False,从而冻结所有模型参数
# 使得在反向传播过程中不会计算它们的梯度,以此减少模型的计算量,提高理速度。

 

3.修改参数

  • 因为我们所用的数据分类是20个,原模型分类是1000个,所以需要修改全连接层
  • 获取原模型输入层的特征个数
  • 将原模型的全连接层替换成原输入,输出为20的全连接层
  • 保存需要训练的参数,后面优化器进行优化时就可以只训练该层参数
in_features = resent_model.fc.in_features  # 获取模型原输入的特征个数
resent_model.fc = nn.Linear(in_features, 20)  # 创建一个全连接层,输入特征为in_features,输出为20param_to_update = []  # 保存需要训练的参数,仅仅包含全连接层的参数
for param in resent_model.parameters():if param.requires_grad == True:param_to_update.append(param)

 

4.创建类,数据增强,导入数据

  • 将图片从本地导入,并进行数据增强,最后进行打包
class food_dataset(Dataset):def __init__(self, file_path, transform=None):  # 类的初始化,解析数据文件txtself.file_path = file_pathself.imgs = []self.labels = []self.transform = transformwith open(self.file_path) as f:  # 是把train.txt文件中图片的路径保存在 self.imgs,train.txt文件中标签保存在self.label里samples = [x.strip().split(' ') for x in f.readlines()]  # 去掉首尾空格 再按空格分成两个元素for img_path, label in samples:self.imgs.append(img_path)  # 图像的路径self.labels.append(label)  # 标签,还不是tensor# 初始化:把图片目录加载到selfdef __len__(self):  # 类实例化对象后,可以使用len函数测量对象的个数return len(self.imgs)def __getitem__(self, idx):  # 关键,可通过索引的形式获取每一个图片数据及标签image = Image.open(self.imgs[idx])  # 读取到图片数据,还不是tensorif self.transform:# 将pil图像数据转换为tensorimage = self.transform(image)  # 图像处理为256x256,转换为tenorlabel = self.labels[idx]  # label还不是tensorlabel = torch.from_numpy(np.array(label, dtype=np.int64))  # label也转换为tensorreturn image, labeldata_transforms = {'train':transforms.Compose([transforms.Resize([300, 300]),transforms.RandomRotation(45),transforms.CenterCrop(224),transforms.RandomHorizontalFlip(p=0.5),transforms.RandomVerticalFlip(p=0.5),# transforms.ColorJitter(brightness=0.2, contrast=0.1, saturation=0.1, hue=0.1),transforms.RandomGrayscale(p=0.1),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])  # 为 ImageNet 数据集计算的标准化参数]),'test':transforms.Compose([transforms.Resize([224, 224]),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])  # 为 ImageNet 数据集计算的标准化参数])
}train_data = food_dataset(file_path=r'trainda.txt',transform=data_transforms['train'])  # 64张图片为一个包  训练集60000张图片 打包成了938个包
test_data = food_dataset(file_path=r'testda.txt', transform=data_transforms['test'])train_dataloader = DataLoader(train_data, batch_size=64, shuffle=True)
test_dataloader = DataLoader(test_data, batch_size=64, shuffle=True)

 

5.定义训练集和测试集函数

def train(dataloader, model, loss_fn, optimizer):model.train()  # 告诉模型,我要开始训练,模型中w进行随机化操作,已经更新w.在训练过程中,w会被修改的batch_size_num = 1for x, y in dataloader:x, y = x.to(device), y.to(device)  # 把训练数据集和标签传入CPU或GPUpred = model.forward(x)  # 向前传播loss = loss_fn(pred, y)  # 通过交叉熵损失函数计算损失值lossoptimizer.zero_grad()  # 梯度值清零loss.backward()  # 反向传播计算得到每个参数的梯度值woptimizer.step()  # 根据梯度更新网络w参数loss_value = loss.item()  # 从tensor数据中提取数据出来,tensor获取损失值if batch_size_num % 40 == 0:print(f"loss:{loss_value:>7f} [number:{batch_size_num}]")batch_size_num += 1best_acc = 0def test(dataloader, model, loss_fn):global best_accsize = len(dataloader.dataset)num_batches = len(dataloader)model.eval()  # 测试,w就不能再更新。test_loss, correct = 0, 0with torch.no_grad():  # 一个上下文管理器,关闭梯度计算。当你确认不会调用Tensor.backward()的时候。这可以减少计算所占用的消耗for x, y in dataloader:x, y = x.to(device), y.to(device)pred = model.forward(x)test_loss += loss_fn(pred, y).item()  # test loss是会自动累加每一个批次的损失值correct += (pred.argmax(1) == y).type(torch.float).sum().item()test_loss /= num_batches  # 能来衡量模型测试的好坏。correct /= size  # 平均的正确率print(f"Test result: \n Accuracy: {(100 * correct)}%, Avg loss: {test_loss}\n")acc_s.append(correct)loss_s.append(test_loss)if correct > best_acc:  # 保存正确率最大的那一次的模型best_acc = correct

 

6.将模型传入GPU,并有序调整学习率

from torch import nndevice = 'cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_avaibale() else 'cpu'
model = resent_model.to(device)  # 为什么不需要加括号,之前是model = CNN().to(device) 因为 resnet_model 是对象不是类"""有序调整学习率"""
loss_fn = nn.CrossEntropyLoss()  # 处理多分类
optimizer = torch.optim.Adam(param_to_update, lr=0.001)  # 仅训练最后一层的参数
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.5)  # 调整学习率

 

7.进行训练和测试

  • 选择训练100轮,每训练一轮,输出测试结果
epchos = 100
acc_s = []
loss_s = []
for t in range(epchos):print(f"Epoch {t + 1}\n--------------------------")train(train_dataloader, model, loss_fn, optimizer)scheduler.step()test(test_dataloader, model, loss_fn)
print('最优测试结果为:', best_acc)

输出:

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.xdnf.cn/news/1548440.html

如若内容造成侵权/违法违规/事实不符,请联系一条长河网进行投诉反馈,一经查实,立即删除!

相关文章

C++多态原理

C多态原理 多态的原理动态绑定与静态绑定静态多态动态多态 虚函数存在哪的&#xff1f;虚表存在哪的&#xff1f; 多态的原理 // 这里常考一道笔试题&#xff1a;sizeof(Base)是多少&#xff1f; class Base { public:virtual void Func1(){cout << "Func1()"…

【单调栈】单调栈基础及经典案例

【单调栈】单调栈基础及经典案例 单调栈理论基础每日温度下一个更大元素Ⅰ下一个更大元素Ⅱ经典例题—接雨水思路一 暴力求解思路二 双指针优化思路三 单调栈解法 经典例题—柱状图中最大的矩形思路一 暴力求解思路二 单调栈 单调栈理论基础 单调栈的应用场景&#xff1a;要寻…

109.游戏安全项目:信息显示二-利用游戏通知辅助计算基址

免责声明&#xff1a;内容仅供学习参考&#xff0c;请合法利用知识&#xff0c;禁止进行违法犯罪活动&#xff01; 内容参考于&#xff1a;易道云信息技术研究院 本人写的内容纯属胡编乱造&#xff0c;全都是合成造假&#xff0c;仅仅只是为了娱乐&#xff0c;请不要盲目相信…

Visual Studio使用与“Hello Word“的编写

1.打开Visual Studio点击"创建新项目" 2.点击"空项目"&#xff0c;并点击"下一步" 3.设置"项目名称"并"设置地址" 4.打开项目后&#xff0c;右击"源文件"并选择"添加"的"新建项" 5.点击"…

Java每日面试题(JVM)(day15)

目录 Java对象内存布局markWord 数据结构JDK1.8 JVM 内存结构JDK1.8堆内存结构GC垃圾回收如何发现垃圾如何回收垃圾 JVM调优参数 Java对象内存布局 markWord 数据结构 JDK1.8 JVM 内存结构 程序计数器: 线程私有&#xff0c;记录代码执行的位置. Java虚拟机栈: 线程私有&#…

【移植】标准系统方案之扬帆移植案例

往期知识点记录&#xff1a; 鸿蒙&#xff08;HarmonyOS&#xff09;应用层开发&#xff08;北向&#xff09;知识点汇总 鸿蒙&#xff08;OpenHarmony&#xff09;南向开发保姆级知识点汇总~ 持续更新中…… 本文章是基于瑞芯微RK3399芯片的yangfan开发板&#xff0c;进行标准…

【论文速看】DL最新进展20240927-目标检测、Transformer

目录 【目标检测】【Transformer】 【目标检测】 [2024小目标检测] A DeNoising FPN With Transformer R-CNN for Tiny Object Detection 论文链接&#xff1a;https://arxiv.org/abs/2406.05755 代码链接&#xff1a;https://github.com/hoiliu-0801/DNTR 尽管计算机视觉领域…

笔记整理—linux进程部分(1)进程终止函数注册、进程环境、进程虚拟地址

对于mian()函数而言&#xff0c;执行前也需要先执行一段引导代码才会去执行main()函数&#xff0c;该部分的代码包含构建c语言的运行环境等配置&#xff0c;如清理bss段等。 在使用gcc去编译程序的时候&#xff0c;使用gcc -v xxx.c可见链接过程。在编译完成后可见xxx.out文件。…

数据结构——计数、桶、基数排序

目录 引言 计数排序 1.算法思想 2.算法步骤 3.代码实现 4.复杂度分析 桶排序 1.算法思想 2.算法步骤 3.代码实现 4.复杂度分析 基数排序 1.算法思想 2.算法步骤 3.代码实现 4.复杂度分析 排序算法的稳定性 1.稳定性的概念 2.各个排序算法的稳定性 结束语 引…

NVLM多模态 LLM 在图像和语言任务中的表现优于 GPT-4o

论文地址&#xff1a;https://arxiv.org/pdf/2409.11402 背景 传统的多模态 LLM 有两种主要方法&#xff1a;纯解码器架构&#xff08;如 LLaVA&#xff09;和基于交叉注意力的架构&#xff08;如 Flamingo&#xff09;。混合架构&#xff0c;既提高了训练效率&#xff0c;又增…

[CKA]CKA的购买和注册考试券

CKA的购买和注册考试券 一、购买CKA 1、注册 LF开源软件学园 账号 LF开源软件学园&#xff1a;https://training.linuxfoundation.cn/register 2、个人中心进行实名认证 3、按需求进行购买 4、在考试中心–我的订单 中查看购买的订单 我是在"黑色星期五"打折买的…

LLM大模型书籍:专补大模型短板的RAG入门与实战书来了!

文末赠书 RAG自2020年由Facebook AI Research推出后&#xff0c;一下子就窜红了。 毕竟&#xff0c;它是真的帮了大忙&#xff0c;在解决大语言模型的“幻觉”问题上起到了关键作用。 如今&#xff0c;Google、AWS、IBM、微软、NVIDIA等科技巨头都在支持RAG应用的开发。微软…

中国新媒体联盟与中运律师事务所 建立战略合作伙伴关系

2024年9月27日&#xff0c;中国新媒体联盟与中运律师事务所举行战略合作协议签字仪式。中国新媒体联盟主任兼中国社会新闻网主编、中法新闻法制网运营中心主任左新发&#xff0c;中运律师事务所高级顾问刘学伟代表双方单位签字。 中国新媒体联盟是由央视微电影中文频道联合多家…

你的下一台手机会是眼镜吗?RTE 大会与你一同寻找下一代计算平台丨「空间计算和新硬件」论坛报名

周四 Meta 刚公布新一代 AR 眼镜 Orion 后&#xff0c;Perplexity 的 CEO 发了一条状态&#xff1a;「如果你还在做软件&#xff0c;请转型硬件。」 一家估值 30 亿美元的 AI 软件公司 CEO 说出这样的言论&#xff0c;既有有见到「最强 AR 眼镜」Orion 后的激动情绪&#xff0c…

如何组织一场考试并筛选未参加答题的考生?

&#x1f64b;频繁有小伙伴咨询&#xff1a;我组织了一场答题活动&#xff0c;导出考试成绩时只有参加了答题的人&#xff0c;但我想要找到哪些人没答题 此前我们会建议小伙伴逐人排查&#xff0c;但这建议被反复吐槽&#x1f926; 确实&#xff0c;如果只有十几个人逐人排查还…

鸿蒙开发(NEXT/API 12)【硬件(Pen Kit)】手写笔服务

Pen Kit&#xff08;手写笔服务&#xff09;是华为提供的一套手写套件&#xff0c;提供笔刷效果、笔迹编辑、报点预测、一笔成形和全局取色的功能。手写笔服务可以为产品带来优质手写体验&#xff0c;为您创造更多的手写应用场景。 目前Pen Kit提供了四种能力&#xff1a;手写…

银行大模型,走到哪了?

频道说 透过近期披露的上市银行中报&#xff0c;窥探银行业大模型最新进展。 大模型浪潮依然汹涌澎湃。 9月12日&#xff0c;OpenAI全新发布o1模型&#xff0c;在复杂推理任务取得重大进步&#xff0c;代表了人工智能能力的新水平&#xff0c;被视为AI时代的又一个里程碑。 …

Bigemap Pro首发(一款真正全面替代Arcgis的国产基础软件)

Bigemap Pro是一款功能强大的计算机数据要素辅助设计(Computer-Aided Data Elements Design CADED)软件&#xff0c;由成都比格图数据处理有限公司研发设计&#xff0c;主要应用在数据要素设计领域&#xff0c;为各行业提供安全可靠高效易用的数据要素设计类国产化基础软件。Bi…

公交换乘C++

题目&#xff1a; 样例解释&#xff1a; 样例#1&#xff1a; 第一条记录&#xff0c;在第 3 分钟花费 10 元乘坐地铁。 第二条记录&#xff0c;在第 46 分钟乘坐公交车&#xff0c;可以使用第一条记录中乘坐地铁获得的优惠票&#xff0c;因此没有花费。 第三条记录&#xff0c;…