init:
Result of training for 1,000,000 steps; inference code for the 12*12 grid with at most 1000 steps per episode
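For reference, a minimal inference sketch (an assumption, not part of this commit): it loads the pickled table and rolls out greedy actions directly, instead of going through main(), which reuses train() with epsilon=0. It assumes the snippet runs next to src/game.py and src/q_table_1000000.pkl.

# Hypothetical inference sketch; names mirror the committed game.py.
import pickle
from game import SnakeEnv, QLearning   # importing game.py also opens the pygame window

with open('q_table_1000000.pkl', 'rb') as f:
    q_table = pickle.load(f)

env = SnakeEnv(12)                                            # 12*12 board, as in the commit message
agent = QLearning([0, 1, 2], q_table=q_table, epsilon=0.0)    # epsilon=0 -> always greedy

state = env.reset()
done = False
while not done and env.steps < 1000:                          # external 1000-step cap
    action = agent.choose_action(state)
    env.action_memory.append(action)
    state, reward, done = env.step(action)
print(env.length, env.steps)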
This commit is contained in:
parent
9419df92eb
commit
1bc08f2f97
311
src/game.py
Normal file
@@ -0,0 +1,311 @@
import numpy as np
import random
import pygame
import sys
import pickle

# Initialize pygame
pygame.init()

# Window settings
GRID_SIZE = 128
CELL_SIZE = 10
WINDOW_SIZE = GRID_SIZE * CELL_SIZE
screen = pygame.display.set_mode((WINDOW_SIZE, WINDOW_SIZE))
pygame.display.set_caption('Snake Game')

# Colors
BACKGROUND = (0, 0, 0)
SNAKE_COLOR = (0, 255, 0)
FOOD_COLOR = (255, 0, 0)

# Environment class
class SnakeEnv:
    def __init__(self, size=12):
        self.size = size
        self.reset()
        self.record = {
            "survive": 0,
            "wall": 0,
            "starve": 0,
            "self": 0,
            "max steps": 0,
        }

    def reset(self):
        head = (random.randint(0, self.size-1), random.randint(0, self.size-1))
        self.snake = [head, (head[0]-1, head[1]), (head[0]-2, head[1])]
        self.food = self.generate_food()
        self.steps = 0
        self.death = False
        self.length = 3
        self.action_memory = []
        self.direction = 'right'  # initial direction is right
        self.prev_distance = self.calculate_distance()  # distance to the food in the previous state
        return self.get_state()
    def generate_food(self):
        while True:
            food = (random.randint(0, self.size-1), random.randint(0, self.size-1))
            if food not in self.snake:
                return food

    def calculate_distance(self):
        head = self.snake[0]
        food = self.food
        return np.sqrt((head[0] - food[0])**2 + (head[1] - food[1])**2)
    def step(self, action):
        head = self.snake[0]
        if action == 0:  # keep current direction
            pass
        elif action == 1:  # turn left
            if self.direction == 'up':
                self.direction = 'left'
            elif self.direction == 'down':
                self.direction = 'right'
            elif self.direction == 'left':
                self.direction = 'down'
            elif self.direction == 'right':
                self.direction = 'up'
        elif action == 2:  # turn right
            if self.direction == 'up':
                self.direction = 'right'
            elif self.direction == 'down':
                self.direction = 'left'
            elif self.direction == 'left':
                self.direction = 'up'
            elif self.direction == 'right':
                self.direction = 'down'

        # Move the head in the current direction
        if self.direction == 'up':
            new_head = (head[0]-1, head[1])
        elif self.direction == 'down':
            new_head = (head[0]+1, head[1])
        elif self.direction == 'left':
            new_head = (head[0], head[1]-1)
        elif self.direction == 'right':
            new_head = (head[0], head[1]+1)

        # Check for collision with a wall
        if new_head[0] < 0 or new_head[0] >= self.size or new_head[1] < 0 or new_head[1] >= self.size:
            self.record['wall'] += 1
            self.death = True
            return self.get_state(), -GRID_SIZE * GRID_SIZE, True

        # Check for collision with the snake's own body
        if new_head in self.snake[:-1]:
            self.record['self'] += 1
            self.death = True
            return self.get_state(), -GRID_SIZE * GRID_SIZE, True

        self.snake.insert(0, new_head)

        # Check whether the food was eaten
        if new_head == self.food:
            self.length += 1
            self.food = self.generate_food()
            reward = GRID_SIZE * 3
        else:
            self.snake.pop()
            reward = -1

        # Current distance to the food
        current_distance = self.calculate_distance()
        distance_change = self.prev_distance - current_distance
        self.prev_distance = current_distance

        # Reward shaping based on the change in distance
        if distance_change > 0:
            reward += 0.5
        elif distance_change < 0:
            reward -= 0.5

        self.steps += 1

        # Lose one body segment every GRID_SIZE * 3 steps
        if self.steps % (GRID_SIZE * 3) == 0 and self.length >= 1:
            self.length -= 1
            if self.length == 0:
                self.record['starve'] += 1
                self.death = True
                return self.get_state(), -10, True
            self.snake.pop()

        # # Cap an episode at 1000 steps
        # if self.steps >= 1000:
        #     self.record['survive'] += 1
        #     self.death = True
        #     return self.get_state(), 0, True

        done = self.death
        return self.get_state(), reward, done
    def get_state(self):
        head = self.snake[0]
        food = self.food
        direction = self.direction
        grid = np.zeros((3, 3))
        for i in range(-1, 2):
            for j in range(-1, 2):
                x = head[0] + i
                y = head[1] + j
                if x < 0 or x >= self.size or y < 0 or y >= self.size:
                    grid[i+1][j+1] = 1  # wall
                elif (x, y) in self.snake:
                    grid[i+1][j+1] = 2  # snake body
                elif (x, y) == self.food:
                    grid[i+1][j+1] = 3  # food
                else:
                    grid[i+1][j+1] = 0  # empty cell
        memory = self.action_memory[-5:] if len(self.action_memory) >= 5 else self.action_memory + [0]*(5 - len(self.action_memory))
        food_direction = self.get_food_direction()
        direction_one_hot = self.get_direction_one_hot()  # computed but not included in the state tuple
        return tuple(grid.flatten().tolist() + memory + list(food_direction))
    def get_food_direction(self):
        head = self.snake[0]
        food = self.food
        dx = food[0] - head[0]
        dy = food[1] - head[1]
        if dx > 0:
            fx = 1
        elif dx < 0:
            fx = -1
        else:
            fx = 0
        if dy > 0:
            fy = 1
        elif dy < 0:
            fy = -1
        else:
            fy = 0
        return (fx, fy)

    def get_direction_one_hot(self):
        direction = self.direction
        if direction == 'up':
            return (1, 0, 0, 0)
        elif direction == 'down':
            return (0, 1, 0, 0)
        elif direction == 'left':
            return (0, 0, 1, 0)
        elif direction == 'right':
            return (0, 0, 0, 1)
        else:
            return (0, 0, 0, 0)
# Q-learning agent
class QLearning:
    def __init__(self, actions=[0, 1, 2], epsilon=1.0, alpha=0.1, gamma=0.99, memory_size=1000, q_table=None):
        self.actions = actions
        self.epsilon = epsilon
        self.alpha = alpha
        self.gamma = gamma
        if q_table is None:
            self.q_table = {}
        else:
            print("load qtable!")
            self.q_table = q_table
        self.memory = []
        self.memory_size = memory_size
    def choose_action(self, state):
        state_tuple = state
        if state_tuple not in self.q_table:
            self.q_table[state_tuple] = [0]*len(self.actions)
        if random.uniform(0, 1) < self.epsilon:
            action = random.choice(self.actions)
        else:
            action = np.argmax(self.q_table[state_tuple])
        return action

    def learn(self, state, action, reward, next_state, done):
        state_tuple = state
        next_state_tuple = next_state
        if next_state_tuple not in self.q_table:
            self.q_table[next_state_tuple] = [0]*len(self.actions)
        q_predict = self.q_table[state_tuple][action]
        q_target = reward + self.gamma * max(self.q_table[next_state_tuple]) if not done else reward
        self.q_table[state_tuple][action] += self.alpha * (q_target - q_predict)
        self.remember(state_tuple, action, reward, next_state_tuple, done)
        self.replay()

    def remember(self, state, action, reward, next_state, done):
        if len(self.memory) > self.memory_size:
            del self.memory[0]
        self.memory.append((state, action, reward, next_state, done))

    def replay(self, batch_size=32):
        if len(self.memory) < batch_size:
            return
        batch = random.sample(self.memory, batch_size)
        for state, action, reward, next_state, done in batch:
            q_predict = self.q_table[state][action]
            q_target = reward + self.gamma * max(self.q_table[next_state]) if not done else reward
            self.q_table[state][action] += self.alpha * (q_target - q_predict)

    def decay_epsilon(self, min_epsilon=0.001, decay_rate=0.99999):
        if self.epsilon > min_epsilon:
            self.epsilon *= decay_rate
# Training loop
def train(env, agent, episodes=10000, visualize=False):
    for episode in range(episodes):
        state = env.reset()
        done = False
        if episode % 50000 == 0 or visualize:
            print(episode)
            print(env.record)
        while not done:
            action = agent.choose_action(state)
            env.action_memory.append(action)
            next_state, reward, done = env.step(action)
            agent.learn(state, action, reward, next_state, done)
            state = next_state
            # if reward == 0 and done:
            #     print(episode, "survive!")
            draw_env(env, visualize)
        env.record['max steps'] = max(env.record['max steps'], env.steps)
        agent.decay_epsilon()
        if episode % 1000000 == 0:
            with open(f'q_table_{episode}.pkl', 'wb') as f:
                pickle.dump(agent.q_table, f)
            print(env.record)
# Visualization
def draw_env(env, visualize=True):
    if visualize:
        screen.fill(BACKGROUND)
        for body in env.snake:
            pygame.draw.rect(screen, SNAKE_COLOR, (body[1]*CELL_SIZE, body[0]*CELL_SIZE, CELL_SIZE, CELL_SIZE))
        pygame.draw.rect(screen, FOOD_COLOR, (env.food[1]*CELL_SIZE, env.food[0]*CELL_SIZE, CELL_SIZE, CELL_SIZE))
        pygame.display.flip()

        # Small delay so the animation is watchable
        pygame.time.delay(10)

        # Handle window events
        for event in pygame.event.get():
            if event.type == pygame.QUIT:
                pygame.quit()
                sys.exit()
# Entry point
def main():
    env = SnakeEnv(GRID_SIZE)
    actions = [0, 1, 2]  # keep direction, turn left, turn right
    visualize = True
    if visualize:
        with open('q_table_1000000.pkl', 'rb') as f:
            q_table = pickle.load(f)
    agent = QLearning(actions, q_table=q_table if visualize else None, epsilon=0.0 if visualize else 1.0)
    # agent = QLearning(actions)
    train(env, agent, episodes=10000000, visualize=visualize)
    pygame.quit()


if __name__ == '__main__':
    main()
BIN
src/q_table_1000000.pkl
Normal file
Binary file not shown.