sagemaker中使用pytorch框架的DLC训练和部署cifar图像分类任务

参考资料

  • https://github.com/aws/amazon-sagemaker-examples/blob/main/sagemaker-python-sdk/pytorch_cnn_cifar10/pytorch_local_mode_cifar10.ipynb
  • https://sagemaker.readthedocs.io/en/stable/frameworks/pytorch/using_pytorch.html

获取训练数据

# s3://zhaojiew-sagemaker/data/cifar10/cifar-10-python.tar.gz
import torch
import torchvision
import torchvision.transforms as transformsdef _get_transform():return transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])# 这里加载数据用的路径是/tmp/pytorch-example/cifar-10-data实际下载了tar.gz文件到本地/tmp目录,后续training也要放入tar.gz文件路径
def get_train_data_loader(data_dir='/tmp/pytorch/cifar-10-data'):transform=_get_transform()trainset=torchvision.datasets.CIFAR10(root=data_dir, train=True,download=True, transform=transform)return torch.utils.data.DataLoader(trainset, batch_size=4,shuffle=True, num_workers=2)def get_test_data_loader(data_dir='/tmp/pytorch/cifar-10-data'):transform=_get_transform()testset=torchvision.datasets.CIFAR10(root=data_dir, train=False,download=True, transform=transform)return torch.utils.data.DataLoader(testset, batch_size=4,shuffle=False, num_workers=2)trainloader=get_train_data_loader('/tmp/pytorch-example/cifar-10-data')
testloader=get_test_data_loader('/tmp/pytorch-example/cifar-10-data')

显示加载的数据

import numpy as np
import torchvision, torch
import matplotlib.pyplot as pltdef imshow(img):img = img / 2 + 0.5  # unnormalizenpimg = img.numpy()plt.imshow(np.transpose(npimg, (1, 2, 0)))# get some random training images
dataiter = iter(trainloader)
images, labels = next(dataiter)# show images
imshow(torchvision.utils.make_grid(images))# print labels
classes = ("plane", "car", "bird", "cat", "deer", "dog", "frog", "horse", "ship", "truck")
print(" ".join("%9s" % classes[labels[j]] for j in range(4)))

在这里插入图片描述

训练和推理脚本

脚本同时用来进行训练和推理任务,推理部分的实现为model_fn,没有实现input_fn等函数

