PyTorch模型转ONNX量化模型

你是否发现模型太大,无法部署在你想要的云服务上?或者你是否发现 TensorFlow 和 PyTorch 等框架对于你的云服务来说太臃肿了?ONNX Runtime 可能是你的救星。

如果你的模型在 PyTorch 中,你可以轻松地在 Python 中将其转换为 ONNX,然后根据需要量化模型(对于 TensorFlow 模型,你可以使用 tf2onnx)。ONNX Runtime 是轻量级的,量化可以减小模型大小。

让我们尝试将 PyTorch 中预训练的 ResNet-18 模型转换为 ONNX,然后量化。我们将使用 ImageNet 数据集的子集比较准确率。

NSDT工具推荐: Three.js AI纹理开发包 - YOLO合成数据生成器 - GLTF/GLB在线编辑 - 3D模型格式在线转换 - 可编程3D场景编辑器 - REVIT导出3D模型插件 - 3D模型语义搜索引擎 - AI模型在线查看 - Three.js虚拟轴心开发包 - 3D模型在线减面 - STL模型在线切割 

0、先决条件

首先下载 ImageNet-100 验证数据集并将其解压到一个目录,从现在开始我们将该目录称为  {VAL}。 {VAL} 应如下所示。

{VAL}/|--n01440764/|--ILSVRC2012_val_00000293.JPEG|--...|--...

换句话说, {VAL}/{synset}/{image_name}.JPEG

接下来下载ImageNet的同义词集(synset)。

如果你想知道“同义词集”是什么,ImageNet 网站是这样描述的:

ImageNet 是根据 WordNet 层次结构组织的图像数据集。WordNet 中的每个有意义的概念(可能由多个单词或词组描述)称为“同义词集”或“同义词集”。

现在,下载此 synset_words.txt 文件,由 J.D. Salinger 的《麦田里的守望者》的狂热粉丝提供。你也应该阅读它。😃

1、软件包

