李宏毅机器学习2023HW12—Reinforcement Learning强化学习

文章目录

  • Task
  • Baseline
    • Simple
    • Medium Baseline—Policy Gradient
    • Strong Baseline——Actor-Critic
    • Boss Baseline—Mask

Task

实现深度强化学习方法:

  • Policy Gradient
  • Actor-Critic

环境:月球着陆器

Baseline

Simple

定义优势函数(Advantage function)为执行完action之后直到结束每一步的reward累加,即:
A 1 = R 1 = r 1 + r 2 + . . . . + r T , A 2 = R 2 = r 2 + r 3 + . . . + r T , . . . A T = R T = r T A_1=R_1=r_1+r_2+....+r_T,\\ A_2=R_2=r_2+r_3+...+r_T,\\ ...\\ A_T=R_T=r_T A1=R1=r1+r2+....+rT,A2=R2=r2+r3+...+rT,...AT=RT=rT
其中, R R R为动作状态值函数, r i r_i ri为执行完动作 a i a_i ai得到的reward

for episode in range(EPISODE_PER_BATCH):state = env.reset()total_reward, total_step = 0, 0seq_rewards = []while True:action, log_prob = agent.sample(state) # at, log(at|st)next_state, reward, done, _ = env.step(action)log_probs.append(log_prob) # [log(a1|s1), log(a2|s2), ...., log(at|st)]# seq_rewards.append(reward)state = next_statetotal_reward += rewardtotal_step += 1rewards.append(reward) # change hereif done:final_rewards.append(reward)total_rewards.append(total_reward)break

Medium Baseline—Policy Gradient

第二个版本的cumulated reward,把离a1比较近的的reward给比较大的权重,比较远的给比较小的权重,如下:
A 1 = R 1 = r 1 + γ r 2 + . . . . + γ T − 1 r T , A 2 = R 2 = r 2 + γ r 3 + . . . + γ T − 2 r T , . . . A T = R T = r T A_1=R_1=r_1+\gamma r_2+....+\gamma ^{T-1} r_T,\\ A_2=R_2=r_2+\gamma r_3+...+\gamma ^{T-2} r_T,\\ ...\\ A_T=R_T=r_T A1=R1=r1+γr2+....+γT1rT,A2=R2=r2+γr3+...+γT2rT,...AT=RT=rT
在这里插入图片描述

# Take a state input and generate a probability distribution of an action through a series of Fully Connected Layers
class PolicyGradientNetwork(nn.Module):def __init__(self):super().__init__()self.fc1 = nn.Linear(8, 16)self.fc2 = nn.Linear(16, 16)self.fc3 = nn.Linear(16, 4)def forward(self, state):hid = torch.tanh(self.fc1(state))hid = torch.tanh(hid)return F.softmax(self.fc3(hid), dim=-1)
      while True:action, log_prob = agent.sample(state) # at, log(at|st)next_state, reward, done, _ = env.step(action)log_probs.append(log_prob) # [log(a1|s1), log(a2|s2), ...., log(at|st)]seq_rewards.append(reward) # r1, r2, ...., rtstate = next_statetotal_reward += rewardtotal_step += 1 # total_step in each episode is different# rewards.append(reward) # change here# ! IMPORTANT !# Current reward implementation: immediate reward,  given action_list : a1, a2, a3 ......#                                                         rewards :     r1, r2 ,r3 ......# medium:change "rewards" to accumulative decaying reward, given action_list : a1,                           a2,                           a3, ......#                                                           rewards :           r1+0.99*r2+0.99^2*r3+......, r2+0.99*r3+0.99^2*r4+...... ,  r3+0.99*r4+0.99^2*r5+ ......if done: # done is return by environment, true means current episode is donefinal_rewards.append(reward) # final step rewardtotal_rewards.append(total_reward) # total reward of this episode# calculate accumulative decaying rewarddiscounted_rewards = []R = 0for r in reversed(seq_rewards):R = r + rate * Rdiscounted_rewards.insert(0, R)rewards.extend(discounted_rewards)break

Strong Baseline——Actor-Critic

在这里插入图片描述Actor-Critic 算法中,ActorCritic 的损失函数主要基于策略梯度方法(用于更新 Actor 网络)以及价值函数(用于更新 Critic 网络)。这两部分的损失分别由策略的 Advantage 估计和状态价值的误差构成。

