当前位置: 首页 > news >正文

Pytorch 反向传播

1. 理论基础:动态计算图与自动微分

动态计算图(DAG)

PyTorch采用动态计算图(Dynamic Computation Graph),在运行时即时构建计算流程。每个操作(如加法、乘法)会在执行时生成一个节点,并记录输入输出依赖关系。这种“定义即运行”机制允许灵活控制流(如条件分支、循环),但牺牲了静态图的编译期优化机会。

自动微分(AutoGrad)

基于反向模式自动微分(Reverse-mode AD),利用链式法则计算梯度:

  • 前向传播:记录操作序列,构建DAG。
  • 反向传播:从输出(Loss)出发,按拓扑逆序逐层计算梯度。
    数学形式:若 y = f ( x ) y = f(x) y=f(x) ,则梯度 d L d x = ∑ i d L d y i ⋅ ∂ y i ∂ x \frac{dL}{dx} = \sum_{i} \frac{dL}{dy_i} \cdot \frac{\partial y_i}{\partial x} dxdL=idyidLxyi

典型应用场景

  • 调试时动态调整网络结构(如注意力机制)。
  • 需要复杂控制流的模型(如强化学习策略网络)。

2. 实现流程:反向传播步骤分解

分阶段流程:
  1. 前向传播
    import torch
    # 定义参数
    w = torch.randn(1, requires_grad=True)
    b = torch.randn(1, requires_grad=True)# 输入数据
    x = torch.tensor([2.0])
    y_true = torch.tensor([5.0])# 前向计算
    y_pred = w * x + b  # 构建计算图
    
  2. 计算Loss
    loss = (y_pred - y_true) ** 2  # Loss = (wx + b - y_true)^2
    
  3. 反向传播
    loss.backward()  # 启动反向传播,计算梯度
    
  4. 梯度更新
    with torch.no_grad():w -= 0.1 * w.grad  # 手动更新参数b -= 0.1 * b.grad
    
  5. 梯度清零
    w.grad.zero_()  # 避免梯度累积
    b.grad.zero_()
    

关键点

  • requires_grad标记需跟踪梯度的张量。
  • backward()触发Autograd引擎递归计算梯度。
  • torch.no_grad()临时禁用梯度计算。

3. 核心组件:grad_fn与Autograd引擎

grad_fn属性
a = torch.tensor([2.0], requires_grad=True)
b = a ** 2
c = b + 3
print(c.grad_fn)  # <AddBackward0 object>,记录生成c的操作
  • grad_fn指向创建该Tensor的操作函数(如AddBackward0MulBackward0)。
  • 叶子节点(Leaf Nodes):由用户直接创建的Tensor(如a),其grad_fn=None
Autograd引擎工作原理
  1. 依赖追踪:在前向传播时,Autograd记录每个操作的前向函数和反向传播函数。
  2. 拓扑排序:反向传播时,按DAG逆序处理节点(确保父节点梯度已计算)。
  3. 梯度累加:中间节点梯度会累积到tensor.grad中。

源码级原理

  • C++实现的Engine类管理任务队列,Python端通过torch.autograd.backward()交互。
  • 每个grad_fn实现forward()backward()方法。

4. 代码演示:反向传播与高阶导数

基本反向传播示例
import torch# 定义模型参数
w = torch.randn(1, requires_grad=True)
b = torch.randn(1, requires_grad=True)# 数据
x = torch.tensor([2.0])
y_true = torch.tensor([5.0])# 前向传播
y_pred = w * x + b
loss = (y_pred - y_true) ** 2# 反向传播
loss.backward()# 查看梯度
print(f"w.grad: {w.grad}, b.grad: {b.grad}")
高阶导数计算
x = torch.tensor([3.0], requires_grad=True)
y = x ** 3  # y = x^3# 一阶导数 dy/dx = 3x²
first_derivative = torch.autograd.grad(y, x, create_graph=True)[0]  # create_graph=True保留计算图# 二阶导数 d²y/dx² = 6x
second_derivative = torch.autograd.grad(first_derivative, x)[0]print(f"First derivative: {first_derivative}")  # 输出 3*3²=27
print(f"Second derivative: {second_derivative}")  # 输出 6*3=18

关键点

  • create_graph=True允许计算高阶导数。
  • 返回值是元组,需索引取第一个元素。

5. 内存管理:计算图保留与优化

