2024年增量学习(二) l2p的jax版代码分析

背景介绍

l2p有两个代码实现,官方的jax实现,和个人开源的pytorch实现。两种实现有若干区别,而在jax实现中能看到replay和review机制。

训练机制

先跳过繁琐的代码实现,介绍一下jax版实现的训练机制。以数据集cifar100为例,会分为10个task,每个任务有10个类别。训练时(默认没开启gaussian schedule时),先训任务0的数据,再训任务1、任务2等剩余的数据。

关于replay和review机制

l2p的jax实现与pytorch实现有若干区别,其中一个区别在于jax代码实现有replay机制,核心代码在train_continual.py的train_and_evaluate_per_task,如下:

train_and_evaluate_per_task是个非常冗长的函数,接近400行,因此我们只展示代码片段,认识到核心机制即可。

从下面片段可知,每一轮epoch都有train和review阶段:

in_train_session = (step < initial_step + num_train_steps)
in_review_session = (step >= initial_step + num_train_steps)

然后以下代码有关replay和review机制:

if in_train_session:batch = jax.tree_map(np.asarray, next(train_iter))# replay starts# if replay, we should save it into the bufferif replay_buffer and (relative_step < num_savable_steps):replay_buffer.add_example(task_id, relative_step, batch)# if in 2nd or later task, we also sample from the bufferif replay_buffer and (task_id > 0) and (not review_trick):replay_batch = replay_buffer.get_random_batch(config.per_device_batch_size,config.continual.replay.include_new_task)# concatenate them through the batch_size axisimage_concat = np.concatenate([batch["image"], replay_batch["image"]], axis=1)label_concat = np.concatenate([batch["label"], replay_batch["label"]], axis=1)label_concat = label_concat.astype(np.int32)batch = {"image": image_concat, "label": label_concat}
else:batch = replay_buffer.get_random_batch(config.per_device_batch_size,True)

Q:replay机制的做法
A:从if replay_buffer and (task_id > 0) and (not review_trick):可知,当到了第二个及以后的任务时,只要有replay buffer,且没有激活review机制,就可以使用replay。具体做法是取重放buffer里的任意一个batch,拼接在当前任务的batch后。

Q:如何收集重放数据?
A:从relative_step < num_savable_steps)可知,前num_savable_steps个步的batch数据会存到replay buffer。

Q:review阶段的做法?
A:从else分支可知,review阶段只回顾历史数据。而replay是指训练新数据的同时回顾历史数据。

gaussian_schedule

根据注释可见,gaussian_schedule的作用是可以平滑地过度到新的任务。

def gaussian_schedule(rng,num_classes=100,num_tasks=200,step_per_task=5,random_label=False):"""Returns a schedule where one task blends smoothly into the next."""schedule_length = num_tasks * step_per_task  # schedule length in batchesepisode_length = step_per_task  # episode length in batches# Each class label appears according to a Gaussian probability distribution# with peaks spread evenly over the schedulepeak_every = schedule_length // num_classeswidth = 50  # width of Gaussianpeaks = range(peak_every // 2, schedule_length, peak_every)schedule = []labels = jnp.array(list(range(num_classes)))if random_label:labels = jax.random.permutation(rng, labels)  # labels in random orderfor ep_no in range(0, schedule_length // episode_length):lbls = []while not lbls:  # make sure lbls isn't emptyfor j in range(len(peaks)):peak = peaks[j]# Sample from a Gaussian with peak in the right placep = gaussian(peak, width, ep_no * episode_length)(rng2, rng) = jax.random.split(rng)add = jax.random.bernoulli(rng2, p=p)if add:lbls.append(int(labels[j]))episode = {"label_set": np.array(lbls), "n_batches": episode_length}# episode = {'label_set': lbls}schedule.append(episode)return schedule