我们需要安装和导入以下软件包。你可以使用 pip 来完成此操作。如果你有受支持的 GPU,可能能够使用为 GPU 构建的软件包版本。(例如 --onnxruntime-gpu

from tqdm import tqdm
from PIL import Image
import glob
import numpy as np
import torch
import torchvision as tv
import onnx
import onnxruntime as ort
from onnxruntime import quantization

TQDM 仅用于美观的进度条。 😄

2、PyTorch环节

对于输入图像,模型输出一个向量,其中包含 1000 个元素,每个元素代表一个同义词集。因此,我们需要使用 synset_words.txt 将数据集中的同义词集与索引中的模型输出向量进行匹配。

synset_to_target = {}
f = open("synset_words.txt", "r")
index = 0
for line in f:parts = line.split(" ")synset_to_target[parts[0]] = indexindex = index + 1
f.close()

2.1 数据加载器

创建一个可在 DataLoader 中使用的 dataset 类:

preprocess = tv.transforms.Compose([tv.transforms.Resize(256),tv.transforms.CenterCrop(224),tv.transforms.ToTensor(),tv.transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])def tar_transform(synset):return synset_to_target[synset]class ImageNetValDataset(torch.utils.data.Dataset):def __init__(self, img_dir, transform=None, target_transform=None):self.img_dir = img_dirself.img_paths = sorted(glob.glob(img_dir + "*/*.JPEG"), key=lambda x: int(x.split("_")[-1].split(".")[0]))self.transform = transformself.target_transform = target_transformdef __len__(self):return len(self.img_paths)def __getitem__(self, idx):img_path = self.img_paths[idx]image = Image.open(img_path)synset = img_path.split("/")[-2]label = synsetif self.transform:image = self.transform(image)if self.target_transform:label = self.target_transform(label)return image, labelds = ImageNetValDataset("{VAL}/", transform=preprocess, target_transform=tar_transform)

如果需要,拆分或切片数据集,并保留数据集的未触及部分进行量化。

offset = 500
calib_ds = torch.utils.data.Subset(ds, list(range(offset)))
val_ds = torch.utils.data.Subset(ds, list(range(offset, offset * 2)))

calib_ds 保留用于量化。

创建具有所需批量大小的 DataLoader。

batch_size = 64
dl = torch.utils.data.DataLoader(val_ds, batch_size=batch_size, shuffle=False)

可以关闭 shuffle,因为按图像名称排序会按预定顺序混合图像。

2.2 PyTorch 模型

从 Torch Hub 下载 ResNet-18。

model_pt = torch.hub.load('pytorch/vision:v0.10.0', 'resnet18', weights=tv.models.ResNet18_Weights.DEFAULT)
model_pt.eval()

eval() 将模型设置为推理模式。

使用虚拟输入执行一次推理:

dummy_in = torch.randn(1, 3, 224, 224, requires_grad=True)dummy_out = model_pt(dummy_in)

2.3 转换为 ONNX

ONNX 模型将保存到给定的路径:

# export fp32 model to onnx
model_fp32_path = 'resnet18_fp32.onnx'torch.onnx.export(model_pt,                                         # modeldummy_in,                                         # model inputmodel_fp32_path,                                  # pathexport_params=True,                               # store the trained parameter weights inside the model fileopset_version=14,                                 # the ONNX version to export the model todo_constant_folding=True,                         # constant folding for optimizationinput_names = ['input'],                          # input namesoutput_names = ['output'],                        # output namesdynamic_axes={'input' : {0 : 'batch_size'},       # variable length axes'output' : {0 : 'batch_size'}})

常量折叠(constant folding)将用预先计算的常量节点替换一些具有所有常量输入的 op。

验证模型的结构并确认模型具有有效的架构。通过检查模型的版本、图形的结构以及节点及其输入和输出来验证 ONNX 图的有效性。

model_onnx = onnx.load(model_fp32_path)
onnx.checker.check_model(model_onnx)

如果测试失败,则会引发异常。

2.4 PyTorch vs. ONNX

定义一个将PyTorch张量转换为NumPy数组的函数:

def to_numpy(tensor):return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy()

此函数将允许我们将相同的 PyTorch DataLoader 与 ONNX 一起使用。

准备模型:

ort_provider = ['CPUExecutionProvider']
if torch.cuda.is_available():model_pt.to('cuda')ort_provider = ['CUDAExecutionProvider']ort_sess = ort.InferenceSession(model_fp32_path, providers=ort_provider)

使用 GPU(如果可用)。

测试模型:

correct_pt = 0
correct_onnx = 0
tot_abs_error = 0for img_batch, label_batch in tqdm(dl, ascii=True, unit="batches"):ort_inputs = {ort_sess.get_inputs()[0].name: to_numpy(img_batch)}ort_outs = ort_sess.run(None, ort_inputs)[0]ort_preds = np.argmax(ort_outs, axis=1)correct_onnx += np.sum(np.equal(ort_preds, to_numpy(label_batch)))if torch.cuda.is_available():img_batch = img_batch.to('cuda')label_batch = label_batch.to('cuda')with torch.no_grad():pt_outs = model_pt(img_batch)pt_preds = torch.argmax(pt_outs, dim=1)correct_pt += torch.sum(pt_preds == label_batch)tot_abs_error += np.sum(np.abs(to_numpy(pt_outs) - ort_outs))print("\n")print(f"pt top-1 acc = {100.0 * correct_pt/len(val_ds)} with {correct_pt} correct samples")
print(f"onnx top-1 acc = {100.0 * correct_onnx/len(val_ds)} with {correct_onnx} correct samples")mae = tot_abs_error/(1000*len(val_ds))
print(f"mean abs error = {mae} with total abs error {tot_abs_error}")

你可能会得到一些这样的结果:

# CPU
# pt top-1 acc = 79.0 with 395 correct samples
# onnx top-1 acc = 79.0 with 395 correct samples
# mean abs error = 1.7788757681846619e-06 with total abs error 0.8894378840923309# GPU
# pt top-1 acc = 79.0 with 395 correct samples
# onnx top-1 acc = 79.0 with 395 correct samples
# mean abs error = 4.85603129863739e-06 with total abs error 2.428015649318695

已知 CPU 和 GPU 产生的结果略有不同,具体取决于操作的实现方式和轻微的位错误。

4、ONNX模型量化

根据 ONNX 运行时文档,建议在量化之前执行此预处理步骤,其中包括优化。

model_prep_path = 'resnet18_prep.onnx'quantization.shape_inference.quant_pre_process(model_fp32_path, model_prep_path, skip_symbolic_shape=False)

预处理后的模型将保存到给定的路径。

4.1 校准数据读取器

根据 ONNX 运行时文档,

通常,建议对 RNN 和基于 transformer 的模型使用动态量化,对 CNN 模型使用静态量化。

由于 ResNet-18 主要是 CNN,我们应该进行静态量化。但是,它需要一个数据集来校准量化的模型参数。(幸好我们把 alib_ds留下了! 😉

class QuntizationDataReader(quantization.CalibrationDataReader):def __init__(self, torch_ds, batch_size, input_name):self.torch_dl = torch.utils.data.DataLoader(torch_ds, batch_size=batch_size, shuffle=False)self.input_name = input_nameself.datasize = len(self.torch_dl)self.enum_data = iter(self.torch_dl)def to_numpy(self, pt_tensor):return pt_tensor.detach().cpu().numpy() if pt_tensor.requires_grad else pt_tensor.cpu().numpy()def get_next(self):batch = next(self.enum_data, None)if batch is not None:return {self.input_name: self.to_numpy(batch[0])}else:return Nonedef rewind(self):self.enum_data = iter(self.torch_dl)qdr = QuntizationDataReader(calib_ds, batch_size=64, input_name=ort_sess.get_inputs()[0].name)

量化模型将保存到给定的路径:

q_static_opts = {"ActivationSymmetric":False,"WeightSymmetric":True}
if torch.cuda.is_available():q_static_opts = {"ActivationSymmetric":True,"WeightSymmetric":True}model_int8_path = 'resnet18_int8.onnx'
quantized_model = quantization.quantize_static(model_input=model_prep_path,model_output=model_int8_path,calibration_data_reader=qdr,extra_options=q_static_opts)

根据 ONNX 运行时存储库,

如果模型以 GPU/TRT 为目标,则需要对称激活和权重。如果模型面向 CPU,建议使用非对称激活和对称权重,以平衡性能和准确性。

你可以从这个ResearchGate 页面了解有关对称/非对称量化的更多信息。

4.2 ONNX FP32 vs. INT8

加载 量化的onnx 模型:

ort_int8_sess = ort.InferenceSession(model_int8_path, providers=ort_provider)

测试模型:

correct_int8 = 0
correct_onnx = 0
tot_abs_error = 0for img_batch, label_batch in tqdm(dl, ascii=True, unit="batches"):ort_inputs = {ort_sess.get_inputs()[0].name: to_numpy(img_batch)}ort_outs = ort_sess.run(None, ort_inputs)[0]ort_preds = np.argmax(ort_outs, axis=1)correct_onnx += np.sum(np.equal(ort_preds, to_numpy(label_batch)))ort_int8_outs = ort_int8_sess.run(None, ort_inputs)[0]ort_int8_preds = np.argmax(ort_int8_outs, axis=1)correct_int8 += np.sum(np.equal(ort_int8_preds, to_numpy(label_batch)))tot_abs_error += np.sum(np.abs(ort_int8_outs - ort_outs))print("\n")print(f"onnx top-1 acc = {100.0 * correct_onnx/len(val_ds)} with {correct_onnx} correct samples")
print(f"onnx int8 top-1 acc = {100.0 * correct_int8/len(val_ds)} with {correct_int8} correct samples")mae = tot_abs_error/(1000*len(val_ds))
print(f"mean abs error = {mae} with total abs error {tot_abs_error}")

可能得到类似如下的结果:

# CPU
# onnx top-1 acc = 79.0 with 395 correct samples
# onnx int8 top-1 acc = 77.8 with 389 correct samples
# mean abs error = 0.265933556640625 with total abs error 132966.7783203125# GPU
# onnx top-1 acc = 79.0 with 395 correct samples
# onnx int8 top-1 acc = 77.4 with 387 correct samples
# mean abs error = 0.44179485546875 with total abs error 220897.427734375

CPU 与 GPU 的结果也可能不同,因为选择了对称与非对称量化方法。


原文链接:PyTorch转ONNX量化模型 - BimAnt

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.xdnf.cn/news/1544754.html

如若内容造成侵权/违法违规/事实不符,请联系一条长河网进行投诉反馈,一经查实,立即删除!

相关文章

智能感知,主动防御:移动云态势感知为政企安全护航

数字化时代,网络安全已成为企业持续运营和发展的重要基石。随着业务扩展,企业资产的数量急剧增加,且分布日益分散,如何全面、准确地掌握和管理资产成为众多政企单位的难题。同时,传统安全手段又难以有效应对新型、隐蔽…

你的提交信息还在拖后腿?看这里,提升代码质量的绝招!

文章目录 前言一、什么是约定式提交?二、创建新仓库三、将代码推送到远程仓库的步骤1.检查当前远程仓库2.添加代码到暂存区3. 进行约定式提交4. 推送代码到远程仓库5. 完成推送 总结 前言 在当今软件开发领域,Git已经成为最广泛使用的版本控制系统之一。…

二阶滤波算法总结(对RC滤波算法整理的部分修正和完善)

文章目录 1、一阶低通滤波2、一阶高通滤波3、二阶低通滤波器3.1 二阶RC低通滤波器的连续域数学模型3.2 二阶RC低通滤波器的算法推导3.3 matlab仿真 4、二阶高通滤波器4.1 二阶RC高通滤波器的连续域数学模型4.2 二阶RC高通滤波器的算法推导4.3 matlab仿真 5、陷波滤波6、带通滤波…

白杨SEO:从小红书、抖音图文再到小绿书,为什么现在制作图文内容搞SEO搜索精准流量更容易?

前言:为什么想到写这个?上周参加了一个杭州公司游学,发现大家现在做SEO精准流量都在用图文方式来搞了,还有做小绿书也越来越多了,所以分享给大家,看完对大家有一些启发。 文章大纲: 1、图文是什…

2024年AI技术爆发的元年,用对工具,让你副业比主业赚得多!

大家好,我是强哥 文字的力量不容小觑,或许你没有多好的文笔,或许你已经很久没有拿笔写字了,但是没关系,我们有工具! AI时代的到来,不会用工具,那你可就OUT了 如果你觉得文字不能赚…

产业报告丨2024中国AI大模型场景探索及产业应用调研报告(附下载)

前言 AI大模型是指在机器学习和深度学习领域中,采用大规模参数(至少在一亿个参数以上)的神经网络模型,AI大模型在训练过程中需要依赖大量的算力和高质量的数据资源。2024年,AI大模型的行业应用与技术发展正有效提升千…

2024年 AI大模型我该买一张什么卡?

有钱啥也不用说,买张最贵的就是了。对囊中羞涩的我还说,我该买张什么样的显卡呢? 我的旧显卡RTX1060 6G,满负荷消耗功率110多瓦,几乎达到设计最大TDP,周日时拿了朋友的RTX3060Ti 8G,发现是锁算…

Kaggle-狗种类的识别(Pytorch框架)基本图像识别流程

狗类别实现过程 一. 将数据集按标签分类,将标签转换为数字表示,并制作数据集 二. 搭建网络框架,inception,或者ResNet 三. 选择优化函数,训练模型 数据集制作 首先分析数据集,题中已经很明确告诉有120 种…

【2024W32】肖恩技术周刊(第 10 期):太阳神鸟

周刊内容: 对一周内阅读的资讯或技术内容精品(个人向)进行总结,分类大致包含“业界资讯”、“技术博客”、“开源项目”和“工具分享”等。为减少阅读负担提高记忆留存率,每类下内容数一般不超过3条。 更新时间: 星期天 历史收录:…

LeetCode 刷题基础Ⅰ -- 基础语法

c 基础语法,LeetCode 刷题用 学习网站一、顺序结构基本数据类型① 整型 int② 长整型 long③ 浮点型 double④类型转换 输入输出① getchar 吸收回车符② 数学函数③ 最大值的定义 二、选择结构① switch 三、数组① 初始化② 输入③ 方法 四、结构体① 自定义结构体…

UE5地图白屏/过曝/非常亮の解决方法

今天遇到一个问题 , 新建项目 , 打开虚幻第三人称地图的默认关卡 , 发现白屏 , 啥也看不见 猜测可能是虚幻编辑器的bug , 造成白屏的原因应该是场景过曝了 记录一下解决方案 第一种解决方法 找到场景中的 后期处理体积 (PostProcessVolume) 直接删掉 或者找到 细节面板中 -…

【Transformers基础入门篇5】基础组件之Datasets

文章目录 一、简介二、Datasets基本使用2.1 加载在线数据集(load_dataset)2.2 加载数据集某一项任务(load_dataset)2.3 按照数据集划分进行加载(load_dataset)2.4 查看数据集(index and slice&a…

数据库课程 CMU15-445 2023 Fall Project-2 Extendible Hash Index

0 实验结果 tips:完成项目的前提不需要一定看视频 1 数据结构:扩展哈希 解释下这张图: 图中header的最大深度2,directory最大深度2,桶的容量2。 最开始的时候只有一个header。 插入第一个数据,假设这个数据对应的哈希…

洛汗2保姆级辅助教程攻略:VMOS云手机辅助升级打怪!

在《洛汗2》中,玩家将进入一个充满魔幻色彩的西方世界,体验多种族文明的兴衰与冒险。为了更好地享受这款由普雷威(Playwith)开发的角色扮演动作手游,使用VMOS云手机将是一个明智的选择。VMOS云手机专为游戏打造了定制版…

基于SSM的“在线CRM管理系统”的设计与实现(源码+数据库+文档+开题报告)

基于SSM的“在线CRM管理系统”的设计与实现(源码数据库文档开题报告) 开发语言:Java 数据库:MySQL 技术:SSM 工具:IDEA/Ecilpse、Navicat、Maven 系统展示 总体功能模块图 登录页面 后台管理页面 产品信息页面 客…

JSP(Java Server Pages)基础使用二

简单练习在jsp页面上输出出乘法口诀表 既然大家都是来看这种代码的人了&#xff0c;那么这种输出乘法口诀表的这种简单算法肯定是难不住大家了&#xff0c;所以这次主要是来说jsp的使用格式问题。 <%--Created by IntelliJ IDEA.User: ***Date: 2024/7/18Time: 11:26To ch…

consul注册中心与容器自动发现实战

consul简介 Consul 是 HashiCorp 公司推出的开源工具&#xff0c;用于实现分布式系统的服务发现与配置。内置了服务注册与发现框 架、分布一致性协议实现、健康检查、Key/Value 存储、多数据中心方案&#xff0c;不再需要依赖其它工具&#xff08;比如 ZooKeeper 等&#xff0…

拾色器的取色的演示

前言 今天&#xff0c;有一个新新的程序员问我&#xff0c;如何确定图片中我们需要选定的颜色范围。一开始&#xff0c;我感到对这个问题很不屑。后来&#xff0c;想了想&#xff0c;还是对她说&#xff0c;你可以参考一下“拾色器”。 后来&#xff0c;我想关于拾色器&#…

动态规划11,完全背包模板

NC309 完全背包 问题一&#xff1a;求这个背包至多能装多大价值的物品&#xff1f; 状态表示&#xff1a;经验题目要求 dp[i][j] 表示 从前i个物品中挑选&#xff0c;总体积不超过j&#xff0c;所有选法中&#xff0c;能选出来的最大价值。 状态转移方程 根据最后一步的状态&a…

C语言 typedef - C语言零基础入门教程

目录 一.typedef 简介 二.typedef 实战 1.typedef 定义基本数据变量 2.typedef 定义结构体 A.常规定义结构体B.typedef 定义结构体C.结构体使用 typedef 和不使用 typedef 区别 3.typedef 定义函数指针 三.猜你喜欢 零基础 C/C 学习路线推荐 : C/C 学习目录 >> C 语言基础…