import ast
import argparse
import loggingimport osimport torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.parallel
import torch.optim
import torch.utils.data
import torch.utils.data.distributed
import torchvision
import torchvision.models
import torchvision.transforms as transforms
import torch.nn.functional as Flogger=logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)classes=('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')# https://github.com/pytorch/tutorials/blob/master/beginner_source/blitz/cifar10_tutorial.py#L118
class Net(nn.Module):def __init__(self):super(Net, self).__init__()self.conv1=nn.Conv2d(3, 6, 5)self.pool=nn.MaxPool2d(2, 2)self.conv2=nn.Conv2d(6, 16, 5)self.fc1=nn.Linear(16 * 5 * 5, 120)self.fc2=nn.Linear(120, 84)self.fc3=nn.Linear(84, 10)def forward(self, x):x=self.pool(F.relu(self.conv1(x)))x=self.pool(F.relu(self.conv2(x)))x=x.view(-1, 16 * 5 * 5)x=F.relu(self.fc1(x))x=F.relu(self.fc2(x))x=self.fc3(x)return xdef _train(args):is_distributed=len(args.hosts) > 1 and args.dist_backend is not Nonelogger.debug("Distributed training - {}".format(is_distributed))if is_distributed:# Initialize the distributed environment.world_size=len(args.hosts)os.environ['WORLD_SIZE']=str(world_size)host_rank=args.hosts.index(args.current_host)dist.init_process_group(backend=args.dist_backend, rank=host_rank, world_size=world_size)logger.info('Initialized the distributed environment: \'{}\' backend on {} nodes. '.format(args.dist_backend,dist.get_world_size()) + 'Current host rank is {}. Using cuda: {}. Number of gpus: {}'.format(dist.get_rank(), torch.cuda.is_available(), args.num_gpus))device='cuda' if torch.cuda.is_available() else 'cpu'logger.info("Device Type: {}".format(device))logger.info("Loading Cifar10 dataset")transform=transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])trainset=torchvision.datasets.CIFAR10(root=args.data_dir, train=True,download=False, transform=transform)train_loader=torch.utils.data.DataLoader(trainset, batch_size=args.batch_size,shuffle=True, num_workers=args.workers)testset=torchvision.datasets.CIFAR10(root=args.data_dir, train=False,download=False, transform=transform)test_loader=torch.utils.data.DataLoader(testset, batch_size=args.batch_size,shuffle=False, num_workers=args.workers)logger.info("Model loaded")model=Net()if torch.cuda.device_count() > 1:logger.info("Gpu count: {}".format(torch.cuda.device_count()))model=nn.DataParallel(model)model=model.to(device)criterion=nn.CrossEntropyLoss().to(device)optimizer=torch.optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum)for epoch in range(0, args.epochs):running_loss=0.0for i, data in enumerate(train_loader):# get the inputsinputs, labels=datainputs, labels=inputs.to(device), labels.to(device)# zero the parameter gradientsoptimizer.zero_grad()# forward + backward + optimizeoutputs=model(inputs)loss=criterion(outputs, labels)loss.backward()optimizer.step()# print statisticsrunning_loss += loss.item()if i % 2000 == 1999:  # print every 2000 mini-batchesprint('[%d, %5d] loss: %.3f' %(epoch + 1, i + 1, running_loss / 2000))running_loss=0.0print('Finished Training')return _save_model(model, args.model_dir)def _save_model(model, model_dir):logger.info("Saving the model.")path=os.path.join(model_dir, 'model.pth')# recommended way from http://pytorch.org/docs/master/notes/serialization.htmltorch.save(model.cpu().state_dict(), path)def model_fn(model_dir):logger.info('model_fn triggered, starting to load model...')device="cuda" if torch.cuda.is_available() else "cpu"model=Net()if torch.cuda.device_count() > 1:logger.info("Gpu count: {}".format(torch.cuda.device_count()))model=nn.DataParallel(model)with open(os.path.join(model_dir, 'model.pth'), 'rb') as f:model.load_state_dict(torch.load(f))return model.to(device)if __name__ == '__main__':parser=argparse.ArgumentParser()parser.add_argument('--workers', type=int, default=2, metavar='W',help='number of data loading workers (default: 2)')parser.add_argument('--epochs', type=int, default=2, metavar='E',help='number of total epochs to run (default: 2)')parser.add_argument('--batch-size', type=int, default=4, metavar='BS',help='batch size (default: 4)')parser.add_argument('--lr', type=float, default=0.001, metavar='LR',help='initial learning rate (default: 0.001)')parser.add_argument('--momentum', type=float, default=0.9, metavar='M', help='momentum (default: 0.9)')parser.add_argument('--dist-backend', type=str, default='gloo', help='distributed backend (default: gloo)')# The parameters below retrieve their default values from SageMaker environment variables, which are# instantiated by the SageMaker containers framework.# https://github.com/aws/sagemaker-containers#how-a-script-is-executed-inside-the-containerparser.add_argument('--hosts', type=str, default=ast.literal_eval(os.environ['SM_HOSTS']))parser.add_argument('--current-host', type=str, default=os.environ['SM_CURRENT_HOST'])parser.add_argument('--model-dir', type=str, default=os.environ['SM_MODEL_DIR'])parser.add_argument('--data-dir', type=str, default=os.environ['SM_CHANNEL_TRAINING'])parser.add_argument('--num-gpus', type=int, default=os.environ['SM_NUM_GPUS'])_train(parser.parse_args())

模型训练

提前获取pytorch镜像

  • 托管的DLC中内置了training toolkit和inference toolkit,因此只需要按照规范提供训练和推理脚本即可
