强化学习笔记之【DDPG算法】

强化学习笔记之【DDPG算法】


文章目录

前言:

本文为强化学习笔记第二篇,第一篇讲的是Q-learning和DQN

就是因为DDPG引入了Actor-Critic模型,所以比DQN多了两个网络,网络名字功能变了一下,其它的就是软更新之类的小改动而已

本文初编辑于2024.10.6

CSDN主页:https://blog.csdn.net/rvdgdsva

博客园主页:https://www.cnblogs.com/hassle

博客园本文链接:

真 · 图文无关

原论文伪代码

  • 上述代码为DDPG原论文中的伪代码

DDPG算法

需要先看:

Deep Reinforcement Learning (DRL) 算法在 PyTorch 中的实现与应用【DDPG部分】【没有在选择一个新的动作的时候,给policy函数返回的动作值增加一个噪音】【critic网络与下面不同】

深度强化学习笔记——DDPG原理及实现(pytorch)【DDPG伪代码部分】【这个跟上面的一样没有加噪音】【critic网络与上面不同】

【深度强化学习】(4) Actor-Critic 模型解析,附Pytorch完整代码【选看】【Actor-Critic理论部分】


如果需要给policy函数返回的动作值增加一个噪音,实现如下

def select_action(self, state, noise_std=0.1):state = torch.FloatTensor(state.reshape(1, -1))action = self.actor(state).cpu().data.numpy().flatten()# 添加噪音,上面两个文档的代码都没有这个步骤noise = np.random.normal(0, noise_std, size=action.shape)action = action + noisereturn action

DDPG 中的四个网络

注意!!!这个图只展示了Critic网络的更新,没有展示Actor网络的更新

  • Actor 网络(策略网络)
    • 作用:决定给定状态 ss 时,应该采取的动作 a=π(s)a=π(s),目标是找到最大化未来回报的策略。
    • 更新:基于 Critic 网络提供的 Q 值更新,以最大化 Critic 估计的 Q 值。
  • Target Actor 网络(目标策略网络)
    • 作用:为 Critic 网络提供更新目标,目的是让目标 Q 值的更新更为稳定。
    • 更新:使用软更新,缓慢向 Actor 网络靠近。
  • Critic 网络(Q 网络)
    • 作用:估计当前状态 ss 和动作 aa 的 Q 值,即 Q(s,a)Q(s,a),为 Actor 提供优化目标。
    • 更新:通过最小化与目标 Q 值的均方误差进行更新。
  • Target Critic 网络(目标 Q 网络)
    • 作用:生成 Q 值更新的目标,使得 Q 值更新更为稳定,减少振荡。
    • 更新:使用软更新,缓慢向 Critic 网络靠近。

大白话解释:

​ 1、DDPG实例化为actor,输入state输出action
​ 2、DDPG实例化为actor_target
​ 3、DDPG实例化为critic_target,输入next_state和actor_target(next_state)经DQN计算输出target_Q
​ 4、DDPG实例化为critic,输入state和action输出current_Q,输入state和actor(state)【这个参数需要注意,不是action】经负均值计算输出actor_loss

​ 5、current_Q 和target_Q进行critic的参数更新
​ 6、actor_loss进行actor的参数更新

action实际上是batch_action,state实际上是batch_state,而batch_action != actor(batch_state)

因为actor是频繁更新的,而采样是随机采样,不是所有batch_action都能随着actor的更新而同步更新

Critic网络的更新是一发而动全身的,相比于Actor网络的更新要复杂要重要许多


代码核心更新公式

t a r g e t ‾ Q = c r i t i c ‾ t a r g e t ( n e x t ‾ s t a t e , a c t o r ‾ t a r g e t ( n e x t ‾ s t a t e ) ) t a r g e t ‾ Q = r e w a r d + ( 1 − d o n e ) × g a m m a × t a r g e t ‾ Q . d e t a c h ( ) target\underline{~}Q = critic\underline{~}target(next\underline{~}state, actor\underline{~}target(next\underline{~}state)) \\target\underline{~}Q = reward + (1 - done) \times gamma \times target\underline{~}Q.detach() target Q=critic target(next state,actor target(next state))target Q=reward+(1done)×gamma×target Q.detach()

  • 上述代码与伪代码对应,意为计算预测Q值

