一、实现
使用的目录结构如下：
templates/
└── upload.html
faiss_app.py
（运行时会自动创建 uploads/ 目录，用于存放 Faiss 索引文件和检索命中的图片副本。）
前端代码:upload.html
<!DOCTYPE html>
<html lang="en">
<head>
    <meta charset="UTF-8">
    <meta name="viewport" content="width=device-width, initial-scale=1.0">
    <title>Search and Show Multiple Images</title>
    <style>
        #image-container {
            display: flex;
            flex-wrap: wrap;
        }
        #image-container img {
            max-width: 150px;
            margin: 10px;
        }
    </style>
</head>
<body>
    <h1>Search Images</h1>

    <!-- Search box -->
    <form id="search-form">
        <input type="text" id="search-input" name="query" placeholder="Enter search term" required>
        <input type="submit" value="Search">
    </form>

    <h2>Search Results</h2>
    <!-- The images returned by the search are rendered here -->
    <div id="image-container"></div>

    <!-- Submit the form with fetch() instead of a full page navigation -->
    <script>
        document.getElementById('search-form').addEventListener('submit', async function (event) {
            // Stop the browser's default form submission; the request is sent with fetch() below.
            event.preventDefault();
            const query = document.getElementById('search-input').value;
            const imageContainer = document.getElementById('image-container');

            try {
                // GET /search?query=..., the endpoint served by faiss_app.py
                const response = await fetch(`/search?query=${encodeURIComponent(query)}`, {
                    method: 'GET',
                });
                const data = await response.json();

                // Clear the previous results before rendering the new ones.
                imageContainer.innerHTML = '';

                if (!response.ok) {
                    // The backend reports errors as {"error": "..."} with a non-2xx status.
                    throw new Error(data.error || `HTTP ${response.status}`);
                }

                // One <img> per URL in the backend's image_urls array.
                data.image_urls.forEach(url => {
                    const imgElement = document.createElement('img');
                    imgElement.src = url;
                    imgElement.alt = query;
                    imageContainer.appendChild(imgElement);
                });
            } catch (error) {
                console.error('Error searching for images:', error);
            }
        });
    </script>
</body>
</html>
后端代码 faiss_app.py:
import glob
import hashlib
import os
import shutil
import threading

from sentence_transformers import SentenceTransformer, util
from PIL import Image
from flask import Flask, request, jsonify, current_app, render_template, send_from_directory, url_for
from werkzeug.utils import secure_filename
import faiss
import numpy as np
from markupsafe import escape

# Load CLIP model
model = SentenceTransformer('clip-ViT-B-32')  # shared encoder for indexed images and search queries

# File extensions (lower-case, with leading dot) treated as images when indexing.
IMAGE_EXTENSIONS = {'.jpg', '.jpeg', '.png', '.gif', '.bmp'}
# Folder holding the FAISS index files and the copies of matched images that
# are served back to the browser.
UPLOAD_FOLDER = 'uploads/'
IMAGES_PATH = "C:\\Users\\xxxx\\Pictures\\"

# A string query ending in one of these suffixes (any letter case) is treated
# by retrieve_similar_images() as a path to an image file instead of text.
_QUERY_IMAGE_SUFFIXES = ('.png', '.jpg', '.jpeg', '.tiff', '.bmp', '.gif')

# Index-type thresholds. faiss's k-means warns when given fewer than 39
# training points per centroid; PQ64x8 trains 256 centroids per sub-quantizer
# and IVF4096 trains 4096 coarse centroids.
_PQ_MIN_VECTORS = 39 * 256
_IVF_MIN_VECTORS = 39 * 4096

# GPU resources shared by every index moved to the GPU. Created lazily and kept
# at module level so the object outlives the GPU indexes that use it.
_gpu_resources = None


def _gpu_available():
    """Return True if this faiss build has GPU support and sees at least one GPU."""
    get_num_gpus = getattr(faiss, 'get_num_gpus', None)
    return (hasattr(faiss, 'StandardGpuResources')
            and get_num_gpus is not None
            and get_num_gpus() > 0)


def _to_gpu_if_possible(cpu_index):
    """Return a float16 copy of *cpu_index* on GPU 0, or *cpu_index* itself if that fails."""
    global _gpu_resources
    if not _gpu_available():
        return cpu_index
    if _gpu_resources is None:
        _gpu_resources = faiss.StandardGpuResources()
    co = faiss.GpuClonerOptions()
    co.useFloat16 = True
    try:
        return faiss.index_cpu_to_gpu(_gpu_resources, 0, cpu_index, co)
    except RuntimeError as err:
        # Not every index type can be cloned to the GPU; the CPU index still works.
        print(f"Could not move index to GPU ({err}); searching on CPU")
        return cpu_index


def _read_image_paths(paths_file):
    """Read the one-path-per-line file written next to an index."""
    try:
        with open(paths_file, 'r', encoding='utf-8') as f:
            return [line.rstrip('\n') for line in f]
    except UnicodeDecodeError:
        # Files written before the switch to UTF-8 used the locale's default encoding.
        with open(paths_file, 'r') as f:
            return [line.rstrip('\n') for line in f]


