李宏毅2023机器学习HW15-Few-shot Classification

文章目录

  • Link
  • Task: Few-shot Classification
  • Baseline
    • Simple—transfer learning
    • Medium — FO-MAML
    • Strong — MAML

Link

Kaggle

Task: Few-shot Classification

The Omniglot dataset

  • background set: 30 alphabets
  • evaluation set: 20 alphabets
  • Problem setup: 5-way 1-shot classification

Omniglot数据集

  • 背景集:30个字母
  • 评估集:20个字母
  • 问题设置:5-way 1-shot分类
    Definition of support set and query set

Baseline

Simple—transfer learning

直接把sample code运行即可

  • traing:
    对随机选择的5个任务进行正常分类训练验证/测试
  • validation / testing:
    对五个 Support Images 进行微调,并对Query Images进行推理

Slover首先从训练集中选择5个任务,然后对选择的5个任务进行正常分类训练。在推理中,模型在支持集support set图像上微调inner_train_step步骤,然后在查询集Query Set图像上进行推理。
为了与元学习Slover保持一致,基本Slover具有与元学习Slover完全相同的输入输出格式

def BaseSolver(model,optimizer,x,n_way,k_shot,q_query,loss_fn,inner_train_step=1,inner_lr=0.4,train=True,return_labels=False,
):criterion, task_loss, task_acc = loss_fn, [], []labels = []for meta_batch in x:# Get datasupport_set = meta_batch[: n_way * k_shot]query_set = meta_batch[n_way * k_shot :]if train:""" training loop """# Use the support set to calculate losslabels = create_label(n_way, k_shot).to(device)logits = model.forward(support_set)loss = criterion(logits, labels)task_loss.append(loss)task_acc.append(calculate_accuracy(logits, labels))else:""" validation / testing loop """# First update model with support set images for `inner_train_step` stepsfast_weights = OrderedDict(model.named_parameters())for inner_step in range(inner_train_step):# Simply trainingtrain_label = create_label(n_way, k_shot).to(device)logits = model.functional_forward(support_set, fast_weights)loss = criterion(logits, train_label)grads = torch.autograd.grad(loss, fast_weights.values(), create_graph=True)# Perform SGDfast_weights = OrderedDict((name, param - inner_lr * grad)for ((name, param), grad) in zip(fast_weights.items(), grads))if not return_labels:""" validation """val_label = create_label(n_way, q_query).to(device)logits = model.functional_forward(query_set, fast_weights)loss = criterion(logits, val_label)task_loss.append(loss)task_acc.append(calculate_accuracy(logits, val_label))else:""" testing """logits = model.functional_forward(query_set, fast_weights)labels.extend(torch.argmax(logits, -1).cpu().numpy())if return_labels:return labelsbatch_loss = torch.stack(task_loss).mean()task_acc = np.mean(task_acc)if train:# Update modelmodel.train()optimizer.zero_grad()batch_loss.backward()optimizer.step()return batch_loss, task_acc

Medium — FO-MAML

FOMAML(First-Order MAML)是MAML(Model-Agnostic Meta-Learning)的一种简化版本。MAML是一种元学习算法,旨在通过训练模型使其能够在少量新数据上快速适应新任务。FOMAML通过忽略二阶导数来简化MAML的计算过程,从而提高计算效率。它在许多情况下表现良好,尤其是在计算资源有限的情况下。然而,它也可能在某些任务上表现不如完整的MAML。

MAML的核心思想是通过在多个任务上进行训练,使得模型能够在面对新任务时,只需少量数据就能快速收敛到一个好的参数配置。具体来说,MAML的训练过程包括两个层次的优化:

  • 内层优化(Inner Loop):在每个任务上进行少量的梯度更新,以适应该任务。

  • 外层优化(Outer Loop):在所有任务上进行梯度更新,以优化模型的初始参数,使得模型在面对新任务时能够快速适应。

""" Inner Loop Update """
grads = torch.autograd.grad(loss, fast_weights.values(), create_graph=False) # create_graph=False:这个参数表示在计算梯度时不创建计算图。在FOMAML中,我们只关心一阶导数,因此不需要创建计算图fast_weights = OrderedDict((name, param - inner_lr * grad)for ((name, param), grad) in zip(fast_weights.items(), grads))""" Outer Loop Update """# TODO: Finish the outer loop update# raise NotimplementedErrormeta_batch_loss.backward()optimizer.step()

