Training code changes

Adjust parameters to enlarge the board and extend the snake's lifespan
Make the snake more inclined to keep moving in a straight line
Increase the starvation penalty
Adjust visualization parameters
Add loading logic
EndersOwner 2025-01-04 00:15:40 +08:00
parent 1bc08f2f97
commit 0e26a7a52e

@@ -57,8 +57,10 @@ class SnakeEnv:
         return np.sqrt((head[0] - food[0])**2 + (head[1] - food[1])**2)
     def step(self, action):
+        reward = 0.0
         head = self.snake[0]
         if action == 0: # keep direction
+            reward += 0.5
             pass
         elif action == 1: # turn left
             if self.direction == 'up':
@@ -107,10 +109,10 @@ class SnakeEnv:
         if new_head == self.food:
             self.length += 1
             self.food = self.generate_food()
-            reward = GRID_SIZE * 3
+            reward += GRID_SIZE * 3
         else:
             self.snake.pop()
-            reward = -1
+            reward += -1
         # compute the current distance to the food
         current_distance = self.calculate_distance()
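
Because reward is now initialized to 0.0 at the top of step() (first hunk above) and the later assignments were changed from = to +=, the straight-line bonus stacks with the food and step rewards instead of being overwritten. A minimal sketch of the new accumulation, assuming GRID_SIZE = 20 (the real constant is defined at module level in this file):

    GRID_SIZE = 20             # assumed value, for illustration only
    reward = 0.0
    reward += 0.5              # action == 0: kept direction
    reward += GRID_SIZE * 3    # new_head == self.food: ate the food
    print(reward)              # 60.5; the old `reward = GRID_SIZE * 3` would have dropped the 0.5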
@@ -131,14 +133,14 @@ class SnakeEnv:
         if self.length == 0:
             self.record['starve'] += 1
             self.death = True
-            return self.get_state(), -10, True
+            return self.get_state(), - GRID_SIZE * GRID_SIZE, True
         self.snake.pop()
-        # # at most 1000 steps per episode
-        # if self.steps >= 1000:
-        #     self.record['survive'] += 1
-        #     self.death = True
-        #     return self.get_state(), 0, True
+        # at most 10000 steps per episode
+        if self.steps >= 10000:
+            self.record['survive'] += 1
+            self.death = True
+            return self.get_state(), 0, True
         done = self.death
         return self.get_state(), reward, done
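
The starvation penalty now scales with the board area rather than being a flat -10: with GRID_SIZE = 20, for example, starving returns a reward of -400, which outweighs several food pickups (each worth GRID_SIZE * 3 = 60). The previously commented-out per-episode step cap is also re-enabled and raised from 1000 to 10000 steps, matching the longer lifespan mentioned in the commit message.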
@@ -248,7 +250,7 @@ class QLearning:
         q_target = reward + self.gamma * max(self.q_table[next_state]) if not done else reward
         self.q_table[state][action] += self.alpha * (q_target - q_predict)
-    def decay_epsilon(self, min_epsilon=0.001, decay_rate=0.99999):
+    def decay_epsilon(self, min_epsilon=0.001, decay_rate=0.9999):
         if self.epsilon > min_epsilon:
             self.epsilon *= decay_rate
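
The faster decay rate matters because decay_epsilon() is called once per episode (see the train() hunks below). Assuming epsilon starts at 1.0, as main() sets it for training, a rough check of when it first reaches min_epsilon = 0.001:

    import math

    def episodes_until_min(start=1.0, min_epsilon=0.001, decay_rate=0.9999):
        # smallest n with start * decay_rate**n <= min_epsilon
        return math.ceil(math.log(min_epsilon / start) / math.log(decay_rate))

    print(episodes_until_min(decay_rate=0.99999))  # 690773 episodes at the old rate
    print(episodes_until_min(decay_rate=0.9999))   # 69075 episodes at the new rate

So exploration is annealed roughly ten times sooner, which fits the shorter 1,000,000-episode run configured in main() below.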
@@ -257,7 +259,7 @@ def train(env, agent, episodes=10000, visualize=False):
     for episode in range(episodes):
         state = env.reset()
         done = False
-        if episode % 50000 == 0 or visualize:
+        if episode % 5000 == 0 or visualize and episode % 1 == 0:
             print(episode)
             print(env.record)
         while not done:
@@ -268,10 +270,10 @@ def train(env, agent, episodes=10000, visualize=False):
             state = next_state
             # if reward == 0 and done:
             #     print(episode, "survive!")
-            draw_env(env, visualize)
+            draw_env(env, visualize and episode % 1 == 0)
         env.record['max steps'] = max(env.record['max steps'], env.steps)
         agent.decay_epsilon()
-        if episode % 1000000 == 0:
+        if episode % 100000 == 0:
             with open(f'q_table_{episode}.pkl', 'wb') as f:
                 pickle.dump(agent.q_table, f)
             print(env.record)
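
Checkpoints are now written ten times as often. Assuming the loop runs the full episodes=1000000 passed in main() below, pickle files are produced at episodes 0, 100000, ..., 900000 (ten files, including an essentially untrained q_table_0.pkl); the q_table_1000000.pkl that main() loads is not produced by this loop and presumably comes from an earlier run.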
@@ -286,7 +288,7 @@ def draw_env(env, visualize=True):
     pygame.display.flip()
     # add a delay
-    pygame.time.delay(10)
+    pygame.time.delay(5)
     # handle events
     for event in pygame.event.get():
@@ -298,13 +300,13 @@ def draw_env(env, visualize=True):
 def main():
     env = SnakeEnv(GRID_SIZE)
     actions = [0, 1, 2] # keep direction, turn left, turn right
-    visualize = True
-    if visualize:
+    visualize = False
+    # if visualize:
     with open('q_table_1000000.pkl', 'rb') as f:
         q_table = pickle.load(f)
-    agent = QLearning(actions, q_table=q_table if visualize else None, epsilon=0.0 if visualize else 1.0)
+    agent = QLearning(actions, q_table=q_table if visualize else None, epsilon=0.001 if visualize else 1.0)
     # agent = QLearning(actions)
-    train(env, agent, episodes=10000000, visualize=visualize) # visualization off
+    train(env, agent, episodes=1000000, visualize=visualize) # visualization off
     pygame.quit()
 if __name__ == '__main__':