丹摩征文活动 | 深度学习实战:UNet模型的训练与测试详解

🍑个人主页:Jupiter.
🚀 所属专栏:Linux从入门到进阶
欢迎大家点赞收藏评论😊

在这里插入图片描述

在这里插入图片描述

目录

    • 1、云实例:配置选型与启动
      • 1.1 登录注册
      • 1.2 配置 SSH 密钥对
      • 1.3 创建实例
      • 1.4 登录云实例
    • 2、云存储:数据集上传与下载
    • 3、云开发:眼底血管分割案例
      • 3.1 案例背景
      • 3.2 网络搭建
      • 3.3 网络训练
      • 3.4 模型测试


1、云实例:配置选型与启动

1.1 登录注册

首先进入登录界面注册并登录账号

在这里插入图片描述

1.2 配置 SSH 密钥对

配置 SSH 密钥对的作用是后续远程登录服务器不需要密码验证,更加方便。

首先创建本地公钥,进入本地.ssh目录输入ssh-keygen -o命令,这里文件名可以设置为id_dsa,也可以是其他任意名字
在这里插入图片描述
之后我们可以在.ssh目录看到刚刚创建的两个文件

id_dsa id_dsa.pub
其中id_dsa.pub就是需要的公钥文件

进入密钥对配置,创建密钥对,将id_dsa.pub的内容复制到这里就可以

1.3 创建实例

进入GPU 云实例,点击创建实例。如下图所示,按需选择需要的 GPU 型号和镜像
在这里插入图片描述
在这里插入图片描述

1.4 登录云实例

等待实例创建完成后,点击复制“访问链接”。

在这里插入图片描述
接着来到任意一个 SSH 连接终端进行云实例登录,我这里选择的是 VSCode,如下所示
在这里插入图片描述
成功后,输入:

nvidia-smi
torch.cuda.is_available()

简单验证一下功能即可,如下所示即为成功
在这里插入图片描述

2、云存储:数据集上传与下载

文件存储为网络共享存储,可挂载至的不同实例中。相比本地数据盘,其优势是实例间共享,可以多点读写,不受实例释放的影响;此外存储后端有多冗余副本,数据可靠性非常高;但缺陷是 IO 性能一般

考虑到以上优劣,推荐使用方式:将重要数据或代码存放于文件存储中,所有实例共享,便利的同时数据可靠性也有保障;在训练时,需要高 IO 性能的数据(如训练数据),先拷贝到实例本地数据盘,从本地盘读数据获得更好的 IO 性能。如此兼顾便利、安全和性能。

接下来,我们将训练数据上传到云实例数据盘中。使用scp工具如下

scp -rP 35740 ./DRIVE-SEG-DATA root@cn-north-b.ssh.damodel.com:/root/workspace

具体地:

35740与cn-north-b.ssh.damodel.com分别为端口号和远程地址,请参考 1.4 节替换为自己的参数
./DRIVE-SEG-DATA是本地数据集路径
/root/workspace是远程实例数据集路径

在这里插入图片描述
数据的下载也是类似的命令

scp -rP 35740 root@cn-north-b.ssh.damodel.com:/root/workspace ./DRIVE-SEG-DATA

本文提到的数据集可以在DRIVE 数据集中下载:链接:https://drive.grand-challenge.org/Download/

3、云开发:眼底血管分割案例

3.1 案例背景

眼底,作为眼球的内膜,其结构复杂且关键,包括黄斑、视网膜以及视网膜中央动静脉等重要组成部分。在眼科医学领域,眼底图像是医生诊断眼疾病不可或缺的重要依据。近年来,随着深度学习技术的迅猛发展,医学影像分割领域迎来了革命性的变化,眼底图像分割技术也随之取得了显著进步。

