pytorch量化训练

训练时量化(Quantization-aware Training, QAT)是一种在模型训练过程中,通过模拟低精度量化效应来增强模型对量化操作的鲁棒性的技术。与后训练量化不同,QAT 允许模型在训练过程中考虑到量化引入的误差,从而在实际部署时使用低精度进行推理时能够维持更高的性能。

1. 假量化节点插入(Fake Quantization Nodes)

在训练过程中,通过在网络中插入假量化节点来模拟量化和反量化的过程。这些节点在前向传播过程中将权重和激活值量化到指定的数值范围和精度(如INT8),然后再反量化回浮点数,以进行后续的计算。通过这种方式,模型可以适应量化带来的信息损失。

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.quantization import QuantStub, DeQuantStubclass QuantAwareNet(nn.Module):def __init__(self):super(QuantAwareNet, self).__init__()self.quant = QuantStub() # 新插入内容self.dequant = DeQuantStub() # 新插入内容self.fc1 = nn.Linear(784, 256)self.relu = nn.ReLU()self.fc2 = nn.Linear(256, 10)def forward(self, x):x = self.quant(x) # 新插入内容x = self.fc1(x)x = self.relu(x)x = self.fc2(x)x = self.dequant(x) # 新插入内容return x

2. 量化配置

在PyTorch中,量化配置(QConfig)指定了模型量化过程中将使用的量化方案和算子。通过使用不同的QConfig,你可以控制如何量化模型中的权重和激活,这对于模型的性能和精度具有重要影响。

2.1 量化配置函数 get_default_qat_qconfig

get_default_qat_qconfig 是PyTorch提供的一个函数,用于获取用于量化感知训练(QAT)的默认量化配置。这个函数的一个重要参数是后端,通常是 ‘fbgemm’ 或 ‘qnnpack’:

  • ‘fbgemm’: 主要用于服务器和桌面平台上的x86架构,支持INT8量化。
  • ‘qnnpack’: 适用于移动设备,也支持INT8量化,优化了ARM架构。
from torch.quantization import get_default_qconfig
qconfig = get_default_qconfig('fbgemm')

这个函数会设置一个QConfig,其中包括针对权重和激活的量化方案。对于QAT,权重通常在前向过程中进行伪量化,而激活则在训练时进行动态量化。

2.2 可以设置的其他配置选项

PyTorch允许用户自定义QConfig,以适应特定的需求或实验不同的量化方案。自定义QConfig通常涉及以下部分:

2.2.1 量化方案:

  • torch.quantization.default_observer:
    默认的观察者,用于激活,基于移动平均和最小最大值自动调整量化参数。
  • torch.quantization.default_per_channel_weight_observer:
    用于权重的通道级观察者,每个输出通道有独立的量化参数。

2.2.2 量化和反量化函数:

  • torch.quantization.FakeQuantize: 实现伪量化和反量化,模拟量化的效果而不改变底层数据类型。

创建自定义的QConfig:

from torch.quantization import QConfig, default_observer, default_per_channel_weight_observercustom_qconfig = QConfig(activation=default_observer.with_args(dtype=torch.qint8),weight=default_per_channel_weight_observer.with_args(dtype=torch.qint8)
)

2.3 使用自定义QConfig

可以应用到模型的特定部分或整个模型上

model.fc1.qconfig = custom_qconfig  # 应用到模型的一个特定层 

# 应用到整个模型
from torch.quantization import prepare_qat  
model.qconfig = custom_qconfig 
model = prepare_qat(model, inplace=True) 

3. 量化感知训练

import torch
import torch.nn as nn
import torch.optim as optim
from torch.quantization import QuantStub, DeQuantStub, prepare_qat, convert# 实例化模型
model = MyQuantizedModel() # 指定量化配置
model.qconfig = torch.quantization.get_default_qat_qconfig('fbgemm')# 准备量化感知训练,
model = prepare_qat(model)# 训练配置
optimizer = optim.SGD(model.parameters(), lr=0.01)# 训练循环
for epoch in range(num_epochs):for data, target in dataloader:optimizer.zero_grad()output = model(data)loss = loss_fn(output, target)loss.backward()optimizer.step()# 转换模型为完全量化if epoch == num_epochs - 1:model = convert(model.eval(), inplace=True)

4. 量化推理测试

import torch
from torch.quantization import convertdef test_quantized_model(model, dataloader, device='cpu'):model = convert(model.eval(), inplace=True)model.to(device)  # 确保模型在正确的设备上correct = 0total = 0with torch.no_grad():  # 关闭梯度计算,因为我们只做推理for data, targets in dataloader:data, targets = data.to(device), targets.to(device)  # 移动数据到相应设备outputs = model(data)  # 前向推理_, predicted = torch.max(outputs.data, 1)  # 获取预测结果total += targets.size(0)correct += (predicted == targets).sum().item()accuracy = 100 * correct / totalprint(f'Accuracy of the quantized model on the test data: {accuracy:.2f}%')# 'test_loader' 是用于测试的 DataLoader
# 测试模型
# test_quantized_model(quantized_model, test_loader, device='cuda' if torch.cuda.is_available() else 'cpu'

