第一篇:【以图搜图代码实现】–犬类以图搜图示例 使用保存成h5文件,使用向量积来度量相似性,实现了以图搜图,说明了可以优化的点。
第二篇:【使用resnet18训练自己的数据集】 准对模型问题进行了优化,取得了显著性的效果。
本篇继续第一篇中所说的优化方向,使用faiss实现以图搜图。
1.faiss使用介绍
Faiss的全称是Facebook AI Similarity Search,是FaceBook针对大规模相似度检索问题开发的一个工具,底层是使用C++代码实现的,提供了python的接口,号称对10亿量级的索引可以做到毫秒级检索。
使用faiss的基本步骤
1、数据转换:把原始数据转换为"float32"数据类型的向量。
2、index构建:用 faiss 构建index
3、数据添加:将向量add到创建的index中
4、通过创建的index进行检索
1.创建索引
import faissdef create_index(datas_embedding):# 构建索引,L2代表构建的index采用的相似度度量方法为L2范数# 必须传入一个向量的维度,创建一个空的索引index = faiss.IndexFlatL2(datas_embedding.shape[1]) # 把向量数据加入索引index.add(datas_embedding) return index
2.保存索引
def faiss_index_save(faiss_index, save_file_location):faiss.write_index(faiss_index, save_file_location)
3.加载索引
def faiss_index_load(faiss_index_save_file_location):index = faiss.read_index(faiss_index_save_file_location)return index
4.向索引中添加向量
def index_data_add(faiss_index, img_path):# 获得索引向量的数量print(faiss_index.ntotal)img_embedding = extract_image_features(img_path)faiss_index.add(img_embedding)print(faiss_index.ntotal)
5.删除索引中的向量
def index_data_delete(faiss_index):print(faiss_index.ntotal)# remove, 指定要删除的向量id,是一个np的arrayfaiss_index.remove_ids(np.array([0]))print(faiss_index.ntotal)
可以看出使用Faiss工具更加的灵活,可以向索引中添加和删除向量。
2.faiss实现以图搜图
本篇代码有部分是在前两篇的基础之上的,这里使用11类犬类数据集微调之后的resnet18进行特征提取。
第一篇:【以图搜图代码实现】–犬类以图搜图示例
第二篇:【使用resnet18训练自己的数据集】
数据集准备和下载可以去看第二篇文章。
1.模型加载
为了更好的适配,对第一篇中的resnet18的初始化方法进行了修改,如下:
@Project :ImageRec
@File :resnet18.py
@IDE :PyCharm
@Author :菜菜2024
@Date :2024/9/30
'''
from PIL import Image
from torchvision import transforms
import torch
import torch.nn as nn
from torchvision import modelsclass ResNet18:def __init__(self,out_feature = 11,model_path='E:\\xxx\\ImageRec\\weights\\resnet18.pth'):self.trans = transforms.Compose([transforms.Resize(size=(256, 256)),transforms.CenterCrop(224),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])print("-----------loading resnet18------------")self.model = models.resnet18()num_feats = self.model.fc.in_featuresself.model.fc = nn.Linear(num_feats, out_feature)self.model.load_state_dict(torch.load(model_path))self.model.eval()def extract_image_features(self, img_path):image = Image.open(img_path).convert('RGB')image_tensor = self.trans(image).unsqueeze(0)with torch.no_grad():features = self.model(image_tensor)return features
其中out_feature 根据自己的数据集的类别个数进行更改,我这里的犬类是11种。model_path是训练好的保存的权重文件【训练过程可以去看第二篇】
2.文件名映射
在第一篇:【以图搜图代码实现】–犬类以图搜图示例 中使用的是保存成h5文件,索引是没有要求是整数的,这里faiss要求是整数,搞了一个映射方法,同时也是为了在后面可视化的时候,能根据索引再解码得到对应的文件路径。
#!/usr/bin/env python
# -*- coding: UTF-8 -*-
'''
@Project :ImageRec
@File :Imgmap.py
@IDE :PyCharm
@Author :菜菜2024
@Date :2024/9/29 18:02
'''
import os
import uuid
import numpy as npdef getImgMap(img_path):# 为类别生成一个映射文件subnames = [f.split('\\')[-1] for f in os.listdir(img_path)]element_mapping = {}for i in range(len(subnames)):unique_id = str(i+2024)element_mapping[unique_id] = subnames[i]return element_mappingdef valueGetKey(mapping, target_value):for key, value in mapping.items():if value == target_value:# print(f"值 '{target_value}' 对应的键是: {key}")breakreturn keydef nameMap(imgnames, img_path='E:\\xxx\\datas\\pet_dog\\train'):'''getImagVector函数得到的image_ids在保存为h5文件时进行了编码现在faiss工具中index需要是int类型的,这里进行映射转化:param img_path: 数据集目录,来得到类别映射:param imgnames: 需要映射的图片名称,解码之后是“中华田园犬_0”格式这里传参是列表:return:'''element_mapping = getImgMap(img_path)decode_names = [imgname.decode('utf-8') for imgname in imgnames]name_ids=[]for decode_name in decode_names:cla_name = decode_name.split("_")[0]img_name = decode_name.split("_")[-1]key = valueGetKey(element_mapping, cla_name)name_id = key+img_namename_ids.append(name_id)name_ids=np.array(name_ids).astype('int32')return name_idsif __name__ == "__main__":database = 'E:\\xxx\\datas\\pet_dog\\train'element_mapping = getImgMap(database)print(element_mapping)print(element_mapping.get("2024"))
映射文件:
{‘2024’: ‘中华田园犬’, ‘2025’: ‘吉娃娃’, ‘2026’: ‘哈士奇’, ‘2027’: ‘德牧’, ‘2028’: ‘拉布拉多’, ‘2029’: ‘杜宾’, ‘2030’: ‘柴犬’, ‘2031’: ‘法国斗牛’, ‘2032’: ‘萨摩耶’, ‘2033’: ‘藏獒’, ‘2034’: ‘金毛’}
nameMap函数是将之前编码的图像名称进行解码,然后重新编码,编码成20240,20301,分别表示的中华田园犬文件夹下的0.jpg, 柴犬下面的1.jpg。这都是为了可视化的时候进行追溯,得到文件路径。
3.以图搜图实现
定义了一个类ImageRetrival,使用faiss实现创建索引,保存索引,加载索引和图像检索功能
#!/usr/bin/env python
# -*- coding: UTF-8 -*-
'''
@Project :ImageRec
@File :faiss_index.py
@IDE :PyCharm
@Author :菜菜2024
@Date :2024/9/30 15:04
'''import os
import faiss
from utils.split_data import array_norm
from utils.Imgmap import nameMap, getImgMap
from model import ResNet18
from save_feature import getImagVectors
import matplotlib.image as mpimg
import matplotlib.pyplot as plt
import numpy as np
from matplotlib import rc
# 设置全局字体为支持中文的字体
rc('font', family='SimHei') # 黑体class ImageRetrival:def __init__(self, model_path,index_dim=None):self.index_dim = index_dimself.index = faiss.IndexFlatL2(self.index_dim)self.model_path = model_pathdef build_index(self, image_files):# image_vectors图片特征,image_ids对应的标签image_vectors, image_ids = getImagVectors(image_files)# image_ids 在之前保存为h5文件时进行了编码,这里进行映射name_ids = nameMap(image_ids)index = faiss.IndexIDMap(self.index)index.add_with_ids(image_vectors, name_ids)return indexdef save_index(self, index, index_path):faiss.write_index(index, index_path)def load_index(self, index_path):return faiss.read_index(index_path)def image_topK_search(self, index, input_image, topK=None):resnet18 = ResNet18(out_feature=11,model_path=self.model_path)queryVec = resnet18.extract_image_features(input_image)dist, ind = index.search(queryVec, topK)dist, ind = dist.flatten(), ind.flatten()res = array_norm(dist, ind)return res
4.运行调用
if __name__=="__main__":model_path='E:\\xxx\\Pycharm_files\\ImageRec\\weights\\resnet18.pth'# 1.创建索引imageRetrival = ImageRetrival(model_path=model_path,index_dim=11)image_files = 'E:\\xxx\\datas\\pet_dog\\train'save_index = "./weights/dog.index"index = imageRetrival.build_index(image_files)# # 2.保存索引imageRetrival.save_index(index, save_index)# 3.加载索引index_load = imageRetrival.load_index(save_index)## # 4.相似度匹配input_image = './data/pic/德牧.jpg'out = imageRetrival.image_topK_search(index_load, input_image, topK=3)print(out)showFaissRes(image_files, input_image, out)
运行时选择性注销其中的某一步骤。
最后是可视化实现showFaissRes
5.可视化实现
def showFaissRes(image_files, input_image, faissRes):'''对faiss得到的结果进行可视化:param image_files: 图片数据库:param input_image: 查询图片路径:param faissRes: 返回的topk跟距离最近的结果[(ind, score), (ind, score)]:return:'''scores = []imgs = []info = []# 1.得到图片名称的映射element_mapping = getImgMap(image_files)imgs.append(mpimg.imread(input_image))info.append(input_image.split("/")[-1])for i in range(len(faissRes)):score = faissRes[i][1]ind = str(faissRes[i][0])scores.append(score)# 根据索引构建原本的图像路径ind格式:20276,前四个是类别表示claName = element_mapping.get(ind[:4])imgName = ind[4:]+".jpg"imgpath = image_files +"\\"+ claName+ "\\"+imgNameimgs.append(mpimg.imread(imgpath))info.append(claName+"_"+ imgName+"_"+ str(score))print("图片名称是: " + claName+ imgName + " 对应得分是: %f" %score)num = int((len(faissRes) + 1) // 2)+1fig, axs = plt.subplots(nrows=num, ncols=num, figsize=(10, 10))# 确保即使只有一个子图,也可以进行索引if not isinstance(axs, np.ndarray):axs = np.array([[axs]])# 显示图像flat_index = 0for i in range(num):for j in range(num):if flat_index < len(imgs):img = imgs[flat_index]axs[i, j].imshow(img, cmap='gray')axs[i, j].axis('off')axs[i, j].set_title(info[flat_index])flat_index += 1else:axs[i, j].set_visible(False)plt.tight_layout()plt.show()
3.效果对比
第一篇:【以图搜图代码实现】–犬类以图搜图示例 预训练的resnet18
第二篇:【使用resnet18训练自己的数据集】 微调的resnet18
本章 Faiss实现: 分数不重要,本篇对分数进行了归一化。
准确性更高了。