深度学习模型,如AlexNet、VGGNet、GoogLeNet和ResNet等,通过训练大规模数据集,能够学习到更加抽象和高级的特征表示。这些特征表示对于实现精确的眼底图像分割至关重要。相较于传统方法,深度学习模型在分割精度和泛化能力上均表现出色。它们能够对未见过的眼底图像数据做出相对准确的预测,从而提高了分割结果的可靠性和稳定性。

此外,深度学习技术还支持端到端的学习方式。这意味着从原始眼底图像输入到最终的分割结果输出,整个过程无需手工设计复杂的特征提取和预处理流程。这种端到端的学习方式不仅简化了分割算法的开发流程,还提高了分割效率和准确性。

值得注意的是,医学影像数据往往包含多种模态,如CT、MRI以及眼底图像等。深度学习技术凭借其强大的数据处理能力,能够更好地处理这些多模态数据。通过实现不同模态之间的信息融合,深度学习模型能够捕捉到更加全面和丰富的医学影像信息,从而进一步提高医学影像分割的准确性和全面性。

在眼底图像分割领域,深度学习技术的应用不仅提高了分割精度和效率,还为眼科医生提供了更加可靠和全面的诊断依据。随着技术的不断进步和应用的深入拓展,深度学习技术有望在眼科医学领域发挥更加重要的作用,为患者带来更加精准和有效的治疗方案。

在这里插入图片描述
本次实践,我们采用 UNet 进行眼底血管医学图像分割任务。UNet 是一种被广泛应用于语义分割任务的网络结构,其编码器-解码器结构以及跳跃连接的设计,使其能够有效地捕获图像中不同尺度的特征信息,从而在眼底血管分割任务中取得较好的效果。同时,在推理阶段,UNet 采用全卷积网络结构,能够快速对新的眼底图像进行血管分割,为临床应用提供了实时性支持。

3.2 网络搭建

选用 U-Net 网络结构作为基础分割模型的原因在于其通过编解码器架构,有效地结合局部信息和全局信息,提高分割准确性;同时,U-Net 的跳跃连接结构有助于保留和恢复图像中的细节和边缘信息,且在小样本情况下表现优异,能够充分利用有限数据进行有效训练,广泛应用于医学图像分割任务中。网络架构如下

class UNet(nn.Module):def __init__(self, n_channels, n_classes, bilinear=True):super(UNet, self).__init__()self.n_channels = n_channelsself.n_classes = n_classesself.bilinear = bilinearself.inc = DoubleConv(n_channels, 64)self.down1 = Down(64, 128)self.down2 = Down(128, 256)self.down3 = Down(256, 512)self.down4 = Down(512, 512)self.up1 = Up(1024, 256, bilinear)self.up2 = Up(512, 128, bilinear)self.up3 = Up(256, 64, bilinear)self.up4 = Up(128, 64, bilinear)self.outc = OutConv(64, n_classes)def forward(self, x):x1 = self.inc(x)x2 = self.down1(x1)x3 = self.down2(x2)x4 = self.down3(x3)x5 = self.down4(x4)x = self.up1(x5, x4)x = self.up2(x, x3)x = self.up3(x, x2)x = self.up4(x, x1)logits = self.outc(x)return logits

3.3 网络训练

基于 PyTorch 的神经网络训练流程可以分为以下步骤(不考虑前期数据准备和模型结构):

定义损失函数 根据任务类型选择合适的损失函数(loss function),如分类任务常用的交叉熵损失(Cross-Entropy Loss)或回归任务中的均方误差(Mean Square Error)。

选择优化器 选择合适的优化器(optimizer),如随机梯度下降(SGD)、Adam 或 RMSprop,并设置初始学习率及其它优化参数。

训练模型 在训练过程中,通过迭代训练数据集来调整模型参数。每个迭代周期称为一个 epoch。对于每个 epoch,数据会被分成多个 batch,每个 batch 被输入到模型中进行前向传播、计算损失、反向传播更新梯度,并最终优化模型参数。

保存模型 当满足需求时,可以将训练好的模型保存下来,以便后续部署和使用。

