"""Solution for q_learning.py — remove before publishing to students."""

import random

import babysnake
from retro.input import ProgrammaticInput
from retro.views.headless import HeadlessView

# Discrete action space: the key names passed to ProgrammaticInput.press().
ACTIONS = ["KEY_RIGHT", "KEY_DOWN", "KEY_LEFT", "KEY_UP"]


def _state_from(game):
    """Return the discrete Q-table state for *game*.

    The state is the tuple ``(agent_x, agent_y, food_x, food_y)`` read from
    ``game.state``, with each value coerced to ``int`` so the tuple is a
    hashable dictionary key.
    """
    s = game.state
    return (int(s['agent_x']), int(s['agent_y']),
            int(s['food_x']), int(s['food_y']))


def _greedy_action(q_table, state):
    """Return the action with the highest Q-value in *state*.

    Unseen (state, action) pairs count as 0.0. Ties go to the action listed
    first in ACTIONS, because ``list.index`` returns the first match.
    """
    q_values = [q_table.get((state, a), 0.0) for a in ACTIONS]
    return ACTIONS[q_values.index(max(q_values))]


class BabySnakeEnv:
    """Gym-style wrapper around a headless BabySnake game.

    ``reset()`` starts a fresh game and returns its initial state.
    ``step(action)`` applies one action and returns
    ``(next_state, reward, done)``. Here *reward* is the change in the
    game's cumulative ``state['reward']`` since the previous step, and
    *done* is true once the game has stopped playing.
    """

    def reset(self):
        """Start a new headless game driven by programmatic input.

        Returns:
            The initial state tuple (see ``_state_from``).
        """
        self._inp = ProgrammaticInput()
        self.game = babysnake.create_game()
        self.game.input_source = self._inp
        self.game.view = HeadlessView()
        self.game.start()
        # The game reports a running total; remember it so step() can
        # return per-step deltas.
        self._prev_reward = 0.0
        return self._get_state()

    def step(self, action):
        """Press *action*, advance the game one tick, and report the result.

        Returns:
            ``(next_state, reward, done)`` where *reward* is the increase in
            the game's cumulative reward during this tick.
        """
        self._inp.press(action)
        self.game.step()
        next_state = self._get_state()
        reward = self.game.state['reward'] - self._prev_reward
        self._prev_reward = self.game.state['reward']
        done = not self.game.playing
        return next_state, reward, done

    def _get_state(self):
        """Return the discrete state of the current game."""
        return _state_from(self.game)


def choose_action(q_table, state, epsilon):
    """Pick an action epsilon-greedily.

    With probability *epsilon*, return a uniformly random action (explore).
    Otherwise, return the greedy action for *state* (exploit).
    """
    if random.random() < epsilon:
        return random.choice(ACTIONS)
    return _greedy_action(q_table, state)


def update_q(q_table, state, action, reward, next_state, alpha, gamma,
             done=False):
    """Apply one tabular Q-learning update to *q_table* in place.

    The update is ``Q(s, a) <- Q(s, a) + alpha * (target - Q(s, a))``, where
    the target is ``reward + gamma * max_a' Q(s', a')`` for a non-terminal
    transition. When *done* is true, the target is ``reward`` alone: no
    action is ever taken from a terminal state, so it must not contribute a
    bootstrapped future value.

    Args:
        q_table: dict mapping ``(state, action)`` to a Q-value; missing
            entries count as 0.0.
        state, action: the transition's starting state and chosen action.
        reward: reward received for the transition.
        next_state: state reached after the transition.
        alpha: learning rate.
        gamma: discount factor.
        done: True if *next_state* ends the episode. Defaults to False,
            which keeps the old always-bootstrap behavior for callers that
            do not pass it.
    """
    old_q = q_table.get((state, action), 0.0)
    if done:
        best_next_q = 0.0
    else:
        best_next_q = max(q_table.get((next_state, a), 0.0) for a in ACTIONS)
    new_q = old_q + alpha * (reward + gamma * best_next_q - old_q)
    q_table[(state, action)] = new_q


def train(
    episodes=1000,
    alpha=0.1,
    gamma=0.95,
    epsilon=1.0,
    epsilon_decay=0.995,
    epsilon_min=0.05,
):
    """Train a tabular Q-learning agent on BabySnake.

    Epsilon is multiplied by *epsilon_decay* after every episode and is
    never allowed below *epsilon_min*. Every 100 episodes, the last
    episode's stats are printed.

    Returns:
        The learned Q-table, a dict mapping ``(state, action)`` to a value.
    """
    q_table = {}
    env = BabySnakeEnv()
    for episode in range(episodes):
        state = env.reset()
        total_reward = 0.0
        while env.game.playing:
            action = choose_action(q_table, state, epsilon)
            next_state, reward, done = env.step(action)
            # Pass `done` so the final transition of an episode does not
            # bootstrap from the terminal state.
            update_q(q_table, state, action, reward, next_state, alpha, gamma,
                     done)
            state = next_state
            total_reward += reward
        epsilon = max(epsilon_min, epsilon * epsilon_decay)
        if (episode + 1) % 100 == 0:
            print(
                f"Episode {episode + 1:5d} "
                f"reward={total_reward:6.1f} "
                f"score={env.game.state['score']} "
                f"epsilon={epsilon:.3f} "
                f"q_entries={len(q_table)}"
            )
    return q_table


def watch(q_table=None):
    """Play a game in which every input is the greedy action from *q_table*.

    If *q_table* is None, a new agent is trained first with ``train()``.
    """
    if q_table is None:
        print("Training first...")
        q_table = train()
    _inp = ProgrammaticInput()

    class PolicyInput:
        """Input source that presses the greedy action each time it is polled."""

        def collect(self):
            # `game` is bound below, before play() starts polling this object.
            _inp.press(_greedy_action(q_table, _state_from(game)))
            return _inp.collect()

    game = babysnake.create_game()
    game.play(input_source=PolicyInput())


if __name__ == '__main__':
    print("Training Q-learning agent on BabySnake...")
    q_table = train()
    print(f"\nDone. Q-table has {len(q_table)} entries.")
    print("\nWatching trained agent (press Enter or Escape to quit)...")
    watch(q_table)