from sagemaker import get_execution_rolerole=get_execution_role()from sagemaker import image_uris
image_uri_inference = image_uris.retrieve(framework='pytorch',region='cn-north-1',version='1.8.0',py_version='py3',image_scope='inference', instance_type='ml.c5.4xlarge')
image_uri_train = image_uris.retrieve(framework='pytorch',region='cn-north-1',version='1.8.0',py_version='py3',image_scope='training', instance_type='ml.c5.4xlarge')
print(image_uri_inference)
print(image_uri_train)

创建Estimator

from sagemaker.estimator import Estimator# 超参数实际上会作为训练脚本的参数传入,可以通过argparse进行解析
hyperparameters = {'epochs': 1,
}# 使用通用的Estimator,
estimator=Estimator(image_uri=image_uri_train, # 这里可以使用托管镜像或基于托管的扩展镜像role=role,instance_count=1,instance_type='ml.p3.2xlarge',hyperparameters=hyperparameters,source_dir="src",entry_point="cifar10.py"# model_uri="s3://zhaojiew-sagemaker/model/cifar10-pytorch/" # 如果有pre-trained的模型可以使用此参数导入)
# 在本地测试训练任务,实际上是通过docker-compose运行
#estimator.fit('file:///tmp/pytorch-example/cifar-10-data')
# 提交train任务
estimator.fit('s3://zhaojiew-tmp/cifar-10-data/',)

也可以使用PyTorch的Estimator

from sagemaker.pytorch.estimator import PyTorch
# 也可以使用PyTorch
pytorch_estimator = PyTorch(entry_point='cifar10.py',instance_type='ml.p3.2xlarge',instance_count=1,role=role,framework_version='1.8.0',py_version='py3',hyperparameters=hyperparameters
)
pytorch_estimator.fit('s3://zhaojiew-tmp/cifar-10-data/')

最终存储的模型位置为

model_location = 's3://sagemaker-cn-north-1-xxxxxxx/pytorch-training-2024-11-19-09-56-55-508/output/model.tar.gz'

模型部署

实际上可以直接基于estimator进行部署,但是这里导入模型将两个阶段分开

from sagemaker.pytorch.model import PyTorchModelpytorch_model = PyTorchModel(# 指定模型所在位置model_data=model_location,role=role,image_uri=image_uri_inference,entry_point='cifar10.py', # 如果指定了推理脚本会打包为source.tar.gz并和model.tar.gz合并成一个tar文件source_dir="src" # 指定代码所在目录
)
pytorch_predictor = pytorch_model.deploy(instance_type='ml.m5.xlarge', initial_instance_count=1)

也可以使用更通用的Model

from sagemaker.model import Modelmodel = Model(# # 指定模型所在位置model_data=model_location,image_uri=image_uri_inference,role=role,entry_point="cifar10.py",source_dir="src"
)model_predictor=model.deploy(1, "ml.m5.xlarge")

模型调用

如果predictor丢失,可以通过如下方法重建

from sagemaker.predictor import Predictor
from sagemaker.serializers import NumpySerializer
from sagemaker.deserializers import NumpyDeserializermodel_predictor = Predictor(endpoint_name="pytorch-inference-2024-11-19-14-19-49-678"
)
model_predictor.serializer = NumpySerializer()
model_predictor.deserializer = NumpyDeserializer()

使用测试集测试

# get some test images
dataiter = iter(testloader)
images, labels = next(dataiter)# print images
imshow(torchvision.utils.make_grid(images))
print("GroundTruth: ", " ".join("%4s" % classes[labels[j]] for j in range(4)))outputs = model_predictor.predict(images.numpy())
_, predicted = torch.max(torch.from_numpy(np.array(outputs)), 1)print("Predicted: ", " ".join("%4s" % classes[predicted[j]] for j in range(4)))

在这里插入图片描述

由于模型部署后仅仅是在机器学习实例上启动容器,因此也可以在本地测试,例如以下docker-compose文件

networks:sagemaker-local:name: sagemaker-local
services:localendpoint:command: serve # 也可以忽略,默认为servecontainer_name: localendpointenvironment:- AWS_REGION=cn-north-1- SAGEMAKER_PROGRAM=cifar10.py- S3_ENDPOINT_URL=https://s3.cn-north-1.amazonaws.com.cn- SAGEMAKER_SUBMIT_DIRECTORY=/opt/ml/model/codeimage: 727897471807.dkr.ecr.cn-north-1.amazonaws.com.cn/pytorch-inference:1.8.0-cpu-py3ports:- 8080:8080networks:sagemaker-local:volumes:- ./src/cifar10.py:/opt/ml/model/code/cifar10.py- ./model/model.pth:/opt/ml/model/model.pth
version: '2.3'

但是这只能测试推理服务器能够正常启动,实际调用由于无法使用boto3和sagemaker sdk,可能需要手动封装http请求

import numpy as np
import torch
import requests
from io import BytesIObuffer = BytesIO()
np.save(buffer, images.numpy(), allow_pickle=False)
payload = buffer.getvalue()local_url = "http://localhost:8080/invocations"
try:response = requests.post(local_url,data=payload,headers={'Content-Type': 'application/x-npy'})response.raise_for_status()result = np.frombuffer(response.content, dtype=np.float32)print(result)
except Exception as e:print(f"发生错误: {e}")

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.xdnf.cn/news/20420.html

如若内容造成侵权/违法违规/事实不符,请联系一条长河网进行投诉反馈,一经查实,立即删除!

相关文章

golang笔记8-函数

1. 基本函数 package mainimport "fmt"/*什么是函数:完成某一功能的程序指令的集合语法:func 函数名称(形参列表)(返回值类型列表){执行语句。。。返回值列表}注意事项:函数名:函数名首字母大写:可以被本包…

vite+vue3+ts编译vue组件后,编译产物中d.ts文件为空

一、前言 使用vue3vitets实现一个UI组件库,为了生成类型文件便于其他项目引用该组件库。根据推荐使用了vite-plugin-dts插件进行ts文件的生成 二、版本 组件版本vue ^3.5.12 vite ^5.4.10 vite-plugin-dts ^4.3.0 typescript ~5.6.2 三、问题描述 使用vitevi…

向量数据库FAISS之二:基础进阶版

基础 1.评价类型和距离 1.METRIC_L2 Faiss 使用了欧几里得 (L2) 距离的平方,避免了平方根。 这仍然与欧几里德距离一样单调,但如果需要精确距离,则需要结果的额外平方根。 2.METRIC_INNER_PRODUCT 这通常用于推荐系统中的最大内积搜索。…

家庭网络常识:猫与路由器

这张图大家应该不陌生——以前家庭网络的连接方式。 图1 家庭网络连接示意图 来说说猫/光猫: 先看看两者的图片。 图2 猫 图3 光猫 这个东西因为英文叫“modem”,类似中文的“猫”,所以简称“猫”。 猫和光猫的区别就是,一…

华三预赛学习笔记(每日编辑,复习完为止)

知识点分布 路由交换技术基础 计算机网络基本概念 计算机网络基本概念: 很多电脑和设备通过电线或无线信号连在一起,可以互相“说话”和“分享东西” 网络的主要形式和发展历程: 诞生阶段-最早的计算机网络是以单个计算机为中心的联机系统-终…

技术速递|Microsoft.Extensions.VectorData 预览版简介

作者:Luis Quintanilla - 项目经理 排版:Alan Wang 我们很高兴推出 Microsoft.Extensions.VectorData.Abstractions 库,该库现已提供预览版。 正如 Microsoft.Extensions.AI 库为使用 AI 服务提供了一个统一层一样,此包为 .NET 生…

第5章-总体设计 5.3 硬件架构设计

5.3 硬件架构设计 1.哪些类型的产品需要架构设计?2.硬件架构师到底做什么?(1)理解需求和业务模型的情况。(2)背板设计,既需要考虑业务数据交换能力,也需要考虑子模块的管理监控能力。…

图像/文字差异类型验证码识别 无需训练

某像差异在个别全家桶验证方便有使用,对于这种验证码类型如下: 首先还是目标检测,直接用 dddd 自带的detection 就足够了。 特征提取 其次经过观察,差异答案与其他三个无非就是颜色,字体,方向&#xff0c…

新华三H3CNE网络工程师认证—生成树协议

新华三H3CNE网络工程师认证本节讲解生成树协议,关于生成树协议,提到生成树协议,这个时候不得不提到另外一个概念叫二层环路。二层环路导致的原因是交换机的转发机制导致的,本博客将分析这个机制导致这个问题的原因。 文章目录 一…

使用ai工具探究论文的工作流(阅读一个EEG的cnn-lstm文献(2021))

文章目录 李沐老师的方法论第一遍:做海选第二遍:对相关论文进行精选第三遍:重点研读 AI是怎么分析一个文章的标题(Title)和关键词摘要(Abstract)分析引言(Introduction)梳…

Scala的Array习题

答案:CBBBB import scala.collection.mutable.ArrayBuffer //1 case class DreamItem(content:String,var isDone:Boolean,deadline:String,var order:Int) object p5 {def main(args: Array[String]): Unit {//2val dreamListArrayBuffer[DreamItem]()//梦想清单/…

深度学习的实践层面

深度学习的实践层面 设计机器学习应用 在训练神经网络时,超参数选择是一个高度迭代的过程。我们通常从一个初步的模型框架开始,进行编码、运行和测试,通过不断调整优化模型。 数据集一般划分为三部分:训练集、验证集和测试集。常…

TPU-MLIR 总览

TPU-MLIR 总览 💡深度学习编译器可以实现一次性代码开发和重用各种计算能力处理器的目标 ## 项目简介: TPU-MLIR 是 AI 芯片的 TPU 编译器工程。该工程提供了一套完整的工具链, 其可以将不同框架下预训练的神经网络, 转化为可以在算能 TPU 上高效运算的…

Vue3 + Vite 项目引入 Typescript

文章目录 一、TypeScript简介二、TypeScript 开发环境搭建三、编译方式1. 自动编译单个文件2. 自动编译整个项目 四、配置文件1. compilerOptions基本选项严格模式相关选项(启用 strict 后自动包含这些)模块与导入相关选项 2. include 和 excludeinclude…

苹果MacOS 调用自编译opencv的Dylib显示一个图片程序的步骤

前言 为了测试自编译的opencv库是否能在苹果MacOS系统下使用,需要写一个简单的测试程序。这个测试程序写起来不难,麻烦的是一些配置。网上的办法很多,里面因为版本的问题有一些坑。特此写了一个建立步骤,供大家参考。 1、新建一个…

AI赋能:高职院校实验实训教学如何拥抱人工智能浪潮?

随着信息技术的迅猛发展,人工智能技术已成为推动社会各行业转型升级的核心力量。它不仅在提升生产效率、优化管理流程、提高服务质量方面发挥着关键作用,也深刻影响着高职教育的专业发展和课程教学内容的改革。作为培养专业技术技能人才的摇篮&#xff0…

消费者行为学领域的顶级期刊

一、期刊 1.Journal of Consumer Research 2.Journal of Consumer Psychology 3.Journal of Research in Interactive Marketing 4.Journal of the Academy of Marketing Science 5.Tourism Management 下面是我整理的一个excel,大家按需丝我获取。 二、期刊&z…

STM32单片机CAN总线汽车线路通断检测-分享

目录 目录 前言 一、本设计主要实现哪些很“开门”功能? 二、电路设计原理图 1.电路图采用Altium Designer进行设计: 2.实物展示图片 三、程序源代码设计 四、获取资料内容 前言 随着汽车电子技术的不断发展,车辆通信接口在汽车电子控…

Zmap+python脚本+burp实现自动化Fuzzing测试

声明 学习视频来自 B 站UP主泷羽sec,如涉及侵权马上删除文章。 笔记的只是方便各位师傅学习知识,以下网站只涉及学习内容,其他的都与本人无关,切莫逾越法律红线,否则后果自负。 ✍🏻作者简介:致…

3.tree of thought 源码 (Thought 和ToTDFSMemory 类)

本教程将介绍 tree of thought 源码 中的Thought 和ToTDFSMemory 类 定义思维有效性 使用Enum模块来定义思维的有效性。 from enum import Enumclass ThoughtValidity(Enum):"""Enum for the validity of a thought."""VALID_INTERMEDIATE 0…