Meta-Learning数学原理

文章目录

  • 什么是元学习
  • 元学习的目标
  • 元学习的类型
  • 数学推导
    • 1. 传统机器学习的数学表述
    • 2. 元学习的基本思想
    • 3. MAML 算法推导
      • 3.1 元任务设置
      • 3.2 内层优化:任务级别学习
      • 3.3 外层优化:元级别学习
      • 3.4 元梯度计算
      • 3.5 最终更新规则
    • 4. 算法合并
    • 5. 理解 MAML 的优化
  • 图例
  • MAML 的优势
  • 其他元学习方法
  • 总结
  • 手写笔记

🍃作者介绍:双非本科大四网络工程专业在读,阿里云专家博主,专注于Java领域学习,擅长web应用开发,目前开始人工智能领域相关知识的学习
🦅个人主页:@逐梦苍穹
📕所属专栏:人工智能
🌻gitee地址:xzl的人工智能代码仓库
✈ 您的一键三连,是我创作的最大动力🌹

之前介绍过元学习的内容:https://xzl-tech.blog.csdn.net/article/details/142025393
这篇文章讲一下Meta-Learning的数学原理。

什么是元学习

元学习(Meta-Learning),也称为“学习如何学习”,是一种机器学习方法,其目的是通过学习算法的经验和结构特性,提升算法在新任务上的学习效率。

换句话说,元学习试图学习一种更有效的学习方法,使得模型能够快速适应新的任务或环境。


传统的机器学习算法通常需要大量的数据来训练模型,并且当数据分布发生变化或者遇到一个新任务时,模型往往需要重新训练才能保持良好的性能。

而元学习则不同,它通过 从多个相关任务中学习,从而在面对新任务时更快速地进行学习。

元学习的核心思想是利用“学习的经验”来提高学习的速度和质量。

在元学习的框架中,有两个层次的学习过程:

  1. 元学习者(Meta-Learner): 负责从多个任务中提取经验和知识,用于更新学习策略或模型参数。
  2. 基础学习者(Base Learner): 在每个具体任务上执行实际的学习过程。

元学习的目标

元学习的目标是解决以下问题:

  • 快速适应: 当模型面临新任务时,能够基于已有的经验快速适应,而无需大量的数据和计算资源。
  • 跨任务泛化: 提高模型从多个任务中学习到的知识在新任务上的泛化能力。
  • 提高数据效率: 减少模型在新任务上所需的数据量,尤其是在数据稀缺或高昂的情况下。

元学习的类型

元学习可以按照不同的方式分类,以下是三种主要类型:

  1. 基于模型的元学习(Model-Based Meta-Learning):
    • 这种方法通过直接设计一种能够快速适应新任务的模型架构,通常是通过某种特殊的神经网络结构来实现的。例如,基于记忆的神经网络(如 LSTM 或 Memory-Augmented Neural Networks)被设计成能有效地记住过去的任务信息,并在新任务上进行快速调整。
    • 例子: MANN(Memory-Augmented Neural Networks),SNAIL(Simple Neural Attentive Meta-Learner)。
  2. 基于优化的元学习(Optimization-Based Meta-Learning):
    • 这种方法的核心是通过改进优化过程本身来实现快速学习。其代表算法是 MAML(Model-Agnostic Meta-Learning),它通过在所有任务上共享一个初始模型参数,使得初始模型在每个任务上进行少量梯度下降更新后能够快速适应新任务。
    • 例子: MAML(Model-Agnostic Meta-Learning),Reptile。
  3. 基于记忆的元学习(Memory-Based Meta-Learning):
    • 这类方法直接存储并检索训练过程中的经验数据。当遇到新任务时,通过查找与之相似的旧任务,并利用这些旧任务的数据和经验来快速学习。k-NN(k-近邻)方法是最基本的例子,而更复杂的方法可能使用深度记忆网络。
    • 例子: Meta Networks,Prototypical Networks。

数学推导

1. 传统机器学习的数学表述

在传统的机器学习中,我们通常试图找到一个函数 f θ f_\theta fθ来最小化给定数据集 D D D的损失函数:
θ ∗ = arg ⁡ min ⁡ θ L ( f θ , D ) \theta^* = \arg\min_{\theta} L(f_\theta, D) θ=argminθL(fθ,D)
其中:

  • θ \theta θ是模型的参数。
  • L ( f θ , D ) L(f_\theta, D) L(fθ,D)是损失函数,例如交叉熵损失。
  • 通过梯度下降等优化方法,我们不断更新参数 θ \theta θ以最小化损失。

