多模态大语言模型(MLLM)-Blip2深度解读

前言

在这里插入图片描述
Blip2是一个多模态大语言模型,因其提出时间较早(2023年),且效果较好,很快成为一个标杆性工作。Blip2中提出的Q-former也成为衔接多模态和文本的重要桥梁。

在这里插入图片描述Blip2发表时间是2023年,现在引用已经3288了,表明大家对Blip2背后的多模态语言模型是多么的追捧。

作者前期的工作还有Albef(先对齐再融合)、Blip1,都是十分硬核的工作。

创新点

  • Blip2利用冻结的预训练图像模型和语言模型,来有效减小纯多模态模型的训练成本。提出Q-Former框架,通过两个训练阶段(图文对齐+图文指令微调)来弥补模态GAP。在visual question answering, image captioning, and image-text retrieval等典型视觉-语言任务上表现出色。
  • 由LLM驱动,BLIP-2 可以zero-shot得执行图像到文本生成的。因为LLM具备涌现效应,Blip2能够实现视觉知识推理、视觉问答等较难的任务。
  • 由于使用了冻结的预训练图像模型和语言模型,Blip2的训练成本更低,例如,BLIP-2 在零样本 VQAv2 上比 Flamingo高出 8.7%,同时可训练参数减少了 54 倍。

具体细节

Blip2通过两阶段训练,来学习Q-Former模块。两阶段训练包含:

  • 一阶段视觉-语言表示学习(vision-language representation learning stage)
  • 二阶段视觉-语言生成学习(vision-to-language generative learning stage)

Q-former模块在这里插入图片描述

如上图,Q-Former包括左右两列并行的attention模块。左列为self attention+cross attention+feed forward,右列为self attention+feed forward。
左列用于提取图像特征
右列用于提取文本特征
左列+右列用于提取多模态特征

左列-图像特征

左列做的事情整体可以理解为输入N个learned query,对N个learned query做self attention,对输入图像做cross attention,得到N个总结后的图像输出(类似于目标检测的DETR算法,不同query关注不同区域)。论文中,N等于32。
具体如下:

输入learned queries(随机初始化的embedding)

query_tokens = nn.Parameter(torch.zeros(1, num_query_token, encoder_config.hidden_size))query_tokens.data.normal_(mean=0.0, std=encoder_config.initializer_range)

先过self attention层,再和冻结的图像编码器输出的特征做cross attention,代码如下:

class BertLayer(nn.Module):def __init__(self, config, layer_num):super().__init__()self.config = configself.chunk_size_feed_forward = config.chunk_size_feed_forwardself.seq_len_dim = 1self.attention = BertAttention(config)self.layer_num = layer_numif (self.config.add_cross_attentionand layer_num % self.config.cross_attention_freq == 0):self.crossattention = BertAttention(config, is_cross_attention=self.config.add_cross_attention)self.has_cross_attention = Trueelse:self.has_cross_attention = Falseself.intermediate = BertIntermediate(config)self.output = BertOutput(config)self.intermediate_query = BertIntermediate(config)self.output_query = BertOutput(config)def forward(self,hidden_states,attention_mask=None,head_mask=None,encoder_hidden_states=None,encoder_attention_mask=None,past_key_value=None,output_attentions=False,query_length=0,):# decoder uni-directional self-attention cached key/values tuple is at positions 1,2self_attn_past_key_value = (past_key_value[:2] if past_key_value is not None else None)self_attention_outputs = self.attention(hidden_states,attention_mask,head_mask,output_attentions=output_attentions,past_key_value=self_attn_past_key_value,)attention_output = self_attention_outputs[0]outputs = self_attention_outputs[1:-1]present_key_value = self_attention_outputs[-1]if query_length > 0:query_attention_output = attention_output[:, :query_length, :]if self.has_cross_attention:assert (encoder_hidden_states is not None), "encoder_hidden_states must be given for cross-attention layers"cross_attention_outputs = self.crossattention(query_attention_output,attention_mask,head_mask,encoder_hidden_states,encoder_attention_mask,output_attentions=output_attentions,)query_attention_output = cross_attention_outputs[0]outputs = (outputs + cross_attention_outputs[1:-1])  # add cross attentions if we output attention weightslayer_output = apply_chunking_to_forward(self.feed_forward_chunk_query,self.chunk_size_feed_forward,self.seq_len_dim,query_attention_output,)if attention_output.shape[1] > query_length:layer_output_text = apply_chunking_to_forward(self.feed_forward_chunk,self.chunk_size_feed_forward,self.seq_len_dim,attention_output[:, query_length:, :],)layer_output = torch.cat([layer_output, layer_output_text], dim=1)else:layer_output = apply_chunking_to_forward(self.feed_forward_chunk,self.chunk_size_feed_forward,self.seq_len_dim,attention_output,)outputs = (layer_output,) + outputsoutputs = outputs + (present_key_value,)return outputs

