train.py
在main函数中找到构建模型的地方,按Ctrl+左键跳转进这个函数。
来到了这里
又来到了这里,这里就是构建模型的地方:
又来到了这里,还是在VLTVG.py这个文件中:
Method
The Overall Network
Visual-Linguistic Verification Module
输入图像首先经过卷积网络(CNN backbone),然后再经过Transformer encoder进行编码,得到视觉特征图F_v。F_v中包含图像中各个对象实例的特征,但不包含任何先验的语言文本信息。
# Image feature encoder (CNN + Transformer encoder)
self.backbone = build_backbone(args)
self.trans_encoder = build_visual_encoder(args)
# 1x1 conv that maps the backbone's output channels (backbone.num_channels)
# to the Transformer encoder's width (trans_encoder.d_model).
self.input_proj = nn.Conv2d(self.backbone.num_channels, self.trans_encoder.d_model, kernel_size=1)
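self.backbone = build_backbone(args) 构造出的backbone如下。从打印结果看,它由Joiner把两部分组合起来:(0) ResNet-50结构的主干网络(layer1~layer4分别有3/4/6/3个Bottleneck,BN层均为FrozenBatchNorm2d),(1) 正弦位置编码PositionEmbeddingSine: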
Joiner((0): Backbone((body): IntermediateLayerGetter((conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)(bn1): FrozenBatchNorm2d()(relu): ReLU(inplace=True)(maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)(layer1): Sequential((0): Bottleneck((conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)(bn1): FrozenBatchNorm2d()(conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)(bn2): FrozenBatchNorm2d()(conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)(bn3): FrozenBatchNorm2d()(relu): ReLU(inplace=True)(downsample): Sequential((0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)(1): FrozenBatchNorm2d()))(1): Bottleneck((conv1): Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)(bn1): FrozenBatchNorm2d()(conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)(bn2): FrozenBatchNorm2d()(conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)(bn3): FrozenBatchNorm2d()(relu): ReLU(inplace=True))(2): Bottleneck((conv1): Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)(bn1): FrozenBatchNorm2d()(conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)(bn2): FrozenBatchNorm2d()(conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)(bn3): FrozenBatchNorm2d()(relu): ReLU(inplace=True)))(layer2): Sequential((0): Bottleneck((conv1): Conv2d(256, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)(bn1): FrozenBatchNorm2d()(conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)(bn2): FrozenBatchNorm2d()(conv3): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)(bn3): FrozenBatchNorm2d()(relu): ReLU(inplace=True)(downsample): Sequential((0): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)(1): FrozenBatchNorm2d()))(1): Bottleneck((conv1): Conv2d(512, 128, 
kernel_size=(1, 1), stride=(1, 1), bias=False)(bn1): FrozenBatchNorm2d()(conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)(bn2): FrozenBatchNorm2d()(conv3): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)(bn3): FrozenBatchNorm2d()(relu): ReLU(inplace=True))(2): Bottleneck((conv1): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)(bn1): FrozenBatchNorm2d()(conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)(bn2): FrozenBatchNorm2d()(conv3): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)(bn3): FrozenBatchNorm2d()(relu): ReLU(inplace=True))(3): Bottleneck((conv1): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)(bn1): FrozenBatchNorm2d()(conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)(bn2): FrozenBatchNorm2d()(conv3): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)(bn3): FrozenBatchNorm2d()(relu): ReLU(inplace=True)))(layer3): Sequential((0): Bottleneck((conv1): Conv2d(512, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)(bn1): FrozenBatchNorm2d()(conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)(bn2): FrozenBatchNorm2d()(conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)(bn3): FrozenBatchNorm2d()(relu): ReLU(inplace=True)(downsample): Sequential((0): Conv2d(512, 1024, kernel_size=(1, 1), stride=(2, 2), bias=False)(1): FrozenBatchNorm2d()))(1): Bottleneck((conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)(bn1): FrozenBatchNorm2d()(conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)(bn2): FrozenBatchNorm2d()(conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)(bn3): FrozenBatchNorm2d()(relu): ReLU(inplace=True))(2): Bottleneck((conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)(bn1): FrozenBatchNorm2d()(conv2): 
Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)(bn2): FrozenBatchNorm2d()(conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)(bn3): FrozenBatchNorm2d()(relu): ReLU(inplace=True))(3): Bottleneck((conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)(bn1): FrozenBatchNorm2d()(conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)(bn2): FrozenBatchNorm2d()(conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)(bn3): FrozenBatchNorm2d()(relu): ReLU(inplace=True))(4): Bottleneck((conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)(bn1): FrozenBatchNorm2d()(conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)(bn2): FrozenBatchNorm2d()(conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)(bn3): FrozenBatchNorm2d()(relu): ReLU(inplace=True))(5): Bottleneck((conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)(bn1): FrozenBatchNorm2d()(conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)(bn2): FrozenBatchNorm2d()(conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)(bn3): FrozenBatchNorm2d()(relu): ReLU(inplace=True)))(layer4): Sequential((0): Bottleneck((conv1): Conv2d(1024, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)(bn1): FrozenBatchNorm2d()(conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)(bn2): FrozenBatchNorm2d()(conv3): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)(bn3): FrozenBatchNorm2d()(relu): ReLU(inplace=True)(downsample): Sequential((0): Conv2d(1024, 2048, kernel_size=(1, 1), stride=(2, 2), bias=False)(1): FrozenBatchNorm2d()))(1): Bottleneck((conv1): Conv2d(2048, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)(bn1): FrozenBatchNorm2d()(conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), 
bias=False)(bn2): FrozenBatchNorm2d()(conv3): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)(bn3): FrozenBatchNorm2d()(relu): ReLU(inplace=True))(2): Bottleneck((conv1): Conv2d(2048, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)(bn1): FrozenBatchNorm2d()(conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)(bn2): FrozenBatchNorm2d()(conv3): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)(bn3): FrozenBatchNorm2d()(relu): ReLU(inplace=True)))))(1): PositionEmbeddingSine()
)
接下来看一下build_visual_encoder这个函数:
def build_visual_encoder(args):
    """Create the Transformer visual encoder from parsed command-line arguments.

    Argument mapping (CLI arg -> VisualEncoder parameter):
        hidden_dim      -> d_model
        dropout         -> dropout
        nheads          -> nhead
        dim_feedforward -> dim_feedforward
        enc_layers      -> num_encoder_layers
        pre_norm        -> normalize_before

    Returns:
        A ``VisualEncoder`` instance (class defined elsewhere in the project).
    """
    encoder_kwargs = dict(
        d_model=args.hidden_dim,
        dropout=args.dropout,
        nhead=args.nheads,
        dim_feedforward=args.dim_feedforward,
        num_encoder_layers=args.enc_layers,
        normalize_before=args.pre_norm,
    )
    return VisualEncoder(**encoder_kwargs)
hidden_dim:输入的单词(或其他元素)会通过嵌入层转换为固定维度的向量,这个维度就是hidden_dim,也就是VisualEncoder的d_model(本模型中为256,见下面打印结果中的in_features=256)。使用多头注意力时,每个头处理的维度为hidden_dim/nheads。
dim_feedforward:encoder中还有前馈神经网络,通常由两个线性层和一个激活函数组成。第一个线性层(linear1)将hidden_dim(256,较低)升维到dim_feedforward(2048,较高);
第二个线性层(linear2)再把dim_feedforward(2048)降回hidden_dim(256)。
self.trans_encoder = build_visual_encoder(args) 构造出的视觉编码器如下(共6层TransformerEncoderLayer,编号0~5):
VisualEncoder((encoder): TransformerEncoder((layers): ModuleList((0): TransformerEncoderLayer((self_attn): MultiheadAttention((out_proj): Linear(in_features=256, out_features=256, bias=True))(linear1): Linear(in_features=256, out_features=2048, bias=True)(dropout): Dropout(p=0.1, inplace=False)(linear2): Linear(in_features=2048, out_features=256, bias=True)(norm1): LayerNorm((256,), eps=1e-05, elementwise_affine=True)(norm2): LayerNorm((256,), eps=1e-05, elementwise_affine=True)(dropout1): Dropout(p=0.1, inplace=False)(dropout2): Dropout(p=0.1, inplace=False)(activation): ReLU(inplace=True))(1): TransformerEncoderLayer((self_attn): MultiheadAttention((out_proj): Linear(in_features=256, out_features=256, bias=True))(linear1): Linear(in_features=256, out_features=2048, bias=True)(dropout): Dropout(p=0.1, inplace=False)(linear2): Linear(in_features=2048, out_features=256, bias=True)(norm1): LayerNorm((256,), eps=1e-05, elementwise_affine=True)(norm2): LayerNorm((256,), eps=1e-05, elementwise_affine=True)(dropout1): Dropout(p=0.1, inplace=False)(dropout2): Dropout(p=0.1, inplace=False)(activation): ReLU(inplace=True))(2): TransformerEncoderLayer((self_attn): MultiheadAttention((out_proj): Linear(in_features=256, out_features=256, bias=True))(linear1): Linear(in_features=256, out_features=2048, bias=True)(dropout): Dropout(p=0.1, inplace=False)(linear2): Linear(in_features=2048, out_features=256, bias=True)(norm1): LayerNorm((256,), eps=1e-05, elementwise_affine=True)(norm2): LayerNorm((256,), eps=1e-05, elementwise_affine=True)(dropout1): Dropout(p=0.1, inplace=False)(dropout2): Dropout(p=0.1, inplace=False)(activation): ReLU(inplace=True))(3): TransformerEncoderLayer((self_attn): MultiheadAttention((out_proj): Linear(in_features=256, out_features=256, bias=True))(linear1): Linear(in_features=256, out_features=2048, bias=True)(dropout): Dropout(p=0.1, inplace=False)(linear2): Linear(in_features=2048, out_features=256, bias=True)(norm1): LayerNorm((256,), eps=1e-05, 
elementwise_affine=True)(norm2): LayerNorm((256,), eps=1e-05, elementwise_affine=True)(dropout1): Dropout(p=0.1, inplace=False)(dropout2): Dropout(p=0.1, inplace=False)(activation): ReLU(inplace=True))(4): TransformerEncoderLayer((self_attn): MultiheadAttention((out_proj): Linear(in_features=256, out_features=256, bias=True))(linear1): Linear(in_features=256, out_features=2048, bias=True)(dropout): Dropout(p=0.1, inplace=False)(linear2): Linear(in_features=2048, out_features=256, bias=True)(norm1): LayerNorm((256,), eps=1e-05, elementwise_affine=True)(norm2): LayerNorm((256,), eps=1e-05, elementwise_affine=True)(dropout1): Dropout(p=0.1, inplace=False)(dropout2): Dropout(p=0.1, inplace=False)(activation): ReLU(inplace=True))(5): TransformerEncoderLayer((self_attn): MultiheadAttention((out_proj): Linear(in_features=256, out_features=256, bias=True))(linear1): Linear(in_features=256, out_features=2048, bias=True)(dropout): Dropout(p=0.1, inplace=False)(linear2): Linear(in_features=2048, out_features=256, bias=True)(norm1): LayerNorm((256,), eps=1e-05, elementwise_affine=True)(norm2): LayerNorm((256,), eps=1e-05, elementwise_affine=True)(dropout1): Dropout(p=0.1, inplace=False)(dropout2): Dropout(p=0.1, inplace=False)(activation): ReLU(inplace=True))))
)
# Project backbone feature channels down to the Transformer width with a 1x1 conv.
self.input_proj = nn.Conv2d(
    in_channels=self.backbone.num_channels,
    out_channels=self.trans_encoder.d_model,
    kernel_size=1,
)
Conv2d(2048, 256, kernel_size=(1, 1), stride=(1, 1))
接下来构造文本特征编码器(Text feature encoder),使用的是预训练的BERT模型(具体由args.bert_model指定,以bert-base为例共12层Transformer encoder),再通过线性层bert_proj把BERT输出维度(bert_output_dim)投影到hidden_dim:
# Text feature encoder: pretrained BERT (checkpoint named by args.bert_model),
# followed by a linear projection from args.bert_output_dim to args.hidden_dim.
self.bert = BertModel.from_pretrained(args.bert_model)
self.bert_proj = nn.Linear(
    in_features=args.bert_output_dim,
    out_features=args.hidden_dim,
)
# Stored as-is; how it is consumed is not shown here (presumably selects which
# BERT layers' outputs are used) -- confirm in forward().
self.bert_output_layers = args.bert_output_layers
接下来就是vg_decoder的构造
# visual grounding: build the grounding decoder; its architecture is
# determined by build_vg_decoder(args), which is defined elsewhere.
self.trans_decoder = build_vg_decoder(args)
首先看论文中提出的Visual-linguistic verification这个模块:
上面截图中框出的代码会构造一个DiscriminativeFeatEncLayer模块,打印出来如下所示:
DiscriminativeFeatEncLayer((img2text_attn): MultiheadAttention((out_proj): Linear(in_features=256, out_features=256, bias=True))(text_proj): MLP((layers): ModuleList((0): Linear(in_features=256, out_features=256, bias=True)))(img_proj): MLP((layers): ModuleList((0): Linear(in_features=256, out_features=256, bias=True)))(img2textcond_attn): MultiheadAttention((out_proj): Linear(in_features=256, out_features=256, bias=True))(img2img_attn): MHAttentionRPE((out_proj): Linear(in_features=256, out_features=256, bias=True))(norm_text_cond_img): LayerNorm((256,), eps=1e-05, elementwise_affine=True)(norm_img): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
)
下面这段代码就是Language-guided Context Encoder的执行过程,对应这张流程图:
text_cond_info = self.img2textcond_attn(query=img_feat, key=self.with_pos_embed(word_feat, word_pos),value=word_feat, key_padding_mask=word_key_padding_mask)[0]q = k = img_feat + text_cond_info
text_cond_img_ctx = self.img2img_attn(query=q, key=k, value=img_feat, key_padding_mask=img_key_padding_mask)[0]# discriminative feature
fuse_img_feat = (self.norm_img(img_feat) +self.norm_text_cond_img(text_cond_img_ctx)) * verify_scorereturn torch.cat([orig_img_feat, fuse_img_feat], dim=-1)
multi-stage cross-modal decoder:迭代地查询视觉和语言信息,逐步减少推理过程中的歧义。