Dedicated visualization code for the trained results

EndersOwner 2025-01-04 00:17:16 +08:00
parent a9e76e9b47
commit b173c29000

src/eval.py (new file, 319 lines added)

@@ -0,0 +1,319 @@
import numpy as np
import random
import pygame
import sys
import pickle

# Initialize pygame
pygame.init()

# Window settings: a GRID_SIZE x GRID_SIZE board drawn at CELL_SIZE pixels per cell
GRID_SIZE = 128
CELL_SIZE = 10
WINDOW_SIZE = GRID_SIZE * CELL_SIZE
screen = pygame.display.set_mode((WINDOW_SIZE, WINDOW_SIZE))
pygame.display.set_caption('Snake Game')

# Colour definitions
BACKGROUND = (0, 0, 0)
SNAKE_COLOR = (0, 255, 0)
FOOD_COLOR = (255, 0, 0)


# Environment class
class SnakeEnv:
    def __init__(self, size=12):
        self.size = size
        self.reset()
        self.record = {
            "survive": 0,
            "wall": 0,
            "starve": 0,
            "self": 0,
            "max steps": 0,
        }

    def reset(self):
        head = (random.randint(0, self.size - 1), random.randint(0, self.size - 1))
        self.snake = [head, (head[0] - 1, head[1]), (head[0] - 2, head[1])]
        self.food = self.generate_food()
        self.steps = 0
        self.death = False
        self.length = 3
        self.action_memory = []
        self.direction = 'right'  # Initial direction is right
        self.prev_distance = self.calculate_distance()  # Distance to the food in the previous state
        return self.get_state()

    def generate_food(self):
        while True:
            food = (random.randint(0, self.size - 1), random.randint(0, self.size - 1))
            if food not in self.snake:
                return food

    def calculate_distance(self):
        head = self.snake[0]
        food = self.food
        return np.sqrt((head[0] - food[0]) ** 2 + (head[1] - food[1]) ** 2)
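    # Note: this Euclidean distance is only used for reward shaping in step();
    # moving closer to the food earns a small bonus and moving away a small penalty.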

    def step(self, action):
        reward = 0.0
        head = self.snake[0]
        if action == 0:  # Keep the current direction
            reward += 0.5
        elif action == 1:  # Turn left
            if self.direction == 'up':
                self.direction = 'left'
            elif self.direction == 'down':
                self.direction = 'right'
            elif self.direction == 'left':
                self.direction = 'down'
            elif self.direction == 'right':
                self.direction = 'up'
        elif action == 2:  # Turn right
            if self.direction == 'up':
                self.direction = 'right'
            elif self.direction == 'down':
                self.direction = 'left'
            elif self.direction == 'left':
                self.direction = 'up'
            elif self.direction == 'right':
                self.direction = 'down'
        # Move the head one cell in the current direction
        if self.direction == 'up':
            new_head = (head[0] - 1, head[1])
        elif self.direction == 'down':
            new_head = (head[0] + 1, head[1])
        elif self.direction == 'left':
            new_head = (head[0], head[1] - 1)
        elif self.direction == 'right':
            new_head = (head[0], head[1] + 1)
        # Check for collision with the wall
        if new_head[0] < 0 or new_head[0] >= self.size or new_head[1] < 0 or new_head[1] >= self.size:
            self.record['wall'] += 1
            self.death = True
            return self.get_state(), -GRID_SIZE * GRID_SIZE, True
        # Check for collision with the snake's own body
        if new_head in self.snake[:-1]:
            self.record['self'] += 1
            self.death = True
            return self.get_state(), -GRID_SIZE * GRID_SIZE, True
        self.snake.insert(0, new_head)
        # Check whether the food was eaten
        if new_head == self.food:
            self.length += 1
            self.food = self.generate_food()
            reward += GRID_SIZE * 3
        else:
            self.snake.pop()
            reward += -1
        # Reward shaping based on the change in distance to the food
        current_distance = self.calculate_distance()
        distance_change = self.prev_distance - current_distance
        self.prev_distance = current_distance
        if distance_change > 0:
            reward += 0.5
        elif distance_change < 0:
            reward -= 0.5
        self.steps += 1
        # Starvation: lose one body segment every GRID_SIZE * 3 steps
        if self.steps % (GRID_SIZE * 3) == 0 and self.length >= 1:
            self.length -= 1
            if self.length == 0:
                self.record['starve'] += 1
                self.death = True
                return self.get_state(), -GRID_SIZE * GRID_SIZE, True
            self.snake.pop()
        # # Cap an episode at 10000 steps
        # if self.steps >= 10000:
        #     self.record['survive'] += 1
        #     self.death = True
        #     return self.get_state(), 0, True
        done = self.death
        return self.get_state(), reward, done
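
    # Reward scheme as implemented above: +0.5 for keeping direction,
    # +GRID_SIZE * 3 for eating food, -1 per step otherwise, +/-0.5 for moving
    # closer to / further from the food, and -GRID_SIZE * GRID_SIZE on any
    # death (wall, self-collision, or starvation).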

    def get_state(self):
        head = self.snake[0]
        # 3x3 neighbourhood around the head: 0 = empty, 1 = wall, 2 = body, 3 = food
        grid = np.zeros((3, 3))
        for i in range(-1, 2):
            for j in range(-1, 2):
                x = head[0] + i
                y = head[1] + j
                if x < 0 or x >= self.size or y < 0 or y >= self.size:
                    grid[i + 1][j + 1] = 1  # Wall
                elif (x, y) in self.snake:
                    grid[i + 1][j + 1] = 2  # Snake body
                elif (x, y) == self.food:
                    grid[i + 1][j + 1] = 3  # Food
                else:
                    grid[i + 1][j + 1] = 0  # Empty cell
        # Last five actions, zero-padded if fewer have been taken
        memory = self.action_memory[-5:] if len(self.action_memory) >= 5 else self.action_memory + [0] * (
            5 - len(self.action_memory))
        food_direction = self.get_food_direction()
        direction_one_hot = self.get_direction_one_hot()  # Computed but not currently included in the state
        return tuple(grid.flatten().tolist() + memory + list(food_direction))

    def get_food_direction(self):
        # Sign of the food's offset from the head along each axis
        head = self.snake[0]
        food = self.food
        dx = food[0] - head[0]
        dy = food[1] - head[1]
        if dx > 0:
            fx = 1
        elif dx < 0:
            fx = -1
        else:
            fx = 0
        if dy > 0:
            fy = 1
        elif dy < 0:
            fy = -1
        else:
            fy = 0
        return (fx, fy)

    def get_direction_one_hot(self):
        direction = self.direction
        if direction == 'up':
            return (1, 0, 0, 0)
        elif direction == 'down':
            return (0, 1, 0, 0)
        elif direction == 'left':
            return (0, 0, 1, 0)
        elif direction == 'right':
            return (0, 0, 0, 1)
        else:
            return (0, 0, 0, 0)
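

# State layout (16 values): the flattened 3x3 neighbourhood grid (9 values),
# the last five actions (5), and the two food-direction signs. Keeping the
# state this small and discrete is what makes a tabular Q-table feasible here.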

# Q-learning agent
class QLearning:
    def __init__(self, actions=[0, 1, 2], epsilon=1.0, alpha=0.1, gamma=0.99, memory_size=1000, q_table=None):
        self.actions = actions
        self.epsilon = epsilon
        self.alpha = alpha
        self.gamma = gamma
        if q_table is None:
            self.q_table = {}
        else:
            print("load qtable!")
            self.q_table = q_table
        self.memory = []
        self.memory_size = memory_size

    def choose_action(self, state):
        state_tuple = state
        if state_tuple not in self.q_table:
            self.q_table[state_tuple] = [0] * len(self.actions)
        if random.uniform(0, 1) < self.epsilon:
            action = random.choice(self.actions)  # Explore
        else:
            action = np.argmax(self.q_table[state_tuple])  # Exploit
        return action

    def learn(self, state, action, reward, next_state, done):
        state_tuple = state
        next_state_tuple = next_state
        if next_state_tuple not in self.q_table:
            self.q_table[next_state_tuple] = [0] * len(self.actions)
        q_predict = self.q_table[state_tuple][action]
        q_target = reward + self.gamma * max(self.q_table[next_state_tuple]) if not done else reward
        self.q_table[state_tuple][action] += self.alpha * (q_target - q_predict)
        self.remember(state_tuple, action, reward, next_state_tuple, done)
        self.replay()

    def remember(self, state, action, reward, next_state, done):
        if len(self.memory) > self.memory_size:
            del self.memory[0]
        self.memory.append((state, action, reward, next_state, done))

    def replay(self, batch_size=32):
        if len(self.memory) < batch_size:
            return
        batch = random.sample(self.memory, batch_size)
        for state, action, reward, next_state, done in batch:
            q_predict = self.q_table[state][action]
            q_target = reward + self.gamma * max(self.q_table[next_state]) if not done else reward
            self.q_table[state][action] += self.alpha * (q_target - q_predict)

    def decay_epsilon(self, min_epsilon=0.001, decay_rate=0.9999):
        if self.epsilon > min_epsilon:
            self.epsilon *= decay_rate
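

# learn() and replay() above apply the standard tabular Q-learning update,
#   Q(s, a) <- Q(s, a) + alpha * (r + gamma * max_a' Q(s', a') - Q(s, a)),
# dropping the bootstrapped term (q_target = reward) on terminal transitions;
# replay() re-applies the same update to a random minibatch of stored transitions.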


# Training / evaluation loop
def train(env, agent, episodes=10000, visualize=False):
    for episode in range(episodes):
        state = env.reset()
        done = False
        if episode % 5000 == 0 or (visualize and episode % 1 == 0):
            print(episode)
            print(env.record)
        while not done:
            action = agent.choose_action(state)
            env.action_memory.append(action)
            next_state, reward, done = env.step(action)
            agent.learn(state, action, reward, next_state, done)
            state = next_state
            # if reward == 0 and done:
            #     print(episode, "survive!")
            draw_env(env, visualize and episode % 1 == 0)
        env.record['max steps'] = max(env.record['max steps'], env.steps)
        agent.decay_epsilon()
        if episode % 100000 == 0 and not visualize:
            with open(f'q_table_{episode}.pkl', 'wb') as f:
                pickle.dump(agent.q_table, f)
            print(env.record)


# Visualization helper
def draw_env(env, visualize=True):
    if visualize:
        screen.fill(BACKGROUND)
        for body in env.snake:
            pygame.draw.rect(screen, SNAKE_COLOR, (body[1] * CELL_SIZE, body[0] * CELL_SIZE, CELL_SIZE, CELL_SIZE))
        pygame.draw.rect(screen, FOOD_COLOR, (env.food[1] * CELL_SIZE, env.food[0] * CELL_SIZE, CELL_SIZE, CELL_SIZE))
        pygame.display.flip()
        # Small delay so the animation is watchable
        pygame.time.delay(5)
        # Handle window events
        for event in pygame.event.get():
            if event.type == pygame.QUIT:
                pygame.quit()
                sys.exit()


# Main entry point
def main():
    env = SnakeEnv(GRID_SIZE)
    actions = [0, 1, 2]  # Keep direction, turn left, turn right
    visualize = True
    # if visualize:
    with open('q_table_200000.pkl', 'rb') as f:
        q_table = pickle.load(f)
    agent = QLearning(actions, q_table=q_table if visualize else None, epsilon=0.0001 if visualize else 1.0)
    # agent = QLearning(actions)
    train(env, agent, episodes=1000000, visualize=visualize)
    pygame.quit()


if __name__ == '__main__':
    main()
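

For a quick, headless sanity check of a saved Q-table, a greedy rollout along the lines of the sketch below could be appended to eval.py. It only reuses the classes defined in this commit; the checkpoint file name and episode count are illustrative assumptions, not part of the commit.

# Hedged sketch: greedy evaluation of a saved Q-table using SnakeEnv and QLearning above.
def evaluate(checkpoint='q_table_200000.pkl', episodes=100):
    with open(checkpoint, 'rb') as f:
        q_table = pickle.load(f)
    env = SnakeEnv(GRID_SIZE)
    # epsilon=0.0 makes choose_action effectively always pick argmax Q(s, a)
    agent = QLearning([0, 1, 2], epsilon=0.0, q_table=q_table)
    lengths = []
    for _ in range(episodes):
        state = env.reset()
        done = False
        while not done:
            action = agent.choose_action(state)
            env.action_memory.append(action)
            state, _, done = env.step(action)
        lengths.append(env.length)
    print(env.record)  # Cumulative death/outcome counters across all episodes
    print('mean final length:', sum(lengths) / len(lengths))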