
Training an AI to Play Snake
👋 Hello AI enthusiasts! Today we’ll explore how to train an AI to play the classic Snake game using PyTorch. This is part of our journey in applying reinforcement learning to games 🚀
📦 Source Code: You can find the complete implementation on GitHub
Prerequisites and running the project 🛠️
Before we dive in, let’s set up our environment properly:
Environment Setup 🔧
- Create and activate a virtual environment:
python -m venv ai-snake-venv
source ai-snake-venv/bin/activate  # Linux/WSL
- Install dependencies:
pip install -r requirements.txt
Dependencies used:
pygame
torch
numpy
For CUDA-enabled PyTorch, visit PyTorch Get Started for details.
- Run the project:
python snake-game-entry.py
Code Structure 📁
Our project is organized into:
ai-snake-game/
├── ai_trainer.py # AI training logic
├── constants.py # Game constants
├── food.py # Food class
├── game.py # Game rendering functions
├── snake.py # Snake class
├── snake-game-entry.py # Main game entry point
├── requirements.txt # Project dependencies
└── README.md # Project documentation
Project Features ✨
Our implementation includes:
- 🎮 Classic Snake gameplay with manual controls
- 🧠 AI training using reinforcement learning
- 👀 Watch the AI play the game after training
- 📊 Real-time training progress and metrics
Snake Game Code Explanation 🐍
snake.py
The Snake class handles the snake’s behavior:
class Snake:
    def __init__(self):
        self.length = 1
        self.positions = [(GRID_WIDTH // 2, GRID_HEIGHT // 2)]
        self.direction = random.choice([(0, 1), (0, -1), (1, 0), (-1, 0)])
        self.color = SNAKE_COLOR
        self.score = 1
        self.growth_pending = False

    def get_head_position(self):
        return self.positions[0]

    def turn(self, point):
        if self.length > 1 and (point[0] * -1, point[1] * -1) == self.direction:
            return
        self.direction = point

    def move(self):
        cur = self.get_head_position()
        x, y = self.direction
        new = (((cur[0] + x) % GRID_WIDTH), (cur[1] + y) % GRID_HEIGHT)
        if new in self.positions[2:]:
            return False
        self.positions.insert(0, new)
        if not self.growth_pending:
            self.positions.pop()
        else:
            self.growth_pending = False
        return True

    def grow(self):
        self.growth_pending = True
        self.length += 1
        self.score = self.length
- Initialization: Sets the initial position, direction, and length of the snake.
- Movement: Updates the snake’s position based on its direction.
- Growth: Increases the snake’s length when it eats food.
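Putting these pieces together, the manual game loop calls turn() on keyboard input and move() once per tick, growing the snake when its head lands on the food. The actual loop lives in snake-game-entry.py and may be structured differently; the snippet below is only a rough sketch of how the class is meant to be used:

import pygame
from snake import Snake
from food import Food

snake, food = Snake(), Food()
running = True
while running:
    for event in pygame.event.get():
        if event.type == pygame.KEYDOWN:
            if event.key == pygame.K_UP:
                snake.turn((0, -1))   # screen y grows downward, so up is -1
            elif event.key == pygame.K_DOWN:
                snake.turn((0, 1))
            elif event.key == pygame.K_LEFT:
                snake.turn((-1, 0))
            elif event.key == pygame.K_RIGHT:
                snake.turn((1, 0))

    if not snake.move():              # False means the snake ran into itself
        running = False               # game over -> show the play-again prompt

    if snake.get_head_position() == food.position:
        snake.grow()                  # body extends on the next move()
        food.randomize_position()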
food.py
The Food class handles the food’s behavior:
class Food:
    def __init__(self):
        self.position = (0, 0)
        self.color = FOOD_COLOR
        self.randomize_position()

    def randomize_position(self):
        self.position = (
            random.randint(0, GRID_WIDTH - 1),
            random.randint(0, GRID_HEIGHT - 1),
        )
- Initialization: Sets the initial position and color of the food.
- Randomization: Changes the food’s position to a random location on the grid.
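One detail worth noting: randomize_position() can drop the food onto a cell the snake already occupies. If you want to rule that out, a small tweak (not part of the code shown above, just a hedged sketch) is to re-roll until a free cell is found:

def randomize_position(self, occupied=()):
    # Re-roll until the food lands on a cell the snake does not occupy.
    # `occupied` would be snake.positions; passing nothing keeps the
    # original behavior.
    while True:
        self.position = (
            random.randint(0, GRID_WIDTH - 1),
            random.randint(0, GRID_HEIGHT - 1),
        )
        if self.position not in occupied:
            break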
game.py
This file contains functions for rendering the game:
def draw_grid(screen):
    for y in range(GRID_HEIGHT):
        for x in range(GRID_WIDTH):
            r = pygame.Rect((x * CELL_SIZE, y * CELL_SIZE), (CELL_SIZE, CELL_SIZE))
            pygame.draw.rect(screen, BORDER_COLOR, r, 1)

def draw_play_again_prompt(screen, font, snake):
    overlay = pygame.Surface((SCREEN_WIDTH, SCREEN_HEIGHT))
    overlay.fill((0, 0, 0))
    overlay.set_alpha(128)
    screen.blit(overlay, (0, 0))
    game_over_text = font.render("GAME OVER!", True, TEXT_COLOR)
    score_text = font.render(f"Final Length: {snake.score}", True, TEXT_COLOR)
    prompt_text = font.render("Play Again? (Y/N)", True, TEXT_COLOR)
    game_over_rect = game_over_text.get_rect(
        center=(SCREEN_WIDTH / 2, SCREEN_HEIGHT / 2 - 40)
    )
    score_rect = score_text.get_rect(center=(SCREEN_WIDTH / 2, SCREEN_HEIGHT / 2))
    prompt_rect = prompt_text.get_rect(
        center=(SCREEN_WIDTH / 2, SCREEN_HEIGHT / 2 + 40)
    )
    screen.blit(game_over_text, game_over_rect)
    screen.blit(score_text, score_rect)
    screen.blit(prompt_text, prompt_rect)

def draw_main_menu(screen, font, model_exists=False):
    screen.fill(BACKGROUND)
    title_text = font.render("AI Snake Game", True, TEXT_COLOR)
    play_text = font.render("1. Play Game", True, TEXT_COLOR)
    train_text = font.render("2. Train AI", True, TEXT_COLOR)
    watch_text = font.render("3. Watch AI Play", True, TEXT_COLOR)
    quit_text = font.render("4. Quit", True, TEXT_COLOR)
    title_rect = title_text.get_rect(center=(SCREEN_WIDTH / 2, SCREEN_HEIGHT / 4))
    play_rect = play_text.get_rect(center=(SCREEN_WIDTH / 2, SCREEN_HEIGHT / 2 - 40))
    train_rect = train_text.get_rect(center=(SCREEN_WIDTH / 2, SCREEN_HEIGHT / 2))
    watch_rect = watch_text.get_rect(center=(SCREEN_WIDTH / 2, SCREEN_HEIGHT / 2 + 40))
    quit_rect = quit_text.get_rect(center=(SCREEN_WIDTH / 2, SCREEN_HEIGHT / 2 + 80))
    screen.blit(title_text, title_rect)
    screen.blit(play_text, play_rect)
    screen.blit(train_text, train_rect)
    if model_exists:
        screen.blit(watch_text, watch_rect)
    else:
        screen.blit(
            font.render("3. Watch AI Play (No model yet)", True, (128, 128, 128)),
            watch_rect,
        )
    screen.blit(quit_text, quit_rect)

def draw_training_prompt(screen, font):
    screen.fill(BACKGROUND)
    text = font.render("Enter number of episodes (1-1000):", True, TEXT_COLOR)
    rect = text.get_rect(center=(SCREEN_WIDTH / 2, SCREEN_HEIGHT / 2))
    screen.blit(text, rect)

def draw_training_progress(screen, font, episode, max_episodes, metrics=None):
    progress = episode / max_episodes
    bar_width = SCREEN_WIDTH * 0.8
    bar_height = 30
    bar_x = SCREEN_WIDTH * 0.1
    bar_y = SCREEN_HEIGHT / 2
    pygame.draw.rect(screen, BORDER_COLOR, (bar_x, bar_y, bar_width, bar_height))
    pygame.draw.rect(
        screen, SNAKE_COLOR, (bar_x, bar_y, bar_width * progress, bar_height)
    )
    text = font.render(f"Training Progress: {int(progress * 100)}%", True, TEXT_COLOR)
    rect = text.get_rect(center=(SCREEN_WIDTH / 2, bar_y - 40))
    screen.blit(text, rect)
    if metrics:
        y_offset = bar_y + 60
        length_color = (
            SNAKE_COLOR
            if metrics["Current Length"] >= metrics["Longest Snake"]
            else TEXT_COLOR
        )
        headers = [
            ("Longest Snake Length", metrics["Longest Snake"], TEXT_COLOR),
            ("Current Snake Length", metrics["Current Length"], length_color),
            ("Average Length", metrics["Average Length"], TEXT_COLOR),
        ]
        for label, value, color in headers:
            metric_text = font.render(f"{label}: {value}", True, color)
            metric_rect = metric_text.get_rect(center=(SCREEN_WIDTH / 2, y_offset))
            screen.blit(metric_text, metric_rect)
            y_offset += 30
- Grid Drawing: Draws the game grid.
- Game Over Prompt: Displays a prompt when the game is over.
- Main Menu: Renders the main menu with options to play, train the AI, watch the AI play, or quit.
- Training Prompt: Displays a prompt to enter the number of training episodes.
- Training Progress: Shows the training progress and metrics in real-time.
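snake-game-entry.py ties these helpers together. As a rough, hedged sketch (the real entry point may handle keys and screen setup differently, and SCREEN_WIDTH/SCREEN_HEIGHT are assumed to come from constants.py), the menu loop might look like this:

import os
import pygame
from constants import SCREEN_WIDTH, SCREEN_HEIGHT
from game import draw_main_menu

pygame.init()
screen = pygame.display.set_mode((SCREEN_WIDTH, SCREEN_HEIGHT))
font = pygame.font.Font(None, 36)
model_exists = os.path.exists(os.path.join("models", "latest_model.pth"))

choosing = True
while choosing:
    draw_main_menu(screen, font, model_exists)
    pygame.display.flip()
    for event in pygame.event.get():
        if event.type == pygame.KEYDOWN:
            if event.key == pygame.K_1:
                choosing = False      # start manual play
            elif event.key == pygame.K_2:
                choosing = False      # ask for episode count, then train
            elif event.key == pygame.K_3 and model_exists:
                choosing = False      # watch the trained AI
            elif event.key == pygame.K_4:
                pygame.quit()
                raise SystemExit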
AI Code Explanation 🧠
Architecture and Algorithm
The AI for the Snake game is built around a Deep Q-Network (DQN), a reinforcement learning algorithm that uses a neural network to approximate the Q-value function, which predicts the expected future reward of each action in a given state.
Neural Network Architecture
The neural network used in our AI consists of three fully connected layers:
- Input Layer: Takes the state representation of the game, which includes information about the snake’s position, direction, and the food’s position.
- Hidden Layers: Two hidden layers with ReLU activation functions to introduce non-linearity and help the network learn complex patterns.
- Output Layer: Produces Q-values for the four possible actions (up, down, left, right).
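The trainer shown below instantiates this network as SnakeAI. Its exact state size and hidden-layer widths aren’t reproduced in this post, so treat the following as a hedged sketch of the model’s shape rather than the project’s exact code:

import torch.nn as nn

class SnakeAI(nn.Module):
    # Sketch only: STATE_SIZE and HIDDEN_SIZE are placeholders for whatever
    # values the real ai_trainer.py uses.
    STATE_SIZE = 11     # e.g. danger flags, current direction, food direction
    HIDDEN_SIZE = 128

    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(self.STATE_SIZE, self.HIDDEN_SIZE),
            nn.ReLU(),
            nn.Linear(self.HIDDEN_SIZE, self.HIDDEN_SIZE),
            nn.ReLU(),
            nn.Linear(self.HIDDEN_SIZE, 4),  # Q-values for up, down, left, right
        )

    def forward(self, x):
        return self.net(x)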
Why DQN?
- Experience Replay: DQN stores transitions in a replay buffer and samples random mini-batches from it, breaking the correlation between consecutive experiences and making training more stable.
- Target Network: A separate target network computes the target Q-values, which reduces oscillations during training.
- Exploration vs. Exploitation: The epsilon-greedy strategy is used to balance exploration (trying new actions) and exploitation (using the best-known actions).
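In code, the epsilon-greedy choice is only a few lines: with probability epsilon pick a random action, otherwise pick the action with the highest predicted Q-value. The trainer’s actual action-selection method isn’t shown in this post, so this is a hedged sketch of what it typically looks like (device is the same global used in train_step below):

import random
import torch

def get_action(self, state):
    # Explore: random action with probability epsilon.
    if random.random() < self.epsilon:
        return random.randrange(4)
    # Exploit: action with the highest predicted Q-value.
    with torch.no_grad():
        state_t = torch.FloatTensor(state).unsqueeze(0).to(device)
        q_values = self.model(state_t)
    return int(q_values.argmax(dim=1).item())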
Important Snippets for AI Training
Initializing the AI Trainer
class AITrainer:
    def __init__(self):
        self.model = SnakeAI()
        self.target_model = SnakeAI()
        self.target_model.load_state_dict(self.model.state_dict())
        self.optimizer = optim.Adam(self.model.parameters(), lr=0.001)
        self.criterion = nn.MSELoss()
        self.memory = deque(maxlen=100000)
        self.batch_size = 64
        self.gamma = 0.95
        self.epsilon = 1.0
        self.epsilon_min = 0.01
        self.epsilon_decay = 0.995
        self.target_update = 10
        self.metrics = {
            "Longest Snake": 1,
            "Average Length": 1,
            "Games Played": 0,
            "Current Length": 1,
        }
        self.models_dir = "models"
        if not os.path.exists(self.models_dir):
            os.makedirs(self.models_dir)
- Initialization: Sets up the neural network, optimizer, loss function, and training parameters.
- Replay Buffer: Uses a deque to store experiences for experience replay.
- Epsilon-Greedy Strategy: Initializes epsilon for exploration vs. exploitation.
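Each game step produces a (state, action, reward, next_state, done) transition that gets appended to this buffer. The project’s own method for this isn’t reproduced above, so here is a hedged sketch of the storage step:

def remember(self, state, action, reward, next_state, done):
    # Append one transition; deque(maxlen=100000) silently drops the
    # oldest experience once the buffer is full.
    self.memory.append((state, action, reward, next_state, done))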
Training Step
def train_step(self):
    if len(self.memory) < self.batch_size:
        return
    batch = random.sample(self.memory, self.batch_size)
    states = np.array([x[0] for x in batch])
    actions = np.array([x[1] for x in batch])
    rewards = np.array([x[2] for x in batch])
    next_states = np.array([x[3] for x in batch])
    dones = np.array([x[4] for x in batch])
    states = torch.FloatTensor(states).to(device)
    actions = torch.LongTensor(actions).to(device)
    rewards = torch.FloatTensor(rewards).to(device)
    next_states = torch.FloatTensor(next_states).to(device)
    dones = torch.FloatTensor(dones).to(device)
    current_q_values = self.model(states).gather(1, actions.unsqueeze(1))
    next_q_values = self.target_model(next_states).max(1)[0].detach()
    target_q_values = rewards + (1 - dones) * self.gamma * next_q_values
    loss = self.criterion(current_q_values.squeeze(), target_q_values)
    self.optimizer.zero_grad()
    loss.backward()
    self.optimizer.step()
    if self.epsilon > self.epsilon_min:
        self.epsilon *= self.epsilon_decay
- Batch Sampling: Samples a batch of experiences from the replay buffer.
- Q-Value Calculation: Computes current and target Q-values.
- Loss Calculation: Calculates the loss between current and target Q-values.
- Backpropagation: Updates the model parameters using backpropagation.
- Epsilon Decay: Decays epsilon to reduce exploration over time.
Updating the Target Network
def update_target_network(self, episode):
    if episode % self.target_update == 0:
        self.target_model.load_state_dict(self.model.state_dict())
- Target Network Update: Periodically updates the target network with the current model’s weights to stabilize training.
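Tying these methods together, the outer training loop plays one step at a time, stores the transition, runs a train_step, and calls update_target_network at the end of each episode. The real loop in ai_trainer.py also tracks metrics and redraws the progress screen; the outline below is a hedged sketch, and reset_game/step_game are hypothetical helper names:

def train(self, episodes):
    for episode in range(1, episodes + 1):
        state = self.reset_game()                              # hypothetical helper
        done = False
        while not done:
            action = self.get_action(state)
            next_state, reward, done = self.step_game(action)  # hypothetical helper
            self.remember(state, action, reward, next_state, done)
            self.train_step()
            state = next_state
        self.update_target_network(episode)
    self.save_model()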
Saving the Model
def save_model(self):
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    filename = f"snake_ai_model_{timestamp}.pth"
    filepath = os.path.join(self.models_dir, filename)
    model_data = {
        "model_state_dict": self.model.state_dict(),
        "optimizer_state_dict": self.optimizer.state_dict(),
    }
    metadata = {
        "epsilon": self.epsilon,
        "metrics": self.metrics,
        "timestamp": timestamp,
    }
    torch.save(model_data, filepath)
    metadata_path = os.path.join(
        self.models_dir, f"snake_ai_metadata_{timestamp}.pt"
    )
    torch.save(metadata, metadata_path)
    self.latest_model = filepath
    latest_model_link = os.path.join(self.models_dir, "latest_model.pth")
    latest_metadata_link = os.path.join(self.models_dir, "latest_metadata.pt")
    if os.path.exists(latest_model_link):
        os.remove(latest_model_link)
    if os.path.exists(latest_metadata_link):
        os.remove(latest_metadata_link)
    torch.save(model_data, latest_model_link)
    torch.save(metadata, latest_metadata_link)
- Model Saving: Saves the model’s state dictionary and optimizer state.
- Metadata Saving: Saves training metadata including epsilon and metrics.
- Latest Model Links: Updates the latest model and metadata links for easy access.
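When you pick “Watch AI Play”, the trainer needs to read these files back. The loading code isn’t reproduced above, so here is a hedged sketch of how the latest checkpoint could be restored:

import os
import torch

def load_latest_model(self):
    # Restore weights (and optionally the optimizer state) from the
    # latest_model.pth copy written by save_model().
    path = os.path.join(self.models_dir, "latest_model.pth")
    if not os.path.exists(path):
        return False
    checkpoint = torch.load(path, map_location=device)
    self.model.load_state_dict(checkpoint["model_state_dict"])
    self.optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
    self.model.eval()   # inference mode while watching the AI play
    return True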
For more details, check the ai_trainer.py file.
Conclusion 🎉
Training an AI to play Snake is a fun and educational project that demonstrates the power of reinforcement learning. By following this guide, you can create your own AI Snake game and experiment with different training parameters to see how they affect the AI’s performance.
Happy coding! 🚀