5.完整参考代码

import torch
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoaderimport torch.nn as nn
import torch.nn.functional as F
from torch.quantization import QuantStub, DeQuantStubimport torch.optim as optim
from torch.quantization import get_default_qconfig, prepare_qat, convert# 定义数据预处理
transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))  # 标准化
])# 加载训练集和测试集
trainset = torchvision.datasets.CIFAR10(root='./data', train=True,download=True, transform=transform)
trainloader = DataLoader(trainset, batch_size=64, shuffle=True, num_workers=2)testset = torchvision.datasets.CIFAR10(root='./data', train=False,download=True, transform=transform)
testloader = DataLoader(testset, batch_size=64, shuffle=False, num_workers=2)classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')class QuantizedCNN(nn.Module):def __init__(self):super(QuantizedCNN, self).__init__()self.quant = QuantStub()self.conv1 = nn.Conv2d(3, 16, 5)self.pool = nn.MaxPool2d(2, 2)self.conv2 = nn.Conv2d(16, 32, 5)self.fc1 = nn.Linear(32 * 5 * 5, 120)self.fc2 = nn.Linear(120, 84)self.fc3 = nn.Linear(84, 10)self.dequant = DeQuantStub()def forward(self, x):# x = self.quant(x)x = self.pool(F.relu(self.conv1(x)))x = self.pool(F.relu(self.conv2(x)))x = torch.flatten(x, 1)x = F.relu(self.fc1(x))x = F.relu(self.fc2(x))x = self.fc3(x)x = self.dequant(x)return xdevice = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = QuantizedCNN().to(device)
model.qconfig = get_default_qconfig('qnnpack')# # 准备模型进行量化感知训练
model = prepare_qat(model, inplace=True)optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
criterion = nn.CrossEntropyLoss()# 训练循环
num_epochs = 10
for epoch in range(num_epochs):model.train()running_loss = 0.0for i, data in enumerate(trainloader, 0):inputs, labels = data[0].to(device), data[1].to(device)optimizer.zero_grad()outputs = model(inputs)loss = criterion(outputs, labels)loss.backward()optimizer.step()running_loss += loss.item()if i % 2000 == 1999:print(f'[{epoch + 1}, {i + 1}] loss: {running_loss / 2000:.3f}')running_loss = 0.0# 切换到评估模式进行测试model.eval()correct = 0total = 0with torch.no_grad():for data in testloader:images, labels = data[0].to(device), data[1].to(device)outputs = model(images)_, predicted = torch.max(outputs.data, 1)total += labels.size(0)correct += (predicted == labels).sum().item()print('Accuracy of the network on the 10000 test images: %d %%' % (100 * correct / total))# 在最后一个epoch后完成量化if epoch == num_epochs - 1:model = convert(model.eval(), inplace=True)print("Model quantization completed.")

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.xdnf.cn/news/12212.html

如若内容造成侵权/违法违规/事实不符,请联系一条长河网进行投诉反馈,一经查实,立即删除!

相关文章

datastage在升级版本到11.7之后,部分在11.3上正常执行的SP报错SQLSTATE = 22007: 本机错误代码 = -180

