当前位置: 首页 > news >正文

Tensorflow实现用接口调用模型训练和停止训练功能

语言:Python
框架:Flask、Tensorflow
功能描述:存在两个接口,一个接口实现开始训练模型的功能,一个接口实现停止训练的功能。
实现:用一个全局变量存储在训练中的模型。

# 存储所有训练任务
training_tasks = {}
# 训练模型的接口
@train_model.route("/train", methods=["POST"])
def train():try:data = request.get_data()data = json.loads(data)print(data)modelId = data["modelId"]if modelId in training_tasks:return {"success": False, "message": f"{modelId} 已经在训练中"}stop_event = threading.Event()# 在新线程中启动训练train_thread = threading.Thread(target=start_train,args=(data, stop_event))training_tasks[modelId] = {'thread': train_thread,'stop_event': stop_event}train_thread.start()return {"success": "success", "message": "开始训练"}except Exception as e:return  {"success": False, "message": str(e)}
def start_train(data, stop_event):try:# 获取任务参数modelId = data["modelId"]except Exception as e:response_data = {"success": False, "message": str(e)}return response_data
class StopTrainingCallback(keras.callbacks.Callback):def __init__(self, model, modelId, stop_event):super().__init__()self.model = modelself.modelId = modelIdself.stop_event = stop_eventdef on_train_begin(self, logs=None):if self.stop_event.is_set():self.model.stop_training = True # 设置此标志会使model.fit提前终止print(f"训练在开始前被停止")def on_batch_begin(self, batch, logs=None):if self.stop_event.is_set():self.model.stop_training = True # 设置此标志会使model.fit提前终止print(f"训练在批次被停止")# 强制抛出一个异常以确保立即停止raise KeyboardInterrupt("训练被用户停止")
# 模型真正训练的函数
def start_train(data, stop_event):# 定义模型及训练数据model = "xxx"modelId = "xxx"train_dataset = "xxx"test_dataset = "xxx"train_steps = len(list(train_dataset))test_steps = len(list(test_dataset))epochs = "xxx"stoptrainingcallback = StopTrainingCallback(model, modelId, stop_event)try:# 在开始训练前立即检查停止事件if stop_event.is_set():log.info(f"训练 {modelId} 在开始前被停止")callback_log.info("模型训练在开始前被停止")raise KeyboardInterrupt("Training stopped before start")model.fit(train_dataset,steps_per_epoch=train_steps,epochs=epochs,verbose=2,shuffle=True,validation_data=test_dataset,validation_steps=test_steps,callbacks=[stoptrainingcallback])response_data = {"success": True, "message": "Success"}except KeyboardInterrupt:response_data = {"success": False, "message": "模型训练被用户停止."}except tf.errors.ResourceExhaustedError as e:# 显存不足错误response_data = {"success": False, "message": "GPU内存不足,请调整训练参数."}except Exception as e:print("模型训练失败")response_data = {"success": False, "message": str(e)}finally:if data["modelId"] in training_tasks:del training_tasks[data["modelId"]]return response_data
# 停止训练的接口
@stop_train.route('/stop', methods=['POST'])
def stop():data = request.get_data()try:data = json.loads(data)modelId = data.get("modelId",'') # 每个模型有一个唯一的UUIDif modelId == '':return jsonify({"success": False, "message": "modelId为空,无法停止训练.", "data": ''})except Exception as e:print("停止模型训练接口请求数据出错:", str(e))return jsonify({"success": False, "message": "参数错误.", "data": ''})# 调用服务层停止训练result = stop_train_service(modelId)print(result["message"])# 返回响应return jsonify(result)
# 调用服务层停止训练
def stop_train_service(modelId):# 检查模型是否存在if modelId not in training_tasks:return {"success": "error", "message": f"没有找到模型 {modelId} 的训练任务"}# 获取停止事件并设置stop_event = training_tasks[modelId].get('stop_event')if stop_event:stop_event.set()# 清理任务记录del training_tasks[modelId]return {"success": "success", "message": f"停止 {modelId} 模型训练的请求已发送"}else:return {"success": "error", "message": f"模型 {modelId} 的停止训练事件不存在"}
http://www.xdnf.cn/news/7165.html

相关文章:

  • Mac mini 安装mysql数据库以及出现的一些问题的解决方案
  • 【前端HTML生成二维码——MQ】
  • uni-app 安卓10以上上传原图解决方案
  • 基于FPGA的AES加解密系统verilog实现,包含testbench和开发板硬件测试
  • 4.Rust+Axum Tower 中间件实战:从集成到自定义
  • 【Leetcode 每日一题】2364. 统计坏数对的数目
  • 再读bert(Bidirectional Encoder Representations from Transformers)
  • 学习设计模式《二》——外观模式
  • 京东物流基于Flink StarRocks的湖仓建设实践
  • UI 在教育产品涉及的领域
  • 如何评价2025 mathorcup妈妈杯数学建模竞赛?完整建模过程+完整代码论文全解全析来了
  • 2025年MathorCup数学应用挑战赛D题问题一求解与整体思路分析
  • Android 12.0 framework实现对系统语言切换的功能实现
  • 硬盘变废为宝!西部数据携微软等启动稀土回收 效率可达90%
  • SQL预编译——预编译真的能完美防御SQL注入吗
  • 关于hadoop和yarn的问题
  • 基于Flask的AI工具聚合平台技术解析
  • TypeScript 从入门到精通:完整教程与实战应用(二)
  • stl 容器 – map
  • 校平机:精密制造的“材料雕刻家“
  • MQTTClient.c中的协议解析与报文处理机制
  • SpringBoot运维问题
  • FreeRTOS任务通知
  • 51单片机实验五:A/D和D/A转换
  • 前端:uniapp框架中<scroll-view>r如何控制元素进行局部滚动
  • ASP.NET MVC 实现增删改查(CRUD)操作的完整示例
  • 从代码学习深度学习 - 小批量随机梯度下降 PyTorch 版
  • Spring Boot启动流程深度解析:从main()到应用就绪的完整旅程
  • Starrocks 数据均衡DiskAndTabletLoadReBalancer的实现
  • 使用Lean 4和C#进行数学定理证明与逻辑推理