根据这个步骤编写以下代码

def train_net(net, device, data_path, epochs=40, batch_size=1, lr=0.00001):dataset = Dateset_Loader(data_path)per_epoch_num = len(dataset) / batch_sizetrain_loader = torch.utils.data.DataLoader(dataset=dataset,batch_size=batch_size,shuffle=True)optimizer = optim.Adam(net.parameters(),lr=lr,betas=(0.9, 0.999),eps=1e-08, weight_decay=1e-08,amsgrad=False)criterion = nn.BCEWithLogitsLoss()best_loss = float('inf')loss_record = []with tqdm(total=epochs*per_epoch_num) as pbar:for epoch in range(epochs):net.train()for image, label in train_loader:optimizer.zero_grad()image = image.to(device=device, dtype=torch.float32)label = label.to(device=device, dtype=torch.float32)pred = net(image)loss = criterion(pred, label)pbar.set_description("Processing Epoch: {} Loss: {}".format(epoch+1, loss))if loss < best_loss:best_loss = losstorch.save(net.state_dict(), 'best_model.pth')loss.backward()optimizer.step()pbar.update(1)loss_record.append(loss.item())plt.figure()plt.plot([i+1 for i in range(0, len(loss_record))], loss_record)plt.title('Training Loss')plt.xlabel('Epoch')plt.ylabel('Loss')plt.savefig('/root/shared-storage/results/training_loss.png')

运行这个脚本,可以在终端看到进度

在这里插入图片描述
训练损失函数如下,可以看到已经收敛

在这里插入图片描述

3.4 模型测试

测试逻辑如下所示,主要是计算 IoU 指标

def cal_miou(test_dir="/root/workspace/DRIVE-SEG-DATA/Test_Images",pred_dir="/root/workspace/DRIVE-SEG-DATA/results", gt_dir="/root/workspace/DRIVE-SEG-DATA/Test_Labels",model_path='best_model_drive.pth'):name_classes = ["background", "vein"]num_classes = len(name_classes)if not os.path.exists(pred_dir):os.makedirs(pred_dir)device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')net = UNet(n_channels=1, n_classes=1)net.to(device=device)net.load_state_dict(torch.load(model_path, map_location=device))net.eval()img_names = os.listdir(test_dir)image_ids = [image_name.split(".")[0] for image_name in img_names]time.sleep(1)for image_id in tqdm(image_ids):image_path = os.path.join(test_dir, image_id + ".png")img = cv2.imread(image_path)origin_shape = img.shapeimg = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)img = cv2.resize(img, (512, 512))img = img.reshape(1, 1, img.shape[0], img.shape[1])img_tensor = torch.from_numpy(img)img_tensor = img_tensor.to(device=device, dtype=torch.float32)pred = net(img_tensor)pred = np.array(pred.data.cpu()[0])[0]pred[pred >= 0.5] = 255pred[pred < 0.5] = 0pred = cv2.resize(pred, (origin_shape[1], origin_shape[0]), interpolation=cv2.INTER_NEAREST)cv2.imwrite(os.path.join(pred_dir, image_id + ".png"), pred)hist, IoUs, PA_Recall, Precision = compute_mIoU_gray(gt_dir, pred_dir, image_ids, num_classes, name_classes)miou_out_path = "/root/shared-storage/results/"show_results(miou_out_path, hist, IoUs, PA_Recall, Precision, name_classes)

模型保存的时候保存到共享存储路径/root/shared-storage,其他实例可以直接从共享存储中获取训练后的模型
在这里插入图片描述
在这里插入图片描述


本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.xdnf.cn/news/17816.html

如若内容造成侵权/违法违规/事实不符,请联系一条长河网进行投诉反馈,一经查实,立即删除!

相关文章

# 10_ Python基础到实战一飞冲天(一)--linux基础(十)

