由于需要用一个图像小模型进行推理并返回相关参数,而目前 Coze 插件商店中没有满足需求的插件,所以自己开发了一个。
在开发之前,需要先明白一点:
1. Coze 发送给插件的图片不是二进制流,而是图片的 URL 链接,因此后端需要先根据 URL 下载图片,再交给模型处理。
以下是后端代码:
import os
import tempfile
from urllib.parse import urlparse

import requests
from flask import Flask, request, jsonify

import preprocess_data_module as preproc
from cnn_model import RD_net
import torch

app = Flask(__name__)

# Path to the trained RD_net weights; override with the RD_NET_MODEL_PATH
# environment variable instead of editing the code.
MODEL_PATH = os.environ.get("RD_NET_MODEL_PATH", "img_model3/xxxxx")


def _select_device(preferred_index=1):
    """Return the torch device to run inference on.

    Uses ``cuda:<preferred_index>`` when that GPU exists. Falls back to
    ``cuda:0`` when CUDA is available but has fewer GPUs, because
    ``torch.device("cuda:1")`` on a single-GPU machine makes ``.to(device)``
    fail at start-up. Uses the CPU when CUDA is unavailable.
    """
    if not torch.cuda.is_available():
        return torch.device("cpu")
    if torch.cuda.device_count() > preferred_index:
        return torch.device(f"cuda:{preferred_index}")
    return torch.device("cuda:0")


device = _select_device()
model = RD_net().to(device)
# NOTE: torch.load unpickles the checkpoint; only load weight files you trust.
model.load_state_dict(torch.load(MODEL_PATH, map_location=device))
# Put the model in evaluation mode (affects layers such as Dropout/BatchNorm,
# if RD_net contains any) since this process only serves inference.
model.eval()


@app.route('/')
def index():
    """Return a short welcome message identifying the API."""
    return jsonify({"message": "Welcome to the RD_net API"})


# Health-check route
@app.route('/health', methods=['GET'])
def health_check():
    """Liveness probe: always responds with ``{"status": "OK"}``."""
    return jsonify({"status": "OK"})


# (connect, read) timeouts in seconds for fetching the image; without a
# timeout a stalled image host would block the worker indefinitely.
DOWNLOAD_TIMEOUT = (5, 30)


def _download_to_tempfile(file_url):
    """Download *file_url* into a new unique ``.png`` temp file; return its path.

    A fresh file per request keeps concurrent requests from overwriting or
    deleting each other's image. The caller must remove the returned file.
    Raises ``requests.RequestException`` on network/HTTP errors and
    ``OSError`` if the file cannot be written (the partial file is removed).
    """
    response = requests.get(file_url, timeout=DOWNLOAD_TIMEOUT)
    response.raise_for_status()  # turn 4xx/5xx responses into exceptions
    fd, temp_image_path = tempfile.mkstemp(suffix='.png')
    try:
        with os.fdopen(fd, 'wb') as f:
            f.write(response.content)
    except Exception:
        os.remove(temp_image_path)
        raise
    return temp_image_path


def _predict_structure_params(image_path):
    """Run RD_net on the image at *image_path*; return the output as a 1-D numpy array."""
    processed_data = preproc.preprocess_image(image_path)
    # Assumes preprocess_image returns an (H, W, C) array -> (C, H, W) for
    # the model. TODO confirm against preprocess_data_module.
    processed_data = processed_data.transpose((2, 0, 1))
    # unsqueeze(0) adds a batch dimension of size 1.
    input_tensor = torch.tensor(processed_data, dtype=torch.float).unsqueeze(0).to(device)
    with torch.no_grad():
        output = model(input_tensor)
    # reshape(-1) flattens any rank (including 0-d), so the result is
    # always iterable by the caller.
    return output.cpu().numpy().reshape(-1)


@app.route('/predict', methods=['POST'])
def get_structure_params():
    """Download the image referenced by JSON field ``file`` and return RD_net's outputs.

    Coze passes images as URLs rather than binary uploads, so the body is a
    JSON object such as ``{"file": "https://..."}``.

    Responses:
        200 ``{"structure_params": "<v1>, <v2>, ..."}`` on success;
        400 if the body is not a JSON object, or ``file`` is missing or not
            an http(s) URL;
        415 if the request is not JSON;
        500 ``{"error": ...}`` if downloading or inference fails.
    """
    # is_json also accepts parameters such as "application/json; charset=utf-8",
    # which an exact comparison against the raw header rejected.
    if not request.is_json:
        return jsonify({"error": "Invalid Content-Type, expected application/json"}), 415

    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        return jsonify({"error": "Request body must be a JSON object"}), 400
    file_url = data.get('file')
    if not file_url:
        return jsonify({"error": "No file URL provided"}), 400
    # The URL is client-supplied and fetched server-side: restrict to http(s).
    # NOTE(review): this does not stop requests to internal hosts (SSRF);
    # add a host allow-list if the service is publicly reachable.
    if not isinstance(file_url, str) or urlparse(file_url).scheme not in ('http', 'https'):
        return jsonify({"error": "File URL must be an http(s) URL"}), 400

    try:
        temp_image_path = _download_to_tempfile(file_url)
    except Exception as e:
        app.logger.exception("Failed to download image from %s", file_url)
        return jsonify({"error": str(e)}), 500

    try:
        structure_params = _predict_structure_params(temp_image_path)
        # Serialize the values as a comma-separated string.
        structure_params_str = ', '.join(map(str, structure_params))
        return jsonify({'structure_params': structure_params_str})
    except Exception as e:
        app.logger.exception("Inference failed")
        return jsonify({"error": str(e)}), 500
    finally:
        # Always delete this request's temp file.
        if os.path.exists(temp_image_path):
            os.remove(temp_image_path)


if __name__ == '__main__':
    app.run(debug=False, host='0.0.0.0', port=8714)