def generate_clip_embeddings(images_path, model):
    """Encode every image found under *images_path* with *model*.

    Walks *images_path* recursively (os.walk) and keeps files whose lower-cased
    extension is in IMAGE_EXTENSIONS. Files Pillow cannot open or decode are
    skipped with a message, so one broken file does not abort indexing.

    Returns:
        (embeddings, image_paths): lists of equal length where embeddings[i]
        is model.encode() of the image stored at image_paths[i].
    """
    candidates = []
    for root, _dirs, files in os.walk(images_path):
        for file in files:
            if os.path.splitext(file)[1].lower() in IMAGE_EXTENSIONS:
                candidates.append(os.path.join(root, file))

    embeddings = []
    image_paths = []
    for img_path in candidates:
        try:
            # The context manager closes the file handle once the image is encoded.
            with Image.open(img_path) as image:
                embedding = model.encode(image)
        except OSError as err:
            print(f"Skipping unreadable image {img_path}: {err}")
            continue
        # Append both together so ids stay aligned with paths.
        embeddings.append(embedding)
        image_paths.append(img_path)
    return embeddings, image_paths


def create_faiss_index(embeddings, image_paths, output_path):
    """Build a FAISS index over *embeddings*, save it, and return it.

    The index type depends on the number of vectors:
      * fewer than _PQ_MIN_VECTORS  -> exact IndexFlatL2
      * fewer than _IVF_MIN_VECTORS -> 'OPQ64_256,PQ64x8'
      * otherwise                   -> 'OPQ64_256,IVF4096,PQ64x8'
    It is wrapped in IndexIDMap with ids 0..n-1, so id i refers to
    image_paths[i]. The index is written to *output_path*. The image paths are
    written one per line, UTF-8, to ``output_path + '.paths'``.

    Returns:
        The (CPU) index.

    Raises:
        ValueError: if *embeddings* is empty or ragged, or its length differs
            from *image_paths*.
    """
    vectors = np.asarray(embeddings, dtype=np.float32)
    if vectors.ndim != 2 or len(vectors) == 0:
        raise ValueError("create_faiss_index needs a non-empty list of equal-length embeddings")
    if len(vectors) != len(image_paths):
        raise ValueError(f"{len(vectors)} embeddings but {len(image_paths)} image paths")
    n, dimension = vectors.shape

    if n < _PQ_MIN_VECTORS:
        # Few entries: plain exact L2 index, no training required.
        base_index = faiss.IndexFlatL2(dimension)
    elif n < _IVF_MIN_VECTORS:
        # Not enough vectors to train IVF4096: product quantization only.
        base_index = faiss.index_factory(dimension, 'OPQ64_256,PQ64x8')
    else:
        base_index = faiss.index_factory(dimension, 'OPQ64_256,IVF4096,PQ64x8')
    # OPQ/PQ/IVF indexes must be trained before vectors can be added.
    if not base_index.is_trained:
        base_index.train(vectors)

    faiss_index = faiss.IndexIDMap(base_index)
    # faiss ids must be int64; np.array(range(n)) is int32 on Windows with NumPy < 2.
    faiss_index.add_with_ids(vectors, np.arange(n, dtype=np.int64))

    faiss.write_index(faiss_index, output_path)
    print(f"Index created and saved to {output_path}")
    with open(output_path + '.paths', 'w', encoding='utf-8') as f:
        f.writelines(img_path + '\n' for img_path in image_paths)
    return faiss_index


def load_faiss_index(index_path):
    """Load an index written by create_faiss_index together with its image paths.

    The index is moved to GPU 0 (float16 storage) when faiss has GPU support
    and a GPU is visible; otherwise, or if the transfer fails, the CPU index
    is returned.

    Returns:
        (faiss_index, image_paths)

    Raises:
        RuntimeError: if the loaded index is not trained.
    """
    faiss_index = faiss.read_index(index_path)
    image_paths = _read_image_paths(index_path + '.paths')
    print(f"Index loaded from {index_path}")
    if not faiss_index.is_trained:
        raise RuntimeError(f'从[{index_path}]加载的Faiss索引未训练')
    return _to_gpu_if_possible(faiss_index), image_paths


def retrieve_similar_images(query, model, index, image_paths, top_k=3):
    """Return the paths of the images nearest to *query*.

    *query* is either text, or a path string ending in one of
    _QUERY_IMAGE_SUFFIXES (any case). A path is opened and the image is encoded
    instead of the text. A PIL image may also be passed directly.

    Returns:
        (query, retrieved_images): the query that was encoded (a PIL image
        when a path was given), and at most *top_k* entries of *image_paths*
        in the order returned by ``index.search``. Fewer than *top_k* are
        returned when the index holds fewer vectors.
    """
    if isinstance(query, str) and query.lower().endswith(_QUERY_IMAGE_SUFFIXES):
        # copy() loads the pixels into memory, so the file can be closed while
        # the returned image stays usable.
        with Image.open(query) as opened:
            query = opened.copy()
    query_features = np.asarray(model.encode(query), dtype=np.float32).reshape(1, -1)
    _distances, indices = index.search(query_features, top_k)
    # faiss pads missing results with id -1; image_paths[-1] would silently
    # return the last image, so those slots are dropped.
    retrieved_images = [image_paths[int(idx)] for idx in indices[0] if idx >= 0]
    return query, retrieved_images