Strong — MAML

""" Inner Loop Update """
grads = torch.autograd.grad(loss, fast_weights.values(), create_graph=True)fast_weights = OrderedDict((name, param - inner_lr * grad)for ((name, param), grad) in zip(fast_weights.items(), grads))""" Outer Loop Update """# TODO: Finish the outer loop update# raise NotimplementedErrormeta_batch_loss.backward()optimizer.step()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.xdnf.cn/news/1537583.html

如若内容造成侵权/违法违规/事实不符,请联系一条长河网进行投诉反馈,一经查实,立即删除!

相关文章

【代码随想录训练营第42期 Day59打卡 - 图论Part9 - Bellman-Ford算法

目录 一、Bellman-Ford算法 定义 特性 伪代码实现 二、经典题目 题目:卡码网 94. 城市间货物运输 I 题目链接 题解: Bellman-Ford算法 三、小结 一、Bellman-Ford算法 定义 Bellman-Ford算法是一个迭代算法,它可以处理包含负权边的…

Zabbix的安装与基本使用(主机群组、应用集、监控项、触发器、动作、媒介)

前言:本博客仅作记录学习使用,部分图片出自网络,如有侵犯您的权益,请联系删除 一、环境准备 (1)实验基本设置: 主机名IP地址角色Mater192.168.1.10监控端node1192.168.1.11被监控端 # 网络自…

『功能项目』制作提示主角升级面板【56】

我们打开上一篇55事件中心处理怪物死亡的项目, 本章做的事情是制作提示主角升级的界面,当主角升级时就会被显示出来点击确认即可消失 首先在unity编辑场景制作 在确认按钮对象上添加事件 点击Button将Panel添加至事件框选 在事件函数中选择gameobject.S…

Linux操作系统入门(五)

————————————————————————————————————————— 至此,大部分Linux操作系统的文件操作指令已经总结完成,最后还需进行vim编辑器的使用 使用方法:在FinalShell终端中输入"vim [文件]",以下图…

微信支付开发-前端api实现

一、操作流程图 二、代码实现 <?php /*** 数字人答题业务流* User: 龙哥三年风水* Date: 2024/9/11* Time: 14:59*/ namespace app\controller\shuziren; use app\controller\Base; use app\model\param\QuestionParam as PQPModel; use app\model\answer\QuestionBank; u…

这个公司可以做点什么呢?

我是穿拖鞋的汉子&#xff0c;魔都中坚持长期主义的汽车电子工程师。 老规矩&#xff0c;分享一段喜欢的文字&#xff0c;避免自己成为高知识低文化的工程师&#xff1a; 屏蔽力是信息过载时代一个人的特殊竞争力&#xff0c;任何消耗你的人和事&#xff0c;多看一眼都是你的不…

C++ Primer Plus(速记版)-容器和算法

第九章 顺序容器 容器是存储特定类型对象的集合&#xff0c;标准库提供了多种容器类型以支持不同的使用场景。其中&#xff0c;顺序容器&#xff08;如vector、list、deque&#xff09;根据元素添加到容器中的顺序来存储和访问元素&#xff0c;与元素值无关。 这些顺序容器各有…

Vue Application exit (SharedArrayBuffer is not defined)

