Adding logic to set the number of training iterations and stop training early
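- Modify the "tools/train_net.py" script.
Add an --early_stop_iteration command-line argument (args.early_stop_iteration) and pass it to the training function train, which forwards it to do_train.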
Before:
```python
def train(cfg, local_rank, distributed, use_tensorboard=False,):
    ……
    do_train(
        cfg,
        model,
        data_loader,
        optimizer,
        scheduler,
        checkpointer,
        device,
        checkpoint_period,
        arguments,
        data_loaders_val,
        meters
    )
    return model


def main():
    parser = argparse.ArgumentParser(description="PyTorch Object Detection Training")
    ……
    parser.add_argument("--override_output_dir", default=None)
    args = parser.parse_args()
    ……
    model = train(
        cfg=cfg,
        local_rank=args.local_rank,
        distributed=args.distributed,
        use_tensorboard=args.use_tensorboard
    )
```
After:
```python
def train(cfg, local_rank, distributed, use_tensorboard=False, early_stop_iteration=-1):
    ……
    do_train(
        cfg,
        model,
        data_loader,
        optimizer,
        scheduler,
        checkpointer,
        device,
        checkpoint_period,
        arguments,
        data_loaders_val,
        meters,
        early_stop_iteration=early_stop_iteration,  # new: forward the early-stop iteration
    )
    return model


def main():
    parser = argparse.ArgumentParser(description="PyTorch Object Detection Training")
    ……
    parser.add_argument("--override_output_dir", default=None)
    # new: last iteration to train; the default -1 disables early stopping
    parser.add_argument("--early_stop_iteration", type=int, default=-1)
    args = parser.parse_args()
    ……
    model = train(
        cfg=cfg,
        local_rank=args.local_rank,
        distributed=args.distributed,
        use_tensorboard=args.use_tensorboard,
        early_stop_iteration=args.early_stop_iteration,  # new
    )
```
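You can check the argument handling on its own with a minimal sketch that uses only the standard-library argparse module. The helper names build_parser and early_stop_enabled are hypothetical. The parser keeps only the two arguments shown in the excerpt. The sketch assumes, as the code above implies, that -1 disables early stopping and that any positive value is the last iteration to train.

```python
import argparse


def build_parser():
    # Hypothetical, trimmed-down parser: only the two arguments shown in the
    # excerpt are reproduced; the real train_net.py defines many more.
    parser = argparse.ArgumentParser(description="PyTorch Object Detection Training")
    parser.add_argument("--override_output_dir", default=None)
    # -1 (the default) means "no early stop"; a positive value is the last
    # iteration to train.
    parser.add_argument("--early_stop_iteration", type=int, default=-1)
    return parser


def early_stop_enabled(early_stop_iteration):
    # Mirrors the `if early_stop_iteration > 0` guard in do_train.
    return early_stop_iteration > 0


if __name__ == "__main__":
    args = build_parser().parse_args(["--early_stop_iteration", "100"])
    print(args.early_stop_iteration, early_stop_enabled(args.early_stop_iteration))
```

Because of type=int, a command-line value such as --early_stop_iteration 100 reaches train() as the integer 100, not the string "100". The comparisons in do_train therefore work without any conversion.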
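- Modify the "maskrcnn_benchmark/engine/trainer.py" script.
Add an early_stop_iteration parameter to the training function do_train. Inside the training loop, exit once iteration reaches early_stop_iteration + 1, and also run evaluation at iteration early_stop_iteration. The default value of -1 (or any non-positive value) disables early stopping.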
Before:
```python
def do_train(
    cfg, model, data_loader, optimizer, scheduler, checkpointer, device,
    checkpoint_period, arguments, val_data_loader=None, meters=None, zero_shot=False
):
    ……
    for iteration, (images, targets, idxs, positive_map, positive_map_eval, greenlight_map) in enumerate(data_loader, start_iter):
        ……
        arguments["iteration"] = iteration
        images = images.to(device)
        ……
        if val_data_loader and (iteration % checkpoint_period == 0 or iteration == max_iter):
            if is_main_process():
                print("Evaluating")
            ……
```
After:
```python
def do_train(
    cfg, model, data_loader, optimizer, scheduler, checkpointer, device,
    checkpoint_period, arguments, val_data_loader=None, meters=None, zero_shot=False,
    early_stop_iteration=-1,  # new: a value <= 0 disables early stopping
):
    ……
    for iteration, (images, targets, idxs, positive_map, positive_map_eval, greenlight_map) in enumerate(data_loader, start_iter):
        ……
        arguments["iteration"] = iteration
        # new: leave the loop once the requested last iteration has been trained
        if early_stop_iteration > 0:
            if iteration == early_stop_iteration + 1:
                break
        images = images.to(device)
        ……
        # new: also evaluate at the early-stop iteration
        if val_data_loader and (
            iteration % checkpoint_period == 0
            or iteration == max_iter
            or iteration == early_stop_iteration
        ):
            if is_main_process():
                print("Evaluating")
            ……
```
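You can replay the control flow of the modified loop without a model or device. The following sketch is hypothetical: simulate_training is not part of maskrcnn_benchmark. It replaces the data loader with range() and omits the forward and backward pass. It also assumes that iterations are numbered from start_iter up to max_iter, one batch each. It returns the iterations that were trained and the iterations at which evaluation would run.

```python
def simulate_training(max_iter, checkpoint_period, early_stop_iteration=-1, start_iter=1):
    """Replay the early-stop and evaluation conditions of the modified do_train.

    Hypothetical stand-in: range() replaces the data loader and no model is run.
    Returns (trained_iterations, evaluated_iterations).
    """
    trained, evaluated = [], []
    # Assumption: one batch per iteration, numbered start_iter..max_iter.
    for iteration in range(start_iter, max_iter + 1):
        # Same guard as do_train: stop just after the requested last iteration.
        if early_stop_iteration > 0:
            if iteration == early_stop_iteration + 1:
                break
        trained.append(iteration)  # placeholder for images.to(device), forward, backward, step
        # Same evaluation condition as the modified do_train.
        if (iteration % checkpoint_period == 0
                or iteration == max_iter
                or iteration == early_stop_iteration):
            evaluated.append(iteration)
    return trained, evaluated
```

For example, simulate_training(10, 4, early_stop_iteration=6) trains iterations 1 to 6 and evaluates at 4 and 6. The excerpt compares with ==, so a run resumed from a checkpoint that is already past early_stop_iteration + 1 never hits the break: simulate_training(10, 4, early_stop_iteration=3, start_iter=6) runs to iteration 10. A check such as iteration > early_stop_iteration may be worth considering if that case matters.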
Adding real-time FPS computation during training
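- Modify the "maskrcnn_benchmark/engine/trainer.py" script.
In do_train, compute the training throughput of each iteration as train_fps = cfg.SOLVER.IMS_PER_BATCH / batch_time (images per second), and log it to meters alongside time and data. In maskrcnn_benchmark, cfg.SOLVER.IMS_PER_BATCH is the total batch size summed over all devices, so this value is the aggregate throughput, not the per-device throughput.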
Before:
```python
def do_train(
    cfg, model, data_loader, optimizer, scheduler, checkpointer, device,
    checkpoint_period, arguments, val_data_loader=None, meters=None, zero_shot=False,
    early_stop_iteration=-1,
):
    ……
    for iteration, (images, targets, idxs, positive_map, positive_map_eval, greenlight_map) in enumerate(data_loader, start_iter):
        ……
        meters.update(time=batch_time, data=data_time)
```
After:
```python
def do_train(
    cfg, model, data_loader, optimizer, scheduler, checkpointer, device,
    checkpoint_period, arguments, val_data_loader=None, meters=None, zero_shot=False,
    early_stop_iteration=-1,
):
    ……
    for iteration, (images, targets, idxs, positive_map, positive_map_eval, greenlight_map) in enumerate(data_loader, start_iter):
        ……
        # new: images processed per second in this iteration
        train_fps = cfg.SOLVER.IMS_PER_BATCH / batch_time
        meters.update(time=batch_time, data=data_time, fps=train_fps)
```
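The following is a minimal, standalone sketch of the same timing and FPS bookkeeping. A hypothetical SimpleMeters class stands in for the real meters object, and a placeholder step_fn stands in for data loading plus the training step. The sketch times iterations with time.perf_counter. The names compute_train_fps, timed_loop and SimpleMeters are illustrative only.

```python
import time


class SimpleMeters:
    """Hypothetical stand-in for the `meters` object: records every value per key."""

    def __init__(self):
        self.values = {}

    def update(self, **kwargs):
        for name, value in kwargs.items():
            self.values.setdefault(name, []).append(float(value))

    def avg(self, name):
        vals = self.values[name]
        return sum(vals) / len(vals)


def compute_train_fps(ims_per_batch, batch_time):
    # Same formula as above: images in one batch / seconds spent on that iteration.
    return ims_per_batch / batch_time


def timed_loop(num_iters, ims_per_batch, step_fn, meters):
    """Time each call of step_fn (placeholder for one training iteration)
    and log the iteration time and the resulting fps."""
    end = time.perf_counter()
    for _ in range(num_iters):
        step_fn()
        batch_time = time.perf_counter() - end
        end = time.perf_counter()
        meters.update(time=batch_time, fps=compute_train_fps(ims_per_batch, batch_time))
    return meters
```

Averaging the logged fps values does not give the overall throughput when iteration times vary. Take 16 images per batch with iteration times of 0.5 s and 1.5 s. The per-iteration values are 32 and about 10.67 images/s, with a mean of about 21.3. The total throughput, however, is 32 images / 2.0 s = 16 images/s. For an end-of-run figure, dividing total images by total elapsed time may be preferable.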
Source: Overview - Model Development - Ascend Extension for PyTorch 6.0.RC2 Development Documentation - Ascend Community (hiascend.com)