Initial commit
This commit is contained in:
218
q_learning.py
Normal file
218
q_learning.py
Normal file
@@ -0,0 +1,218 @@
|
||||
"""Q-learning agent for BabySnake.
|
||||
|
||||
This file contains starter code for implementing a Q-learning agent.
|
||||
You need to fill in two functions:
|
||||
- choose_action: select an action using an epsilon-greedy policy
|
||||
- update_q: update the Q-table using the Bellman equation
|
||||
|
||||
Run this file to train the agent:
|
||||
python q_learning.py
|
||||
|
||||
After training, run this to watch it play:
|
||||
python -c "from q_learning import watch; watch()"
|
||||
"""
|
||||
|
||||
import random
|
||||
import babysnake
|
||||
from retro.input import ProgrammaticInput
|
||||
from retro.views.headless import HeadlessView
|
||||
|
||||
# The four actions the agent can take.
|
||||
# Key names handed to the game's input source, one per move.
# The list order is significant: watch() resolves equal Q-values in favour
# of the earliest entry in this list.
ACTIONS = ["KEY_RIGHT", "KEY_DOWN", "KEY_LEFT", "KEY_UP"]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Environment wrapper
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class BabySnakeEnv:
    """Drive a BabySnake game one step at a time from code.

    Every call to ``reset`` builds a brand-new game wired to a programmatic
    input source and a headless view. ``step`` then presses one key and
    advances the game by one step.

    Example:
        env = BabySnakeEnv()
        state = env.reset()                 # begin an episode
        next_state, reward, done = env.step("KEY_RIGHT")

    Note: ``game``, ``_inp`` and ``_prev_reward`` only exist after the first
    ``reset`` call, so ``reset`` must be called before ``step``.
    """

    def reset(self):
        """Begin a fresh episode and return its initial state tuple."""
        keys = ProgrammaticInput()
        game = babysnake.create_game()
        game.input_source = keys
        game.view = HeadlessView()
        self._inp = keys
        self.game = game
        game.start()
        # Baseline used by step() to turn the game's reward reading into a
        # per-step difference.
        self._prev_reward = 0.0
        return self._get_state()

    def step(self, action):
        """Press ``action``, advance the game once, and report the outcome.

        Arguments:
            action (str): One of ACTIONS, or None for no-op.

        Returns:
            tuple: ``(next_state, reward, done)`` where
                next_state (tuple) is the state after the action,
                reward (float) is the change in ``game.state['reward']``
                    since the previous step, and
                done (bool) is True once the game is no longer playing.
        """
        self._inp.press(action)
        self.game.step()
        observation = self._get_state()
        reading = self.game.state['reward']
        delta = reading - self._prev_reward
        self._prev_reward = reading
        return observation, delta, not self.game.playing

    def _get_state(self):
        """Return (agent_x, agent_y, food_x, food_y) as a tuple of ints."""
        snapshot = self.game.state
        fields = ('agent_x', 'agent_y', 'food_x', 'food_y')
        return tuple(int(snapshot[name]) for name in fields)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Q-learning functions — fill these in!
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def choose_action(q_table, state, epsilon):
    """Pick the agent's next move with an epsilon-greedy rule.

    Exercise stub: the body is intentionally left for you to write.

    Contract:
        * Explore: with probability `epsilon`, return one of ACTIONS chosen
          at random.
        * Exploit: otherwise, return the action whose Q-value for `state`
          is the highest in `q_table`. A (state, action) pair that has not
          been seen before counts as a Q-value of 0.

    Arguments:
        q_table (dict): Maps (state, action) -> Q-value.
        state (tuple): The current state, e.g. (1, 2, 3, 0).
        epsilon (float): Exploration rate, between 0.0 and 1.0.

    Returns:
        str: One action from ACTIONS.

    Raises:
        NotImplementedError: Always, until the exercise is completed.

    Hints:
        - random.random() returns a float in [0.0, 1.0).
        - random.choice(ACTIONS) returns a random action.
        - q_table.get(key, default) is handy for missing entries.
    """
    raise NotImplementedError("Fill in choose_action")
