半监督学习与数据增强


✨✨ 欢迎大家来访Srlua的博文(づ ̄3 ̄)づ╭❤~✨✨

🌟🌟 欢迎各位亲爱的读者,感谢你们抽出宝贵的时间来阅读我的文章。

我是Srlua小谢,在这里我会分享我的知识和经验。🎥

希望在这里,我们能一起探索IT世界的奥妙,提升我们的技能。🔮

记得先点赞👍后阅读哦~ 👏👏

📘📚 所属专栏:人工智能、话题分享

欢迎访问我的主页:Srlua小谢 获取更多信息和资源。✨✨🌙🌙

​​

​​

目录

概述

算法原理

核心逻辑

效果演示

使用方式

参考文献


 本文所有资源均可在该地址处获取。

概述

本文复现论文 FixMatch: Simplifying Semi-Supervised Learning with Consistency and Confidence[1] 提出的半监督学习方法。

半监督学习(Semi-supervised Learning)是一种机器学习方法,它将少量的标注数据(带有标签的数据)和大量的未标注数据(不带标签的数据)结合起来训练模型。在许多实际应用中,标注数据获取成本高且困难,而未标注数据通常较为丰富和容易获取。因此,半监督学习方法被引入并被用于利用未标注数据来提高模型的性能和泛化能力。

图1:半监督数据集

该论文介绍了一种基于一致性和置信度的半监督学习方法 FixMatch。FixMatch首先使用模型为弱增强后的未标注图像生成伪标签。对于给定图像,只有当模型产生高置信度预测时才保留伪标签。然后,模型在输入同一图像的强增强版本时被训练去预测伪标签。FixMatch 在各种半监督学习数据集上实现了先进的性能。

算法原理

FixMatch 结合了两种半监督学习方法:一致性正则化和伪标签。其主要创新点在于这两种方法的结合以及在执行一致性正则化时分别使用了弱增强和强增强。

FixMatch 的损失函数由两个交叉熵损失项组成:一个用于有标签数据的监督损失 lsls​ 和一个用于无标签数据的无监督损失 lulu​ 。具体来说,lsls​ 只是对弱增强有标签样本应用的标准交叉熵损失:

ls=1B∑b=1BH(pb,pm(y∣α(xb)))ls​=B1​b=1∑B​H(pb​,pm​(y∣α(xb​)))

其中 BB 表示 batch size,HH 表示交叉熵损失,pbpb​ 表示标记,pm(y∣α(xb))pm​(y∣α(xb​)) 表示模型对弱增强样本的预测结果。

FixMatch 对每个无标签样本计算一个伪标签,然后在标准交叉熵损失中使用该标签。为了获得伪标签,我们首先计算模型对给定无标签图像的弱增强版本的预测类别分布:qb=pm(y∣α(ub))qb​=pm​(y∣α(ub​))。然后,我们使用 q^b=arg⁡max⁡qbq^​b​=argmaxqb​ 作为伪标签,但我们在交叉熵损失中对模型对 ubub​ 的强增强版本的输出进行约束:

lu=1μB∑b=1μB1(max(qb)>τ)H(q^b,pm(y∣A(ub)))lu​=μB1​b=1∑μB​1(max(qb​)>τ)H(q^​b​,pm​(y∣A(ub​)))

其中 μμ 表示无标签样本与有标签样本数量之比,1(max(qb)>τ)1(max(qb​)>τ) 当前仅当 max(qb)>τmax(qb​)>τ 成立时为 1 否则为 0,ττ 表示置信度阈值,A(ub)A(ub​) 表示对无标签样本的强增强。

FixMatch的总损失就是 ls+λululs​+λu​lu​,其中 λuλu​ 是表示无标签损失相对权重的标量超参数。

图2:方法原理图

FixMatch 利用两种增强方法:“弱增强”和“强增强”。论文所使用的弱增强是一种标准的翻转和位移增强策略。具体来说,除了SVHN数据集之外,我们在所有数据集上以50%的概率随机水平翻转图像,并随机在垂直和水平方向上平移图像最多12.5%。对于“强增强”,我采用了基于随机幅度采样的 RandAugment,然后进行了 Cutout 处理。