Actor 网络的目标是通过 策略梯度(Policy Gradient) 方法,最大化预期的累计奖励 $ \mathbb{E} [R] $。为此,损失函数通常为负的 log 概率乘以 Advantage(优势函数),该优势函数描述了当前策略执行动作的好坏程度。
L Actor = − E π [ log ⁡ ( π ( a ∣ s ) ) ⋅ A ( s , a ) ] L_{\text{Actor}} = -\mathbb{E}_{\pi} [\log(\pi(a|s)) \cdot A(s, a)] LActor=Eπ[log(π(as))A(s,a)]

其中:

  • log ⁡ ( π ( a ∣ s ) ) \log(\pi(a|s)) log(π(as)) :是状态 ( s ) 下选择动作 ( a ) 的 log 概率。
  • $ A(s, a) :是 ∗ ∗ A d v a n t a g e ∗ ∗ ,代表实际收益与 C r i t i c 估计的差距,定义为: :是 **Advantage**,代表实际收益与 Critic 估计的差距,定义为: :是Advantage,代表实际收益与Critic估计的差距,定义为:A(s, a) = r + \gamma V(s)_{t+1} - V(s)_t$
    其中:
    • r r r :当前动作的即时奖励。
    • $\gamma $:折扣因子,用于考虑未来奖励的权重。
    • V ( s ) t + 1 V(s)_{t+1} V(s)t+1:下一个状态的价值估计。
    • $V(s)_t $:当前状态的价值估计。