retain_graph参数
loss1 = (w * x + b - y_true) ** 2
loss2 = (w * x**2 + b - y_true) ** 2loss1.backward(retain_graph=True)  # 第一次反向传播保留图
loss2.backward()  # 继续第二次反向传播
  • 默认反向传播后释放计算图,retain_graph=True防止释放。
  • 应用场景:多任务学习中多个Loss顺序反向传播。
内存优化技巧
  • 及时清除无用中间变量del tensor或使用上下文管理器。
  • 合并操作:减少冗余计算(如将多个激活函数合并)。
  • 降低精度:使用torch.float16或混合精度训练。
with torch.no_grad():  # 推理阶段禁用梯度predictions = model(inputs)

6. 注意事项:常见陷阱与解决方案

梯度累积
for batch in data_loader:optimizer.zero_grad()outputs = model(batch)loss = loss_function(outputs, labels)loss.backward()  # 梯度累积(未调用zero_grad)
optimizer.step()  # 累积多个batch后更新
  • 应用场景:小显存设备模拟大batch size训练。
In-place操作限制
x = torch.randn(3, requires_grad=True)
# 错误!In-place操作破坏计算图
x.add_(1)  # 报错:a leaf Variable that requires grad is being used in an in-place operation
  • PyTorch禁止修改requires_grad=True张量的值(除非标记为.volatile)。
非标量输出处理
x = torch.tensor([2.0], requires_grad=True)
y = torch.stack([x**2, x**3])  # 非标量输出
v = torch.tensor([1.0, 0.1])  # 外部梯度权重
y.backward(v)  # 相当于计算 ∂L/∂x = v[0]*dy[0]/dx + v[1]*dy[1]/dx
print(x.grad)  # 输出 1.0*4 + 0.1*12 = 5.2
  • 对非标量输出调用backward()时必须传入gradient参数,用于指定外部梯度。

总结

维度关键技术点典型应用
理论基础DAG、反向模式AD动态模型设计
实现流程forward → backward → update训练自定义模型
核心组件grad_fn、Autograd引擎调试梯度计算流程
内存管理retain_graph、no_grad多任务学习、低显存训练
注意事项梯度累积、in-place限制、非标量处理复杂Loss设计、高阶优化问题
http://www.xdnf.cn/news/217351.html

相关文章:

  • 塔能照明节能服务流程:精准驱动工厂能耗优化
  • leetcode:3005. 最大频率元素计数(python3解法)
  • 第三次作业(密码学)
  • 【android bluetooth 协议分析 06】【l2cap详解 11】【l2cap连接超时处理逻辑介绍】
  • (29)VTK C++开发示例 ---绘制两条彩色线
  • 想做博闻强记的自己
  • IoTDB数据库建模与资源优化指南
  • Python中的defaultdict方法
  • 驱动开发硬核特训 · Day 24(下篇):深入理解 Linux 内核时钟子系统结构
  • 【深度学习的灵魂】图片布局生成模型LayoutPrompt(1)
  • MATLAB函数调用全解析:从入门到精通
  • 【Linux】g++安装教程
  • Linux 命名管道+日志
  • 婴幼儿托育实训室生活照料流程标准化设计
  • Flowable7.x学习笔记(十五)动态指定用户分配参数启动工作流程
  • AutogenStudio使用
  • 快速掌握向量数据库-Milvus探索2_集成Embedding模型
  • AI技术前沿:Function Calling、RAG与MCP的深度解析与应用实践
  • 基于PyTorch的图像分类特征提取与模型训练文档
  • 集群系统的五大核心挑战与困境解析
  • EtherCAT转CANopen方案落地:推动运动控制器与传感器通讯的工程化实践
  • CKESC Breeze 6S 40A_4S 50A FOC BEC电调测评:全新vfast 技术赋能高效精准控制
  • 低代码平台部署方案解析:百特搭四大部署方式
  • 大模型推理:Qwen3 32B vLLM Docker本地部署
  • 强化学习贝尔曼方程推导
  • 流量守门员:接口限流艺术
  • Manus AI多语言手写识别技术全解析:从模型架构到实战部署
  • JavaScript 中深拷贝浅拷贝的区别?如何实现一个深拷贝?
  • 信雅达 AI + 悦数 Graph RAG | 大模型知识管理平台在金融行业的实践
  • C# 类的基本概念(实例成员)