# Check whether a file name's extension is allowed.
# In-memory cache of (faiss_index, image_paths) keyed by index file path, so the
# index is loaded or built once per process instead of on every request.
_search_index_cache = {}
# Guards the cache and every search on the cached index. Flask's development
# server handles requests on multiple threads, and faiss GPU indexes must not
# be searched from several threads at once.
_search_index_lock = threading.Lock()


def allowed_file(filename):
    """Return True if *filename*'s last extension is in IMAGE_EXTENSIONS (case-insensitive)."""
    # IMAGE_EXTENSIONS entries include the leading dot, hence the '.' prefix.
    return '.' in filename and ('.' + filename.rsplit('.', 1)[1].lower()) in IMAGE_EXTENSIONS


def _get_search_index(index_path):
    """Return (faiss_index, image_paths), loading or building the index on first use.

    Loads *index_path* if it exists. Otherwise it embeds every image under
    IMAGES_PATH and writes a new index to *index_path*. The result is cached
    for the lifetime of the process. Callers must hold _search_index_lock.
    """
    cached = _search_index_cache.get(index_path)
    if cached is None:
        if os.path.exists(index_path):
            cached = load_faiss_index(index_path)
        else:
            embeddings, image_paths = generate_clip_embeddings(IMAGES_PATH, model)
            cached = (create_faiss_index(embeddings, image_paths, index_path), image_paths)
        _search_index_cache[index_path] = cached
    return cached


def _served_name(path):
    """Return a file name for *path* that is unique per source path but keeps its basename."""
    # Images in different sub-folders may share a basename; prefixing a digest
    # of the full path keeps their copies from overwriting each other.
    digest = hashlib.sha256(os.fsencode(path)).hexdigest()[:12]
    return f"{digest}_{os.path.basename(path)}"


def search():
    """Handle GET /search?query=...: return JSON ``{"image_urls": [...]}`` for the top 5 matches.

    Responds 400 with ``{"error": ...}`` when no query is given. Each matched
    image is copied into app.config['UPLOAD_FOLDER'], so it can be served by
    the uploaded_file_path endpoint.
    """
    query = request.args.get('query')  # search keywords
    if not query:
        return jsonify({"error": "No search query provided"}), 400

    upload_folder = app.config['UPLOAD_FOLDER']
    index_path = os.path.join(upload_folder, 'vector.index')

    # NOTE(review): retrieve_similar_images treats a query ending in an image
    # extension as a server-side file path, so a client can make the server
    # open arbitrary image files. Consider accepting text queries only here.
    with _search_index_lock:
        faiss_index, image_paths = _get_search_index(index_path)
        _, retrieved_images = retrieve_similar_images(query, model, faiss_index, image_paths, top_k=5)

    image_urls = []
    for path in retrieved_images:
        served_name = _served_name(path)
        shutil.copy(path, os.path.join(upload_folder, served_name))
        image_urls.append(url_for('uploaded_file_path', filename=served_name))
    return jsonify({"image_urls": image_urls})


def index():
    """Render the search page (templates/upload.html)."""
    return render_template('upload.html')


# Route handler that serves files from the upload folder.
def uploaded_file_path(filename):
    """Serve *filename* from app.config['UPLOAD_FOLDER'], where search() copies matched images.

    Only names with an extension in IMAGE_EXTENSIONS are served. The same
    folder also holds vector.index and vector.index.paths, and the latter lists
    the full local paths of every indexed image.
    """
    if os.path.splitext(filename)[1].lower() not in IMAGE_EXTENSIONS:
        return "Not Found", 404
    return send_from_directory(app.config['UPLOAD_FOLDER'], filename)


if __name__ == "__main__":
    app = Flask(__name__)
    app.config['UPLOAD_FOLDER'] = UPLOAD_FOLDER
    os.makedirs(UPLOAD_FOLDER, exist_ok=True)

    # Home page with the search form.
    app.route('/')(index)
    app.route('/search', methods=['GET'])(search)
    # Serves the copies of matched images written by search().
    app.route('/uploads/images/<filename>')(uploaded_file_path)

    # Debug mode enables the interactive Werkzeug debugger, which executes code
    # sent from the browser, so it must not be on by default while listening
    # on all interfaces. Opt in with FLASK_DEBUG=1. Debug mode also starts the
    # reloader, which re-runs this module (and reloads the CLIP model) in a
    # child process.
    debug = os.environ.get('FLASK_DEBUG', '').strip().lower() in ('1', 'true', 'yes')
    app.run(host='0.0.0.0', port=8080, debug=debug)