因此,Actor 损失函数的整体公式为:
L Actor = − ∑ i = 1 T log ⁡ ( π ( a ∣ s ) ) ⋅ A ( s , a ) = − ∑ i = 1 T log ⁡ ( π ( a ∣ s ) ) ⋅ ( r + γ V ( s ′ ) − V ( s ) ) L_{\text{Actor}} =-\sum{ ^T _{i=1}} \log(\pi(a|s)) \cdot A(s,a)= -\sum { ^T _{i=1}}\log(\pi(a|s)) \cdot (r + \gamma V(s') - V(s)) LActor=i=1Tlog(π(as))A(s,a)=i=1Tlog(π(as))(r+γV(s)V(s))

2. Critic 损失公式

Critic 网络的目标是尽可能精确地估计状态的价值 ($ V(s) $),所以我们使用 价值误差 作为 Critic 损失。常用的损失函数是 均方误差(Mean Squared Error, MSE) 或者 平滑 L1 损失

Critic 的损失函数可以写为:
L Critic = E [ ( V ( s ) t − ( r + γ V ( s ) t + 1 ) ) 2 ] L_{\text{Critic}} = \mathbb{E} \left[ \left( V(s)_t - \left( r + \gamma V(s)_{t+1} \right) \right)^2 \right] LCritic=E[(V(s)t(r+γV(s)t+1))2]

也就是说,Critic 通过最小化 ( V ( s ) t V(s)_t V(s)t) 和 ( r + γ V ( s ) t + 1 r + \gamma V(s)_{t+1} r+γV(s)t+1 ) 之间的误差,来提高状态价值的估计。

在使用 平滑 L1 损失 的情况下,公式为: L Critic = smooth_l1_loss ( V ( s ) t , r + γ V ( s ) t + 1 ) L_{\text{Critic}} = \text{smooth\_l1\_loss}(V(s)_t, r + \gamma V(s)_{t+1}) LCritic=smooth_l1_loss(V(s)t,r+γV(s)t+1).平滑 L1 损失 比均方误差对异常值更具鲁棒性。

from torch.optim.lr_scheduler import StepLR
class ActorCritic(nn.Module):def __init__(self):super().__init__()self.fc = nn.Sequential(nn.Linear(8, 16),nn.Tanh(),nn.Linear(16, 16),nn.Tanh())self.actor = nn.Linear(16, 4)self.critic = nn.Linear(16, 1)self.values = []self.optimizer = optim.SGD(self.parameters(), lr=0.001)def forward(self, state):hid = self.fc(state)self.values.append(self.critic(hid).squeeze(-1))return F.softmax(self.actor(hid), dim=-1)def learn(self, log_probs, rewards):values = torch.stack(self.values)loss = (-log_probs * (rewards - values.detach())).sum() + F.smooth_l1_loss(values, rewards)self.optimizer.zero_grad()loss.backward()self.optimizer.step()self.values = []def sample(self, state):action_prob = self(torch.FloatTensor(state))action_dist = Categorical(action_prob)action = action_dist.sample()log_prob = action_dist.log_prob(action)return action.item(), log_prob

Boss Baseline—Mask

Mask 蒙版(Mask) 和 Rate(折扣因子) 解释**

  • Mask

mask 是一个 蒙版向量,用于过滤掉无效的或不需要考虑的状态值或奖励。这通常在处理 序列数据 或者 部分状态无效的任务 时很有用。例如,在某些环境中,某些时间步的奖励可能不可用,或这些时间步不需要计入学习。

通过 mask,我们可以有选择性地忽略某些状态或动作:
A ( s , a ) = r + γ ⋅ mask ⋅ V ( s ) t + 1 − V ( s ) + t A(s, a) = r + \gamma \cdot \text{mask} \cdot V(s)_{t+1} - V(s)+t A(s,a)=r+γmaskV(s)t+1V(s)+t

  • Rate (折扣因子 ( γ \gamma γ ))

rate 也就是折扣因子 ( γ \gamma γ ),用于对未来奖励进行折现。它的作用是 平衡即时奖励与长期奖励。折扣因子 ( γ \gamma γ ) 的取值范围通常在 ( $[0, 1] $),当 ( γ = 0 \gamma = 0 γ=0 ) 时,表示完全只关注即时奖励;当 ($ \gamma \to 1 $) 时,表示对长期奖励的重视程度增加。

因此,完整的 Advantage 函数(带有蒙版和折扣因子的形式)是:
A ( s , a ) = r + γ ⋅ mask ⋅ V ( s ) t + 1 − V ( s ) t A(s, a) = r + \gamma \cdot \text{mask} \cdot V(s)_{t+1} - V(s)_t A(s,a)=r+γmaskV(s)t+1V(s)t

完整的损失函数公式()

  1. Actor 损失公式:
    L Actor = − ∑ i = 1 T log ⁡ ( π ( a ∣ s ) ) ⋅ ( r + γ ⋅ mask ⋅ V ( s ) t + 1 − V ( s ) t ) L_{\text{Actor}} = -\sum{ ^T _{i=1}} \log(\pi(a|s)) \cdot \left( r + \gamma \cdot \text{mask} \cdot V(s)_{t+1} - V(s)_t \right) LActor=i=1Tlog(π(as))(r+γmaskV(s)t+1V(s)t)

  2. Critic 损失公式:
    L Critic = smooth_l1_loss ( V ( s ) t , r + γ ⋅ mask ⋅ V ( s ) t + 1 ) L_{\text{Critic}} = \text{smooth\_l1\_loss}\left( V(s)_t, r + \gamma \cdot \text{mask} \cdot V(s)_{t+1} \right) LCritic=smooth_l1_loss(V(s)t,r+γmaskV(s)t+1)

from torch.optim.lr_scheduler import StepLR
class ActorCritic(nn.Module):def __init__(self):super().__init__()self.fc = nn.Sequential(nn.Linear(8, 16),nn.Tanh(),nn.Linear(16, 16),nn.Tanh())self.actor = nn.Linear(16, 4)self.critic = nn.Linear(16, 1)self.values = []self.optimizer = optim.SGD(self.parameters(), lr=0.001)self.scheduler = optim.lr_scheduler.CyclicLR(self.optimizer, base_lr=2e-4, max_lr=2e-3, step_size_up=10, mode='triangular2')def forward(self, state):hid = self.fc(state)self.values.append(self.critic(hid).squeeze(-1))return F.softmax(self.actor(hid), dim=-1)def learn(self, log_probs, rewards, mask, rate):values = torch.stack(self.values)advantage = rewards + rate* mask * torch.cat([values[1:], torch.zeros(1)]) - valuesloss = (-log_probs * (advantage.detach())).sum() + \F.smooth_l1_loss(advantage, torch.zeros(len(advantage)))self.optimizer.zero_grad()loss.backward()self.optimizer.step()self.scheduler.step()self.values = []def sample(self, state):action_prob = self(torch.FloatTensor(state))action_dist = Categorical(action_prob)action = action_dist.sample()log_prob = action_dist.log_prob(action)return action.item(), log_prob
while True:action, log_prob = agent.sample(state) # at, log(at|st)next_state, reward, done, _ = env.step(action)log_probs.append(log_prob) # [log(a1|s1), log(a2|s2), ...., log(at|st)]seq_rewards.append(reward)state = next_statetotal_reward += rewardtotal_step += 1if done:final_rewards.append(reward)total_rewards.append(total_reward)rewards += seq_rewardsmask += [1]*len(seq_rewards)mask[-1] = 0break

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.xdnf.cn/news/1541772.html

如若内容造成侵权/违法违规/事实不符,请联系一条长河网进行投诉反馈,一经查实,立即删除!

相关文章

C++之Person类

首先设置头文件&#xff0c;将题目中的要求完成。 #include <iostream>using namespace std;class Person { public:Person();Person(string name, int id, string address);~Person();void setPerson(string name, int id, string address);void setName(string name);…

Kotlin编程全攻略:从基础到实战项目的系统学习资料

Kotlin作为一种现代、简洁的编程语言&#xff0c;正逐渐成为Android开发的新宠。本文将为您介绍一套全面的Kotlin学习资料&#xff0c;包括学习大纲、PDF文档、源代码以及配套视频教程&#xff0c;帮助您从Kotlin的基础语法到实战项目开发&#xff0c;系统地提升您的编程技能。…

2024年中国研究生数学建模竞赛B题(华为题目)WLAN组网中网络吞吐量建模一

2024年中国研究生数学建模竞赛B题&#xff08;华为题目&#xff09; WLAN组网中网络吞吐量建模 一、背景 无线局域网&#xff08;Wireless Local Area Network&#xff0c;WLAN&#xff09;是一种无线计算机网络&#xff0c;使用无线信道作为传输介质连接两个或多个设备。WL…

什么情况下会导致索引失效?

什么情况下会导致索引失效&#xff1f; 1. 组合索引非最左前缀2. LIKE查询%开头3. 字符串未加引号4. 不等比较5. 索引列运算6. OR连接查询 &#x1f496;The Begin&#x1f496;点点关注&#xff0c;收藏不迷路&#x1f496; 1. 组合索引非最左前缀 描述&#xff1a;在组合索引…

Linux之实战命令02:shred应用实例(三十六)

简介&#xff1a; CSDN博客专家、《Android系统多媒体进阶实战》一书作者 新书发布&#xff1a;《Android系统多媒体进阶实战》&#x1f680; 优质专栏&#xff1a; Audio工程师进阶系列【原创干货持续更新中……】&#x1f680; 优质专栏&#xff1a; 多媒体系统工程师系列【…

python sql中带引号字符串(单双引号)转义处理

描述&#xff1a; 最近在爬取数据保存到数据库时&#xff0c;遇到有引号的字符串插入MySQL报错&#xff1a;1064, "You have an error in your SQL syntax; check the manual that corresponds to your MySQL server version for the right syntax to use near 转义字符串…

【大模型实战篇】关于Bert的一些实操回顾以及clip-as-service的介绍

最近在整理之前的一些实践工作&#xff0c;一方面是为了笔记记录&#xff0c;另一方面也是自己做一些温故知新&#xff0c;或许对于理解一些现在大模型工作也有助益。 1. 基于bert模型实现中文语句的embedding编码 首先是基于bert模型实现中文语句的embedding编码&#xff0c;…

828华为云征文|Flexus X实例GitLab部署构建流水线-私人一体化代码仓库~

目录 前言Gitlab 环境准备 GitLab部署 拉取GitLab镜像 创建映射目录 运行GitLab容器 修改gitlab.rb配置 开放端口 切换语言 创建项目 添加ssh密钥 克隆代码 CICD 什么是CICD Gitlab中使用CICD 什么是Runner 安装GitLab Runner 获取注册令牌 runner注册 交互…

2024华为杯数学建模竞赛E题

2024年中国研究生数学建模竞赛E题 高速公路应急车道紧急启用模型 高速公路拥堵现象的原因众多&#xff0c;除了交通事故外&#xff0c;最典型的就是部分路段出现瓶颈现象&#xff0c;主要原因是车辆汇聚&#xff0c;而拥堵后又容易蔓延。高速公路一些特定的路段容易形成堵点&…

8-Python基础编程之数据类型操作——字典和集合

Python基础编程之数据类型操作——字典和集合 字典概念定义意义操作增删改查遍历计算和判定 集合概念定义可变集合不可变集合 操作单一集合操作增删查 集合之间操作交集并集差值判定 字典 概念 无序的&#xff0c;可变的键值对的集合 定义 方式一直接定义&#xff1a; per…

Springboot使用ThreadPoolTaskScheduler轻量级多线程定时任务框架

简介&#xff1a; Spring注解定时任务使用不是很灵活&#xff0c;如果想要灵活的配置定时任务&#xff0c;可以使用xxl-job 或者 quartz等定时任务框架&#xff0c;但是过于繁琐&#xff0c;可能成本较大。所以可以使用ThreadPoolTaskScheduler来灵活处理定时任务 ThreadPoolT…

2024 ICPC ShaanXi Provincial Contest —— C. Seats(个人理解)拓扑+dfs

2024 ICPC ShaanXi Provincial Contest —— C. Seats&#xff08;个人理解&#xff09;拓扑dfs 先放个传送门 https://codeforces.com/gym/105257/problem/C ———————————————————————————————————— 思路 可以看到&#xff0c;每一个编…

Vision Based Navigation :针对航天领域的基于视觉导航机器学习应用生成训练数据集

2024-09-18 由欧洲空间局主导&#xff0c;由空客防务与空间公司参与创建Vision Based Navigation &#xff0c; 为空间任务中的基于视觉导航&#xff08;VBN&#xff09;机器学习应用生成训练数据集。 目前遇到的困难和挑战 1、数据集的可用性和充分性&#xff1a; 挑战&#x…

BFS 解决多源最短路问题

文章目录 多源BFS542. 01 矩阵题目解析算法原理代码实现 1020. 飞地的数量题目解析算法原理 1765. 地图中的最高点题目解析算法原理代码实现 1162. 地图分析题目解析算法原理代码实现 多源BFS 单源最短路&#xff1a; 一个起点、一个终点 多源最短路&#xff1a; 可以多个起点…

9.安卓逆向-安卓开发基础-安卓四大组件2

免责声明&#xff1a;内容仅供学习参考&#xff0c;请合法利用知识&#xff0c;禁止进行违法犯罪活动&#xff01; 内容参考于&#xff1a;图灵Python学院 本人写的内容纯属胡编乱造&#xff0c;全都是合成造假&#xff0c;仅仅只是为了娱乐&#xff0c;请不要盲目相信。 工…

[云服务器13] 如何正确选择云服务器?

【非广告&#xff0c;仅提供建议&#xff0c;没有强制消费引导】 这期我们不讲搭建教程了&#xff0c;因为我想到前面12篇的教程&#xff0c;有关套餐配置的教程好像都有点敷衍…… 所以这期我们主要来说一说服务器的配置选择和不同配置的应用场景。 网站: 雨云 打开后&…

基于ZigBee的农业大棚信息采集系统设计

过去的农业大棚种植中大多需要依靠经验来实现浇水施肥等工作&#xff0c;无法根据天气的变化做出顺应的改变&#xff0c;也就造成了大棚内种植农作物的产量和质量很难得到保证。伴随着物联网与农业种植的结合&#xff0c;基于ZigBee通信和传感器等技术开发一款能监测大棚内环境…

Linux:路径末尾加/和不加/的区别

相关阅读 Linuxhttps://blog.csdn.net/weixin_45791458/category_12234591.html?spm1001.2014.3001.5482 普通文件操作 首先说明这个问题只会出现在目录和符号链接中&#xff0c;因为如果想要索引普通文件但却在路径末尾加/则会出现错误&#xff0c;如例1所示。 # 例1 zhang…

free源码

文章目录 free源码调试main_arena结构&#xff1a;free_hooktcachetcache的结构free_chunk进入tcache&#xff1a; fastbinunlink 合并top free源码调试 main_arena结构&#xff1a; 整体看一下main_arena的结构&#xff1a; free_hook free_hook&#xff0c;在glibc-2.3…

在 Windows 11 中,可以通过修改注册表来更改系统的自动更新时间设置

regedit 计算机\HKEY_LOCAL_MACHINE\SOFTWARE\Microsoft\WindowsUpdate\UX\Settings FlightSettingsMaxPauseDays 36524