|
||||
|
||||
|
||||
def update_q(q_table, state, action, reward, next_state, alpha, gamma):
    """Apply one Bellman-equation update to a single Q-table entry.

    Exercise stub: the body is intentionally left for you to write.

    The rule to implement is:

        Q(s, a) <- Q(s, a) + alpha * (r + gamma * max_a' Q(s', a') - Q(s, a))

    with:
        s, a        the state we were in and the action we took
        r           the reward we received
        s'          the state we ended up in
        max_a' ...  the best possible Q-value from the new state

    Arguments:
        q_table (dict): Maps (state, action) -> Q-value (modified in place).
        state (tuple): The state before the action.
        action (str): The action taken.
        reward (float): The reward received.
        next_state (tuple): The state after the action.
        alpha (float): Learning rate (how much to update).
        gamma (float): Discount factor (how much to value future rewards).

    Returns:
        None. The function's only effect is the in-place change to q_table.

    Raises:
        NotImplementedError: Always, until the exercise is completed.

    Hint: Q-values for unseen (state, action) pairs start at 0.
    """
    raise NotImplementedError("Fill in update_q")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Training loop
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def train(
    episodes=1000,
    alpha=0.1,
    gamma=0.95,
    epsilon=1.0,
    epsilon_decay=0.995,
    epsilon_min=0.05,
):
    """Run Q-learning on BabySnake and return the learned Q-table.

    Each episode resets the environment, then repeatedly chooses an action,
    steps the game, and updates the table until the game stops playing.
    A progress line is printed after every 100th episode.

    Arguments:
        episodes (int): How many episodes to run.
        alpha (float): Learning rate.
        gamma (float): Discount factor.
        epsilon (float): Starting exploration rate.
        epsilon_decay (float): Multiply epsilon by this each episode.
        epsilon_min (float): Epsilon never falls below this.

    Returns:
        dict: The trained Q-table, mapping (state, action) -> Q-value.
    """
    env = BabySnakeEnv()
    q_table = {}
    report_every = 100

    for episode_number in range(1, episodes + 1):
        state = env.reset()
        episode_return = 0.0

        # One iteration per game step, for as long as the game is playing.
        while env.game.playing:
            action = choose_action(q_table, state, epsilon)
            successor, reward, _done = env.step(action)
            update_q(q_table, state, action, reward, successor, alpha, gamma)
            episode_return += reward
            state = successor

        # Decay exploration once per episode, clamped at the floor.
        epsilon = max(epsilon_min, epsilon * epsilon_decay)

        if episode_number % report_every != 0:
            continue

        print(
            f"Episode {episode_number:5d} "
            f"reward={episode_return:6.1f} "
            f"score={env.game.state['score']} "
            f"epsilon={epsilon:.3f} "
            f"q_entries={len(q_table)}"
        )

    return q_table
|
||||
|
||||
|
||||
def watch(q_table=None):
    """Watch the trained agent play in the terminal.

    The game is given an input source that, each time it is asked for
    input, reads the current game state, looks up the Q-value of every
    action in ACTIONS for that state (unseen pairs count as 0.0), and
    presses the best one. When several actions share the best value, the
    one listed first in ACTIONS is used.

    Arguments:
        q_table (dict | None): A trained Q-table mapping
            (state, action) -> Q-value. If None, trains first.
    """
    # babysnake and ProgrammaticInput are already imported at module level,
    # so no function-local re-import is needed here.
    if q_table is None:
        print("Training first...")
        q_table = train()

    key_queue = ProgrammaticInput()
    game = babysnake.create_game()

    class PolicyInput:
        """An input source that picks actions from the Q-table."""

        def collect(self):
            """Press the greedy action for the current state, then return
            whatever the underlying programmatic input collects."""
            snapshot = game.state
            state = (int(snapshot['agent_x']), int(snapshot['agent_y']),
                     int(snapshot['food_x']), int(snapshot['food_y']))
            # max() keeps the first maximal item, so ties go to the action
            # that appears earliest in ACTIONS.
            best = max(ACTIONS, key=lambda a: q_table.get((state, a), 0.0))
            key_queue.press(best)
            return key_queue.collect()

    game.play(input_source=PolicyInput())
|
||||
|
||||
|
||||
if __name__ == '__main__':
    # Script entry point: train with the default hyperparameters, report how
    # many entries the learned table holds, then replay the agent in the
    # terminal.
    print("Training Q-learning agent on BabySnake...")
    q_table = train()

    entry_count = len(q_table)
    print(f"\nDone. Q-table has {entry_count} entries.")

    print("\nWatching trained agent (press Enter or Escape to quit)...")
    watch(q_table)
|
||||
Reference in New Issue
Block a user