c r i t i c ‾ l o s s = M S E L o s s ( c r i t i c ( s t a t e , a c t i o n ) , t a r g e t ‾ Q ) c r i t i c ‾ o p t i m i z e r . z e r o ‾ g r a d ( ) c r i t i c ‾ l o s s . b a c k w a r d ( ) c r i t i c ‾ o p t i m i z e r . s t e p ( ) critic\underline{~}loss = MSELoss(critic(state, action), target\underline{~}Q) \\critic\underline{~}optimizer.zero\underline{~}grad() \\critic\underline{~}loss.backward() \\critic\underline{~}optimizer.step() critic loss=MSELoss(critic(state,action),target Q)critic optimizer.zero grad()critic loss.backward()critic optimizer.step()

  • 上述代码与伪代码对应,意为使用均方误差损失函数更新Critic

a c t o r ‾ l o s s = − c r i t i c ( s t a t e , a c t o r ( s t a t e ) ) . m e a n ( ) a c t o r ‾ o p t i m i z e r . z e r o ‾ g r a d ( ) a c t o r ‾ l o s s . b a c k w a r d ( ) a c t o r ‾ o p t i m i z e r . s t e p ( ) actor\underline{~}loss = -critic(state,actor(state)).mean() \\actor\underline{~}optimizer.zero\underline{~}grad() \\ actor\underline{~}loss.backward() \\ actor\underline{~}optimizer.step() actor loss=critic(state,actor(state)).mean()actor optimizer.zero grad()actor loss.backward()actor optimizer.step()

  • 上述代码与伪代码对应,意为使用确定性策略梯度更新Actor

c r i t i c ‾ t a r g e t . p a r a m e t e r s ( ) . d a t a = ( t a u × c r i t i c . p a r a m e t e r s ( ) . d a t a + ( 1 − t a u ) × c r i t i c ‾ t a r g e t . p a r a m e t e r s ( ) . d a t a ) a c t o r ‾ t a r g e t . p a r a m e t e r s ( ) . d a t a = ( t a u × a c t o r . p a r a m e t e r s ( ) . d a t a + ( 1 − t a u ) × a c t o r ‾ t a r g e t . p a r a m e t e r s ( ) . d a t a ) critic\underline{~}target.parameters().data=(tau \times critic.parameters().data + (1 - tau) \times critic\underline{~}target.parameters().data) \\ actor\underline{~}target.parameters().data=(tau \times actor.parameters().data + (1 - tau) \times actor\underline{~}target.parameters().data) critic target.parameters().data=(tau×critic.parameters().data+(1tau)×critic target.parameters().data)actor target.parameters().data=(tau×actor.parameters().data+(1tau)×actor target.parameters().data)

  • 上述代码与伪代码对应,意为使用策略梯度更新目标网络

Actor和Critic的角色

  • Actor:负责选择动作。它根据当前的状态输出一个确定性动作。
  • Critic:评估Actor的动作。它通过计算状态-动作值函数(Q值)来评估给定状态和动作的价值。

更新逻辑

  • Critic的更新
    1. 使用经验回放缓冲区(Experience Replay)从中采样一批经验(状态、动作、奖励、下一个状态)。
    2. 计算目标Q值:使用目标网络(critic_target)来估计下一个状态的Q值(target_Q),并结合当前的奖励。
    3. 使用均方误差损失函数(MSELoss)来更新Critic的参数,使得预测的Q值(target_Q)与当前Q值(current_Q)尽量接近。
  • Actor的更新
    1. 根据当前的状态(state)从Critic得到Q值的梯度(即对Q值相对于动作的偏导数)。
    2. 使用确定性策略梯度(DPG)的方法来更新Actor的参数,目标是最大化Critic评估的Q值。

