Files
lab_reinforcement_learning/q_learning_solution.py
Chris Proctor 42bc2e7a50 Initial commit
2026-06-22 16:14:58 -04:00

114 lines
3.3 KiB
Python

"""Solution for q_learning.py — remove before publishing to students."""
import random
import babysnake
from retro.input import ProgrammaticInput
from retro.views.headless import HeadlessView
# Key names the agent can choose from; each chosen name is handed to
# ProgrammaticInput.press(). Greedy selection scans this list in order, so
# on a tie between Q-values the earliest entry wins.
ACTIONS = ["KEY_RIGHT", "KEY_DOWN", "KEY_LEFT", "KEY_UP"]
class BabySnakeEnv:
    """Reset/step wrapper around a BabySnake game for a learning agent.

    ``reset`` builds a fresh game wired to a programmatic input source and a
    headless view; ``step`` presses one key, advances the game once, and
    reports the resulting state, the reward earned, and whether play ended.
    """

    def reset(self):
        """Create and start a new game, then return its initial state tuple."""
        self._inp = key_source = ProgrammaticInput()
        self.game = game = babysnake.create_game()
        game.input_source = key_source
        game.view = HeadlessView()
        game.start()
        # Baseline for the reward difference computed in step().
        self._prev_reward = 0.0
        return self._get_state()

    def step(self, action):
        """Press ``action``, advance the game once, and report the outcome.

        Returns:
            A ``(next_state, reward, done)`` tuple. ``reward`` is the change
            in the game's ``state['reward']`` since the previous step (or
            since ``reset``); ``done`` is true once the game is no longer
            playing.
        """
        self._inp.press(action)
        self.game.step()
        observation = self._get_state()
        reward_now = self.game.state['reward']
        gained = reward_now - self._prev_reward
        self._prev_reward = reward_now
        finished = not self.game.playing
        return observation, gained, finished

    def _get_state(self):
        """Return ``(agent_x, agent_y, food_x, food_y)`` as a tuple of ints."""
        snapshot = self.game.state
        return tuple(
            int(snapshot[key])
            for key in ('agent_x', 'agent_y', 'food_x', 'food_y')
        )
def choose_action(q_table, state, epsilon):
    """Select an action epsilon-greedily.

    With probability ``epsilon`` a uniformly random action is returned;
    otherwise the action with the highest Q-value for ``state`` is returned.
    Missing table entries count as 0.0, and ties go to the action listed
    first in ``ACTIONS``.
    """
    exploring = random.random() < epsilon
    if exploring:
        return random.choice(ACTIONS)

    def value_of(candidate):
        return q_table.get((state, candidate), 0.0)

    # max() keeps the first maximal item, so ties resolve in ACTIONS order.
    return max(ACTIONS, key=value_of)
def update_q(q_table, state, action, reward, next_state, alpha, gamma,
             done=False):
    """Apply one Q-learning update to ``q_table`` in place.

    Computes ``Q(s, a) += alpha * (target - Q(s, a))`` where ``target`` is
    ``reward + gamma * max_a' Q(s', a')`` for an ordinary transition and
    just ``reward`` when the transition ended the episode.

    Args:
        q_table: dict mapping ``(state, action)`` to a Q-value; missing
            entries are treated as 0.0.
        state: state the action was taken in.
        action: action that was taken.
        reward: reward received for the transition.
        next_state: state observed after the action.
        alpha: learning rate.
        gamma: discount factor.
        done: True when the transition ended the episode. Defaults to False,
            which reproduces the always-bootstrap behaviour.
    """
    old_q = q_table.get((state, action), 0.0)
    if done:
        # Nothing follows a terminal transition, so there is no future value
        # to bootstrap from. Skipping the lookup matters because states are
        # keyed only by positions: the terminal observation may share its key
        # with a non-terminal state that already has learned values.
        best_next_q = 0.0
    else:
        best_next_q = max(q_table.get((next_state, a), 0.0) for a in ACTIONS)
    q_table[(state, action)] = old_q + alpha * (
        reward + gamma * best_next_q - old_q
    )


def train(
    episodes=1000,
    alpha=0.1,
    gamma=0.95,
    epsilon=1.0,
    epsilon_decay=0.995,
    epsilon_min=0.05,
):
    """Train a tabular Q-learning agent on BabySnake and return its Q-table.

    Args:
        episodes: number of games to play.
        alpha: learning rate passed to ``update_q``.
        gamma: discount factor passed to ``update_q``.
        epsilon: initial exploration probability.
        epsilon_decay: factor applied to ``epsilon`` after every episode.
        epsilon_min: floor below which ``epsilon`` is not decayed.

    Returns:
        dict mapping ``(state, action)`` to the learned Q-value.
    """
    q_table = {}
    env = BabySnakeEnv()
    for episode in range(episodes):
        state = env.reset()
        total_reward = 0.0
        while env.game.playing:
            action = choose_action(q_table, state, epsilon)
            next_state, reward, done = env.step(action)
            # Forward the terminal flag so the final transition of an
            # episode is not credited with value from beyond the game's end.
            update_q(q_table, state, action, reward, next_state, alpha, gamma,
                     done=done)
            state = next_state
            total_reward += reward
        # Explore less as training progresses, but never below epsilon_min.
        epsilon = max(epsilon_min, epsilon * epsilon_decay)
        if (episode + 1) % 100 == 0:
            print(
                f"Episode {episode + 1:5d} "
                f"reward={total_reward:6.1f} "
                f"score={env.game.state['score']} "
                f"epsilon={epsilon:.3f} "
                f"q_entries={len(q_table)}"
            )
    return q_table
def watch(q_table=None):
    """Play a BabySnake game driven by the greedy policy in ``q_table``.

    When no table is supplied, one is trained first with ``train()``'s
    default settings.
    """
    if q_table is None:
        print("Training first...")
        q_table = train()

    key_buffer = ProgrammaticInput()

    class PolicyInput:
        # Input source handed to game.play(); each collect() call presses the
        # best-known action for the current state and returns what the
        # underlying ProgrammaticInput collects.
        def collect(self):
            snapshot = game.state
            state = tuple(
                int(snapshot[key])
                for key in ('agent_x', 'agent_y', 'food_x', 'food_y')
            )
            # max() keeps the first maximal item, so ties resolve in
            # ACTIONS order.
            best = max(
                ACTIONS,
                key=lambda candidate: q_table.get((state, candidate), 0.0),
            )
            key_buffer.press(best)
            return key_buffer.collect()

    # `game` is bound after the class definition; collect() only looks it up
    # when called, by which point it exists.
    game = babysnake.create_game()
    game.play(input_source=PolicyInput())
# Script entry point: train with the default hyperparameters, report the
# size of the learned table, then play a game using the greedy policy.
if __name__ == '__main__':
    print("Training Q-learning agent on BabySnake...")
    q_table = train()
    print(f"\nDone. Q-table has {len(q_table)} entries.")
    print("\nWatching trained agent (press Enter or Escape to quit)...")
    # Pass the trained table so watch() does not train a second time.
    watch(q_table)