Jiahui Yu† Zirui Wang†{jiahuiyu, ziruiw}@google.comVijay Vasudevan Legg Yeung Mojtaba Seyedhosseini Yonghui WuGoogle Research
参考代码链接:https://github.com/lucidrains/CoCa-pytorch
模型效果对比网址:CoCa: Contrastive Captioners are Image-Text Foundation Models | Papers With Code
论文摘要:
本文介绍了对比描述生成器(CoCa),这是一种极简的设计,可以与(contrastive loss and captioning loss)对比损失和图像描述损失一起预训练图像-文本编码器-解码器基础模型,从而从 CLIP 等对比方法和 SimVLM 等生成方法中假设模型能力。与所有解码器层都关注编码器输出的标准编码器-解码器转换器相比,CoCa 在解码器层的前半部分省略了交叉注意来编码单模态文本表示,并级联剩余的解码器层,这些解码器层交叉关注图像编码器用于多模态图像-文本表示。除了自回归预测文本标记的多模态解码器输出上的字幕损失之外,我们还在单模态图像和文本嵌入之间应用了对比损失。通过共享相同的计算图,以最小的开销有效地计算两个训练目标。CoCa 在网络规模的 Alt-text 数据和带注释的图像上都是端到端和从头开始预训练的,通过将所有标签简单地视为文本,无缝统一自然语言监督以进行表示学习。
模型结构思路&后续下游任务可实现功能:
模型训练伪代码:
算法1 对比式描述器架构的伪代码。
# image, text.ids, text.labels, text.mask: 配对的图像,文本数据
# con_query: 1个用于对比嵌入的查询令牌
# cap_query: N个用于描述嵌入的查询令牌
# cls_token_id: 词汇表中的一个特殊cls_token_id
def 注意力池化(features, query):
out = 多头注意力(features, query)
return 层归一化(out)
img_feature = vit_encoder(image) # [批次, 序列长度, 维度]
cap_feature = 注意力池化(img_feature, cap_query) # [批次, N, 维度]
ids = 连接(text.ids, cls_token_id)
mask = 连接(text.mask, zeros_like(cls_token_id)) # 对cls_token_id进行非填充
txt_embs = 嵌入查找(ids)
unimodal_out = lm_transformers(txt_embs, mask, 交叉注意力=None)
multimodal_out = lm_transformers(
unimodal_out[:, :-1, :], mask, 交叉注意力=cap_feature)
cls_token_feature = 层归一化(unimodal_out)[:, -1:, :] # [批次,1, 维度]
con_loss = 对比损失(con_feature, cls_token_feature)
cap_loss = softmax交叉熵损失(
multimodal_out, 标签=text.labels, 掩码=text.mask)
vit_encoder:基于视觉变换器的编码器;lm_transformer:语言模型变换器