在升级版本到11.7之后,部分在11.3上正常执行的SP开始报错,报的SQL错误是时间参数问题,但是一样的SP可以直接call sp执行,也可以手动调用作业执行,只有设置定时调度时作业会报错, CALLXXX.XXX(1,CURRENT TIM…

Windows VSCode .NET CORE WebAPI Debug配置

1.安装C#插件 全名C# for Visual Studio Code,选择微软的 2. 安装C# Dev Kit插件 全名C# Dev Kit for Visual Studio Code,同样是选择微软的 3.安装Debugger for Unity 4.配置launch.json 文件 {"version": "0.2.0","config…

Docker使用docker-compose一键部署nacos、Mysql、redis

下面是一个简单的例子,展示如何通过Docker Compose文件部署Nacos、MySQL和Redis。请确保您的机器上已经安装了Docker和Docker Compose。 1,准备好mysql、redis、nacos镜像 sudo docker pull mysql:8 && sudo docker pull redis:7.2 &&…

【模块一】kubernetes容器编排进阶实战之k8s基础概念

kubernetes 基本介绍 kubernetes 组件简介 - master: 主人,并不部署服务,而是管理salve节点。 后期更名为: controll plane,控制面板。 etcd: 2379(客户端通信)、2…

【MPC-Simulink】EX04 信号归一化简化权重调节过程与提高数值计算质量

【MPC-Simulink】EX04 信号归一化简化权重调节过程与提高数值计算质量 参考 Matlab 官网提供的 Model Predictive Control Toolbox - Getting Started Guide,在 MPC 控制器中指定缩放因子,可以简化权重调节过程,提高数值计算质量。 当被控对…

Dubbo分布式日志跟踪实现

前言 随着越来越多的应用逐渐微服务化后,分布式服务之间的RPC调用使得异常排查的难度骤增,最明显的一个问题,就是整个调用链路的日志不在一台机器上,往往定位问题就要花费大量时间。如何在一个分布式网络中把单次请求的整个调用日…

企业网络转型:优势与挑战

◎ 网络研究观 事实上,现代企业网络是一个由相互连接的数据、应用程序和基础设施组成的复杂网络。然而,企业不应让这种复杂性成为服务不可靠、安全漏洞或网络停机的借口。 由于组织和公司面临着从并购到云扩展的诸多挑战,以及网络技术日益复…

【算法一周目】双指针(1)

目录 1.双指针介绍 2.移动零 解题思路 C代码实现 3.复写零 解题思路 C代码实现 4.快乐数 解题思路 C代码实现 5.盛水最多的容器 解题思路 C代码实现 1.双指针介绍 常见的双指针有两种形式,一种是对撞指针,一种是快慢指针。 对撞指针&#x…

6547网:青少年软件编程Python等级考试(六级)真题试卷

2024年9月青少年软件编程Python等级考试(六级)真题试卷 题目总数:38 总分数:100 选择题 第 1 题 单选题 下面Python代码运行后出现的图像是?( ) import matplotlib.pyplot as plt im…

【5种灵活有效方式】如何从死机手机中恢复内部数据?

本文介绍了5种方法来从死机的Android设备中恢复数据,包括使用U1tData安卓数据恢复软件、SD卡、OTG、Google云端硬盘和SamsungCloud。这些方法覆盖了不同情况下的数据恢复需求。 摘要由CSDN通过智能技术生成 我的手机掉在地上,现在无法开机。我丢失了所…

【安全测试】sqlmap工具(sql注入)学习

前言:sqimap是一个开源的渗透测试工具,它可以自动化检测和利用SQL注入缺陷以及接管数据库服务器的过程。它有一个强大的检测引擎,许多适合于终极渗透测试的小众特性和广泛的开关,从数据库指纹、从数据库获 取数据到访问底层文件系…

行业类别-智慧城市-子类别智能交通-细分类别自动驾驶技术-应用场景城市公共交通优化

1.大纲分析 针对题目“8.0 行业类别-智慧城市-子类别智能交通-细分类别自动驾驶技术-应用场景城市公共交通优化”的大纲分析,可以从以下几个方面进行展开: 一、引言 简述智慧城市的概念及其重要性。强调智能交通在智慧城市中的核心地位。引出自动驾驶…

24.11.11 JavaScript1

JavaScript(简称js)是⼀种描述语⾔,基于对象和事件驱动的脚本语⾔ JavaScript特点:脚本语⾔(⼀种轻量级的编程语⾔) ⼀种解释性语⾔(⽆需预编译) 被设计为向HTML⻚⾯添加交互⾏为 运⾏于客户端&…

PDF24:多功能 PDF 工具使用指南

PDF24:多功能 PDF 工具使用指南 在日常工作和学习中,PDF 是一种常见且重要的文档格式。无论是查看、编辑、合并,还是转换 PDF 文件,能够快速高效地处理 PDF 文档对于提高工作效率至关重要。PDF24 是一款免费、功能全面的 PDF 工具…

计算机的错误计算(一百五十一)

摘要 探讨 MATLAB 中反正弦 asin 与反余弦 acos 函数的计算精度问题。 例1. 已知 计算 及 直接贴图吧: 另外,16位的正确值分别为 0.1570785896071048e1、0.1043072384837152e-4、-0.1570785896071048e1 与 0.3141582222865945e1(I…

Lua进阶用法之Lua和C的接口设计

一:lua/c的接口编程 首先skynet、openresty 都是深度使用 lua 语言的典范;学习 lua 不仅仅要学习基本用法,还要学会使用 c 与 lua 交互,这样才学会了 lua 作为胶水语言的精髓,下面看一下他们两个的调用过程。 虚拟栈&a…

macOS 下的 ARM 裸机嵌入式开发入门- 第二部分:实现第一个裸机应用并且调试

1、准备二进制运行程序镜像 利用 QEMU 仿真一个完整的系统,并创建最简单的“Hello world!”示例。 QEMU 模拟器支持 VersatilePB 平台,该平台包含一个 ARM926EJ-S 核心,以及其他外设,四个 UART 串行端口;特别是第一个…

【网络面试篇】其他面试题——Cookie、Session、DNS、CDN、SSL/TLS、加密概念

目录 一、HTTP 相关问题 1. Cookie 和 Session 是什么? (1)Cookie (2)Session 2. Cookie 的工作原理? 3. Session 的工作原理? 4. Cookie 和 Session 有什么区别? 二、其他问…

【数值分析】复习1---牛顿迭代法

首先,我们先来回顾一下牛顿迭代法的概念。 这里注意的是,牛顿迭代法是一种线性方法,它在点 x k x_k xk​处进行线性展开,而且展开成一阶泰勒公式!注意是一阶,不是二阶,不是更高阶,所…

文本语义分块、RAG 系统的分块难题:小型语言模型如何找到最佳断点

文本语义分块、RAG 系统的分块难题:小型语言模型如何找到最佳断点? 转自jina最新的关于文本语义分块的分享和模型 之前我们聊过RAG 里文档分块 (Chunking) 的挑战,也介绍了 迟分 (Late Chunking) 的概念,它可以在向量化的时候减…