我在CIFAR-10、CIFAR-100 、SVHN 和 FER2013 数据集上对 FixMatch 进行了实验。关于使用的神经网络,我在 CIFAR-10 和 SVHN 上使用了 Wide ResNet-28-2,在 CIFAR-100 上使用了 Wide ResNet-28-8,在 FER2013 上使用了 Wide ResNe-37-2。实验结果如下表所示:

数据集准确率(%)
CIFAR-1086.39
CIFAR-10068.88
SVHN91.25
FER201368.57

为了直观展示 FixMatch 的效果,我在线部署了基于 FER2013 数据集训练的 Wide ResNe-37-2 模型。FER2013[2] 是一个面部表情识别数据集,其包含约 30000 张不同表情的面部 RGB 图像,尺寸限制为 48×48。其主要标签可分为 7 种类型:愤怒(Angry),厌恶(Disgust),恐惧(Fear),快乐(Happy),悲伤(Sad),惊讶(Surprise),中性(Neutral)。厌恶表情的图像数量最少,只有 600 张,而其他标签的样本数量均接近 5,000 张。

核心逻辑

具体的核心逻辑如下所示:

for epoch in range(epochs):model.train()train_tqdm = zip(labeled_dataloader, unlabeled_dataloader)for labeled_batch, unlabeled_batch in train_tqdm:optimizer.zero_grad()# 利用标记样本计算损失data = labeled_batch[0].to(device)labels = labeled_batch[1].to(device)logits = model(normalize(strong_aug(data)))loss = F.cross_entropy(logits, labels)# 计算未标记样本伪标签with torch.no_grad():data = unlabeled_batch[0].to(device)logits = model(normalize(weak_aug(data)))probs = F.softmax(logits, dim=-1)trusted = torch.max(probs, dim=-1).values > thresholdpseudo_labels = torch.argmax(probs[trusted], dim=-1)loss_factor = weight * torch.sum(trusted).item() / data.shape[0]# 利用未标记样本计算损失logits = model(normalize(strong_aug(data[trusted])))loss += loss_factor * F.cross_entropy(logits, pseudo_labels)# 反向梯度传播并更新模型参数loss.backward()optimizer.step()

以上代码仅作展示,更详细的代码文件请参见附件。

效果演示

网站提供了在线体验功能。用户需要输入一张长宽尽可能相等且大小不超过 1MB 的正面脸部 JPG 图像,网站就会返回图片中人物表情所表达的情绪。

图3:在线演示结果

使用方式

  • 解压附件压缩包并进入工作目录。如果是Linux系统,请使用如下命令:
unzip FixMatch.zip
cd FixMatch

  • 代码的运行环境可通过如下命令进行配置:
pip install -r requirements.txt

  • 如果希望在本地运行程序,请运行如下命令:
python main.py

  • 如果希望在线部署,请运行如下命令:
python main-flask.py

(以上内容皆为原创,请勿转载)

参考文献

[1] Sohn K, Berthelot D, Carlini N, et al. Fixmatch: Simplifying semi-supervised learning with consistency and confidence[J]. Advances in neural information processing systems, 2020, 33: 596-608.

[2] Wang L, Xu S, Wang X, et al. Eavesdrop the composition proportion of training labels in federated learning[J]. arXiv preprint arXiv:1910.06044, 2019.

​​

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.xdnf.cn/news/35214.html

如若内容造成侵权/违法违规/事实不符,请联系一条长河网进行投诉反馈,一经查实,立即删除!

相关文章

【成功解决:Can‘t uninstall ‘ultralytics‘. No files were found to uninstall.】

问题: 尝试卸载ultralytics时,使用pip uninstall ultralytics命令,已经卸载了ultralytics,但是仍出现Cant uninstall ultralytics. No files were found to uninstall,导致无法卸载干净。 原因 ultralytics相应的dis…