其中,self.crossattention输入包含query_attention_output(self attention输出),encoder_hidden_states(图像编码器输出)。Q-Former的实现就是对BertLayer进行魔改,通过cross_attention_freq参数来控制cross attention的频率,如果等于3,则第0、3、6层BertLayer会对图像特征做cross attention,其他实现流程和原实现的(huggingface实现)BertLayer类似。

提取图像特征的Q-Former调用方式为:

        image = samples["image"]image_embeds = self.ln_vision(self.visual_encoder(image))image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(image.device)query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)query_output = self.Qformer.bert(query_embeds=query_tokens,encoder_hidden_states=image_embeds,encoder_attention_mask=image_atts,use_cache=True,return_dict=True,)image_feats = F.normalize(self.vision_proj(query_output.last_hidden_state), dim=-1)
右列-文本特征

右列的实现流程和左列类似,不同之处在于将BertLayer层中hidden_states改为文本编码,编码方式的实现也类似于原实现的(huggingface实现),如下:


class BertEmbeddings(nn.Module):"""Construct the embeddings from word and position embeddings."""def __init__(self, config):super().__init__()self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)# self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load# any TensorFlow checkpoint fileself.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)self.dropout = nn.Dropout(config.hidden_dropout_prob)# position_ids (1, len position emb) is contiguous in memory and exported when serializedself.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")self.config = configdef forward(self,input_ids=None,position_ids=None,query_embeds=None,past_key_values_length=0,):if input_ids is not None:seq_length = input_ids.size()[1]else:seq_length = 0if position_ids is None:position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length].clone()if input_ids is not None:embeddings = self.word_embeddings(input_ids)if self.position_embedding_type == "absolute":position_embeddings = self.position_embeddings(position_ids)embeddings = embeddings + position_embeddingsif query_embeds is not None:embeddings = torch.cat((query_embeds, embeddings), dim=1)else:embeddings = query_embedsembeddings = self.LayerNorm(embeddings)embeddings = self.dropout(embeddings)return embeddings

即将输入文本转为token,再将token转为embedding,embedding作为BertLayer层中hidden_states的输入。
其中,不涉及到cross attention。

提取文本特征的Q-Former调用方式为:

        text = samples["text_input"]text_tokens = self.tokenizer(text,padding="max_length",truncation=True,max_length=self.max_txt_len,return_tensors="pt",).to(image.device)text_output = self.Qformer.bert(text_tokens.input_ids,attention_mask=text_tokens.attention_mask,return_dict=True,)text_feat = F.normalize(self.text_proj(text_output.last_hidden_state[:, 0, :]), dim=-1)
左列+右列

即输入图像特征,又输入文本,能够提取图文多模态表征,一般用作ITM的loss计算。

一阶段视觉-语言表示学习

该阶段的目的类似于clip,将视觉表征和文本表征拉到统一空间,具备三个损失函数,分别是Image-Text Contrastive Learning (ITC)、Image-grounded Text Generation (ITG)和Image-Text Matching (ITM)

Image-Text Contrastive Learning (ITC)

