Training a DQN reinforcement learning agent to play the 是男人就下xx层 (descend-the-floors) mini game

The game code is adapted from the CSDN blog post "Python是男人就下一百层小游戏源代码_是男人就下一百层完整代码python" (CSDN博客).

On top of the original author's code, the game was modified so that it automatically restarts after a failure, which makes uninterrupted training possible later on.

    def reset_game(self):
        self.score = 0
        self.end = False
        self.last = 6 * SIDE
        self.dire = 0
        self.barrier.clear()
        self.barrier.append(Barrier(self.screen, SOLID))
        self.body = pygame.Rect(self.barrier[0].rect.center[0], 200, SIDE, SIDE)
  • Added a reset_game method

    • This method resets the game state: the score, the end flag, the barrier list, and the player's position.
  • Added restart logic in the show_end method

    • When the game ends, the end state is drawn first, the game waits 2 seconds (via pygame.time.delay(2000)), and then reset_game is called to start a new game, as sketched below.
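A sketch of that restart flow, assuming the two-second pause is inserted before the reset (the full listing below keeps show_end minimal and omits the delay call):

    def show_end(self):
        self.draw(0, end=True)    # draw the final frame with the player in the DEADLY color
        pygame.time.delay(2000)   # pause 2 seconds so the failure stays visible
        self.end = True
        self.reset_game()         # restart automatically so training can continue uninterrupted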

The full game code is as follows:

#!python3
# -*- coding: utf-8 -*-
'''
公众号:Python代码大全
'''
from random import choice, randint
import pygame
from sys import exit

SCORE = 0
SOLID = 1
FRAGILE = 2
DEADLY = 3
BELT_LEFT = 4
BELT_RIGHT = 5
BODY = 6

GAME_ROW = 40
GAME_COL = 28
OBS_WIDTH = GAME_COL // 4
SIDE = 13
SCREEN_WIDTH = SIDE*GAME_COL
SCREEN_HEIGHT = SIDE*GAME_ROW
COLOR = {SOLID: 0x00ffff, FRAGILE: 0xff5500, DEADLY: 0xff2222, SCORE: 0xcccccc,
         BELT_LEFT: 0xffff44, BELT_RIGHT: 0xff99ff, BODY: 0x00ff00}
CHOICE = [SOLID, SOLID, SOLID, FRAGILE, FRAGILE, BELT_LEFT, BELT_RIGHT, DEADLY]


class Game(object):
    def __init__(self, title, size, fps=30):
        self.size = size
        pygame.init()
        self.screen = pygame.display.set_mode(size, 0, 32)
        pygame.display.set_caption(title)
        self.keys = {}
        self.keys_up = {}
        self.clicks = {}
        self.timer = pygame.time.Clock()
        self.fps = fps
        self.score = 0
        self.end = False
        self.fullscreen = False
        self.last_time = pygame.time.get_ticks()
        self.is_pause = False
        self.is_draw = True
        self.score_font = pygame.font.SysFont("Calibri", 130, True)

    def bind_key(self, key, action):
        if isinstance(key, list):
            for k in key:
                self.keys[k] = action
        elif isinstance(key, int):
            self.keys[key] = action

    def bind_key_up(self, key, action):
        if isinstance(key, list):
            for k in key:
                self.keys_up[k] = action
        elif isinstance(key, int):
            self.keys_up[key] = action

    def bind_click(self, button, action):
        self.clicks[button] = action

    def pause(self, key):
        self.is_pause = not self.is_pause

    def set_fps(self, fps):
        self.fps = fps

    def handle_input(self, event):
        if event.type == pygame.QUIT:
            pygame.quit()
            exit()
        if event.type == pygame.KEYDOWN:
            if event.key in self.keys.keys():
                self.keys[event.key](event.key)
            if event.key == pygame.K_F11:    # F11全屏
                self.fullscreen = not self.fullscreen
                if self.fullscreen:
                    self.screen = pygame.display.set_mode(self.size, pygame.FULLSCREEN, 32)
                else:
                    self.screen = pygame.display.set_mode(self.size, 0, 32)
        if event.type == pygame.KEYUP:
            if event.key in self.keys_up.keys():
                self.keys_up[event.key](event.key)
        if event.type == pygame.MOUSEBUTTONDOWN:
            if event.button in self.clicks.keys():
                self.clicks[event.button](*event.pos)

    def run(self):
        while True:
            for event in pygame.event.get():
                self.handle_input(event)
            self.timer.tick(self.fps)
            self.update(pygame.time.get_ticks())
            self.draw(pygame.time.get_ticks())

    def draw_score(self, color, rect=None):
        score = self.score_font.render(str(self.score), True, color)
        if rect is None:
            r = self.screen.get_rect()
            rect = score.get_rect(center=r.center)
        self.screen.blit(score, rect)

    def is_end(self):
        return self.end

    def update(self, current_time):
        pass

    def draw(self, current_time):
        pass


class Barrier(object):
    def __init__(self, screen, opt=None):
        self.screen = screen
        if opt is None:
            self.type = choice(CHOICE)
        else:
            self.type = opt
        self.frag_touch = False
        self.frag_time = 12
        self.score = False
        self.belt_dire = 0
        self.belt_dire = pygame.K_LEFT if self.type == BELT_LEFT else pygame.K_RIGHT
        left = randint(0, SCREEN_WIDTH - 7 * SIDE - 1)
        top = SCREEN_HEIGHT - SIDE - 1
        self.rect = pygame.Rect(left, top, 7*SIDE, SIDE)

    def rise(self):
        if self.frag_touch:
            self.frag_time -= 1
            if self.frag_time == 0:
                return False
        self.rect.top -= 2
        return self.rect.top >= 0

    def draw_side(self, x, y):
        if self.type == SOLID:
            rect = pygame.Rect(x, y, SIDE, SIDE)
            self.screen.fill(COLOR[SOLID], rect)
        elif self.type == FRAGILE:
            rect = pygame.Rect(x+2, y, SIDE-4, SIDE)
            self.screen.fill(COLOR[FRAGILE], rect)
        elif self.type == BELT_LEFT or self.type == BELT_RIGHT:
            rect = pygame.Rect(x, y, SIDE, SIDE)
            pygame.draw.circle(self.screen, COLOR[self.type], rect.center, SIDE // 2 + 1)
        elif self.type == DEADLY:
            p1 = (x + SIDE//2 + 1, y)
            p2 = (x, y + SIDE)
            p3 = (x + SIDE, y + SIDE)
            points = [p1, p2, p3]
            pygame.draw.polygon(self.screen, COLOR[DEADLY], points)

    def draw(self):
        for i in range(7):
            self.draw_side(i*SIDE+self.rect.left, self.rect.top)


class Hell(Game):
    def __init__(self, title, size, fps=60):
        super(Hell, self).__init__(title, size, fps)
        self.last = 6 * SIDE
        self.dire = 0
        self.barrier = [Barrier(self.screen, SOLID)]
        self.body = pygame.Rect(self.barrier[0].rect.center[0], 200, SIDE, SIDE)
        self.bind_key([pygame.K_LEFT, pygame.K_RIGHT], self.move)
        self.bind_key_up([pygame.K_LEFT, pygame.K_RIGHT], self.unmove)
        self.bind_key(pygame.K_SPACE, self.pause)

    def move(self, key):
        self.dire = key

    def unmove(self, key):
        self.dire = 0

    def reset_game(self):
        self.score = 0
        self.end = False
        self.last = 6 * SIDE
        self.dire = 0
        self.barrier.clear()
        self.barrier.append(Barrier(self.screen, SOLID))
        self.body = pygame.Rect(self.barrier[0].rect.center[0], 200, SIDE, SIDE)

    def show_end(self):
        self.draw(0, end=True)
        self.end = True
        self.reset_game()

    def move_man(self, dire):
        if dire == 0:
            return True
        rect = self.body.copy()
        if dire == pygame.K_LEFT:
            rect.left -= 1
        else:
            rect.left += 1
        if rect.left < 0 or rect.left + SIDE >= SCREEN_WIDTH:
            return False
        for ba in self.barrier:
            if rect.colliderect(ba.rect):
                return False
        self.body = rect
        return True

    def get_score(self, ba):
        if self.body.top > ba.rect.top and not ba.score:
            self.score += 1
            ba.score = True

    def to_hell(self):
        self.body.top += 2
        for ba in self.barrier:
            if not self.body.colliderect(ba.rect):
                self.get_score(ba)
                continue
            if ba.type == DEADLY:
                self.show_end()
                return
            self.body.top = ba.rect.top - SIDE - 2
            if ba.type == FRAGILE:
                ba.frag_touch = True
            elif ba.type == BELT_LEFT or ba.type == BELT_RIGHT:
                # self.body.left += ba.belt_dire
                self.move_man(ba.belt_dire)
            break
        top = self.body.top
        if top < 0 or top+SIDE >= SCREEN_HEIGHT:
            self.show_end()

    def create_barrier(self):
        solid = list(filter(lambda ba: ba.type == SOLID, self.barrier))
        if len(solid) < 1:
            self.barrier.append(Barrier(self.screen, SOLID))
        else:
            self.barrier.append(Barrier(self.screen))
        self.last = randint(3, 5) * SIDE

    def update(self, current_time):
        if self.end or self.is_pause:
            return
        self.last -= 1
        if self.last == 0:
            self.create_barrier()
        for ba in self.barrier:
            if not ba.rise():
                if ba.type == FRAGILE and ba.rect.top > 0:
                    self.score += 1
                self.barrier.remove(ba)
        self.move_man(self.dire)
        self.to_hell()

    def draw(self, current_time, end=False):
        if self.end or self.is_pause:
            return
        self.screen.fill(0x000000)
        self.draw_score((0x3c, 0x3c, 0x3c))
        for ba in self.barrier:
            ba.draw()
        if not end:
            self.screen.fill(COLOR[BODY], self.body)
        else:
            self.screen.fill(COLOR[DEADLY], self.body)
        pygame.display.update()


def hex2rgb(color):
    b = color % 256
    color = color >> 8
    g = color % 256
    color = color >> 8
    r = color % 256
    return (r, g, b)


if __name__ == '__main__':
    hell = Hell("是男人就下一百层", (SCREEN_WIDTH, SCREEN_HEIGHT))
    hell.run()

Next is choosing a suitable reinforcement learning algorithm; Deep Q-Learning (DQN) is the candidate here. DQN combines deep learning with Q-learning and is well suited to environments with high-dimensional state spaces.
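For reference, the quantity a DQN regresses toward is the standard one-step Q-learning target; the template below follows this recipe with an MSE loss:

    y = r + gamma * max_a' Q(s', a'; θ)      # Bellman target for a transition (s, a, r, s')
    loss = (y - Q(s, a; θ))²                 # squared error between target and prediction

Here gamma is the discount factor, and the neural network Q(·; θ) stands in for the Q-table, which is what makes high-dimensional state spaces tractable.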

First, start from a PyTorch template for the DQN algorithm:

import numpy as np
import random
import torch
import torch.nn as nn
import torch.optim as optim
from collections import deque


class DQNAgent(nn.Module):
    def __init__(self, state_size, action_size):
        super(DQNAgent, self).__init__()
        self.state_size = state_size
        self.action_size = action_size
        self.memory = deque(maxlen=2000)
        self.gamma = 0.95    # discount rate
        self.epsilon = 1.0   # exploration rate
        self.epsilon_min = 0.01
        self.epsilon_decay = 0.995
        self.model = self._build_model()
        self.optimizer = optim.Adam(self.model.parameters(), lr=0.001)
        self.loss_fn = nn.MSELoss()

    def _build_model(self):
        model = nn.Sequential(
            nn.Linear(self.state_size, 24),   # 第一层
            nn.ReLU(),
            nn.Linear(24, 24),                # 第二层
            nn.ReLU(),
            nn.Linear(24, self.action_size)   # 输出层
        )
        return model

    def remember(self, state, action, reward, next_state, done):
        self.memory.append((state, action, reward, next_state, done))

    def act(self, state):
        if np.random.rand() <= self.epsilon:
            return random.randrange(self.action_size)  # 随机选择动作
        with torch.no_grad():
            state_tensor = torch.FloatTensor(state)
            act_values = self.model(state_tensor)
        return np.argmax(act_values.numpy())

    def replay(self, batch_size):
        minibatch = random.sample(self.memory, batch_size)
        for state, action, reward, next_state, done in minibatch:
            target = reward
            if not done:
                next_state_tensor = torch.FloatTensor(next_state)
                target += self.gamma * torch.max(self.model(next_state_tensor)).item()
            target_f = self.model(torch.FloatTensor(state))
            target_f[action] = target
            self.optimizer.zero_grad()
            loss = self.loss_fn(target_f, self.model(torch.FloatTensor(state)))
            loss.backward()
            self.optimizer.step()
        if self.epsilon > self.epsilon_min:
            self.epsilon *= self.epsilon_decay


# 环境类的示例
class Env:
    def __init__(self):
        # 初始化游戏环境
        pass

    def reset(self):
        # 重置游戏
        pass

    def step(self, action):
        # 执行动作,返回下一个状态、奖励、是否结束等
        pass

    def get_state(self):
        # 返回当前状态
        pass


# 训练主循环
if __name__ == "__main__":
    env = Env()      # 实例化游戏环境
    state_size = 3   # 根据状态特征数量调整
    action_size = 3  # 根据实际动作数量调整
    agent = DQNAgent(state_size, action_size)

    episodes = 1000
    for e in range(episodes):
        state = env.reset()
        state = np.reshape(state, [1, state_size])
        for time in range(500):
            action = agent.act(state)
            next_state, reward, done, _ = env.step(action)
            reward = reward if not done else -10  # 奖励调整
            next_state = np.reshape(next_state, [1, state_size])
            agent.remember(state, action, reward, next_state, done)
            state = next_state
            if done:
                print(f"Episode: {e}/{episodes}, Score: {time}, Epsilon: {agent.epsilon:.2}")
                break
            if len(agent.memory) > 32:
                agent.replay(32)

Now, combining this with the game code, let's write the training script:

import os
import random

import numpy as np
import pygame
import torch
import torch.nn as nn
import torch.optim as optim
from collections import deque

from getdown import Hell, SCREEN_WIDTH, SCREEN_HEIGHT


class DQNAgent(nn.Module):
    def __init__(self, state_size, action_size):
        super(DQNAgent, self).__init__()
        self.state_size = state_size
        self.action_size = action_size
        self.memory = deque(maxlen=2000)
        self.gamma = 0.95    # discount rate
        self.epsilon = 1.0   # exploration rate
        self.epsilon_min = 0.01
        self.epsilon_decay = 0.995
        self.model = self._build_model()
        self.optimizer = optim.Adam(self.model.parameters(), lr=0.001)
        self.loss_fn = nn.MSELoss()

    def _build_model(self):
        model = nn.Sequential(
            nn.Linear(self.state_size, 24),   # 第一层
            nn.ReLU(),
            nn.Linear(24, 24),                # 第二层
            nn.ReLU(),
            nn.Linear(24, self.action_size)   # 输出层
        )
        return model

    def remember(self, state, action, reward, next_state, done):
        self.memory.append((state, action, reward, next_state, done))

    def act(self, state):
        if np.random.rand() <= self.epsilon:
            return random.randrange(self.action_size)  # 随机选择动作
        with torch.no_grad():
            state_tensor = torch.FloatTensor(state)
            act_values = self.model(state_tensor)
        return np.argmax(act_values.numpy())

    def replay(self, batch_size):
        minibatch = random.sample(self.memory, batch_size)
        for state, action, reward, next_state, done in minibatch:
            target = reward
            if not done:
                next_state_tensor = torch.FloatTensor(next_state)
                target += self.gamma * torch.max(self.model(next_state_tensor)).item()
            target_f = self.model(torch.FloatTensor(state))
            target_f = target_f.squeeze()
            target_f[action] = target
            self.optimizer.zero_grad()
            loss = self.loss_fn(target_f, self.model(torch.FloatTensor(state)))
            loss.backward()
            self.optimizer.step()
        if self.epsilon > self.epsilon_min:
            self.epsilon *= self.epsilon_decay

    def save_model(self, file_name):
        torch.save(self.model.state_dict(), file_name)  # 保存模型的参数

    def load_model(self, file_name):
        if os.path.exists(file_name):
            self.model.load_state_dict(torch.load(file_name))  # 加载模型的参数
            print(f"Model loaded from {file_name}")
        else:
            print(f"No model found at {file_name}, starting from scratch.")


# 环境类的示例
class Env:
    def __init__(self, hell):
        self.hell = hell      # 创建游戏实例
        self.state_size = 9   # 根据状态特征数量调整
        self.action_size = 3  # 左, 右, 不动

    def reset(self):
        # 重置游戏
        self.hell.reset()
        return self.get_state()

    def step(self, action):
        if action == 0:    # Move left
            self.hell.move(pygame.K_LEFT)
        elif action == 1:  # Move right
            self.hell.move(pygame.K_RIGHT)
        else:              # Stay still
            self.hell.unmove(None)

        self.hell.update(pygame.time.get_ticks())  # 更新游戏状态

        # 获取状态、奖励和结束信息
        state = self.get_state()
        reward = self.hell.score
        done = self.hell.end
        return state, reward, done, {}

    def get_state(self):
        state = []
        state.append(self.hell.body.x)        # 玩家 x 坐标
        state.append(self.hell.body.y)        # 玩家 y 坐标
        state.append(len(self.hell.barrier))  # 障碍物数量

        # 记录最多 2 个障碍物的信息
        max_barriers = 2
        for i in range(max_barriers):
            if i < len(self.hell.barrier):
                ba = self.hell.barrier[i]
                state.append(ba.rect.x)  # 障碍物 x 坐标
                state.append(ba.rect.y)  # 障碍物 y 坐标
                state.append(ba.type)    # 障碍物类型
            else:
                # 如果没有障碍物,用零填充
                state.extend([0, 0, 0])

        # 确保状态的长度与 state_size 一致
        return np.array(state)


# 训练主循环
if __name__ == "__main__":
    env = Env(Hell("是男人就下一百层", (SCREEN_WIDTH, SCREEN_HEIGHT)))  # 初始化强化学习环境
    agent = DQNAgent(env.state_size, env.action_size)                 # 创建 DQN 代理

    model_path = "getdown_hell_model.h5"  # 你可以根据需要更改路径
    agent.load_model(model_path)

    total_steps = 0  # 初始化总步数
    while True:  # 无限训练
        state = env.reset()
        state = np.reshape(state, [1, env.state_size])
        for time in range(500):
            action = agent.act(state)
            next_state, reward, done, _ = env.step(action)
            reward = reward if not done else -10  # 奖励调整
            next_state = np.reshape(next_state, [1, env.state_size])
            agent.remember(state, action, reward, next_state, done)
            state = next_state
            total_steps += 1
            if done:
                print(f"Score: {time}, Total Steps: {total_steps}, Epsilon: {agent.epsilon:.2}")
                break
            # 在每1000步时保存模型
            if total_steps % 1000 == 0:
                agent.save_model("getdown_hell_model.h5")  # 保存模型
            if len(agent.memory) > 32:
                agent.replay(32)

Checkpoint-and-resume support has been added (the save_model/load_model calls above). The state space and reward design above are certainly not complete, but let's get training running first and see what happens.

Training turned out to get stuck in a local optimum: the agent simply drops straight to the bottom in one go. That still collects the score-based reward, but it ignores the possibility of scoring more by staying in the game, so the reward function needs revising. For now, though, let's first modify the game code and plug the model into the controls to see how it behaves.

#!python3
# -*- coding: utf-8 -*-
'''
公众号:Python代码大全
'''
from random import choice, randint

import numpy as np
import pygame
from sys import exit

from getdown_dqn import DQNAgent, Env

SCORE = 0
SOLID = 1
FRAGILE = 2
DEADLY = 3
BELT_LEFT = 4
BELT_RIGHT = 5
BODY = 6

GAME_ROW = 40
GAME_COL = 28
OBS_WIDTH = GAME_COL // 4
SIDE = 13
SCREEN_WIDTH = SIDE*GAME_COL
SCREEN_HEIGHT = SIDE*GAME_ROW
COLOR = {SOLID: 0x00ffff, FRAGILE: 0xff5500, DEADLY: 0xff2222, SCORE: 0xcccccc,
         BELT_LEFT: 0xffff44, BELT_RIGHT: 0xff99ff, BODY: 0x00ff00}
CHOICE = [SOLID, SOLID, SOLID, FRAGILE, FRAGILE, BELT_LEFT, BELT_RIGHT, DEADLY]


class Game(object):
    def __init__(self, title, size, fps=30):
        self.size = size
        pygame.init()
        self.screen = pygame.display.set_mode(size, 0, 32)
        pygame.display.set_caption(title)
        self.keys = {}
        self.keys_up = {}
        self.clicks = {}
        self.timer = pygame.time.Clock()
        self.fps = fps
        self.score = 0
        self.end = False
        self.fullscreen = False
        self.last_time = pygame.time.get_ticks()
        self.is_pause = False
        self.is_draw = True
        self.score_font = pygame.font.SysFont("Calibri", 130, True)

    def bind_key(self, key, action):
        if isinstance(key, list):
            for k in key:
                self.keys[k] = action
        elif isinstance(key, int):
            self.keys[key] = action

    def bind_key_up(self, key, action):
        if isinstance(key, list):
            for k in key:
                self.keys_up[k] = action
        elif isinstance(key, int):
            self.keys_up[key] = action

    def bind_click(self, button, action):
        self.clicks[button] = action

    def pause(self, key):
        self.is_pause = not self.is_pause

    def set_fps(self, fps):
        self.fps = fps

    def handle_input(self, event):
        if event.type == pygame.QUIT:
            pygame.quit()
            exit()
        if event.type == pygame.KEYDOWN:
            if event.key in self.keys.keys():
                self.keys[event.key](event.key)
            if event.key == pygame.K_F11:    # F11全屏
                self.fullscreen = not self.fullscreen
                if self.fullscreen:
                    self.screen = pygame.display.set_mode(self.size, pygame.FULLSCREEN, 32)
                else:
                    self.screen = pygame.display.set_mode(self.size, 0, 32)
        if event.type == pygame.KEYUP:
            if event.key in self.keys_up.keys():
                self.keys_up[event.key](event.key)
        if event.type == pygame.MOUSEBUTTONDOWN:
            if event.button in self.clicks.keys():
                self.clicks[event.button](*event.pos)

    def run(self):
        while True:
            state = env.get_state()
            state = np.reshape(state, [1, env.state_size])
            action = agent.act(state)
            if action == 0:    # Move left
                self.handle_input(simulate_key_press(pygame.K_LEFT))
            elif action == 1:  # Move right
                self.handle_input(simulate_key_press(pygame.K_RIGHT))
            self.timer.tick(self.fps)
            self.update(pygame.time.get_ticks())
            self.draw(pygame.time.get_ticks())

    def draw_score(self, color, rect=None):
        score = self.score_font.render(str(self.score), True, color)
        if rect is None:
            r = self.screen.get_rect()
            rect = score.get_rect(center=r.center)
        self.screen.blit(score, rect)

    def is_end(self):
        return self.end

    def get_state(self):
        return self.end

    def update(self, current_time):
        pass

    def draw(self, current_time):
        pass


class Barrier(object):
    def __init__(self, screen, opt=None):
        self.screen = screen
        if opt is None:
            self.type = choice(CHOICE)
        else:
            self.type = opt
        self.frag_touch = False
        self.frag_time = 12
        self.score = False
        self.belt_dire = 0
        self.belt_dire = pygame.K_LEFT if self.type == BELT_LEFT else pygame.K_RIGHT
        left = randint(0, SCREEN_WIDTH - 7 * SIDE - 1)
        top = SCREEN_HEIGHT - SIDE - 1
        self.rect = pygame.Rect(left, top, 7*SIDE, SIDE)

    def rise(self):
        if self.frag_touch:
            self.frag_time -= 1
            if self.frag_time == 0:
                return False
        self.rect.top -= 2
        return self.rect.top >= 0

    def draw_side(self, x, y):
        if self.type == SOLID:
            rect = pygame.Rect(x, y, SIDE, SIDE)
            self.screen.fill(COLOR[SOLID], rect)
        elif self.type == FRAGILE:
            rect = pygame.Rect(x+2, y, SIDE-4, SIDE)
            self.screen.fill(COLOR[FRAGILE], rect)
        elif self.type == BELT_LEFT or self.type == BELT_RIGHT:
            rect = pygame.Rect(x, y, SIDE, SIDE)
            pygame.draw.circle(self.screen, COLOR[self.type], rect.center, SIDE // 2 + 1)
        elif self.type == DEADLY:
            p1 = (x + SIDE//2 + 1, y)
            p2 = (x, y + SIDE)
            p3 = (x + SIDE, y + SIDE)
            points = [p1, p2, p3]
            pygame.draw.polygon(self.screen, COLOR[DEADLY], points)

    def draw(self):
        for i in range(7):
            self.draw_side(i*SIDE+self.rect.left, self.rect.top)


class Hell(Game):
    def __init__(self, title, size, fps=60):
        super(Hell, self).__init__(title, size, fps)
        self.last = 6 * SIDE
        self.dire = 0
        self.barrier = [Barrier(self.screen, SOLID)]
        self.body = pygame.Rect(self.barrier[0].rect.center[0], 200, SIDE, SIDE)
        self.bind_key([pygame.K_LEFT, pygame.K_RIGHT], self.move)
        self.bind_key_up([pygame.K_LEFT, pygame.K_RIGHT], self.unmove)
        self.bind_key(pygame.K_SPACE, self.pause)

    def move(self, key):
        self.dire = key

    def unmove(self, key):
        self.dire = 0

    def reset(self):
        self.score = 0
        self.end = False
        self.last = 6 * SIDE
        self.dire = 0
        self.barrier.clear()
        self.barrier.append(Barrier(self.screen, SOLID))
        self.body = pygame.Rect(self.barrier[0].rect.center[0], 200, SIDE, SIDE)

    def show_end(self):
        self.draw(0, end=True)
        self.end = True
        self.reset()

    def move_man(self, dire):
        if dire == 0:
            return True
        rect = self.body.copy()
        if dire == pygame.K_LEFT:
            rect.left -= 1
        else:
            rect.left += 1
        if rect.left < 0 or rect.left + SIDE >= SCREEN_WIDTH:
            return False
        for ba in self.barrier:
            if rect.colliderect(ba.rect):
                return False
        self.body = rect
        return True

    def get_score(self, ba):
        if self.body.top > ba.rect.top and not ba.score:
            self.score += 1
            ba.score = True

    def to_hell(self):
        self.body.top += 2
        for ba in self.barrier:
            if not self.body.colliderect(ba.rect):
                self.get_score(ba)
                continue
            if ba.type == DEADLY:
                self.show_end()
                return
            self.body.top = ba.rect.top - SIDE - 2
            if ba.type == FRAGILE:
                ba.frag_touch = True
            elif ba.type == BELT_LEFT or ba.type == BELT_RIGHT:
                # self.body.left += ba.belt_dire
                self.move_man(ba.belt_dire)
            break
        top = self.body.top
        if top < 0 or top+SIDE >= SCREEN_HEIGHT:
            self.show_end()

    def create_barrier(self):
        solid = list(filter(lambda ba: ba.type == SOLID, self.barrier))
        if len(solid) < 1:
            self.barrier.append(Barrier(self.screen, SOLID))
        else:
            self.barrier.append(Barrier(self.screen))
        self.last = randint(3, 5) * SIDE

    def update(self, current_time):
        if self.end or self.is_pause:
            return
        self.last -= 1
        if self.last == 0:
            self.create_barrier()
        for ba in self.barrier:
            if not ba.rise():
                if ba.type == FRAGILE and ba.rect.top > 0:
                    self.score += 1
                self.barrier.remove(ba)
        self.move_man(self.dire)
        self.to_hell()

    def draw(self, current_time, end=False):
        if self.end or self.is_pause:
            return
        self.screen.fill(0x000000)
        self.draw_score((0x3c, 0x3c, 0x3c))
        for ba in self.barrier:
            ba.draw()
        if not end:
            self.screen.fill(COLOR[BODY], self.body)
        else:
            self.screen.fill(COLOR[DEADLY], self.body)
        pygame.display.update()


def simulate_key_press(key):
    event = pygame.event.Event(pygame.KEYDOWN, key=key)
    return event


def hex2rgb(color):
    b = color % 256
    color = color >> 8
    g = color % 256
    color = color >> 8
    r = color % 256
    return (r, g, b)


if __name__ == '__main__':
    hell = Hell("是男人就下一百层", (SCREEN_WIDTH, SCREEN_HEIGHT))
    env = Env(hell)
    agent = DQNAgent(env.state_size, env.action_size)

    model_path = "getdown_hell_model.h5"  # 你可以根据需要更改路径
    agent.load_model(model_path)

    # 开始游戏
    hell.run()

First, a quick test of the control: the trained model is now plugged in to drive the player.

If training relies only on the score (e.g. self.hell.score), the RL model may fail to capture long-range feedback, because a single reward signal is often not enough to encourage the agent to take appropriate actions over longer time horizons. Improving the reward function is therefore important.

Suggestions for improving the reward function

  1. Segmented rewards: give extra reward for specific events or state changes, e.g. when the player successfully leaves the current platform or lands on a platform below.

    if agent_avoided_obstacle:
        reward += 10  # 例如,成功避开障碍物时给予奖励
    
  2. Negative rewards: penalize undesirable behavior. For example, give a negative reward when the player hits a spiked (deadly) barrier and fails; this helps the agent learn to avoid it.

    if agent_hit_obstacle:
        reward -= 10  # 碰到障碍物时给予惩罚
    
  3. Time rewards: give a small positive reward at every time step to encourage the agent to keep the game going.

    reward += 0.1  # 每个时间步骤给予小奖励
    
  4. Long-range rewards: use a discount factor (usually denoted gamma) so that future rewards are taken into account when computing the target value.

    target += self.gamma * next_value
    
  5. State-change rewards: reward changes in state, e.g. when the player reaches a new area or a new barrier appears, as sketched below.
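A minimal sketch of point 5, assuming the environment keeps a preview_barrier_num counter initialized in Env.__init__ (the compute_reward method below uses the same idea):

    if len(self.hell.barrier) > self.preview_barrier_num:
        self.preview_barrier_num = len(self.hell.barrier)  # a new barrier has appeared
        reward += 1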

Revising the reward function along the lines listed above:

    def compute_reward(self, action):
        reward = self.hell.score  # 基于当前分数的奖励
        body = self.hell.body
        barrier = self.hell.barrier

        target_y = body.y + body.h + 2
        matching_barriers = [ba for ba in barrier
                             if ba.rect.y == target_y and ba.rect.x < body.x < (ba.rect.x + ba.rect.width)]
        # 当物体成功离开本级台阶或达到下面某级台阶时,可以提供额外奖励。
        if matching_barriers:
            left_distance = body.x - matching_barriers[0].rect.x
            right_distance = matching_barriers[0].rect.x + matching_barriers[0].rect.width - body.x
            # 说明在台面上移动
            if left_distance < right_distance and action == 0:
                reward += 0.1
            elif left_distance > right_distance and action == 1:
                reward += 0.1
            else:
                reward -= 0.1

        thres_hold = 100
        matching_barriers = [ba for ba in barrier
                             if 0 < (ba.rect.y - body.y) < thres_hold and ba.rect.x < body.x < (ba.rect.x + ba.rect.width)]
        # 对于不良行为,给予负奖励。例如,下方快碰到带刺的障碍时
        if matching_barriers and matching_barriers[0].type == 2:
            reward -= 5
        else:
            reward += 3

        # 当物体到达新的区域或发现新的障碍物时,可以给予奖励。
        if self.preview_barrier_num < len(self.hell.barrier):
            self.preview_barrier_num = len(self.hell.barrier)
            reward += 1
        else:
            reward -= 0.5

        # 增加下落时朝向障碍物的奖励
        falling_towards_barrier = any(
            ba.rect.x < body.x < (ba.rect.x + ba.rect.width) and ba.rect.y > body.y
            for ba in barrier
        )
        if falling_towards_barrier:
            reward += 2

        # 为每个时间步骤提供小的正奖励,以鼓励持续进行游戏。
        reward += 0.1
        return reward

The training code also gets a reward curve to monitor training, plotted with the matplotlib library.

import matplotlib.pyplot as plt  # needed for the reward plot below

if __name__ == "__main__":
    env = Env(Hell("是男人就下一百层", (SCREEN_WIDTH, SCREEN_HEIGHT)))  # 初始化强化学习环境
    agent = DQNAgent(env.state_size, env.action_size)                 # 创建 DQN 代理

    model_path = "getdown_hell_model.h5"  # 你可以根据需要更改路径
    agent.load_model(model_path)

    total_steps = 0  # 初始化总步数
    total_game_num = 0
    rewards = []     # 用于记录每个回合的总奖励

    try:
        state = env.reset()
        while True:  # 无限训练
            state = np.reshape(state, [1, env.state_size])
            total_reward = 0  # 每个回合的总奖励
            # for time in range(1000):
            action = agent.act(state)
            next_state, reward, done, _ = env.step(action)
            reward = reward if not done else -10  # 奖励调整
            next_state = np.reshape(next_state, [1, env.state_size])
            agent.remember(state, action, reward, next_state, done)
            state = next_state
            total_steps += 1
            total_reward += reward  # 更新总奖励
            rewards.append(total_reward)
            # rewards 只保留一万条记录
            if len(rewards) > 10000:
                rewards.pop(0)
            if done:
                print(f"Total game num: {total_game_num},Total Steps: {total_steps}, total score: {env.hell.score}, Epsilon: {agent.epsilon:.7}")
                print(f'current step:{total_steps}, save getdown_hell_model.h5')
                agent.save_model("getdown_hell_model.h5")  # 保存模型
                total_game_num += 1
                env.hell.reset()
            if len(agent.memory) > 32:
                agent.replay(32)
    except KeyboardInterrupt:
        print('rewards', rewards)
        print("\nTraining interrupted. Saving model...")
        agent.save_model("getdown_hell_model.h5")  # 保存模型
        # 绘制奖励线图
        plt.plot(rewards)
        plt.title("Training Rewards Over Time")
        plt.xlabel("Episode")
        plt.ylabel("Total Reward")
        plt.savefig("training_rewards.png", format='png')

Continue training.

Log output:

......

Total game num: 6,Total Steps: 176, total score: 4, Epsilon: 0.009986452
current step:176, save getdown_hell_model.h5
Total game num: 7,Total Steps: 200, total score: 3, Epsilon: 0.009986452
current step:200, save getdown_hell_model.h5
Total game num: 8,Total Steps: 234, total score: 4, Epsilon: 0.009986452
current step:234, save getdown_hell_model.h5
Total game num: 9,Total Steps: 265, total score: 5, Epsilon: 0.009986452
current step:265, save getdown_hell_model.h5
Total game num: 10,Total Steps: 288, total score: 6, Epsilon: 0.009986452
current step:288, save getdown_hell_model.h5
Total game num: 11,Total Steps: 258, total score: 5, Epsilon: 0.009986452
current step:258, save getdown_hell_model.h5
Total game num: 12,Total Steps: 884, total score: 17, Epsilon: 0.009986452
current step:884, save getdown_hell_model.h5
Total game num: 13,Total Steps: 221, total score: 4, Epsilon: 0.009986452

......

During training, epsilon is usually reduced gradually, a practice known as "epsilon decay". This ensures the agent explores enough early on and then increasingly exploits what it has learned as training progresses. The typical schedule is:

  • Early stage: set a high epsilon (e.g. 1.0) to encourage extensive exploration.
  • Middle stage: gradually lower epsilon, e.g. by a fixed amount per episode or by exponential decay (see the sketch after this list).
  • Late stage: let epsilon settle at a small minimum (e.g. 0.01) so a little exploration remains late in training.
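A small standalone sketch (not part of the training script) of how the exponential decay used by DQNAgent (epsilon_decay = 0.995, epsilon_min = 0.01) plays out over steps:

    epsilon, epsilon_min, epsilon_decay = 1.0, 0.01, 0.995
    for step in range(1, 1001):
        # multiplicative decay, clipped at the minimum
        epsilon = max(epsilon_min, epsilon * epsilon_decay)
        if step % 250 == 0:
            print(f"step {step}: epsilon = {epsilon:.3f}")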

The effect was still not obvious, so after some thought a blunt reward term was added to force the player to stay within a band on the y axis:

# 判断物体所处的位置控制在100~400之间
# check the narrowest band first so the tiered bonuses actually apply
if 200 < body.y < 300:
    reward += 3
elif 150 < body.y < 350:
    reward += 2
elif 100 < body.y < 400:
    reward += 1
else:
    reward -= 1

A pod has been set up on k8s for long-running training; tests will follow once there is a good training result.

The code is committed to GitHub:

https://github.com/chenrui2200/getdown_hell_rl_train