AcWing 3496. 特殊年份

文章目录 前言代码思路 前言 写简单题没啥。反正都是要写的&#xff0c;先把能拿到的分数拿了&#xff0c;之后有机会再去啃一啃硬骨头。啃不下来就算了。 代码 #include<bits/stdc.h> using namespace std; char a1[10],a2[10],a3[10],a4[10],a5[10]; int main(){cin…

MongoDB性能监控工具

mongostat mongostat是MongoDB自带的监控工具&#xff0c;其可以提供数据库节点或者整个集群当前的状态视图。该功能的设计非常类似于Linux系统中的vmstat命令&#xff0c;可以呈现出实时的状态变化。不同的是&#xff0c;mongostat所监视的对象是数据库进程。mongostat常用于…

【LeetCode: 999. 可以被一步捕获的棋子数 + 模拟】

&#x1f680; 算法题 &#x1f680; &#x1f332; 算法刷题专栏 | 面试必备算法 | 面试高频算法 &#x1f340; &#x1f332; 越难的东西,越要努力坚持&#xff0c;因为它具有很高的价值&#xff0c;算法就是这样✨ &#x1f332; 作者简介&#xff1a;硕风和炜&#xff0c;…

智创 AI 新视界 -- 优化 AI 模型训练效率的策略与技巧(16 - 1)

&#x1f496;&#x1f496;&#x1f496;亲爱的朋友们&#xff0c;热烈欢迎你们来到 青云交的博客&#xff01;能与你们在此邂逅&#xff0c;我满心欢喜&#xff0c;深感无比荣幸。在这个瞬息万变的时代&#xff0c;我们每个人都在苦苦追寻一处能让心灵安然栖息的港湾。而 我的…

专业140+总分420+上海交通大学819考研经验上交电子信息与通信工程,真题,大纲,参考书。博睿泽信息通信考研论坛,信息通信考研Jenny

考研结束&#xff0c;专业819信号系统与信号处理140&#xff0c;总分420&#xff0c;终于梦圆交大&#xff0c;高考时敢都不敢想目标&#xff0c;现在已经成为现实&#xff0c;考研后劲很大&#xff0c;这一年的复习经历&#xff0c;还是历历在目&#xff0c;整理一下&#xff…

K8S服务突然中断无法访问:报The node had condition: [DiskPressure]异常

一、背景 程序在运行过程中&#xff0c;突然无法访问&#xff0c;发现后台接口也无法访问&#xff1b;查看kuboard&#xff0c;发现报如下异常&#xff1a;The node had condition: [DiskPressure]. 继续查看磁盘使用率&#xff0c;发现系统盘使用率已经高达93%。问题前后呼应…

【工具变量】上市公司企业违规数据(企业当年是否违规、企业当年违规的次数)2000-2022年

一、测算方式&#xff1a;参考C刊《当代财经》纪亚方&#xff08;2023&#xff09;老师的研究&#xff0c;通过对上市公司被处罚涉及的年份进行追溯&#xff0c;为了保证企业违规行为变量度量的准确性&#xff0c;将追溯到公司被处罚的年份定义为违规年份。 采用两个指标对企业…

视频孪生携手视联网 智汇云舟亮相中国电信2024数字科技生态大会

12月3日&#xff0c;由中国电信主办的“2024数字科技生态大会”在广州盛大开幕。活动现场&#xff0c;前沿科技与创新理念交相辉映&#xff0c;数字科技未来蓝图徐徐展开。智汇云舟作为中国电信的战略合作伙伴&#xff0c;受邀出席本次活动。 展会期间&#xff0c;以“天翼视联…

Unity 使用LineRenderer制作模拟2d绳子

效果展示&#xff1a; 实现如下&#xff1a; 首先&#xff0c;直接上代码&#xff1a; using System.Collections; using System.Collections.Generic; using UnityEngine;public class LineFourRender : MonoBehaviour {public Transform StartNode;public Transform MidNod…

力扣-图论-4【算法学习day.54】