首先,每个episode包含step_per_task个batch。每个batch涉及的label数,取决于采样label时,有多少label在采样时被激活了。
peaks = range(peak_every // 2, schedule_length, peak_every)可知,每个label都有一个正太分布,当peak_every=2时,其峰值分别为1,3,5,...
然后以下代码:

p = gaussian(peak, width, ep_no * episode_length)
(rng2, rng) = jax.random.split(rng)
add = jax.random.bernoulli(rng2, p=p)

可以看出,每次采样时,设置当前坐标为ep_no * episode_length,在不同峰值的正太分布中有不同的概率值p,然后在p的伯努利分布中采样,如果激活了,则添加这个label,否则不添加。

这里把所有正太分布画成图更好理解,但懒得画了。

acc_matrix和forgetting指标

在这段代码中,acc_matrix是一个全局矩阵,用于存储每个任务在训练结束时的准确率。这个矩阵用于计算遗忘率(forgetting)和学习准确率(learning accuracy)。

具体来说,acc_matrix是一个二维数组,其中行表示任务的编号,列表示当前任务的编号。矩阵中的每个元素acc_matrix[i, task_id]存储了在任务task_id结束时,任务i的准确率。

从下方代码可知,acc_matrix甚至是下三角矩阵,对角线之上(不含对角线)的元素都是0。

acc_matrix = np.zeros((config.continual.num_tasks, config.continual.num_tasks))

从以下代码可知:

  1. forgetting的计算方式:当前训练到的任务号为task_id,遍历每个之前训的的任务,看它至今为止的峰值,减去在该轮的表现,得到差值。将差值取平均即是遗忘程度。
  2. backward的计算方式类似:当前训练到的任务号为task_id,遍历每个之前的任务,看它在训练期的表现,减去在该轮的表现,得到差值。将差值取平均即是遗忘程度。
  3. learning_acc:在范围[0, task_id]内的任务在训练期的精度的平均值。

这里如果把矩阵画成图会很好理解,但懒得画了。

if task_id > 0:forgetting = np.mean((np.max(acc_matrix, axis=1) -acc_matrix[:, task_id])[:task_id])backward = np.mean((acc_matrix[:, task_id] - diagonal)[:task_id])writer.write_scalars(summary_step, {"forgetting": forgetting,"backward": backward})
learning_acc = np.mean(diagonal[:(task_id + 1)])
writer.write_scalars(summary_step, {"learning_acc": learning_acc})

评估函数

评估函数为evaluate_tasks_till_now,从主流程代码可知,该函数只测评序号范围[0, cur_task_id]的任务,换句话说,不会测那些还没训过的任务。
返回值为eval_metrics_list和prompt_idx_list,即序号范围内任务的测评指标,和那些任务用过的prompt_id(可用于统计)。

# logging.info(f"Starting evaluation for task 0 to {cur_task_id}.")
eval_metrics_list = []
prompt_idx_list = []for task_id in range(cur_task_id + 1):

调试时有两个困难:

  1. 其核心调用eval_func似乎不方便深入调试。只有当step=0的时候,函数内的断点才会生效。当step>0的时候,函数无法停在断点。
  2. 这段函数用到了一些不认识的第三方库,比如from libml.eval_metrics import EvalMetrics_listclu.metrics.Accuracy

但总之这是个简单的分类任务,计算的无非就是准确率,不需要查看细节。调试变量eval_metrics可以看到,其accuracy_0变量的正确率为 182 / 192 182/192 182/192,另外还保存了eval loss。只要知道返回值eval_metrics_list保存了多个任务的准确率就够了

深入eval_func
内部实现在eval_step函数内,可以看出,数据的label,模型输出的logits和loss都会被用于gather_from_model_output计算指标。

variables.update(state.model_state)
res= model(train=False).apply(variables, batch["image"], cls_features=cls_features, mutable=False)
logits = res["logits"]...loss = jnp.mean(losses.cross_entropy_loss(logits=logits, labels=batch["label"]))metrics_update = EvalMetrics_list[task_id].gather_from_model_output(logits=logits,labels=batch["label"],loss=loss,mask=batch.get("mask"),)

gather_from_model_output的返回值类型为EvalMetrics_0:

分类网络的输出机制

参考vit.py的VisionTransformer类,对于模型的输出有如下处理:

    elif self.classifier == 'prompt':x = x[:, 0:total_prompt_len]if self.reweight_prompt:reweight = self.param('reweight', nn.initializers.uniform(),(total_prompt_len,))reweight = nn.softmax(reweight)x = jnp.average(x, axis=1, weights=reweight)else:x = jnp.mean(x, axis=1)

当没有设置reweight_prompt时,模型的输出pre_logits就是前total_prompt_len个输出向量的均值,这个向量随后会送给分类头,作分类预测,如下所示。在cifar100中,一般分类头输出维度是100。

x = nn.Dense(features=self.num_classes,name='head',kernel_init=nn.initializers.zeros)(x)
x = x / self.temperature
res_vit['logits'] = x
return res_vit

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.xdnf.cn/news/149485.html

如若内容造成侵权/违法违规/事实不符,请联系一条长河网进行投诉反馈,一经查实,立即删除!

相关文章

MODELS 2024震撼续章:科技与可持续性的未来交响曲

MODELS 2024国际会议正如火如荼地进行着&#xff0c;每一天都充满了新的发现与启迪&#xff0c;每一场分享都是对技术前沿的一次深刻探索&#xff0c;更是对现实世界可持续性挑战的一次积极回应。现在让我们继续这场科技盛宴&#xff0c;看看小编为您精选几场的学术分享吧~ 会议…

python如何实现日期加减

首先通过import datetime&#xff0c;导入日期处理库。 然后把日期转化成datetime标准格式&#xff0c;使用datetime.datetime.strptime()方法将字符串格式的时间转化为标准格式。 其中"%Y/%m/%d %H:%M:%S"为time字符串的时间格式&#xff1a;Y为年&#xff0c;m为月…

请不要在TS中使用Function类型

在 TypeScript 中&#xff0c;避免使用 Function 作为类型。Function 代表的是“任意类型的函数”&#xff0c;这会带来类型安全问题。对于绝大多数情况&#xff0c;你可能更希望明确地指定函数的参数和返回值类型。 如果你确实想表达一个可以接收任意数量参数并返回任意类型的…

Android13中Android.mk和Android.bp预编译多种架构文件

需求&#xff1a; 1&#xff0c; 当前有多个架构的config文件&#xff0c;但是需要不同架构使用不同config文件 2&#xff0c; 必须将config文件拷贝到out/host目录下 常规思路 在Android.bp中&#xff0c; 一般在编译多架构文件时&#xff0c;都会使用arch属性&#xff…

Stable Diffusion绘画 | XYZ Plot:让对比一目了然

XYZ Plot 是 SD 自带的&#xff0c;无需额外安装。 它的作用&#xff0c;是给我们用来对比不同参数下&#xff0c;生成图片效果的区别。 位置在页面左侧底部&#xff1a; 实操 开启 x轴进行对比&#xff0c;这里面有各种可选的对比参数&#xff1a; 现在 X轴类型 选择「Sampler…

【秋招笔试题】阵营分配

解法&#xff1a;简单背包题。 def solve(nums):n len(nums)totalSum sum(nums)dp [[False] * (totalSum // 2 1) for _ in range(n 1)]for i in range(n 1):dp[i][0] Truefor i in range(1, n 1):for j in range(1, totalSum // 2 1):if nums[i - 1] < j:dp[i][j…

网上超市开发:SpringBoot技术要点

3 系统分析 这部分内容虽然在开发流程中处于最开始的环节&#xff0c;但是它对接下来的设计和实现起着重要的作用&#xff0c;因为系统分析结果的好坏&#xff0c;将直接影响后面环节的开展。 3.1可行性研究 影响系统开发的因素有很多&#xff0c;比如开发成本高就不适合开展&a…

Skyeye 云智能制造 v3.14.6 发布,ERP 商城

Skyeye 云智能制造&#xff0c;采用 Springboot winUI 的低代码平台、移动端采用 UNI-APP。包含 30 多个应用模块、50 多种电子流程&#xff0c;CRM、PM、ERP、MES、ADM、EHR、笔记、知识库、项目、门店、商城、财务、多班次考勤、薪资、招聘、云售后、论坛、公告、问卷、报表…

基于STM32的点滴输液报警器-设计说明书

设计摘要&#xff1a; 本文介绍了基于STM32微控制器的点滴输液报警器的设计与实现。点滴输液是医疗领域中常见的治疗方式&#xff0c;但输液速度的控制对患者的安全和治疗效果至关重要。因此&#xff0c;设计一种能够监测输液速度并在异常情况下发出警报的系统显得十分必要。基…

Linux:进程间通信之命名管道

Linux&#xff1a;进程间通信-CSDN博客 我们说匿名管道只能用于父子进程这样的关系通信&#xff0c;那么陌生进程怎么通信&#xff1f; 我们之前说父子进程能通信的最关键的地方就在于子进程复制了一份父进程的files_struct&#xff0c;从而通过文件的inode映射同一份文件来通…

Serilog文档翻译系列(五) - 编写日志事件

日志事件通过 Log 静态类或 ILogger 接口上的方法写入接收器。下面的示例将使用 Log 以便语法简洁&#xff0c;但下面显示的方法同样可用于接口。 Log.Warning("Disk quota {Quota} MB exceeded by {User}", quota, user); 通过此日志方法创建的警告事件将具有两个相…

亿发零售云解析:新零售破局与年轻群体消费趋势变化

近年来&#xff0c;随着数字化、智能化的快速发展&#xff0c;“新零售”概念逐渐成为商业领域的热门话题。相比传统零售&#xff0c;新零售通过线上与线下的深度融合&#xff0c;利用大数据、人工智能等技术&#xff0c;赋能消费者与品牌之间的互动。尤其在年轻消费群体中&…

JS 特殊运算符有哪些?

JavaScript 特殊运算符有哪些&#xff1f; 众多编程语言之中JavaScript &#xff0c;以其强大而全面的功能深受前端开发者喜爱。其丰富的运算符集&#xff0c;不仅包括了广泛应用的算术运算符、比较运算符以及逻辑运算符&#xff0c;还蕴藏着一系列较为冷门但同样功能强大的运算…

LVGL第一篇-了解lvgl显示原理以及使用C++移植

一、引言 在当今嵌入式系统与图形界面开发的广阔领域中&#xff0c;轻量级图形库 LVGL&#xff08;Light and Versatile Graphics Library&#xff09;恰似一颗璀璨耀眼的明星&#xff0c;正日益受到开发者们的热烈推崇与追逐。它以小巧精致之姿、高效卓越之能以及丰富多元之功…

Qt_事件的介绍

目录 1、理解事件 2、处理事件QEvent 3、键盘事件QKeyEvent 4、鼠标事件QMouseEvent 4.1 鼠标点击事件 4.2 鼠标释放事件 4.3 鼠标移动事件 5、滚轮事件QWheelEvent 6、定时器事件QTimerEvent 7、窗口事件QMoveEvent 8、事件分发器event 9、事件过滤器even…

峟思助力堤防工程安全:构建多功能防洪屏障

堤防工程&#xff0c;作为水利建设中至关重要的防护体系&#xff0c;不仅守护着江河、湖泊及滨海区域的安全&#xff0c;更是确保人民生命财产安全的坚固防线。在现代社会&#xff0c;随着技术的进步与安全意识的提升&#xff0c;堤防工程不仅限于传统的防洪功能&#xff0c;更…

CVPR最牛图像评价算法!

本文所涉及所有资源均在 传知代码平台可获取。 目录 概述 一、论文思路 1.多任务学习框架&#xff1a; 2.视觉-语言对应关系&#xff1a; 3.动态损失权重&#xff1a; 4.模型优化和评估&#xff1a; 二、模型介绍 三、详细实现方法 1.图像编码器和语言编码器&#xff08;Image…

Solidity语言:重点学习Solidity编程语言,这是EVM上最常用的智能合约语言。

Solidity是一种面向合约的编程语言,用于在以太坊虚拟机(EVM)上编写智能合约。它是Solidity开发者在以太坊平台上创建智能合约的主要选择之一。 学习Solidity的重点包括以下几方面: 语法和数据类型:学习Solidity的基本语法、数据类型、变量声明和函数定义等。 智能合约:了…

刷完这个笔记,17K不能再少了....

大家好&#xff0c;最近有不少小伙伴在后台留言&#xff0c;得准备面试了&#xff0c;又不知道从何下手&#xff01;为了帮大家节约时间&#xff0c;特意准备了一份面试相关的资料&#xff0c;内容非常的全面&#xff0c;真的可以好好补一补&#xff0c;希望大家在都能拿到理想…

cobaltstrike之execute-assembly内存加载—后渗透利用

通过execute-assembly内存加载来执行文件&#xff0c;从而避免后渗透中被杀毒软件静态报毒&#xff0c;使更多的工具能够继续利用&#xff0c;常见的方式有权限维持&#xff0c;代理上线等操作 远程bin文件加载 首先尝试远程加载bin文件 使用项目https://github.com/shanekha…