Dedicated result visualization code
This commit is contained in:
parent a9e76e9b47
commit b173c29000
319 src/eval.py Normal file
@@ -0,0 +1,319 @@
import numpy as np
import random
import pygame
import sys
import pickle

# Initialize pygame
pygame.init()

# Window settings
GRID_SIZE = 128
CELL_SIZE = 10
WINDOW_SIZE = GRID_SIZE * CELL_SIZE
screen = pygame.display.set_mode((WINDOW_SIZE, WINDOW_SIZE))
pygame.display.set_caption('Snake Game')

# Colors
BACKGROUND = (0, 0, 0)
SNAKE_COLOR = (0, 255, 0)
FOOD_COLOR = (255, 0, 0)


# Environment class
class SnakeEnv:
    def __init__(self, size=12):
        self.size = size
        self.reset()
        # Count how episodes end and track the longest episode seen
        self.record = {
            "survive": 0,
            "wall": 0,
            "starve": 0,
            "self": 0,
            "max steps": 0,
        }

    def reset(self):
        head = (random.randint(0, self.size - 1), random.randint(0, self.size - 1))
        self.snake = [head, (head[0] - 1, head[1]), (head[0] - 2, head[1])]
        self.food = self.generate_food()
        self.steps = 0
        self.death = False
        self.length = 3
        self.action_memory = []
        self.direction = 'right'  # Initial direction is right
        self.prev_distance = self.calculate_distance()  # Distance to the food in the previous state
        return self.get_state()

    def generate_food(self):
        while True:
            food = (random.randint(0, self.size - 1), random.randint(0, self.size - 1))
            if food not in self.snake:
                return food

    def calculate_distance(self):
        head = self.snake[0]
        food = self.food
        return np.sqrt((head[0] - food[0]) ** 2 + (head[1] - food[1]) ** 2)

    def step(self, action):
        reward = 0.0
        head = self.snake[0]
        if action == 0:  # Keep the current direction
            reward += 0.5
        elif action == 1:  # Turn left
            if self.direction == 'up':
                self.direction = 'left'
            elif self.direction == 'down':
                self.direction = 'right'
            elif self.direction == 'left':
                self.direction = 'down'
            elif self.direction == 'right':
                self.direction = 'up'
        elif action == 2:  # Turn right
            if self.direction == 'up':
                self.direction = 'right'
            elif self.direction == 'down':
                self.direction = 'left'
            elif self.direction == 'left':
                self.direction = 'up'
            elif self.direction == 'right':
                self.direction = 'down'

        # Move the head in the current direction
        if self.direction == 'up':
            new_head = (head[0] - 1, head[1])
        elif self.direction == 'down':
            new_head = (head[0] + 1, head[1])
        elif self.direction == 'left':
            new_head = (head[0], head[1] - 1)
        elif self.direction == 'right':
            new_head = (head[0], head[1] + 1)

        # Check for a collision with a wall
        if new_head[0] < 0 or new_head[0] >= self.size or new_head[1] < 0 or new_head[1] >= self.size:
            self.record['wall'] += 1
            self.death = True
            return self.get_state(), -GRID_SIZE * GRID_SIZE, True

        # Check for a collision with the snake's own body
        if new_head in self.snake[:-1]:
            self.record['self'] += 1
            self.death = True
            return self.get_state(), -GRID_SIZE * GRID_SIZE, True

        self.snake.insert(0, new_head)

        # Check whether the food was eaten
        if new_head == self.food:
            self.length += 1
            self.food = self.generate_food()
            reward += GRID_SIZE * 3
        else:
            self.snake.pop()
            reward += -1

        # Compute the current distance to the food
        current_distance = self.calculate_distance()
        distance_change = self.prev_distance - current_distance
        self.prev_distance = current_distance

        # Reward shaping based on the change in distance
        if distance_change > 0:
            reward += 0.5
        elif distance_change < 0:
            reward -= 0.5

        self.steps += 1

        # Lose one body segment every GRID_SIZE * 3 steps; starving to length 0 ends the episode
        if self.steps % (GRID_SIZE * 3) == 0 and self.length >= 1:
            self.length -= 1
            if self.length == 0:
                self.record['starve'] += 1
                self.death = True
                return self.get_state(), -GRID_SIZE * GRID_SIZE, True
            self.snake.pop()

        # # Cap an episode at 10000 steps
        # if self.steps >= 10000:
        #     self.record['survive'] += 1
        #     self.death = True
        #     return self.get_state(), 0, True

        done = self.death
        return self.get_state(), reward, done

    def get_state(self):
        head = self.snake[0]
        food = self.food
        direction = self.direction
        # 3x3 local view centered on the head
        grid = np.zeros((3, 3))
        for i in range(-1, 2):
            for j in range(-1, 2):
                x = head[0] + i
                y = head[1] + j
                if x < 0 or x >= self.size or y < 0 or y >= self.size:
                    grid[i + 1][j + 1] = 1  # Wall
                elif (x, y) in self.snake:
                    grid[i + 1][j + 1] = 2  # Snake body
                elif (x, y) == self.food:
                    grid[i + 1][j + 1] = 3  # Food
                else:
                    grid[i + 1][j + 1] = 0  # Empty cell
        # Last 5 actions, zero-padded when fewer have been taken
        memory = self.action_memory[-5:] if len(self.action_memory) >= 5 else self.action_memory + [0] * (
            5 - len(self.action_memory))
        food_direction = self.get_food_direction()
        direction_one_hot = self.get_direction_one_hot()  # Computed but not included in the state
        return tuple(grid.flatten().tolist() + memory + list(food_direction))

    def get_food_direction(self):
        head = self.snake[0]
        food = self.food
        dx = food[0] - head[0]
        dy = food[1] - head[1]
        if dx > 0:
            fx = 1
        elif dx < 0:
            fx = -1
        else:
            fx = 0
        if dy > 0:
            fy = 1
        elif dy < 0:
            fy = -1
        else:
            fy = 0
        return (fx, fy)

    def get_direction_one_hot(self):
        direction = self.direction
        if direction == 'up':
            return (1, 0, 0, 0)
        elif direction == 'down':
            return (0, 1, 0, 0)
        elif direction == 'left':
            return (0, 0, 1, 0)
        elif direction == 'right':
            return (0, 0, 0, 1)
        else:
            return (0, 0, 0, 0)