通过阶梯式的对比学习,缩小pair内图文距离,扩大pair间图文距离。
图像特征:Q-Former左列输出的图像特征,有N个
文本特征:Q-Former右列输出的CLS文本特征,有一个

因为图像特征有N个,每一次仅选择N个中和文本特征最近的那个来算ITC。在计算ITC时,负样本数量非常重要,通过多卡in-batch采样来获取多卡的负样本。代码详见:

		image_feats_all = concat_all_gather(image_feats)  # [batch_size*num_gpu, num_query_tokens, embed_dim]text_feat_all = concat_all_gather(text_feat)  # [batch_size*num_gpu, embed_dim]sim_q2t = torch.matmul(image_feats.unsqueeze(1), text_feat_all.unsqueeze(-1)).squeeze()# [batch_size, batch_size*num_gpu, num_query_tokens]# image-text similarity: aggregate across all query tokenssim_i2t, _ = sim_q2t.max(-1)sim_i2t = sim_i2t / self.temp# text-query similarity: [batch_size, batch_size*num_gpu, num_query_tokens]sim_t2q = torch.matmul(text_feat.unsqueeze(1).unsqueeze(1), image_feats_all.permute(0, 2, 1)).squeeze()# text-image similarity: aggregate across all query tokenssim_t2i, _ = sim_t2q.max(-1)sim_t2i = sim_t2i / self.temp  # [batch_size, batch_size*num_gpu]rank = dist.get_rank()bs = image.size(0)targets = torch.linspace(rank * bs, rank * bs + bs - 1, bs, dtype=int).to(image.device)loss_itc = (F.cross_entropy(sim_i2t, targets, label_smoothing=0.1)+ F.cross_entropy(sim_t2i, targets, label_smoothing=0.1)) / 2
Image-grounded Text Generation (ITG)

通过Q-Former架构,实现image caption任务。首先利用Q-Former提取图像特征,将图像特征作为输入,迭代式得利用LM loss约束文本输出。

       	text_tokens = self.tokenizer(text,padding="max_length",truncation=True,max_length=self.max_txt_len,return_tensors="pt",).to(image.device)image_embeds = self.ln_vision(self.visual_encoder(image))image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(image.device)query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)query_output = self.Qformer.bert(query_embeds=query_tokens,encoder_hidden_states=image_embeds,encoder_attention_mask=image_atts,use_cache=True,return_dict=True,)decoder_input_ids = text_tokens.input_ids.clone()decoder_input_ids[:, 0] = self.tokenizer.bos_token_idlabels = decoder_input_ids.masked_fill(decoder_input_ids == self.tokenizer.pad_token_id, -100)query_atts = torch.ones(query_tokens.size()[:-1], dtype=torch.long).to(image.device)attention_mask = torch.cat([query_atts, text_tokens.attention_mask], dim=1)lm_output = self.Qformer(decoder_input_ids,attention_mask=attention_mask,past_key_values=query_output.past_key_values,return_dict=True,labels=labels,)loss_lm = lm_output.loss
Image-Text Matching (ITM)

