背景介绍
l2p有两个代码实现,官方的jax实现,和个人开源的pytorch实现。两种实现有若干区别,而在jax实现中能看到replay和review机制。
训练机制
先跳过繁琐的代码实现,介绍一下jax版实现的训练机制。以数据集cifar100为例,会分为10个task,每个任务有10个类别。训练时(默认没开启gaussian schedule时),先训任务0的数据,再训任务1、任务2等剩余的数据。
关于replay和review机制
l2p的jax实现与pytorch实现有若干区别,其中一个区别在于jax代码实现有replay机制,核心代码在train_continual.py的train_and_evaluate_per_task
,如下:
train_and_evaluate_per_task是个非常冗长的函数,接近400行,因此我们只展示代码片段,认识到核心机制即可。
从下面片段可知,每一轮epoch都有train和review阶段:
in_train_session = (step < initial_step + num_train_steps)
in_review_session = (step >= initial_step + num_train_steps)
然后以下代码有关replay和review机制:
if in_train_session:batch = jax.tree_map(np.asarray, next(train_iter))# replay starts# if replay, we should save it into the bufferif replay_buffer and (relative_step < num_savable_steps):replay_buffer.add_example(task_id, relative_step, batch)# if in 2nd or later task, we also sample from the bufferif replay_buffer and (task_id > 0) and (not review_trick):replay_batch = replay_buffer.get_random_batch(config.per_device_batch_size,config.continual.replay.include_new_task)# concatenate them through the batch_size axisimage_concat = np.concatenate([batch["image"], replay_batch["image"]], axis=1)label_concat = np.concatenate([batch["label"], replay_batch["label"]], axis=1)label_concat = label_concat.astype(np.int32)batch = {"image": image_concat, "label": label_concat}
else:batch = replay_buffer.get_random_batch(config.per_device_batch_size,True)
Q:replay机制的做法
A:从if replay_buffer and (task_id > 0) and (not review_trick):
可知,当到了第二个及以后的任务时,只要有replay buffer,且没有激活review机制,就可以使用replay。具体做法是取重放buffer里的任意一个batch,拼接在当前任务的batch后。
Q:如何收集重放数据?
A:从relative_step < num_savable_steps)
可知,前num_savable_steps
个步的batch数据会存到replay buffer。
Q:review阶段的做法?
A:从else分支可知,review阶段只回顾历史数据。而replay是指训练新数据的同时回顾历史数据。
gaussian_schedule
根据注释可见,gaussian_schedule的作用是可以平滑地过度到新的任务。
def gaussian_schedule(rng,num_classes=100,num_tasks=200,step_per_task=5,random_label=False):"""Returns a schedule where one task blends smoothly into the next."""schedule_length = num_tasks * step_per_task # schedule length in batchesepisode_length = step_per_task # episode length in batches# Each class label appears according to a Gaussian probability distribution# with peaks spread evenly over the schedulepeak_every = schedule_length // num_classeswidth = 50 # width of Gaussianpeaks = range(peak_every // 2, schedule_length, peak_every)schedule = []labels = jnp.array(list(range(num_classes)))if random_label:labels = jax.random.permutation(rng, labels) # labels in random orderfor ep_no in range(0, schedule_length // episode_length):lbls = []while not lbls: # make sure lbls isn't emptyfor j in range(len(peaks)):peak = peaks[j]# Sample from a Gaussian with peak in the right placep = gaussian(peak, width, ep_no * episode_length)(rng2, rng) = jax.random.split(rng)add = jax.random.bernoulli(rng2, p=p)if add:lbls.append(int(labels[j]))episode = {"label_set": np.array(lbls), "n_batches": episode_length}# episode = {'label_set': lbls}schedule.append(episode)return schedule
首先,每个episode包含step_per_task
个batch。每个batch涉及的label数,取决于采样label时,有多少label在采样时被激活了。
从peaks = range(peak_every // 2, schedule_length, peak_every)
可知,每个label都有一个正太分布,当peak_every=2时,其峰值分别为1,3,5,...
。
然后以下代码:
p = gaussian(peak, width, ep_no * episode_length)
(rng2, rng) = jax.random.split(rng)
add = jax.random.bernoulli(rng2, p=p)
可以看出,每次采样时,设置当前坐标为ep_no * episode_length
,在不同峰值的正太分布中有不同的概率值p,然后在p的伯努利分布中采样,如果激活了,则添加这个label,否则不添加。
这里把所有正太分布画成图更好理解,但懒得画了。
acc_matrix和forgetting指标
在这段代码中,acc_matrix是一个全局矩阵,用于存储每个任务在训练结束时的准确率。这个矩阵用于计算遗忘率(forgetting)和学习准确率(learning accuracy)。
具体来说,acc_matrix是一个二维数组,其中行表示任务的编号,列表示当前任务的编号。矩阵中的每个元素acc_matrix[i, task_id]存储了在任务task_id结束时,任务i的准确率。
从下方代码可知,acc_matrix甚至是下三角矩阵,对角线之上(不含对角线)的元素都是0。
acc_matrix = np.zeros((config.continual.num_tasks, config.continual.num_tasks))
从以下代码可知:
- forgetting的计算方式:当前训练到的任务号为
task_id
,遍历每个之前训的的任务,看它至今为止的峰值,减去在该轮的表现,得到差值。将差值取平均即是遗忘程度。 - backward的计算方式类似:当前训练到的任务号为
task_id
,遍历每个之前的任务,看它在训练期的表现,减去在该轮的表现,得到差值。将差值取平均即是遗忘程度。 - learning_acc:在范围[0,
task_id
]内的任务在训练期的精度的平均值。
这里如果把矩阵画成图会很好理解,但懒得画了。
if task_id > 0:forgetting = np.mean((np.max(acc_matrix, axis=1) -acc_matrix[:, task_id])[:task_id])backward = np.mean((acc_matrix[:, task_id] - diagonal)[:task_id])writer.write_scalars(summary_step, {"forgetting": forgetting,"backward": backward})
learning_acc = np.mean(diagonal[:(task_id + 1)])
writer.write_scalars(summary_step, {"learning_acc": learning_acc})
评估函数
评估函数为evaluate_tasks_till_now
,从主流程代码可知,该函数只测评序号范围[0, cur_task_id]的任务,换句话说,不会测那些还没训过的任务。
返回值为eval_metrics_list和prompt_idx_list,即序号范围内任务的测评指标,和那些任务用过的prompt_id(可用于统计)。
# logging.info(f"Starting evaluation for task 0 to {cur_task_id}.")
eval_metrics_list = []
prompt_idx_list = []for task_id in range(cur_task_id + 1):
调试时有两个困难:
- 其核心调用
eval_func
似乎不方便深入调试。只有当step=0的时候,函数内的断点才会生效。当step>0的时候,函数无法停在断点。 - 这段函数用到了一些不认识的第三方库,比如
from libml.eval_metrics import EvalMetrics_list
,clu.metrics.Accuracy
。
但总之这是个简单的分类任务,计算的无非就是准确率,不需要查看细节。调试变量eval_metrics
可以看到,其accuracy_0变量的正确率为 182 / 192 182/192 182/192,另外还保存了eval loss。只要知道返回值eval_metrics_list
保存了多个任务的准确率就够了。
深入eval_func
内部实现在eval_step
函数内,可以看出,数据的label,模型输出的logits和loss都会被用于gather_from_model_output
计算指标。
variables.update(state.model_state)
res= model(train=False).apply(variables, batch["image"], cls_features=cls_features, mutable=False)
logits = res["logits"]...loss = jnp.mean(losses.cross_entropy_loss(logits=logits, labels=batch["label"]))metrics_update = EvalMetrics_list[task_id].gather_from_model_output(logits=logits,labels=batch["label"],loss=loss,mask=batch.get("mask"),)
gather_from_model_output
的返回值类型为EvalMetrics_0:
分类网络的输出机制
参考vit.py的VisionTransformer类,对于模型的输出有如下处理:
elif self.classifier == 'prompt':x = x[:, 0:total_prompt_len]if self.reweight_prompt:reweight = self.param('reweight', nn.initializers.uniform(),(total_prompt_len,))reweight = nn.softmax(reweight)x = jnp.average(x, axis=1, weights=reweight)else:x = jnp.mean(x, axis=1)
当没有设置reweight_prompt时,模型的输出pre_logits就是前total_prompt_len个输出向量的均值,这个向量随后会送给分类头,作分类预测,如下所示。在cifar100中,一般分类头输出维度是100。
x = nn.Dense(features=self.num_classes,name='head',kernel_init=nn.initializers.zeros)(x)
x = x / self.temperature
res_vit['logits'] = x
return res_vit