diff --git a/src/game.py b/src/game.py
new file mode 100644
index 0000000..3d95286
--- /dev/null
+++ b/src/game.py
@@ -0,0 +1,311 @@
+import numpy as np
+import random
+import pygame
+import sys
+import pickle
+
+# Initialize pygame
+pygame.init()
+
+# Window size settings
+GRID_SIZE = 128
+CELL_SIZE = 10
+WINDOW_SIZE = GRID_SIZE * CELL_SIZE
+screen = pygame.display.set_mode((WINDOW_SIZE, WINDOW_SIZE))
+pygame.display.set_caption('Snake Game')
+
+# Colors
+BACKGROUND = (0, 0, 0)
+SNAKE_COLOR = (0, 255, 0)
+FOOD_COLOR = (255, 0, 0)
+
+
+# Environment class
+class SnakeEnv:
+    def __init__(self, size=12):
+        self.size = size
+        self.reset()
+        self.record = {
+            "survive": 0,
+            "wall": 0,
+            "starve": 0,
+            "self": 0,
+            "max steps": 0,
+        }
+
+    def reset(self):
+        head = (random.randint(0, self.size-1), random.randint(0, self.size-1))
+        self.snake = [head, (head[0]-1, head[1]), (head[0]-2, head[1])]
+        self.food = self.generate_food()
+        self.steps = 0
+        self.death = False
+        self.length = 3
+        self.action_memory = []
+        self.direction = 'right'  # initial direction is right
+        self.prev_distance = self.calculate_distance()  # distance to food in the previous state
+        return self.get_state()
+
+    def generate_food(self):
+        while True:
+            food = (random.randint(0, self.size-1), random.randint(0, self.size-1))
+            if food not in self.snake:
+                return food
+
+    def calculate_distance(self):
+        head = self.snake[0]
+        food = self.food
+        return np.sqrt((head[0] - food[0])**2 + (head[1] - food[1])**2)
+
+    def step(self, action):
+        head = self.snake[0]
+        if action == 0:  # keep direction
+            pass
+        elif action == 1:  # turn left
+            if self.direction == 'up':
+                self.direction = 'left'
+            elif self.direction == 'down':
+                self.direction = 'right'
+            elif self.direction == 'left':
+                self.direction = 'down'
+            elif self.direction == 'right':
+                self.direction = 'up'
+        elif action == 2:  # turn right
+            if self.direction == 'up':
+                self.direction = 'right'
+            elif self.direction == 'down':
+                self.direction = 'left'
+            elif self.direction == 'left':
+                self.direction = 'up'
+            elif self.direction == 'right':
+                self.direction = 'down'
+
+        # Move the head according to the current direction
+        if self.direction == 'up':
+            new_head = (head[0]-1, head[1])
+        elif self.direction == 'down':
+            new_head = (head[0]+1, head[1])
+        elif self.direction == 'left':
+            new_head = (head[0], head[1]-1)
+        elif self.direction == 'right':
+            new_head = (head[0], head[1]+1)
+
+        # Check for wall collision
+        if new_head[0] < 0 or new_head[0] >= self.size or new_head[1] < 0 or new_head[1] >= self.size:
+            self.record['wall'] += 1
+            self.death = True
+            return self.get_state(), - GRID_SIZE * GRID_SIZE, True
+
+        # Check for collision with the snake's own body
+        if new_head in self.snake[:-1]:
+            self.record['self'] += 1
+            self.death = True
+            return self.get_state(), - GRID_SIZE * GRID_SIZE, True
+
+        self.snake.insert(0, new_head)
+
+        # Check whether food was eaten
+        if new_head == self.food:
+            self.length += 1
+            self.food = self.generate_food()
+            reward = GRID_SIZE * 3
+        else:
+            self.snake.pop()
+            reward = -1
+
+        # Compute the current distance to the food
+        current_distance = self.calculate_distance()
+        distance_change = self.prev_distance - current_distance
+        self.prev_distance = current_distance
+
+        # Reward shaping based on the change in distance
+        if distance_change > 0:
+            reward += 0.5
+        elif distance_change < 0:
+            reward -= 0.5
+
+        self.steps += 1
+
+        # Lose one body segment every GRID_SIZE * 3 steps (starvation pressure)
+        if self.steps % (GRID_SIZE * 3) == 0 and self.length >= 1:
+            self.length -= 1
+            if self.length == 0:
+                self.record['starve'] += 1
+                self.death = True
+                return self.get_state(), -10, True
+            self.snake.pop()
+
+        # # Cap an episode at 1000 steps
+        # if self.steps >= 1000:
+        #     self.record['survive'] += 1
+        #     self.death = True
+        #     return self.get_state(), 0, True
+
+        done = self.death
+        return self.get_state(), reward, done
+
+    def get_state(self):
+        head = self.snake[0]
+        food = self.food
+        direction = self.direction
+        grid = np.zeros((3, 3))
+        for i in range(-1, 2):
+            for j in range(-1, 2):
+                x = head[0] + i
+                y = head[1] + j
+                if x < 0 or x >= self.size or y < 0 or y >= self.size:
+                    grid[i+1][j+1] = 1  # wall
+                elif (x, y) in self.snake:
+                    grid[i+1][j+1] = 2  # snake body
+                elif (x, y) == self.food:
+                    grid[i+1][j+1] = 3  # food
+                else:
+                    grid[i+1][j+1] = 0  # empty cell
+        memory = self.action_memory[-5:] if len(self.action_memory) >= 5 else self.action_memory + [0]*(5 - len(self.action_memory))
+        food_direction = self.get_food_direction()
+        direction_one_hot = self.get_direction_one_hot()  # computed but not included in the state tuple below
+        return tuple(grid.flatten().tolist() + memory + list(food_direction))
+
+    def get_food_direction(self):
+        head = self.snake[0]
+        food = self.food
+        dx = food[0] - head[0]
+        dy = food[1] - head[1]
+        if dx > 0:
+            fx = 1
+        elif dx < 0:
+            fx = -1
+        else:
+            fx = 0
+        if dy > 0:
+            fy = 1
+        elif dy < 0:
+            fy = -1
+        else:
+            fy = 0
+        return (fx, fy)
+
+    def get_direction_one_hot(self):
+        direction = self.direction
+        if direction == 'up':
+            return (1, 0, 0, 0)
+        elif direction == 'down':
+            return (0, 1, 0, 0)
+        elif direction == 'left':
+            return (0, 0, 1, 0)
+        elif direction == 'right':
+            return (0, 0, 0, 1)
+        else:
+            return (0, 0, 0, 0)
+
+# Q-learning agent
+class QLearning:
+    def __init__(self, actions=[0, 1, 2], epsilon=1.0, alpha=0.1, gamma=0.99, memory_size=1000, q_table=None):
+        self.actions = actions
+        self.epsilon = epsilon
+        self.alpha = alpha
+        self.gamma = gamma
+        self.q_table = {}
+        if q_table is None:
+            self.q_table = {}
+        else:
+            print("load qtable!")
+            self.q_table = q_table
+        self.memory = []
+        self.memory_size = memory_size
+
+    def choose_action(self, state):
+        state_tuple = state
+        if state_tuple not in self.q_table:
+            self.q_table[state_tuple] = [0]*len(self.actions)
+        if random.uniform(0, 1) < self.epsilon:
+            action = random.choice(self.actions)
+        else:
+            action = np.argmax(self.q_table[state_tuple])
+        return action
+
+    def learn(self, state, action, reward, next_state, done):
+        state_tuple = state
+        next_state_tuple = next_state
+        if next_state_tuple not in self.q_table:
+            self.q_table[next_state_tuple] = [0]*len(self.actions)
+        q_predict = self.q_table[state_tuple][action]
+        q_target = reward + self.gamma * max(self.q_table[next_state_tuple]) if not done else reward
+        self.q_table[state_tuple][action] += self.alpha * (q_target - q_predict)
+        self.remember(state_tuple, action, reward, next_state_tuple, done)
+        self.replay()
+
+    def remember(self, state, action, reward, next_state, done):
+        if len(self.memory) > self.memory_size:
+            del self.memory[0]
+        self.memory.append((state, action, reward, next_state, done))
+
+    def replay(self, batch_size=32):
+        if len(self.memory) < batch_size:
+            return
+        batch = random.sample(self.memory, batch_size)
+        for state, action, reward, next_state, done in batch:
+            q_predict = self.q_table[state][action]
+            q_target = reward + self.gamma * max(self.q_table[next_state]) if not done else reward
+            self.q_table[state][action] += self.alpha * (q_target - q_predict)
+
+    def decay_epsilon(self, min_epsilon=0.001, decay_rate=0.99999):
+        if self.epsilon > min_epsilon:
+            self.epsilon *= decay_rate
+
+# Training loop
+def train(env, agent, episodes=10000, visualize=False):
+    for episode in range(episodes):
+        state = env.reset()
+        done = False
+        if episode % 50000 == 0 or visualize:
+            print(episode)
+            print(env.record)
+        while not done:
+            action = agent.choose_action(state)
+            env.action_memory.append(action)
+            next_state, reward, done = env.step(action)
+            agent.learn(state, action, reward, next_state, done)
+            state = next_state
+            # if reward == 0 and done:
+            #     print(episode, "survive!")
+            draw_env(env, visualize)
+        env.record['max steps'] = max(env.record['max steps'], env.steps)
+        agent.decay_epsilon()
+        if episode % 1000000 == 0:
+            with open(f'q_table_{episode}.pkl', 'wb') as f:
+                pickle.dump(agent.q_table, f)
+            print(env.record)
+
+# Rendering / visualization
+def draw_env(env, visualize=True):
+    if visualize:
+        screen.fill(BACKGROUND)
+        for body in env.snake:
+            pygame.draw.rect(screen, SNAKE_COLOR, (body[1]*CELL_SIZE, body[0]*CELL_SIZE, CELL_SIZE, CELL_SIZE))
+        pygame.draw.rect(screen, FOOD_COLOR, (env.food[1]*CELL_SIZE, env.food[0]*CELL_SIZE, CELL_SIZE, CELL_SIZE))
+        pygame.display.flip()
+
+        # Add a small delay between frames
+        pygame.time.delay(10)
+
+        # Handle window events
+        for event in pygame.event.get():
+            if event.type == pygame.QUIT:
+                pygame.quit()
+                sys.exit()
+
+# Entry point
+def main():
+    env = SnakeEnv(GRID_SIZE)
+    actions = [0, 1, 2]  # keep direction, turn left, turn right
+    visualize = True
+    if visualize:
+        with open('q_table_1000000.pkl', 'rb') as f:
+            q_table = pickle.load(f)
+    agent = QLearning(actions, q_table=q_table if visualize else None, epsilon=0.0 if visualize else 1.0)
+    # agent = QLearning(actions)
+    train(env, agent, episodes=10000000, visualize=visualize)  # rendering controlled by the visualize flag above
+    pygame.quit()
+
+if __name__ == '__main__':
+    main()
\ No newline at end of file
diff --git a/src/q_table_1000000.pkl b/src/q_table_1000000.pkl
new file mode 100644
index 0000000..6a5d066
Binary files /dev/null and b/src/q_table_1000000.pkl differ