前言 ###我做这类文章一个重要的目的还是给正在学习的大家提供方向和记录学习过程&#xff08;例如想要掌握基础用法&#xff0c;该刷哪些题&#xff1f;&#xff09;我的解析也不会做的非常详细&#xff0c;只会提供思路和一些关键点&#xff0c;力扣上的大佬们的题解质量是非…

【射频IC进阶实践教程】2.6 LNA版图设计及DRC/LVS验证

射频集成电路的版图设计非常关键&#xff0c;他对寄生参数非常敏感&#xff0c;需要使其最小化。还需要注意相互耦合的方式本次课程主要介绍射频IC的一些相关布局和连线方面的考虑。 一、版图设计 1. 版图的元件布局 首先打开对应的原理图 点击进行版图设计 由于已经有做好的…

uviewplus中的时间单选框up-datetime-picker的在uni-app+vue3的使用方法

uviewplus中的时间单选框up-datetime-picker的使用方法 前言 在实际开发中,我们经常需要使用时间选择器来让用户选择特定的时间。本文将详细介绍uviewplus中up-datetime-picker组件的使用方法,特别是在处理年月选择时的一些关键实现&#xff0c;因为官方有很多相关的功能和方法…

Spring Bean 的生命周期和获取方式

优质博文&#xff1a;IT-BLOG-CN 一、Spring Bean 的生命周期&#xff0c;如何被管理的 对于普通的 Java对象&#xff0c;当 new的时候创建对象&#xff0c;当它没有任何引用的时候被垃圾回收机制回收。而由 Spring IoC容器托管的对象&#xff0c;它们的生命周期完全由容器控…

【Spring MVC篇】返回响应

个人主页&#xff1a;兜里有颗棉花糖 欢迎 点赞&#x1f44d; 收藏✨ 留言✉ 加关注&#x1f493;本文由 兜里有颗棉花糖 原创 收录于专栏【Spring MVC】 本专栏旨在分享学习Spring MVC的一点学习心得&#xff0c;欢迎大家在评论区交流讨论&#x1f48c; 目录 一、返回静态页面…

使用Python创建API服务器并打包成exe文件

本文来记录下使用Python创建API服务器并打包成exe文件 文章目录 概述简述API服务器创建打包API服务器为exe文件本文小结 概述 在软件开发中&#xff0c;API服务器是连接前端和后端服务的桥梁&#xff0c;而Python因其丰富的库和框架&#xff0c;如Flask、Django等&#xff0c;成…

MHA切换过程

MHA&#xff08;Master High Availability&#xff09;是一套用于MySQL数据库的高可用性解决方案&#xff0c;它能够在主服务器发生故障时自动将一个从服务器提升为新的主服务器&#xff0c;从而实现数据库服务的持续可用。MHA的切换过程主要包括以下几个步骤&#xff1a; 1. …

NextUI 教程:打造美观高效的React UI

NextUI 教程&#xff1a;打造美观高效的React UI 项目地址:https://gitcode.com/gh_mirrors/ne/nextui 1. 项目介绍 NextUI 是一个轻量级、快速且现代化的React UI库&#xff0c;提供了一系列优雅的组件以帮助开发者构建令人印象深刻的Web应用。它注重性能和用户体验&#x…

Python和Java后端开发技术对比

在当今互联网技术飞速发展的时代&#xff0c;后端开发扮演着至关重要的角色。Python和Java作为两大主流的后端开发语言&#xff0c;各自具备独特的优势和应用场景。让我们深入了解这两种技术的特点和选择建议。 Java后端开发一直是企业级应用的首选方案。它以强大的类型系统、…

Java HashMap

HashMap 是一个散列表&#xff0c;它存储的内容是键值对(key-value)映射。 HashMap 实现了 Map 接口&#xff0c;根据键的 HashCode 值存储数据&#xff0c;具有很快的访问速度&#xff0c;最多允许一条记录的键为 null&#xff0c;不支持线程同步。 HashMap 是无序的&#x…