2. 元学习的基本思想

元学习的目标是找到一种元算法 F ϕ F_\phi Fϕ,使得它可以快速学习新任务。这里的关键是学习一种 学习算法。换句话说,元学习希望找到一组元参数 ϕ \phi ϕ,从而在给定一个新任务 T i T_i Ti时,使用少量数据和梯度更新就可以迅速找到特定任务的参数 θ i \theta_i θi

3. MAML 算法推导

MAML 的目标是学习一个初始模型参数 θ \theta θ,使得它可以通过少量的梯度更新快速适应新任务。

3.1 元任务设置

假设有一组任务 { T 1 , T 2 , … , T N } \{T_1, T_2, \dots, T_N\} {T1,T2,,TN},每个任务 T i T_i Ti有自己的训练数据 D i train D_i^{\text{train}} Ditrain和测试数据 D i test D_i^{\text{test}} Ditest

3.2 内层优化:任务级别学习

对于每个任务 T i T_i Ti,我们首先使用任务的训练数据 D i train D_i^{\text{train}} Ditrain和当前的模型参数 θ \theta θ进行一次或多次梯度更新,得到任务特定的参数 θ i ′ \theta_i' θi
θ i ′ = θ − α ∇ θ L T i ( f θ , D i train ) \theta_i' = \theta - \alpha \nabla_\theta L_{T_i}(f_\theta, D_i^{\text{train}}) θi=θαθLTi(fθ,Ditrain)
其中:

  • α \alpha α是学习率。
  • L T i ( f θ , D i train ) L_{T_i}(f_\theta, D_i^{\text{train}}) LTi(fθ,Ditrain)是任务 T i T_i Ti的损失函数,例如对于分类任务可以是交叉熵损失。

3.3 外层优化:元级别学习