个人理解:

DQN算法是将q_network中的参数每n轮一次复制到target_network里面

DDPG使用系数 τ \tau τ来更新参数,将学习到的参数更加soft地拷贝给目标网络

DDPG采用了actor-critic网络,所以比DQN多了两个网络

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.xdnf.cn/news/1555763.html

如若内容造成侵权/违法违规/事实不符,请联系一条长河网进行投诉反馈,一经查实,立即删除!

相关文章

Ubuntu22.04 Docker 国内安装最靠谱教程

目前docker在国内安装常存在众所周知的网络问题,如果安装过程如果从官网地址安装以及安装之后从官网要拉取镜像都存在问题。这篇文章主要针对这两个问题总结最靠谱的docker安装教程。 1. docker安装 1.1 系统环境概述 Ubuntu 22.04linux内核版本 6.8(…

重学SpringBoot3-集成Redis(四)之Redisson

更多SpringBoot3内容请关注我的专栏:《SpringBoot3》 期待您的点赞👍收藏⭐评论✍ 重学SpringBoot3-集成Redis(四)之Redisson 1. 添加 Redisson 依赖2. 配置 Redisson 客户端3. 使用 Redisson 实现分布式锁4. 调用分布式锁5. 为什…

二进制的神奇操作——拆位法和贡献思想

拆位的引入 我们来思考这么一个问题,如果给你一个数组,让你去求一个数组里面所有连续子串的异或和的和,问你该怎么求? 我们该如何去处理,首先肯定是会想到暴力的思路,第一层循环遍历左端点,第…

SpringBoot在线教育平台:设计与实现的深度解析

2相关技术 2.1 MYSQL数据库 MySQL是一个真正的多用户、多线程SQL数据库服务器。 是基于SQL的客户/服务器模式的关系数据库管理系统,它的有点有有功能强大、使用简单、管理方便、安全可靠性高、运行速度快、多线程、跨平台性、完全网络化、稳定性等,非常…

654、最大二叉树

1、题目描述 . - 力扣(LeetCode) 其实就是给定了一个所谓"最大二叉树"的规则,让我们去构建二叉树。 以 nums [3,2,1,6,0,5] 为例,规则如下: (1)找出其中的最大值6将其作为根节点,6前面的是左子…

程序传入单片机的过程,以Avrdude为例分析

在市场上有各式各样的单片机,例如Arduino,51单片机,STM等。通常,我们都用其对应的IDE软件进行单片机的编程。这些软件既负责将程序代码转写成二进制代码,即机器语言,也负责将该二进制代码导入单片机。与此同…

C++ 算法学习——7.4.1 优化算法——双指针

双指针法(Two Pointers)是一种常用的算法技巧,通常用于解决数组或链表中的问题。这种技巧通过维护两个指针,通常分别指向数组或链表的不同位置,来协同解决问题。双指针法一般有两种类型:快慢指针和左右指针…

什么是transformer大模型,答案就在这里

Transformer大模型是一种在自然语言处理(NLP)领域中广泛使用的模型,其详细数据与分析可以从以下几个方面进行阐述: 1. 模型架构 Transformer模型本质上是一个Encoder-Decoder架构。编码组件由多层编码器(Encoder&…

(笔记)第三期书生·浦语大模型实战营(十一卷王场)–书生基础岛第3关---浦语提示词工程实践

学员闯关手册:https://aicarrier.feishu.cn/wiki/ZcgkwqteZi9s4ZkYr0Gcayg1n1g?open_in_browsertrue 课程视频:https://www.bilibili.com/video/BV1cU411S7iV/ 课程文档: https://github.com/InternLM/Tutorial/tree/camp3/docs/L1/Prompt 关…

还在“卷”长度?长文本模型真的基于上下文进行回复吗?

近年来,随着长文本模型(Long-context Model, LCM)技术的突飞猛进,处理长上下文的能力已成为各大语言模型(Large Language Model, LLM)的核心竞争力,也是各大技术厂商争夺的焦点。截至2023年12月…

RAG再总结之如何使大模型更好使用外部数据:四个不同层级及查询-文档对齐策略

我们来看看RAG进展。《Retrieval Augmented Generation (RAG) and Beyond: A Comprehensive Survey on How to Make your LLMs use External Data More Wisely》(https://arxiv.org/abs/2409.14924),主要讨论了如何使大型语言模型(LLMs)更明智…

Redis中BitMap实现签到与统计连续签到功能

服务层代码 //签到Overridepublic Result sign() {//1.获取当前登录的用户Long userId UserHolder.getUser().getId();//获取日期LocalDateTime now LocalDateTime.now();//拼接keyString keySuffix now.format(DateTimeFormatter.ofPattern(":yyyyMM"));String …

实例分割、语义分割和 SAM(Segment Anything Model)

实例分割、语义分割和 SAM(Segment Anything Model) 都是图像处理中的重要技术,它们的目标是通过分割图像中的不同对象或区域来帮助识别和分析图像,但它们的工作方式和适用场景各有不同。 1. 语义分割(Semantic Segme…

一款基于 Java 的可视化 HTTP API 接口快速开发框架,干掉 CRUD,效率爆炸(带私活源码)

平常我们经常需要编写 API,但其实常常只是一些简单的增删改查,写这些代码非常枯燥无趣。 今天给大家带来的是一款基于 Java 的可视化 HTTP API 接口快速开发框架,通过 UI 界面编写接口,无需定义 Controller、Service、Dao 等 Jav…

Bolt.new:终极自动化编程工具

兄弟们,终极写代码工具来了—— Bolt.new!全方位的编程支持: StackBlitz 推出了 Bolt․new,这是一款结合了 AI 与 WebContainers 技术的强大开发平台,允许用户快速搭建并开发各种类型的全栈应用。 它的主要特点是无需…

内网靶场 | 渗透攻击红队内网域渗透靶场-1(Metasploit)零基础入门到精通,收藏这一篇就够了

“ 和昨天的文章同一套靶场,这次主要使用的是Kali Linux以及Metasploit来打靶场,熟悉一下MSF在内网渗透中的使用,仅供学习参考,大佬勿喷。本期文章靶场来自公众号:渗透攻击红队。” 靶场下载地址:https://…

展锐平台WIFI国家码信道总结

展锐平台WIFI国家码信道总结 1.下载wireless-regdb wireless-regdb是一个开源的工程,编译它会生成regulatory.bin文件,这实际上是一个加密后的数据库,它记录各个国家可用的无线频段。 可从下面的网站上下载最新的regdb库: https://git.kernel.org/pub/scm/linux/kernel…

【无人水面艇路径跟随控制2】(C++)USV代码阅读: SetOfLos 类的从路径点和里程计信息中计算期望航向

【无人水面艇路径跟随控制2】(C)USV代码阅读: SetOfLos 类的从路径点和里程计信息中计算期望航向 写在最前面set_of_los.cpp小结详细解释头文件包含命名空间构造函数和析构函数设置参数函数获取航向函数 🌈你好呀!我是…

热门:AI变现,看看谁在默默赚大钱?

在这个愈发依赖AI的时代,找到属于自己的盈利方式愈发重要。 更多实操教程和AI绘画工具,可以扫描下方,免费获取 总的来说,利用AI进行盈利的方式主要有三种:技术型、流量型和内容型。 每种方式都根植于AI的特性,但同时也需要特定…

重庆数字孪生工业互联网可视化技术,赋能新型工业化智能制造工厂

重庆作为西南地区的重要工业基地,正积极探索和实践数字孪生、工业互联网及可视化技术在智能制造领域的深度融合,致力于打造新型工业化智能制造工厂,为制造业的高质量发展注入强劲动力。 在重庆的智能制造工厂中,数字孪生技术被广…