From 0e26a7a52e6c6a89bec03c0c37f65021b1616a04 Mon Sep 17 00:00:00 2001
From: EndersOwner <1353708863@qq.com>
Date: Sat, 4 Jan 2025 00:15:40 +0800
Subject: [PATCH] Training code changes
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Tune parameters to enlarge the board and extend episode lifespan
Make the snake prefer to keep a straight line
Increase the starvation penalty
Adjust visualization parameters
Add Q-table loading logic
---
 src/game.py | 40 +++++++++++++++++++++-------------------
 1 file changed, 21 insertions(+), 19 deletions(-)

diff --git a/src/game.py b/src/game.py
index 3d95286..7664c6e 100644
--- a/src/game.py
+++ b/src/game.py
@@ -57,8 +57,10 @@ class SnakeEnv:
         return np.sqrt((head[0] - food[0])**2 + (head[1] - food[1])**2)
 
     def step(self, action):
+        reward = 0.0
         head = self.snake[0]
         if action == 0:  # keep direction
+            reward += 0.5
             pass
         elif action == 1:  # turn left
             if self.direction == 'up':
@@ -107,10 +109,10 @@ class SnakeEnv:
         if new_head == self.food:
             self.length += 1
             self.food = self.generate_food()
-            reward = GRID_SIZE * 3
+            reward += GRID_SIZE * 3
         else:
             self.snake.pop()
-            reward = -1
+            reward += -1
 
         # compute the current distance
         current_distance = self.calculate_distance()
@@ -131,14 +133,14 @@ class SnakeEnv:
         if self.length == 0:
             self.record['starve'] += 1
             self.death = True
-            return self.get_state(), -10, True
+            return self.get_state(), -GRID_SIZE * GRID_SIZE, True
         self.snake.pop()
 
-        # # at most 1000 steps per episode
-        # if self.steps >= 1000:
-        #     self.record['survive'] += 1
-        #     self.death = True
-        #     return self.get_state(), 0, True
+        # at most 10000 steps per episode
+        if self.steps >= 10000:
+            self.record['survive'] += 1
+            self.death = True
+            return self.get_state(), 0, True
 
         done = self.death
         return self.get_state(), reward, done
@@ -248,7 +250,7 @@ class QLearning:
         q_target = reward + self.gamma * max(self.q_table[next_state]) if not done else reward
         self.q_table[state][action] += self.alpha * (q_target - q_predict)
 
-    def decay_epsilon(self, min_epsilon=0.001, decay_rate=0.99999):
+    def decay_epsilon(self, min_epsilon=0.001, decay_rate=0.9999):
         if self.epsilon > min_epsilon:
             self.epsilon *= decay_rate
@@ -257,7 +259,7 @@ def train(env, agent, episodes=10000, visualize=False):
     for episode in range(episodes):
         state = env.reset()
         done = False
-        if episode % 50000 == 0 or visualize:
+        if episode % 5000 == 0 or visualize and episode % 1 == 0:
             print(episode)
             print(env.record)
         while not done:
@@ -268,10 +270,10 @@
             state = next_state
             # if reward == 0 and done:
             #     print(episode, "survive!")
-        draw_env(env, visualize)
+        draw_env(env, visualize and episode % 1 == 0)
         env.record['max steps'] = max(env.record['max steps'], env.steps)
         agent.decay_epsilon()
-        if episode % 1000000 == 0:
+        if episode % 100000 == 0:
             with open(f'q_table_{episode}.pkl', 'wb') as f:
                 pickle.dump(agent.q_table, f)
                 print(env.record)
@@ -286,7 +288,7 @@ def draw_env(env, visualize=True):
         pygame.display.flip()
 
         # add a delay
-        pygame.time.delay(10)
+        pygame.time.delay(5)
 
         # handle events
        for event in pygame.event.get():
@@ -298,13 +300,13 @@
 def main():
     env = SnakeEnv(GRID_SIZE)
     actions = [0, 1, 2]  # keep straight, turn left, turn right
-    visualize = True
-    if visualize:
-        with open('q_table_1000000.pkl', 'rb') as f:
-            q_table = pickle.load(f)
-    agent = QLearning(actions, q_table=q_table if visualize else None, epsilon=0.0 if visualize else 1.0)
+    visualize = False
+    # if visualize:
+    with open('q_table_1000000.pkl', 'rb') as f:
+        q_table = pickle.load(f)
+    agent = QLearning(actions, q_table=q_table if visualize else None, epsilon=0.001 if visualize else 1.0)
     # agent = QLearning(actions)
-    train(env, agent, episodes=10000000, visualize=visualize)  # visualization off
+    train(env, agent, episodes=1000000, visualize=visualize)  # visualization off
     pygame.quit()
 
 if __name__ == '__main__':
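
For reference, the reward changes above combine into the scheme sketched here. This is a minimal sketch only: it ignores the distance-based shaping that SnakeEnv.step() applies afterwards, and step_reward with its boolean arguments is a hypothetical helper, not code from src/game.py.

    GRID_SIZE = 20  # assumed value for illustration; the real one is set in src/game.py

    def step_reward(kept_straight, ate_food, starved):
        # hypothetical summary of the per-step reward after this patch
        if starved:
            # starving now ends the episode at -GRID_SIZE**2 instead of -10
            return -GRID_SIZE * GRID_SIZE
        reward = 0.5 if kept_straight else 0.0  # action == 0: straight-line bonus
        if ate_food:
            reward += GRID_SIZE * 3  # food payoff scales with the board size
        else:
            reward += -1             # small per-step cost while not eating
        return reward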
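
A quick check on the faster epsilon decay: with the floor at 0.001, the number of episodes until epsilon falls from 1.0 to the floor is log(0.001) / log(decay_rate).

    import math
    old = math.log(0.001) / math.log(0.99999)  # ~ 690,000 episodes
    new = math.log(0.001) / math.log(0.9999)   # ~ 69,000 episodes

So exploration now effectively ends after roughly 69k episodes, long before the 1,000,000-episode run in main() finishes.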
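
One caveat on the loading change: main() now opens 'q_table_1000000.pkl' unconditionally, so a fresh checkout without that checkpoint raises FileNotFoundError before training starts; note also that with visualize = False the expression q_table if visualize else None discards the table that was just loaded. A guarded load along these lines would keep both paths working; this is a sketch, not part of the patch, and CHECKPOINT is just an illustrative name for the filename taken from the diff.

    import os
    import pickle

    CHECKPOINT = 'q_table_1000000.pkl'  # filename from the patch

    q_table = None
    if os.path.exists(CHECKPOINT):
        with open(CHECKPOINT, 'rb') as f:
            q_table = pickle.load(f)
    # then pass q_table (possibly None) straight to QLearning(...)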