Training code changes
Adjust parameters to enlarge the grid and lifespan, bias the snake toward keeping a straight line, increase the starvation penalty, adjust visualization parameters, and add loading logic for the saved Q-table.
parent 1bc08f2f97
commit 0e26a7a52e

src/game.py (40)
@@ -57,8 +57,10 @@ class SnakeEnv:
         return np.sqrt((head[0] - food[0])**2 + (head[1] - food[1])**2)
 
     def step(self, action):
+        reward = 0.0
         head = self.snake[0]
         if action == 0:  # keep direction
+            reward += 0.5
             pass
         elif action == 1:  # turn left
             if self.direction == 'up':
@@ -107,10 +109,10 @@ class SnakeEnv:
         if new_head == self.food:
             self.length += 1
             self.food = self.generate_food()
-            reward = GRID_SIZE * 3
+            reward += GRID_SIZE * 3
         else:
             self.snake.pop()
-            reward = -1
+            reward += -1
 
         # compute the current distance
         current_distance = self.calculate_distance()
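Taken together with the straight-line bonus added above, the per-step shaping now looks roughly like the sketch below. This is illustrative only: GRID_SIZE is defined elsewhere in src/game.py (20 is an assumed value here), and the distance-based shaping applied after this block is left out.

# Illustrative sketch of the per-step reward after this change.
# GRID_SIZE = 20 is an assumption; the real constant lives elsewhere in src/game.py.
GRID_SIZE = 20

def per_step_reward(kept_direction: bool, ate_food: bool) -> float:
    reward = 0.0
    if kept_direction:
        reward += 0.5            # new bonus for not turning
    if ate_food:
        reward += GRID_SIZE * 3  # +60 with the assumed GRID_SIZE
    else:
        reward += -1             # per-step cost while not eating
    return reward

print(per_step_reward(True, False))   # -0.5: going straight is cheaper...
print(per_step_reward(False, False))  # -1.0: ...than turning
print(per_step_reward(True, True))    # 60.5: eating dominates either way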
@@ -131,14 +133,14 @@ class SnakeEnv:
             if self.length == 0:
                 self.record['starve'] += 1
                 self.death = True
-                return self.get_state(), -10, True
+                return self.get_state(), - GRID_SIZE * GRID_SIZE, True
             self.snake.pop()
 
-        # # at most 1000 steps per episode
-        # if self.steps >= 1000:
-        #     self.record['survive'] += 1
-        #     self.death = True
-        #     return self.get_state(), 0, True
+        # at most 10000 steps per episode
+        if self.steps >= 10000:
+            self.record['survive'] += 1
+            self.death = True
+            return self.get_state(), 0, True
 
         done = self.death
         return self.get_state(), reward, done
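For scale, with the same assumed GRID_SIZE of 20 the starvation penalty grows from a flat -10 to -GRID_SIZE * GRID_SIZE = -400, so dying of hunger now scales with the enlarged field, and the 10000-step cap that used to be commented out ends very long episodes with reward 0 and records them under 'survive'.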
@@ -248,7 +250,7 @@ class QLearning:
         q_target = reward + self.gamma * max(self.q_table[next_state]) if not done else reward
         self.q_table[state][action] += self.alpha * (q_target - q_predict)
 
-    def decay_epsilon(self, min_epsilon=0.001, decay_rate=0.99999):
+    def decay_epsilon(self, min_epsilon=0.001, decay_rate=0.9999):
         if self.epsilon > min_epsilon:
             self.epsilon *= decay_rate
 
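decay_epsilon() is called once per episode in train() below, and epsilon starts at 1.0 for a fresh training run (see main()), so the faster decay mainly shortens the exploration phase. A rough check, under those assumptions:

# Rough sketch: episodes needed for epsilon to decay from 1.0 down to min_epsilon.
import math

def episodes_to_min(decay_rate, min_epsilon=0.001, start=1.0):
    return math.log(min_epsilon / start) / math.log(decay_rate)

print(round(episodes_to_min(0.99999)))  # roughly 690,000 episodes with the old rate
print(round(episodes_to_min(0.9999)))   # roughly 69,000 episodes with the new rate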
@@ -257,7 +259,7 @@ def train(env, agent, episodes=10000, visualize=False):
     for episode in range(episodes):
         state = env.reset()
         done = False
-        if episode % 50000 == 0 or visualize:
+        if episode % 5000 == 0 or visualize and episode % 1 == 0:
             print(episode)
             print(env.record)
         while not done:
@@ -268,10 +270,10 @@ def train(env, agent, episodes=10000, visualize=False):
             state = next_state
             # if reward == 0 and done:
             #     print(episode, "survive!")
-            draw_env(env, visualize)
+            draw_env(env, visualize and episode % 1 == 0)
         env.record['max steps'] = max(env.record['max steps'], env.steps)
         agent.decay_epsilon()
-        if episode % 1000000 == 0:
+        if episode % 100000 == 0:
             with open(f'q_table_{episode}.pkl', 'wb') as f:
                 pickle.dump(agent.q_table, f)
             print(env.record)
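Checkpoints are now written every 100,000 episodes instead of every 1,000,000, as q_table_0.pkl, q_table_100000.pkl, and so on. A minimal way to inspect one of these dumps, assuming (as the update rule above suggests) that the pickle holds a mapping from state to per-action Q-values:

import pickle

# q_table_100000.pkl is one of the files written by the checkpointing loop above.
with open('q_table_100000.pkl', 'rb') as f:
    q_table = pickle.load(f)

print(len(q_table), "distinct states seen")
state, values = next(iter(q_table.items()))
print(state, values)  # one state and its Q-values for the three actions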
@@ -286,7 +288,7 @@ def draw_env(env, visualize=True):
     pygame.display.flip()
 
     # add a delay
-    pygame.time.delay(10)
+    pygame.time.delay(5)
 
     # handle events
     for event in pygame.event.get():
@@ -298,13 +300,13 @@ def draw_env(env, visualize=True):
 def main():
     env = SnakeEnv(GRID_SIZE)
     actions = [0, 1, 2]  # keep, turn left, turn right
-    visualize = True
-    if visualize:
+    visualize = False
+    # if visualize:
     with open('q_table_1000000.pkl', 'rb') as f:
         q_table = pickle.load(f)
-    agent = QLearning(actions, q_table=q_table if visualize else None, epsilon=0.0 if visualize else 1.0)
+    agent = QLearning(actions, q_table=q_table if visualize else None, epsilon=0.001 if visualize else 1.0)
     # agent = QLearning(actions)
-    train(env, agent, episodes=10000000, visualize=visualize)  # visualization off
+    train(env, agent, episodes=1000000, visualize=visualize)  # visualization off
     pygame.quit()
 
 if __name__ == '__main__':
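As committed, visualize = False, so main() trains from scratch for 1,000,000 episodes with epsilon decaying from 1.0 and no rendering. Switching visualize back to True hands the Q-table loaded from q_table_1000000.pkl to the agent and runs it almost greedily (epsilon = 0.001) with the pygame display on, which appears to be the loading logic the commit message refers to.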