10_ Python基础到实战一飞冲天&#xff08;一&#xff09;–linux基础&#xff08;十&#xff09;–软链接硬链接-tar-gzip-bzip2-apt-软件源 一、其他命令-04-文件软链接的演练实现 1、ubuntu 桌面文件如下图&#xff1a; 2、需求&#xff1a;文件软链接的演练&#xff08;演…

Python学习27天

字典 dict{one:1,two:2,three:3} # 遍历1&#xff1a; # 先取出Key for key in dict:# 取出Key对应的valueprint(f"key:{key}---value:{dict[key]}")#遍历2&#xff0c;依次取出value for value in dict.values():print(value)# 遍历3&#xff1a;依次取出key,value …

【Linux】进程的优先级

进程的优先级 一.概念二.修改优先级的方法三.进程切换的大致原理&#xff1a;四.上下文数据的保存位置&#xff1a; 一.概念 cpu资源分配的先后顺序&#xff0c;就是指进程的优先权&#xff08;priority&#xff09;。 优先权高的进程有优先执行权利。配置进程优先权对多任务环…

ubuntu无密码用SCP复制文件到windows

默认情况下,ubuntu使用scp复制文件到windows需要输入密码: scp *.bin dev001@172.16.251.147:~/Desktop/. 为了解决每次复制文件都要输入密码这个问题,需要按如下操作: 1.创建ssh密钥 ssh-keygen -t ed25519 -C "xxx_xxx_xxx@hotmail.com" 2.使用scp复制公钥到w…

单片机GPIO中断+定时器 软件串口通信

单片机GPIO中断定时器 软件串口通信 解决思路代码示例 解决思路 串口波特率9600bps,每个bit约为1000000us/9600104.16us&#xff1b; 定时器第一次定时时间设为52us即半个bit的时间&#xff0c;其目的是偏移半个bit时间&#xff0c;之后的每104us采样并读取1bit数据。使得采样…

使用Web Components构建模块化Web应用

&#x1f493; 博客主页&#xff1a;瑕疵的CSDN主页 &#x1f4dd; Gitee主页&#xff1a;瑕疵的gitee主页 ⏩ 文章专栏&#xff1a;《热点资讯》 使用Web Components构建模块化Web应用 使用Web Components构建模块化Web应用 使用Web Components构建模块化Web应用 引言 Web Co…

每行数据个数在变的二维数组的输出

#include<stdio.h> int main() {//定义四个一维数组int arr1[1] { 1 };int arr2[3] { 1,2,3 };int arr3[5] { 1,2,3,4,5 };int arr4[7] { 1,2,3,4,5,6,7 };//把四个一维数组放进一个二维数组int* arr[4] { arr1,arr2,arr3,arr4};//预先计算好每一个数组真实的长度in…

【SSL证书】腾讯云SSL续签备忘录

适用于证书过期了&#xff0c;需要替换证书的场景。本备忘录为nginx使用证书场景 步骤&#xff1a;一共7步。 登录腾讯云控制台->申请免费证书->腾讯云审核->下载->登录服务器->替换证书->重启nginx 1.登录控制台 https://console.cloud.tencent.com/ssl…

AVL树

一.AVL树的概念 AVL树是一颗特殊的二叉搜索树。二叉搜索树在有些极端情况下可能会出现单支的情况&#xff0c;这会影响其插入查找的效率。而AVL树是一个高度平衡的二叉搜索树&#xff0c;它要求任何的左右子树的高低差都小于等于1。它可以通过去控制左右子树的高度差来控制二叉…

鸿蒙开发-网络数据访问、应用本地数据保存

HTTP概述 HTTP&#xff0c;全称Hyper Text Transfer Protocol 超文本传输协议。 HTTP请求为短连接。客户端发起请求&#xff0c;服务器返回响应。本次连接即结束。 添加网络权限 在访问网络之前&#xff0c;需要在module.json5中给APP添加网络权限 "module": {&…

画 五边形 思路