在每个任务的测试数据上评估更新后的模型参数 θ i ′ \theta_i' θi,计算其损失,并在所有任务上最小化测试损失的总和:
min ⁡ θ ∑ i = 1 N L T i ( f θ i ′ , D i test ) \min_{\theta} \sum_{i=1}^N L_{T_i}(f_{\theta_i'}, D_i^{\text{test}}) minθi=1NLTi(fθi,Ditest)
θ i ′ \theta_i' θi展开,这个目标实际上是关于初始参数 θ \theta θ的优化问题:
min ⁡ θ ∑ i = 1 N L T i ( f θ − α ∇ θ L T i ( f θ , D i train ) , D i test ) \min_{\theta} \sum_{i=1}^N L_{T_i}(f_{\theta - \alpha \nabla_\theta L_{T_i}(f_\theta, D_i^{\text{train}})}, D_i^{\text{test}}) minθi=1NLTi(fθαθLTi(fθ,Ditrain),Ditest)

3.4 元梯度计算

为了优化这个目标,我们需要对 θ \theta θ求梯度。这里涉及二阶梯度,因为 θ i ′ \theta_i' θi是通过内层优化得到的:
θ ← θ − β ∑ i = 1 N ∇ θ L T i ( f θ i ′ , D i test ) \theta \leftarrow \theta - \beta \sum_{i=1}^N \nabla_\theta L_{T_i}(f_{\theta_i'}, D_i^{\text{test}}) θθβi=1NθLTi(fθi,Ditest)
其中 β \beta β是元学习的学习率。

  • 这个更新包含了二阶导数项: ∇ θ θ i ′ = ∇ θ ( θ − α ∇ θ L T i ( f θ , D i train ) ) \nabla_\theta \theta_i' = \nabla_\theta \left(\theta - \alpha \nabla_\theta L_{T_i}(f_\theta, D_i^{\text{train}})\right) θθi=θ(θαθLTi(fθ,Ditrain))

3.5 最终更新规则

最终的元学习更新规则可以写为:
θ ← θ − β ∑ i = 1 N ∇ θ L T i ( f θ − α ∇ θ L T i ( f θ , D i train ) , D i test ) \theta \leftarrow \theta - \beta \sum_{i=1}^N \nabla_\theta L_{T_i}\left(f_{\theta - \alpha \nabla_\theta L_{T_i}(f_\theta, D_i^{\text{train}})}, D_i^{\text{test}}\right) θθβi=1NθLTi(fθαθLTi(fθ,Ditrain),Ditest)

4. 算法合并

将内层优化 θ i ′ \theta_i' θi代入外层优化的公式中,外层优化的梯度 ∇ θ L T i ( f θ i ′ , D i test ) \nabla_\theta L_{T_i}(f_{\theta_i'}, D_i^{\text{test}}) θLTi(fθi,Ditest)需要应用链式法则:
∇ θ L T i ( f θ i ′ , D i test ) = ∇ θ L T i ( f θ − α ∇ θ L T i ( f θ , D i train ) , D i test ) \nabla_\theta L_{T_i}(f_{\theta_i'}, D_i^{\text{test}}) = \nabla_\theta L_{T_i}\left(f_{\theta - \alpha \nabla_\theta L_{T_i}(f_\theta, D_i^{\text{train}})}, D_i^{\text{test}}\right) θLTi(fθi,Ditest)=θLTi(fθαθLTi(fθ,Ditrain),Ditest)
通过链式法则,展开这个公式:
∇ θ L T i ( f θ i ′ , D i test ) = ∇ θ i ′ L T i ( f θ i ′ , D i test ) ⋅ ∇ θ θ i ′ \nabla_\theta L_{T_i}(f_{\theta_i'}, D_i^{\text{test}}) = \nabla_{\theta_i'} L_{T_i}(f_{\theta_i'}, D_i^{\text{test}}) \cdot \nabla_\theta \theta_i' θLTi(fθi,Ditest)=θiLTi(fθi,Ditest)θθi
其中 ∇ θ θ i ′ \nabla_\theta \theta_i' θθi的形式为:
∇ θ θ i ′ = I − α ∇ θ 2 L T i ( f θ , D i train ) \nabla_\theta \theta_i' = I - \alpha \nabla^2_\theta L_{T_i}(f_\theta, D_i^{\text{train}}) θθi=Iαθ2LTi(fθ,Ditrain)
I I I是单位矩阵, ∇ θ 2 L T i ( f θ , D i train ) \nabla^2_\theta L_{T_i}(f_\theta, D_i^{\text{train}}) θ2LTi(fθ,Ditrain)是损失函数关于 θ \theta θ的二阶导数(Hessian 矩阵)。


最终的公式:

将这些部分合并在一起,得到 MAML 的最终更新公式为:
θ ← θ − β ∑ i = 1 N ∇ θ i ′ L T i ( f θ − α ∇ θ L T i ( f θ , D i train ) , D i test ) ⋅ ( I − α ∇ θ 2 L T i ( f θ , D i train ) ) \theta \leftarrow \theta - \beta \sum_{i=1}^N \nabla_{\theta_i'} L_{T_i}\left(f_{\theta - \alpha \nabla_\theta L_{T_i}(f_\theta, D_i^{\text{train}})}, D_i^{\text{test}}\right) \cdot \left(I - \alpha \nabla^2_\theta L_{T_i}(f_\theta, D_i^{\text{train}})\right) θθβi=1NθiLTi(fθαθLTi(fθ,Ditrain),Ditest)(Iαθ2LTi(fθ,Ditrain))


解释:

  • 内层优化:第一部分 θ i ′ = θ − α ∇ θ L T i ( f θ , D i train ) \theta_i' = \theta - \alpha \nabla_\theta L_{T_i}(f_\theta, D_i^{\text{train}}) θi=θαθLTi(fθ,Ditrain)表示在每个任务上用梯度下降更新 θ \theta θ,得到特定于任务的参数 θ i ′ \theta_i' θi
  • 外层优化:外层优化考虑测试集上的损失,并通过链式法则计算对 θ \theta θ的梯度。这部分的关键是包含了内层更新的二阶导数 ∇ θ θ i ′ \nabla_\theta \theta_i' θθi
  • 合并公式:最终的更新公式同时结合了内层和外层优化的过程,充分考虑了内层更新对外层优化的影响。

简化(在某些情况下):

在实际应用中,计算二阶导数(Hessian 矩阵)非常昂贵。因此,有时会使用近似方法来简化计算,例如“一次近似 MAML (First-Order MAML, FOMAML)”,忽略二阶项,仅使用一阶导数进行更新。简化后的更新公式为:
θ ← θ − β ∑ i = 1 N ∇ θ i ′ L T i ( f θ i ′ , D i test ) \theta \leftarrow \theta - \beta \sum_{i=1}^N \nabla_{\theta_i'} L_{T_i}(f_{\theta_i'}, D_i^{\text{test}}) θθβi=1NθiLTi(fθi,Ditest)

这个简化版本去除了 ∇ θ θ i ′ \nabla_\theta \theta_i' θθi中的二阶导数计算。

5. 理解 MAML 的优化

通过上面的推导,MAML 的优化分为两个阶段:

  1. 内层优化:在每个任务上利用任务的训练数据对模型进行一次或多次更新,以获得任务特定的模型参数。
  2. 外层优化:在所有任务的测试数据上评估内层优化后的模型,并利用这个评估结果更新模型的初始参数。

图例

MAML 的优势

MAML 的一个关键优势在于,它学习了一个初始参数 θ \theta θ,使得它可以通过少量梯度更新快速适应新任务。这使得它非常适合少样本学习场景,如几次样本分类。

其他元学习方法

除了 MAML,文件中还提到其他元学习方法,如基于优化器的元学习、网络架构搜索(NAS)等。这些方法都在不同程度上优化了元学习的过程,使得模型能够在少量数据的情况下快速学习。

总结

元学习的数学推导核心在于通过多个任务的训练,学习到一个通用的学习算法(或模型初始化),使得模型可以快速适应新任务。MAML 是元学习的一个经典方法,通过在元任务上进行二阶优化,使模型获得更好的泛化能力。

手写笔记

最后放几张今天的手写笔记,主要是方便查阅。
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.xdnf.cn/news/142968.html

如若内容造成侵权/违法违规/事实不符,请联系一条长河网进行投诉反馈,一经查实,立即删除!

相关文章

钢索缺陷检测系统源码分享

钢索缺陷检测检测系统源码分享 [一条龙教学YOLOV8标注好的数据集一键训练_70全套改进创新点发刊_Web前端展示] 1.研究背景与意义 项目参考AAAI Association for the Advancement of Artificial Intelligence 项目来源AACV Association for the Advancement of Computer Vis…

在线制作PPT组织架构图!这个AI工具简单又好用!

ppt组织架构图如何制作,用哪个软件好? 在现代商业世界中,组织架构图是展示公司结构和层级关系的重要工具,譬如内部沟通或者对外展示等场合下,一个精美且清晰的组织架构图都能有效传达信息,提升企业形象。 …

高精度加法和减法

高精度加法 在C/C中,我们经常会碰到限定数据范围的情况,我们先来看看常用的int和long long两种数据类型的范围吧。 C标准规定:int占一个机器字长。在32位系统中int占32位,即4个字节,所以int的范围是[-2的31次方&#…

独立站技能树之建站33项自检清单 1.0丨出海笔记

很多时候大家建好站之后很嗨,但过一会就开始担忧各种纠结我是不是还有什么点没做好,或者我的站漏了什么东西,那么接下来以下这个独立站自检清单能很好的帮到你。其实对于新手我还是建议大家直接用一些模板,因为模板上面基本该有的…

基于SpringBoot+Vue+MySQL的在线招投标系统

系统展示 用户前台界面 管理员后台界面 系统背景 在当今商业环境中,招投标活动是企业获取项目、资源及合作伙伴的重要途径。然而,传统招投标过程往往繁琐复杂,涉及众多文件交换、信息审核与沟通环节,不仅效率低下,还易…

车市状态喜人,国内海外“两开花”

文/王俣祺 导语:随着中秋假期告一段落,“金九”也正式过半,整体上这个销售旺季的数据可以说十分喜人,各家车企不是发布新车、改款车就是推出了一系列购车权益,充分刺激了消费者的购车热情。再加上政府政策的鼎力支持&a…

动态线程池实战(一)

动态线程池 对项目的认知 为什么需要动态线程池 DynamicTp简介 接入步骤 功能介绍 模块划分 代码结构介绍

中、美、德、日制造业理念差异

合格的产品依赖稳定可靠的人机料法环,要求减少变量因素,增加稳定因素,避免“熵”增;五个因素中任何一个不可控,批次产品的一致性绝对差; 日本汽车企业,侧重“人”和“环”, 倚重是人…

828华为云征文|华为云Flexus云服务器X实例之openEuler系统下部署SQLite数据库浏览器sqlite-web

828华为云征文|华为云Flexus云服务器X实例之openEuler系统下部署SQLite数据库浏览器sqlite-web 前言一、Flexus云服务器X实例介绍1.1 Flexus云服务器X实例简介1.2 Flexus云服务器X实例特点1.3 Flexus云服务器X实例使用场景 二、sqlite-web介绍2.1 sqlite-web简介2.2…

画台扇-第15届蓝桥省赛Scratch中级组真题第3题

[导读]:超平老师的《Scratch蓝桥杯真题解析100讲》已经全部完成,后续会不定期解读蓝桥杯真题,这是Scratch蓝桥杯真题解析第188讲。 如果想持续关注Scratch蓝桥真题解读,可以点击《Scratch蓝桥杯历年真题》并订阅合集,…

【教程】鸿蒙ARKTS 打造数据驾驶舱---前序

鸿蒙ARKTS 打造数据驾驶舱 ​ 前面2章我介绍了如何通过定义View绘制箭头以及圆形进度,初步了解了鸿蒙如何进行自定义View。接下来我将通过我最近在带的一个VUE的项目,简单实现了几个鸿蒙原生页面。帮助大家快速上手纯血鸿蒙开发. 本项目基于Api11Stage模…

揭开GPRC5D靶点的神秘面纱,助力多发性骨髓瘤药物开发

前 言 多发性骨髓瘤属于第二大常见的血液系统恶性肿瘤,起源于骨髓造血组织的浆细胞恶性增殖。首发症状表现为非特异性,如腰疼、反复感染等,造成误诊、漏诊率较高,且难治愈易复发。目前临床上的治疗有靶向治疗、放疗、化疗、干细…

C++之继承(通俗易懂版)

前言:我们都知道C是一门支持过程化编程,面向对象的高级语言,既然是面向对象的语言,那么对于对象而言,对象会有很多中相同的属性,举个例子:你和你老师,你们都有着共同的属性和身份,例…

Linux--守护进程与会话

进程组 概念 进程组就是一个或多个进程的集合。 一个进程组可以包含多个进程。 下面我们通过一句简单的命令行来展示: 为什么会有进程组? 批量操作:进程组允许将多个进程组织在一起,形成一个逻辑上的整体。当需要对多个进程…

【关联规则】【Apriori算法】理解

关联规则学习是数据挖掘中的一种技术,用于发现大型数据库中变量间的有趣关系,特别是变量之间的有意义的关联、相关和依赖关系。这种类型的规则在零售业中特别有用,因为它可以帮助确定哪些商品经常一起购买。 关键概念 频繁项集(F…

连锁会员管理系统应该有的高级功能

会员连锁管理系统是一种专门针对连锁企业设计的会员管理软件,它可以帮助连锁企业实现跨区域、跨店铺的会员信息、消费记录和积分等的统一管理。以下分析商淘云连锁会员管理系统的主要功能。 会员信息管理:全面收集和管理会员信息,如手机号码、…

2.4 卷积1

2.4 卷积1 2.4 卷积 在了解了系统及其脉冲响应之后,人们可能会想知道是否有一种方法可以通过任何给定的输入信号(不仅仅是单位脉冲)确定系统的输出信号。卷积就是这个问题的答案,前提是系统是线性且时不变的(LTI&…

不用价位宠物空气净化器有什么区别?性价比高宠物空气净化器推荐

自新冠之后,越来越多人意识到优质空气对健康的重要性了,纷纷购置了空气净化器。不少铲屎官便关注到了“宠物空气净化器”这一专业品牌,但越后面入手宠物空气净化器的人,看到的品牌越多。整个市场那是个“蓬勃发展”。 随着消费者…

erlang学习:mnesia数据库与ets表1

Mnesia 和 ETS 都是 Erlang 提供的表管理工具,用于存储和检索数据,但它们之间有一些重要的区别和共同点。 共同点 都是Erlang提供的表存储机制:ETS 和 Mnesia 都允许你在内存中创建表,并且可以用来存储键值对或者更复杂的数据结…

高级大数据开发协会

知识星球——高级大数据开发协会 协会内容: 教你参与开源项目提供新技术学习指导提供工作遇到的疑难问题技术支持参与大数据开源软件源码提升优化以互利共赢为原则,推动大数据技术发展探讨大数据职业发展和规划共享企业实际工作经验 感兴趣的私聊我,…