"""Q-learning agent for BabySnake.

This file contains starter code for implementing a Q-learning agent.
You need to fill in two functions:
  - choose_action: select an action using an epsilon-greedy policy
  - update_q: update the Q-table using the Bellman equation

Run this file to train the agent:
    python q_learning.py

After training, run this to watch it play:
    python -c "from q_learning import watch; watch()"
"""

import random

import babysnake
from retro.input import ProgrammaticInput
from retro.views.headless import HeadlessView

# The four actions the agent can take.
ACTIONS = ["KEY_RIGHT", "KEY_DOWN", "KEY_LEFT", "KEY_UP"]


def _state_from(game_state):
    """Convert a game-state mapping into the hashable Q-table state.

    Shared by BabySnakeEnv (training) and watch (playback) so both key the
    Q-table with exactly the same tuple.

    Arguments:
        game_state: The game's state mapping; must provide 'agent_x',
            'agent_y', 'food_x' and 'food_y'.

    Returns:
        tuple: (agent_x, agent_y, food_x, food_y) as ints.
    """
    return (
        int(game_state['agent_x']),
        int(game_state['agent_y']),
        int(game_state['food_x']),
        int(game_state['food_y']),
    )


# ---------------------------------------------------------------------------
# Environment wrapper
# ---------------------------------------------------------------------------

class BabySnakeEnv:
    """A simple wrapper that lets us step through BabySnake programmatically.

    Usage:
        env = BabySnakeEnv()
        state = env.reset()                         # start a new episode
        next_state, reward, done = env.step("KEY_RIGHT")
    """

    def reset(self):
        """Start a new episode. Returns the initial state tuple."""
        self._inp = ProgrammaticInput()
        self.game = babysnake.create_game()
        self.game.input_source = self._inp
        # Headless view: no rendering during training.
        self.game.view = HeadlessView()
        self.game.start()
        # The game reports a running reward total; remember it so step()
        # can return the per-step difference.
        self._prev_reward = 0.0
        return self._get_state()

    def step(self, action):
        """Take one action. Returns (next_state, reward, done).

        Arguments:
            action (str): One of ACTIONS, or None for no-op.

        Returns:
            next_state (tuple): The state after the action.
            reward (float): The reward received this step.
            done (bool): True if the episode has ended.
        """
        self._inp.press(action)
        self.game.step()
        next_state = self._get_state()
        # Per-step reward = change in the game's cumulative reward.
        reward = self.game.state['reward'] - self._prev_reward
        self._prev_reward = self.game.state['reward']
        done = not self.game.playing
        return next_state, reward, done

    def _get_state(self):
        """Return the current state as a tuple of four integers."""
        return _state_from(self.game.state)


# ---------------------------------------------------------------------------
# Q-learning functions — fill these in!
# ---------------------------------------------------------------------------

def choose_action(q_table, state, epsilon):
    """Choose an action using an epsilon-greedy policy.

    With probability `epsilon`, return a random action from ACTIONS.
    Otherwise, return the action with the highest Q-value in `q_table`
    for the given `state`. If a (state, action) pair has not been seen
    before, treat its Q-value as 0.

    Arguments:
        q_table (dict): Maps (state, action) -> Q-value.
        state (tuple): The current state, e.g. (1, 2, 3, 0).
        epsilon (float): Exploration rate, between 0.0 and 1.0.

    Returns:
        str: One action from ACTIONS.

    Hint:
        random.random() returns a float in [0.0, 1.0).
        random.choice(ACTIONS) returns a random action.
        q_table.get(key, default) is handy for missing entries.
    """
    raise NotImplementedError("Fill in choose_action")