vite配置 export default defineConfig { server: {cors: true, // 启用 CORSheaders: {Cross-Origin-Opener-Policy: same-origin,Cross-Origin-Embedder-Policy: require-corp,cross-origin-resource-policy: cross-origin}}, } 错误处理 报其它错误&#xff0c;如(Compi…

第159天:安全开发-Python-协议库爆破FTPSSHRedisSMTPMYSQL等

案例一: Python-文件传输爆破-ftplib 库操作 ftp 协议 开一个ftp 利用ftp正确登录与失败登录都会有不同的回显 使用ftplib库进行测试 from ftplib import FTP # FTP服务器地址 ftp_server 192.168.172.132 # FTP服务器端口&#xff08;默认为21&#xff09; ftp_po…

chromedriver下载与安装方法

chromedriver下载地址&#xff1a; 版本在114及以下&#xff1a;http://chromedriver.storage.googleapis.com/index.html 版本在128&#xff1a;https://googlechromelabs.github.io/chrome-for-testing/#stable 其他版本下载方法&#xff1a; 如版本128.0.6613.137位下载地址…

婴儿接触危险物品检测系统源码分享

婴儿接触危险物品检测检测系统源码分享 [一条龙教学YOLOV8标注好的数据集一键训练_70全套改进创新点发刊_Web前端展示] 1.研究背景与意义 项目参考AAAI Association for the Advancement of Artificial Intelligence 项目来源AACV Association for the Advancement of Comp…

计算机三级网络技术总结(三)

宽带&#xff08;bandwidth&#xff09;单位是kbpspos framing sdh / pos framing sonetpos flag s1s0 2 / pos flag s1s0 0CRC 32network <目的网络的ip地址><子网掩码的反码>area 0area 0 range<子网地址><子网掩码>ip route <目的网络地址>&l…

通信工程学习:什么是GPON吉比特无源光网络

GPON&#xff1a;吉比特无源光网络 GPON&#xff08;Gigabit-Capable Passive Optical Network&#xff0c;吉比特无源光网络&#xff09;是一种基于ITU-T G.984.x标准的最新一代宽带无源光综合接入技术。该技术以其高带宽、高效率、大覆盖范围和用户接口丰富等特点&#xff0c…

并发安全与锁

总述 这篇文章&#xff0c;我想谈一谈自己对于并发变成的理解与学习。主要涉及以下三个部分&#xff1a;goroutine&#xff0c;channel以及lock 临界区 首先&#xff0c;要明确下面两组概念 并发和并行 并行&#xff1a;指几个程序每时每刻都同时进行 并发&#xff1a;指…

JVM 一个对象是否已经死亡?

目录 前言 引用计数法 可达性分析法 引用 finalize() 方法区回收 前言 虚拟机中垃圾回收器是掌握对象生死的判官, 只要是垃圾回收器认为需要被回收的, 那么这个对象基本可以宣告"死亡". 但是也不是所有的对象, 都需要被回收, 因此, 我们在学习垃圾回收的时候…

如何用MATLAB计算多边形的几何中心

在MATLAB中&#xff0c;计算多边形的几何中心&#xff08;又称质心或重心&#xff09;可以通过以下步骤实现。假设你有一个多边形&#xff0c;其顶点按照顺时针或逆时针顺序排列在一个矩阵中。具体步骤如下&#xff1a; 定义多边形顶点&#xff1a;首先&#xff0c;你需要将多边…

珠宝首饰检测系统源码分享

珠宝首饰检测检测系统源码分享 [一条龙教学YOLOV8标注好的数据集一键训练_70全套改进创新点发刊_Web前端展示] 1.研究背景与意义 项目参考AAAI Association for the Advancement of Artificial Intelligence 项目来源AACV Association for the Advancement of Computer Vis…

【时时三省】(C语言基础)指针进阶 例题8

山不在高&#xff0c;有仙则名。水不在深&#xff0c;有龙则灵。 ----CSDN 时时三省 第一个打印2 a6不管它是多大 前面是&#xff1d;s 都得变成两个字节 所以打印2 第二个打印5 sizeof里面的表达式是不参与运算的 所以打印5 上面所有例题总结…

36.贪心算法3

1.坏了的计算器&#xff08;medium&#xff09; . - 力扣&#xff08;LeetCode&#xff09; 题目解析 算法原理 代码 class Solution {public int brokenCalc(int startValue, int target) {// 正难则反 贪⼼int ret 0;while (target > startValue) {if (target % 2 0…

gcc/g++的使用:

目录 (1). 程序的翻译过程 预处理&#xff1a; gcc -E 源文件 编译&#xff1a; gcc -S 源文件 汇编&#xff1a;gcc -c 源文件 连接&#xff1a; (2) 语言的自举(也叫 编译器的自举)&#xff1a; (3). 查看可执行程序在连接时依赖的库: ldd 可执行程序的名字 。 (4). …