ITM的本质在于输入一个图文pair,用以判断图文pair是否匹配。借助二分类来实现,输出1为匹配;输出0为不匹配。
Blip2在构造图文pair时,采用了多种采样策略

  • 匹配的图文pair,label均为1,表示匹配
  • 图固定,图-文距离为权值,利用torch.multinomial采样文,构成图文pair,label均为0
  • 文固定,图-文距离为权值,利用torch.multinomial采样图,构成图文pair,label均为0
    代码详见:
		text_input_ids_world = concat_all_gather(text_tokens.input_ids)text_attention_mask_world = concat_all_gather(text_tokens.attention_mask)image_embeds_world = all_gather_with_grad(image_embeds)with torch.no_grad():if "image_id" in samples.keys():mask = torch.eq(image_ids, image_ids_all.t())sim_t2i.masked_fill_(mask, -10000)sim_i2t.masked_fill_(mask, -10000)else:    sim_t2i[:, rank * bs : rank * bs + bs].fill_diagonal_(-10000)sim_i2t[:, rank * bs : rank * bs + bs].fill_diagonal_(-10000)            weights_t2i = F.softmax(sim_t2i, dim=1)weights_i2t = F.softmax(sim_i2t, dim=1)# select a negative image for each textimage_embeds_neg = []for b in range(bs):neg_idx = torch.multinomial(weights_t2i[b], 1).item()image_embeds_neg.append(image_embeds_world[neg_idx])image_embeds_neg = torch.stack(image_embeds_neg, dim=0)# select a negative text for each imagetext_ids_neg = []text_atts_neg = []for b in range(bs):neg_idx = torch.multinomial(weights_i2t[b], 1).item()text_ids_neg.append(text_input_ids_world[neg_idx])text_atts_neg.append(text_attention_mask_world[neg_idx])text_ids_neg = torch.stack(text_ids_neg, dim=0)text_atts_neg = torch.stack(text_atts_neg, dim=0)text_ids_all = torch.cat([text_tokens.input_ids, text_tokens.input_ids, text_ids_neg], dim=0)  # pos, pos, negtext_atts_all = torch.cat([text_tokens.attention_mask, text_tokens.attention_mask, text_atts_neg],dim=0,)query_tokens_itm = self.query_tokens.expand(text_ids_all.shape[0], -1, -1)query_atts_itm = torch.ones(query_tokens_itm.size()[:-1], dtype=torch.long).to(image.device)attention_mask_all = torch.cat([query_atts_itm, text_atts_all], dim=1)image_embeds_all = torch.cat([image_embeds, image_embeds_neg, image_embeds], dim=0)  # pos, neg, posimage_atts_all = torch.ones(image_embeds_all.size()[:-1], dtype=torch.long).to(image.device)output_itm = self.Qformer.bert(text_ids_all,query_embeds=query_tokens_itm,attention_mask=attention_mask_all,encoder_hidden_states=image_embeds_all,encoder_attention_mask=image_atts_all,return_dict=True,)vl_embeddings = output_itm.last_hidden_state[:, : query_tokens_itm.size(1), :]vl_output = self.itm_head(vl_embeddings)logits = vl_output.mean(dim=1)itm_labels = torch.cat([torch.ones(bs, dtype=torch.long), torch.zeros(2 * bs, dtype=torch.long)],dim=0,).to(image.device)loss_itm = F.cross_entropy(logits, itm_labels)

输入图像特征、文本,提取图文pair的多模态特征,再送到二元分类器,实现ITM的loss计算。

总体loss

总体loss等于上述三个loss的加和

        return BlipOutput(loss=loss_itc + loss_itm + loss_lm,loss_itc=loss_itc,loss_itm=loss_itm,loss_lm=loss_lm,)

二阶段视觉-语言生成学习

在这里插入图片描述
在生成阶段,需要做

  • 输入图像,借助冻结的图像编码器,提取图像特征。图像特征送入Q-Former,提取图文对齐的图像特征(32个embedding)
  • 采用一个FC层,将图像特征维度与LLM维度对齐
  • 将图像特征作为soft visual prompts,输入到LLM中,借助图文指令微调数据,训练Q-Former、FC层。

实验数据

预训练

Blip2采用129M图片,包括COCO、Visual Genome、CC3M、CC12M、SBU,以及LAION400M。其中115M来自于LAION400M,使用CapFilt对网图进行生成caption,具体步骤如下:
1、使用Blip模型 生成10个caption;
2、10个caption+原始web caption通过CLIP模型计算图像-caption排序;
3、选取top2作为该图的caption,以此作为训练数据;

预训练图像编码器与LLM

两个SOTA视觉transformer预训练模型:
ViT-L/14 from CLIP、ViT-G/14 from EVA-CLIP
移除ViT最后一层,使用倒数第二层特征。
LLM模型:
无监督训练的OPT作为decoder-based LLM
基于指令训练的FlanT5作为encoder-decoder-based LLM

预训练设置