# Q-learning agent
class QLearning:
    def __init__(self, actions=[0, 1, 2], epsilon=1.0, alpha=0.1, gamma=0.99, memory_size=1000, q_table=None):
        self.actions = actions
        self.epsilon = epsilon
        self.alpha = alpha
        self.gamma = gamma
        if q_table is None:
            self.q_table = {}
        else:
            print("load qtable!")
            self.q_table = q_table
        self.memory = []
        self.memory_size = memory_size

    def choose_action(self, state):
        state_tuple = state
        if state_tuple not in self.q_table:
            self.q_table[state_tuple] = [0] * len(self.actions)
        # Epsilon-greedy action selection
        if random.uniform(0, 1) < self.epsilon:
            action = random.choice(self.actions)
        else:
            action = np.argmax(self.q_table[state_tuple])
        return action

    def learn(self, state, action, reward, next_state, done):
        state_tuple = state
        next_state_tuple = next_state
        if next_state_tuple not in self.q_table:
            self.q_table[next_state_tuple] = [0] * len(self.actions)
        q_predict = self.q_table[state_tuple][action]
        q_target = reward + self.gamma * max(self.q_table[next_state_tuple]) if not done else reward
        self.q_table[state_tuple][action] += self.alpha * (q_target - q_predict)
        self.remember(state_tuple, action, reward, next_state_tuple, done)
        self.replay()

    def remember(self, state, action, reward, next_state, done):
        if len(self.memory) > self.memory_size:
            del self.memory[0]
        self.memory.append((state, action, reward, next_state, done))

    def replay(self, batch_size=32):
        # Re-apply the Q-learning update on a random batch of remembered transitions
        if len(self.memory) < batch_size:
            return
        batch = random.sample(self.memory, batch_size)
        for state, action, reward, next_state, done in batch:
            q_predict = self.q_table[state][action]
            q_target = reward + self.gamma * max(self.q_table[next_state]) if not done else reward
            self.q_table[state][action] += self.alpha * (q_target - q_predict)

    def decay_epsilon(self, min_epsilon=0.001, decay_rate=0.9999):
        if self.epsilon > min_epsilon:
            self.epsilon *= decay_rate


# Training loop
def train(env, agent, episodes=10000, visualize=False):
    for episode in range(episodes):
        state = env.reset()
        done = False
        if episode % 5000 == 0 or visualize and episode % 1 == 0:
            print(episode)
            print(env.record)
        while not done:
            action = agent.choose_action(state)
            env.action_memory.append(action)
            next_state, reward, done = env.step(action)
            agent.learn(state, action, reward, next_state, done)
            state = next_state
            # if reward == 0 and done:
            #     print(episode, "survive!")
            draw_env(env, visualize and episode % 1 == 0)
        env.record['max steps'] = max(env.record['max steps'], env.steps)
        agent.decay_epsilon()
        # Periodically checkpoint the Q-table when training without visualization
        if episode % 100000 == 0 and not visualize:
            with open(f'q_table_{episode}.pkl', 'wb') as f:
                pickle.dump(agent.q_table, f)
            print(env.record)


# Visualization
def draw_env(env, visualize=True):
    if visualize:
        screen.fill(BACKGROUND)
        for body in env.snake:
            pygame.draw.rect(screen, SNAKE_COLOR, (body[1] * CELL_SIZE, body[0] * CELL_SIZE, CELL_SIZE, CELL_SIZE))
        pygame.draw.rect(screen, FOOD_COLOR, (env.food[1] * CELL_SIZE, env.food[0] * CELL_SIZE, CELL_SIZE, CELL_SIZE))
        pygame.display.flip()

        # Small delay so the animation is visible
        pygame.time.delay(5)

        # Handle window events
        for event in pygame.event.get():
            if event.type == pygame.QUIT:
                pygame.quit()
                sys.exit()


# Main entry point
def main():
    env = SnakeEnv(GRID_SIZE)
    actions = [0, 1, 2]  # keep direction, turn left, turn right
    visualize = True
    # if visualize:
    with open('q_table_200000.pkl', 'rb') as f:
        q_table = pickle.load(f)
    agent = QLearning(actions, q_table=q_table if visualize else None, epsilon=0.0001 if visualize else 1.0)
    # agent = QLearning(actions)
    train(env, agent, episodes=1000000, visualize=visualize)  # visualization controlled by `visualize`
    pygame.quit()


if __name__ == '__main__':
    main()