def update_q(q_table, state, action, reward, next_state, alpha, gamma):
    """Update one entry of the Q-table using the Bellman equation.

    The update rule is:

        Q(s, a) <- Q(s, a) + alpha * (r + gamma * max_a' Q(s', a') - Q(s, a))

    where:
        s, a      — the state we were in and the action we took
        r         — the reward we received
        s'        — the state we ended up in
        max_a' ... — the best possible Q-value from the new state

    Note: train() passes gamma=0.0 for the step that ends an episode, so
    there the target reduces to just r (nothing follows a terminal step).

    Arguments:
        q_table (dict): Maps (state, action) -> Q-value (modified in place).
        state (tuple): The state before the action.
        action (str): The action taken.
        reward (float): The reward received.
        next_state (tuple): The state after the action.
        alpha (float): Learning rate (how much to update).
        gamma (float): Discount factor (how much to value future rewards).

    Returns:
        None — modifies q_table in place.

    Hint:
        Q-values for unseen (state, action) pairs start at 0.
    """
    raise NotImplementedError("Fill in update_q")


# ---------------------------------------------------------------------------
# Training loop
# ---------------------------------------------------------------------------

def train(
    episodes=1000,
    alpha=0.1,
    gamma=0.95,
    epsilon=1.0,
    epsilon_decay=0.995,
    epsilon_min=0.05,
):
    """Train a Q-learning agent on BabySnake.

    Arguments:
        episodes (int): How many episodes to run.
        alpha (float): Learning rate.
        gamma (float): Discount factor.
        epsilon (float): Starting exploration rate.
        epsilon_decay (float): Multiply epsilon by this each episode.
        epsilon_min (float): Epsilon never falls below this.

    Returns:
        dict: The trained Q-table.
    """
    q_table = {}
    env = BabySnakeEnv()

    for episode in range(episodes):
        state = env.reset()
        total_reward = 0.0

        while env.game.playing:
            action = choose_action(q_table, state, epsilon)
            next_state, reward, done = env.step(action)
            # A terminal transition has no future, so its target is just the
            # reward. Bootstrapping from Q(next_state, .) would be wrong: the
            # state tuple only encodes positions, so the same tuple may also
            # be reached by non-terminal transitions and carry a value.
            # NOTE(review): assumes the game ends on a genuine terminal event
            # rather than a step limit — confirm against babysnake.
            step_gamma = 0.0 if done else gamma
            update_q(
                q_table, state, action, reward, next_state, alpha, step_gamma
            )
            state = next_state
            total_reward += reward

        epsilon = max(epsilon_min, epsilon * epsilon_decay)

        # Progress report every 100 episodes (values are for the last episode).
        if (episode + 1) % 100 == 0:
            print(
                f"Episode {episode + 1:5d}  "
                f"reward={total_reward:6.1f}  "
                f"score={env.game.state['score']}  "
                f"epsilon={epsilon:.3f}  "
                f"q_entries={len(q_table)}"
            )

    return q_table


def watch(q_table=None):
    """Watch the trained agent play in the terminal.

    The agent acts greedily (no exploration): on each input poll it picks
    the action with the highest Q-value for the current state, treating
    unseen (state, action) pairs as 0 and breaking ties in ACTIONS order.

    Arguments:
        q_table (dict | None): A trained Q-table. If None, trains first.
    """
    if q_table is None:
        print("Training first...")
        q_table = train()

    inp = ProgrammaticInput()
    # Created before PolicyInput, whose collect() reads its state.
    game = babysnake.create_game()

    class PolicyInput:
        """An input source that picks actions from the Q-table."""

        def collect(self):
            state = _state_from(game.state)
            # max() returns the first maximal action, i.e. ties go to the
            # earliest entry in ACTIONS.
            best = max(ACTIONS, key=lambda a: q_table.get((state, a), 0.0))
            inp.press(best)
            return inp.collect()

    game.play(input_source=PolicyInput())


if __name__ == '__main__':
    print("Training Q-learning agent on BabySnake...")
    q_table = train()
    print(f"\nDone. Q-table has {len(q_table)} entries.")
    print("\nWatching trained agent (press Enter or Escape to quit)...")
    watch(q_table)