第一阶段训练250k step,第二阶段训练80k step;ViT和LLM 转为FP16,FlanT5转为BFloat16,作者发现相对于32-bit,性能无下降;由于使用frozen模型,作者预训练比现在大规模VLP方法计算量都小,在16个A100(40G)上,对于ViT-G和FlanT5-XXL第一阶段训练耗时6天,第二阶段少于3天。
因为绝大部分参数(图像编码器、LLM)都冻结,所以训练成本较低,这也是Blip2较流行的一个原因


插一句嘴,目前的多模态LLM范式较Blip2更加简单。Blip2采用Q-Former实现图文对齐,现在的大部分工作直接采用FC层实现图文对齐,效果和Q-Former类似,但训练成本更低。相关工作有Llava1.5等

后续会持续更新多模态LLM的相关论文

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.xdnf.cn/news/1557265.html

如若内容造成侵权/违法违规/事实不符,请联系一条长河网进行投诉反馈,一经查实,立即删除!

相关文章

计算机毕业设计 自习室座位预约系统的设计与实现 Java实战项目 附源码+文档+视频讲解

博主介绍:✌从事软件开发10年之余,专注于Java技术领域、Python人工智能及数据挖掘、小程序项目开发和Android项目开发等。CSDN、掘金、华为云、InfoQ、阿里云等平台优质作者✌ 🍅文末获取源码联系🍅 👇🏻 精…

【2024最新】基于springboot+vue的家具销售电商平台lw+ppt

作者:计算机搬砖家 开发技术:SpringBoot、php、Python、小程序、SSM、Vue、MySQL、JSP、ElementUI等,“文末源码”。 专栏推荐:SpringBoot项目源码、Vue项目源码、SSM项目源码、微信小程序源码 精品专栏:Java精选实战项…

Android OpenGLES2.0开发(四):矩阵变换和相机投影

事物的本质是事物本身所固有的、深藏于‌现象背后并决定或支配现象的方面‌。 还记得我们上一篇绘制的三角形吗,我们确实能够顺利用OpenGL ES绘制出图形了,这是一个好的开始,但这还远远不够。我们定义的坐标是正三角形,但是绘制出…

Python网络爬虫从入门到实战

目录 引言 一、网络爬虫的概念 二、 网络爬虫的基本工作流程 (一)过程: (二)安装requests模块和beautifulsoup4模块 (三)requests库的使用 1、requests库的基本介绍 2、导入requests库的…

使用tcpkill断开异常tcp连接

在linux系统中,遇到TCP链接迟迟不能释放的情况,类似FIN_WAIT1、FIN_WAIT2的状态,释放时间不确定,而且对应的程序已经关闭,相应的端口也不再监听,无法通过杀进程来解决,这种情况下,为…

大数据-157 Apache Kylin 背景 历程 特点 场景 架构 组件 详解