1. 计算圆心 view 中心点 2.规定半径 R < view宽度 / 2 3.计算五边形五个顶点&#xff08;角度A 2π / 5&#xff09; 4. 五点相连 转载&#xff1a; Android自定义控件 芝麻信用分雷达图 - 简书

网络工程实验三:DHCP的配置

#实验仅供参考&#xff0c;勿直接粘贴复制&#xff0c;用以学习交流# #对于软件的使用&#xff0c;请移步到实验一观看# 1、实验目的&#xff1a; &#xff08;1&#xff09;掌握DHCP工作原理。 &#xff08;2&#xff09;配置路由器作为DHCP服务器。 &#xff08;3&#x…

手写体识别Tensorflow实现

简介&#xff1a;本文先讲解了手写体识别中涉及到的知识&#xff0c;然后分步讲解了代码的详细思路&#xff0c;完成了手写体识别案例的讲解&#xff0c;希望能给大家带来帮助&#xff0c;也希望大家多多关注我。本文是基于TensorFlow1.14.0的环境下运行的 手写体识别Tensorflo…

【SpringBoot】公共字段自动填充

问题引入 JavaEE开发的时候&#xff0c;新增字段&#xff0c;修改字段大都会涉及到创建时间(createTime)&#xff0c;更改时间(updateTime)&#xff0c;创建人(craeteUser)&#xff0c;更改人(updateUser)&#xff0c;如果每次都要自己去setter()&#xff0c;会比较麻烦&#…

【项目开发】为什么文件名要小写?

未经许可,不得转载。 文章目录 一、可移植性二、易读性三、易用性四、便捷性一、可移植性 Linux 系统对文件名大小写敏感,而 Windows 和 Mac 系统则不敏感。这种差异可能导致跨平台的问题。 例如,以下四个文件名: computerComPutercomPuterCOMPOTer在 Linux 系统上,它们…

ssm127基于SSM的乡镇篮球队管理系统+jsp(论文+源码)_kaic

毕 业 设 计&#xff08;论 文&#xff09; 题目&#xff1a;乡镇篮球队管理系统设计与实现 摘 要 现代经济快节奏发展以及不断完善升级的信息化技术&#xff0c;让传统数据信息的管理升级为软件存储&#xff0c;归纳&#xff0c;集中处理数据信息的管理方式。本乡镇篮球队管理…

C#获取视频第一帧_腾讯云媒体处理获取视频第一帧

一、 使用步骤&#xff1a; 第一步、腾讯云开启万象 第二步、安装Tencent.QCloud.Cos.Sdk 包 第三步、修改 腾讯云配置 图片存储目录配置 第四步、执行获取图片并保存 二、封装代码 using System.Text; using System.Threading.Tasks;using COSXML.Model.CI; using COSXML.A…

【数据分享】2003-2022年各省土地利用面积统计数据

数据介绍 2003-2022年各省土地利用面积统计数据数据时间2003-2008、2013、2015-2017、2019、2022数据类型excel数据指标土地调查面积/万公顷农用地面积/万公顷园林面积/万公顷牧草地面积/万公顷建设用地面积/万公顷居民点及工矿用地/万公顷交通用地/万公顷水利设施用地/万公顷…

任务调度工具Spring Test

Spring Task 是Spring框架提供的任务调度工具&#xff0c;可以按照约定的时间自动执行某个代码逻辑。 作用&#xff1a;定时自动执行某段Java代码 应用场景&#xff1a; 信用卡每月还款提醒 银行贷款每月还款提醒 火车票售票系统处理未支付订单 入职纪念日为用户发送通知 一.…

20 轮转数组

20 轮转数组 20.1 轮转数组解决方案 class Solution { public:void rotate(vector<int>& nums, int k) {int n nums.size();k k % n; // 如果 k 大于数组长度&#xff0c;取模减少不必要的旋转// 第一步&#xff1a;反转整个数组reverse(nums.begin(), nums.end(…