Training an AI to Play Snake

👋 Hello AI enthusiasts! Today we’ll explore how to train an AI to play the classic Snake game using PyTorch. This is part of our journey in applying reinforcement learning to games 🚀

📦 Source Code: You can find the complete implementation on GitHub

Prerequisites and Running the Project 🛠️

Before we dive in, let’s set up our environment properly:

Environment Setup 🔧

  1. Create and activate a virtual environment:
    python -m venv ai-snake-venv
    source ai-snake-venv/bin/activate  # Linux/WSL
    
  2. Install dependencies:
    pip install -r requirements.txt
    

Dependencies used:

pygame
torch
numpy

For CUDA-enabled PyTorch, visit PyTorch Get Started for details.

  3. Run the project:
    python snake-game-entry.py
    

Code Structure 📁

Our project is organized into:

ai-snake-game/
├── ai_trainer.py     # AI training logic
├── constants.py      # Game constants
├── food.py           # Food class
├── game.py           # Game rendering functions
├── snake.py          # Snake class
├── snake-game-entry.py  # Main game entry point
├── requirements.txt  # Project dependencies
└── README.md         # Project documentation
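
The snippets below lean on values defined in constants.py (grid size, screen size, colors). The file isn't reproduced in this post, but an illustrative version might look like the following; the actual values in the repository may differ:

# constants.py (illustrative values only)
CELL_SIZE = 20
GRID_WIDTH = 30
GRID_HEIGHT = 30
SCREEN_WIDTH = GRID_WIDTH * CELL_SIZE
SCREEN_HEIGHT = GRID_HEIGHT * CELL_SIZE

BACKGROUND = (15, 15, 15)      # window background
BORDER_COLOR = (40, 40, 40)    # grid lines
SNAKE_COLOR = (0, 200, 0)
FOOD_COLOR = (200, 0, 0)
TEXT_COLOR = (255, 255, 255)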

Project Features ✨

Our implementation includes:

- A classic Snake game you can play yourself, rendered with Pygame
- A Deep Q-Network (DQN) trainer built with PyTorch
- A training mode with a configurable number of episodes and a live progress bar with length metrics
- A "Watch AI Play" mode that loads the most recently saved model
- Automatic saving of model checkpoints and training metadata with timestamps

Snake Game Code Explanation 🐍

snake.py

The Snake class handles the snake’s behavior:

class Snake:
    def __init__(self):
        self.length = 1
        self.positions = [(GRID_WIDTH // 2, GRID_HEIGHT // 2)]
        self.direction = random.choice([(0, 1), (0, -1), (1, 0), (-1, 0)])
        self.color = SNAKE_COLOR
        self.score = 1
        self.growth_pending = False

    def get_head_position(self):
        return self.positions[0]

    def turn(self, point):
        if self.length > 1 and (point[0] * -1, point[1] * -1) == self.direction:
            return
        self.direction = point

    def move(self):
        cur = self.get_head_position()
        x, y = self.direction
        # Wrap around the grid edges instead of treating walls as collisions
        new = (((cur[0] + x) % GRID_WIDTH), (cur[1] + y) % GRID_HEIGHT)
        # Running into the body (beyond the neck) ends the game
        if new in self.positions[2:]:
            return False

        self.positions.insert(0, new)
        if not self.growth_pending:
            self.positions.pop()
        else:
            self.growth_pending = False
        return True

    def grow(self):
        self.growth_pending = True
        self.length += 1
        self.score = self.length

food.py

The Food class handles the food’s behavior:

class Food:
    def __init__(self):
        self.position = (0, 0)
        self.color = FOOD_COLOR
        self.randomize_position()

    def randomize_position(self):
        self.position = (
            random.randint(0, GRID_WIDTH - 1),
            random.randint(0, GRID_HEIGHT - 1),
        )

game.py

This file contains functions for rendering the game:

def draw_grid(screen):
    for y in range(GRID_HEIGHT):
        for x in range(GRID_WIDTH):
            r = pygame.Rect((x * CELL_SIZE, y * CELL_SIZE), (CELL_SIZE, CELL_SIZE))
            pygame.draw.rect(screen, BORDER_COLOR, r, 1)

def draw_play_again_prompt(screen, font, snake):
    overlay = pygame.Surface((SCREEN_WIDTH, SCREEN_HEIGHT))
    overlay.fill((0, 0, 0))
    overlay.set_alpha(128)
    screen.blit(overlay, (0, 0))

    game_over_text = font.render("GAME OVER!", True, TEXT_COLOR)
    score_text = font.render(f"Final Length: {snake.score}", True, TEXT_COLOR)
    prompt_text = font.render("Play Again? (Y/N)", True, TEXT_COLOR)

    game_over_rect = game_over_text.get_rect(
        center=(SCREEN_WIDTH / 2, SCREEN_HEIGHT / 2 - 40)
    )
    score_rect = score_text.get_rect(center=(SCREEN_WIDTH / 2, SCREEN_HEIGHT / 2))
    prompt_rect = prompt_text.get_rect(
        center=(SCREEN_WIDTH / 2, SCREEN_HEIGHT / 2 + 40)
    )

    screen.blit(game_over_text, game_over_rect)
    screen.blit(score_text, score_rect)
    screen.blit(prompt_text, prompt_rect)

def draw_main_menu(screen, font, model_exists=False):
    screen.fill(BACKGROUND)

    title_text = font.render("AI Snake Game", True, TEXT_COLOR)
    play_text = font.render("1. Play Game", True, TEXT_COLOR)
    train_text = font.render("2. Train AI", True, TEXT_COLOR)
    watch_text = font.render("3. Watch AI Play", True, TEXT_COLOR)
    quit_text = font.render("4. Quit", True, TEXT_COLOR)

    title_rect = title_text.get_rect(center=(SCREEN_WIDTH / 2, SCREEN_HEIGHT / 4))
    play_rect = play_text.get_rect(center=(SCREEN_WIDTH / 2, SCREEN_HEIGHT / 2 - 40))
    train_rect = train_text.get_rect(center=(SCREEN_WIDTH / 2, SCREEN_HEIGHT / 2))
    watch_rect = watch_text.get_rect(center=(SCREEN_WIDTH / 2, SCREEN_HEIGHT / 2 + 40))
    quit_rect = quit_text.get_rect(center=(SCREEN_WIDTH / 2, SCREEN_HEIGHT / 2 + 80))

    screen.blit(title_text, title_rect)
    screen.blit(play_text, play_rect)
    screen.blit(train_text, train_rect)

    if model_exists:
        screen.blit(watch_text, watch_rect)
    else:
        screen.blit(
            font.render("3. Watch AI Play (No model yet)", True, (128, 128, 128)),
            watch_rect,
        )

    screen.blit(quit_text, quit_rect)

def draw_training_prompt(screen, font):
    screen.fill(BACKGROUND)
    text = font.render("Enter number of episodes (1-1000):", True, TEXT_COLOR)
    rect = text.get_rect(center=(SCREEN_WIDTH / 2, SCREEN_HEIGHT / 2))
    screen.blit(text, rect)

def draw_training_progress(screen, font, episode, max_episodes, metrics=None):
    progress = episode / max_episodes
    bar_width = SCREEN_WIDTH * 0.8
    bar_height = 30
    bar_x = SCREEN_WIDTH * 0.1
    bar_y = SCREEN_HEIGHT / 2

    pygame.draw.rect(screen, BORDER_COLOR, (bar_x, bar_y, bar_width, bar_height))
    pygame.draw.rect(
        screen, SNAKE_COLOR, (bar_x, bar_y, bar_width * progress, bar_height)
    )

    text = font.render(f"Training Progress: {int(progress * 100)}%", True, TEXT_COLOR)
    rect = text.get_rect(center=(SCREEN_WIDTH / 2, bar_y - 40))
    screen.blit(text, rect)

    if metrics:
        y_offset = bar_y + 60
        length_color = (
            SNAKE_COLOR
            if metrics["Current Length"] >= metrics["Longest Snake"]
            else TEXT_COLOR
        )
        headers = [
            ("Longest Snake Length", metrics["Longest Snake"], TEXT_COLOR),
            ("Current Snake Length", metrics["Current Length"], length_color),
            ("Average Length", metrics["Average Length"], TEXT_COLOR),
        ]

        for label, value, color in headers:
            metric_text = font.render(f"{label}: {value}", True, color)
            metric_rect = metric_text.get_rect(center=(SCREEN_WIDTH / 2, y_offset))
            screen.blit(metric_text, metric_rect)
            y_offset += 30

AI Code Explanation 🧠

Architecture and Algorithm

The AI for the Snake game is built using a Deep Q-Network (DQN) algorithm. The DQN is a reinforcement learning algorithm that uses a neural network to approximate the Q-value function, which predicts the expected future rewards for each action given a state.
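
Concretely, for every stored transition (state, action, reward, next state) the trainer nudges the network's estimate Q(state, action) toward a bootstrapped target:

target = reward                                              if the episode ended
target = reward + gamma * max over a' of Q(next state, a')   otherwise

This is exactly the quantity the (1 - dones) mask computes in the train_step function shown later.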

Neural Network Architecture

The neural network used by the AI is a small fully connected model: three linear layers that map the encoded game state to one Q-value per possible action.
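
The exact layer sizes live in ai_trainer.py; as a rough sketch, assuming an 11-value state vector, a hidden width of 128, and 3 possible actions (all illustrative rather than confirmed by the repository), a three-layer SnakeAI could look like this:

import torch
import torch.nn as nn

class SnakeAI(nn.Module):
    def __init__(self, state_size=11, hidden_size=128, action_size=3):
        super().__init__()
        # Three fully connected layers: encoded state in, one Q-value per action out
        self.fc1 = nn.Linear(state_size, hidden_size)
        self.fc2 = nn.Linear(hidden_size, hidden_size)
        self.fc3 = nn.Linear(hidden_size, action_size)

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        return self.fc3(x)  # raw Q-values, no activation on the output layer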

Why DQN?

Snake suits a DQN well: the action space is small and discrete, the game state can be encoded as a compact numeric vector, and the reward signal (eat food, survive, die) is easy to define. Compared with plain Q-learning backed by a neural network, the DQN adds two stabilizers that both appear in the trainer below: an experience replay buffer, which breaks the correlation between consecutive frames, and a separate target network, which keeps the bootstrapped targets from chasing a constantly moving estimate.

Important Snippets for AI Training

Initializing the AI Trainer

class AITrainer:
    def __init__(self):
        self.model = SnakeAI()
        self.target_model = SnakeAI()
        self.target_model.load_state_dict(self.model.state_dict())
        self.optimizer = optim.Adam(self.model.parameters(), lr=0.001)
        self.criterion = nn.MSELoss()
        self.memory = deque(maxlen=100000)
        self.batch_size = 64
        self.gamma = 0.95
        self.epsilon = 1.0
        self.epsilon_min = 0.01
        self.epsilon_decay = 0.995
        self.target_update = 10
        self.metrics = {
            "Longest Snake": 1,
            "Average Length": 1,
            "Games Played": 0,
            "Current Length": 1,
        }
        self.models_dir = "models"
        if not os.path.exists(self.models_dir):
            os.makedirs(self.models_dir)
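
The trainer also has to collect experience: choose an action for the current state and store the resulting transition in memory. Those methods aren't reproduced in this post; a typical epsilon-greedy version, using the same class context and module imports as the snippets above (the method names and action count here are assumptions), looks like this:

def get_action(self, state):
    # Explore with probability epsilon, otherwise exploit the learned Q-values
    if random.random() < self.epsilon:
        return random.randrange(3)  # action count is illustrative
    with torch.no_grad():
        state_t = torch.FloatTensor(state).unsqueeze(0).to(device)
        return int(self.model(state_t).argmax(dim=1).item())

def remember(self, state, action, reward, next_state, done):
    # Each stored transition is later sampled by train_step()
    self.memory.append((state, action, reward, next_state, done))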

Training Step

def train_step(self):
    if len(self.memory) < self.batch_size:
        return

    batch = random.sample(self.memory, self.batch_size)
    states = np.array([x[0] for x in batch])
    actions = np.array([x[1] for x in batch])
    rewards = np.array([x[2] for x in batch])
    next_states = np.array([x[3] for x in batch])
    dones = np.array([x[4] for x in batch])

    states = torch.FloatTensor(states).to(device)
    actions = torch.LongTensor(actions).to(device)
    rewards = torch.FloatTensor(rewards).to(device)
    next_states = torch.FloatTensor(next_states).to(device)
    dones = torch.FloatTensor(dones).to(device)

    # Q-values the online network currently assigns to the actions that were taken
    current_q_values = self.model(states).gather(1, actions.unsqueeze(1))
    # Bootstrap targets from the frozen target network; (1 - dones) zeroes out terminal states
    next_q_values = self.target_model(next_states).max(1)[0].detach()
    target_q_values = rewards + (1 - dones) * self.gamma * next_q_values

    loss = self.criterion(current_q_values.squeeze(), target_q_values)

    self.optimizer.zero_grad()
    loss.backward()
    self.optimizer.step()

    if self.epsilon > self.epsilon_min:
        self.epsilon *= self.epsilon_decay
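
train_step moves every tensor to a device variable that isn't shown in the excerpt. A conventional setup, assuming ai_trainer.py defines it at module level, is:

import torch

# Use the GPU when a CUDA build of PyTorch is installed, otherwise fall back to the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

The model and target model would then be moved to the same device with .to(device).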

Updating the Target Network

def update_target_network(self, episode):
    if episode % self.target_update == 0:
        self.target_model.load_state_dict(self.model.state_dict())

Saving the Model

def save_model(self):
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    filename = f"snake_ai_model_{timestamp}.pth"
    filepath = os.path.join(self.models_dir, filename)

    model_data = {
        "model_state_dict": self.model.state_dict(),
        "optimizer_state_dict": self.optimizer.state_dict(),
    }
    metadata = {
        "epsilon": self.epsilon,
        "metrics": self.metrics,
        "timestamp": timestamp,
    }

    torch.save(model_data, filepath)
    metadata_path = os.path.join(
        self.models_dir, f"snake_ai_metadata_{timestamp}.pt"
    )
    torch.save(metadata, metadata_path)

    self.latest_model = filepath

    latest_model_link = os.path.join(self.models_dir, "latest_model.pth")
    latest_metadata_link = os.path.join(self.models_dir, "latest_metadata.pt")

    if os.path.exists(latest_model_link):
        os.remove(latest_model_link)
    if os.path.exists(latest_metadata_link):
        os.remove(latest_metadata_link)

    torch.save(model_data, latest_model_link)
    torch.save(metadata, latest_metadata_link)
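
To power the "Watch AI Play" option, the checkpoint can be loaded back the same way it was written. A minimal sketch, assuming the latest_model.pth layout produced by save_model above:

def load_latest_model(models_dir="models"):
    checkpoint = torch.load(os.path.join(models_dir, "latest_model.pth"))
    model = SnakeAI()
    model.load_state_dict(checkpoint["model_state_dict"])
    model.eval()  # evaluation mode: inference only, no gradient updates
    return model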

For more details, check the ai_trainer.py file.

Conclusion 🎉

Training an AI to play Snake is a fun and educational project that demonstrates the power of reinforcement learning. By following this guide, you can create your own AI Snake game and experiment with different training parameters to see how they affect the AI’s performance.

Happy coding! 🚀

Adel Aloui

Software Engineer with expertise in building innovative AI-driven solutions.