点一下关注吧!!!非常感谢!!持续更新!!! 目前已经更新到了: Hadoop(已更完)HDFS(已更完)MapReduce(已更完&am…

HarmonyOS NEXT - 表单录入组件封装(TextInput)

demo 地址: https://github.com/iotjin/JhHarmonyDemo 组件对应代码实现地址 代码不定时更新,请前往github查看最新代码 HarmonyOS NEXT - 表单录入组件封装(TextInput) 序JhFormInputCellJhFormSelectCellJhLoginTextField 序 鸿蒙next中有两…

PMP--冲刺题--解题--81-90

文章目录 12.采购管理--1.规划采购管理--在自制或外购分析中,可以使用回收期、投资回报率(ROI)内部报酬率(IRR)、现金流贴现、净现值(NPV)、收益成本净现值(BCA)或其他分…

如何使用ssm实现公司进销存管理系统设计与开发

TOC ssm792公司进销存管理系统设计与开发jsp 第1章 绪论 1.1选题动因 当前的网络技术,软件技术等都具备成熟的理论基础,市场上也出现各种技术开发的软件,这些软件都被用于各个领域,包括生活和工作的领域。随着电脑和笔记本的广…

python画图|曲线动态输出基础教程

在前述学习过程中,已经掌握基本的曲线图像画法,并尝试探索过3D动画基础教程。 相关文章可以通过下述链接直达: python画三角函数图|小白入门级教程_正余弦函数画图python-CSDN博客 python动画教程|Animations using Matplotlib-官网教程程…

电玩体验馆计时软件可以倒计时的软件试用版下载 佳易王ps5计时器管理系统使用教程

一、前言 【软件试用版下载可以点击本文章最下方官网卡片】 电玩体验馆计时软件可以倒计时的软件试用版下载 佳易王ps5计时器管理系统使用教程 1、软件能够记录玩家开始的时间和节数时间,从而计算出时长。 2、根据预设的收费标准,软件可以自动计算出…

各省份自然灾害损失造成的直接经济损失数据(2009-2022年)

自然灾害是自然演变过程中不可避免的现象,它们对人类社会构成了巨大的威胁。中国作为一个自然灾害频发的国家,面临着种类繁多的灾害挑战,包括气象灾害、地质灾害、海洋灾害、生物灾害和森林草原火灾等。 数据来源:《中国环境统计…

【GESP】C++一级练习BCQM3030,保留12位小数

浮点数数位保留练习,%m.nf知识点,已在BCQM3027中详细介绍。 题解详见:https://www.coderli.com/gesp-1-bcqm3030/ 【GESP】C一级练习BCQM3030,保留12位小数 | OneCoder浮点数数位保留练习,%m.nf知识点,已在…

培训机构客户管理系统的设计+ssm论文源码调试讲解

2 系统开发环境 2.1微信开发者工具 微信开发者工具现在已经被小程序开发团队开发运行,目前微信开发者工具任然在不断的完善中,在开发小程序时经常要不断的更新。可以使用微信扫码登陆开发者工具,开发者工具将使用这个微信帐号的信息进行小程…

Android车载音频系统概览

目录 1. 什么是Android车载音频系统 2. Android 声音和声音流 2.1 Android 声音 2.2 外部声音流 2.3 输出设备 章节说明:本节内容是Android车载音频系统简介。 1. 什么是Android车载音频系统 官方英文名称是:Automotive audio systems 由于汽车上无论是音频设备的数量还…

QT 实现图片查看工具

QT 实现图片查看工具 1、选择图像文件 单文件选择 QFileDialog::getOpenFileName多文件选择 QFileDialog::getOpenFileNamesQList<QString> imageNames = QFileDialog::getOpenFileNames(this,tr("打开图片"),"",tr("图片文件 (*.png *.jpg *.b…

OpenHarmony中OpenSSL从1.1.1 升级到3.0.7 时不支持MD4算法导致wpa_supplicant报错问题解决

OpenHarmony中OpenSSL从1.1.1 升级到3.0.7 时不支持MD4算法导致wpa_supplicant报错问题解决 1 问题现象 我们在测试EAP-PEAP(MSCHAPV2)功能时发现如下打印,导致认证失败 2 初步分析 openssl_digest_vector 中 调用EVP_DigestInit_ex 时如果报错 会打印"OpenSSL: EVP…

人工智能重点知识点总结整理

一、章节知识点目录 绪论知识表示方法确定性推理非经典推理计算智能机器学习完结篇二、每章重点内容 1. 绪论 定义:人工智能(AI)

厨房用品分割系统源码&数据集分享

厨房用品分割系统源码&#xff06;数据集分享 [yolov8-seg-C2f-DCNV3&#xff06;yolov8-seg-AFPN-P345等50全套改进创新点发刊_一键训练教程_Web前端展示] 1.研究背景与意义 项目参考ILSVRC ImageNet Large Scale Visual Recognition Challenge 项目来源AAAI Global Al ln…

论文阅读笔记-XLNet: Generalized Autoregressive Pretraining for Language Understanding

前言 Google发布的XLNet在问答、文本分类、自然语言理解等任务上都大幅超越BERT,XLNet提出一个框架来连接语言建模方法和预训练方法。我们所熟悉的BERT是denoising autoencoding模型,最大的亮点就是能够获取上下文相关的双向特征表示,所以相对于标准语言模型(自回归)的预…