114 lines
3.3 KiB
Python
114 lines
3.3 KiB
Python
"""Solution for q_learning.py — remove before publishing to students."""
|
|
|
|
import random
|
|
import babysnake
|
|
from retro.input import ProgrammaticInput
|
|
from retro.views.headless import HeadlessView
|
|
|
|
# Discrete action space: the key names pressed on the game's programmatic
# input each step. Order matters only for indexing Q-value lists.
ACTIONS = [
    "KEY_RIGHT",
    "KEY_DOWN",
    "KEY_LEFT",
    "KEY_UP",
]
|
|
|
|
|
|
class BabySnakeEnv:
|
|
def reset(self):
|
|
self._inp = ProgrammaticInput()
|
|
self.game = babysnake.create_game()
|
|
self.game.input_source = self._inp
|
|
self.game.view = HeadlessView()
|
|
self.game.start()
|
|
self._prev_reward = 0.0
|
|
return self._get_state()
|
|
|
|
def step(self, action):
|
|
self._inp.press(action)
|
|
self.game.step()
|
|
next_state = self._get_state()
|
|
reward = self.game.state['reward'] - self._prev_reward
|
|
self._prev_reward = self.game.state['reward']
|
|
done = not self.game.playing
|
|
return next_state, reward, done
|
|
|
|
def _get_state(self):
|
|
s = self.game.state
|
|
return (int(s['agent_x']), int(s['agent_y']),
|
|
int(s['food_x']), int(s['food_y']))
|
|
|
|
|
|
def choose_action(q_table, state, epsilon, actions=None):
    """Pick an action for ``state`` with an epsilon-greedy policy.

    With probability ``epsilon`` a uniformly random action is returned
    (exploration); otherwise an action with the highest Q-value is returned
    (exploitation). Missing ``(state, action)`` entries count as 0.0.

    Args:
        q_table: dict mapping ``(state, action)`` to a Q-value.
        state: hashable state key.
        epsilon: exploration probability.
        actions: candidate actions; defaults to the module-level ``ACTIONS``.

    Returns:
        One element of ``actions``.
    """
    if actions is None:
        actions = ACTIONS
    if random.random() < epsilon:
        return random.choice(actions)
    q_values = [q_table.get((state, a), 0.0) for a in actions]
    best_q = max(q_values)
    # Break ties at random: with a zero-initialised table every unseen state
    # is an all-way tie, and ``list.index`` would always pick the first action,
    # biasing greedy play towards it.
    best = [a for a, q in zip(actions, q_values) if q == best_q]
    return best[0] if len(best) == 1 else random.choice(best)
|
|
|
|
|
|
def update_q(q_table, state, action, reward, next_state, alpha, gamma):
    """Apply one tabular Q-learning update to ``q_table`` in place.

    Moves Q(state, action) a fraction ``alpha`` of the way towards the
    one-step target ``reward + gamma * max_a' Q(next_state, a')``, reading
    missing table entries as 0.0. Returns None.
    """
    key = (state, action)
    current = q_table.get(key, 0.0)
    bootstrap = max(q_table.get((next_state, a), 0.0) for a in ACTIONS)
    td_error = reward + gamma * bootstrap - current
    q_table[key] = current + alpha * td_error
|
|
|
|
|
|
def train(
    episodes=1000,
    alpha=0.1,
    gamma=0.95,
    epsilon=1.0,
    epsilon_decay=0.995,
    epsilon_min=0.05,
    max_steps=None,
):
    """Train a tabular Q-learning agent on BabySnake and return its Q-table.

    Each episode resets a fresh ``BabySnakeEnv`` and takes epsilon-greedy
    steps until the game stops playing (or ``max_steps`` steps, if given).
    Epsilon decays multiplicatively after every episode, floored at
    ``epsilon_min``. A progress line is printed every 100 episodes.

    Args:
        episodes: number of episodes to run.
        alpha: learning rate.
        gamma: discount factor applied to non-terminal transitions.
        epsilon: initial exploration probability.
        epsilon_decay: per-episode multiplier applied to epsilon.
        epsilon_min: lower bound for epsilon.
        max_steps: optional cap on steps per episode; ``None`` (default)
            runs each episode until the game ends.

    Returns:
        dict mapping ``(state, action)`` to learned Q-values.
    """
    q_table = {}
    env = BabySnakeEnv()

    for episode in range(episodes):
        state = env.reset()
        total_reward = 0.0
        steps = 0

        while env.game.playing and (max_steps is None or steps < max_steps):
            action = choose_action(q_table, state, epsilon)
            next_state, reward, done = env.step(action)
            # A terminal transition has no future return: its target is the
            # reward alone, so don't bootstrap from the post-game state.
            discount = 0.0 if done else gamma
            update_q(q_table, state, action, reward, next_state, alpha, discount)
            state = next_state
            total_reward += reward
            steps += 1

        epsilon = max(epsilon_min, epsilon * epsilon_decay)

        if (episode + 1) % 100 == 0:
            print(
                f"Episode {episode + 1:5d} "
                f"reward={total_reward:6.1f} "
                f"score={env.game.state['score']} "
                f"epsilon={epsilon:.3f} "
                f"q_entries={len(q_table)}"
            )

    return q_table
|
|
|
|
|
|
def watch(q_table=None):
    """Play one BabySnake game driven by the greedy policy from ``q_table``.

    If ``q_table`` is None, a table is trained with ``train()`` first.
    Whenever the game calls ``collect()`` on its input source, the current
    state is read from ``game.state`` and the best-valued action (per
    ``choose_action`` with no exploration) is pressed.
    """
    if q_table is None:
        print("Training first...")
        q_table = train()

    keys = ProgrammaticInput()
    # Created before PolicyInput so the closure below refers to an
    # already-bound name.
    game = babysnake.create_game()

    class PolicyInput:
        """Input source that presses the greedy action for the current state."""

        def collect(self):
            s = game.state
            state = (int(s['agent_x']), int(s['agent_y']),
                     int(s['food_x']), int(s['food_y']))
            # epsilon=0.0: never explore, always exploit the learned values.
            keys.press(choose_action(q_table, state, 0.0))
            return keys.collect()

    game.play(input_source=PolicyInput())
|
|
|
|
|
|
if __name__ == '__main__':
    # Script entry: train from scratch, report the table size, then replay
    # the learned policy with watch().
    print("Training Q-learning agent on BabySnake...")
    q_table = train()
    n_entries = len(q_table)
    print(f"\nDone. Q-table has {n_entries} entries.")
    print("\nWatching trained agent (press Enter or Escape to quit)...")
    watch(q_table=q_table)
|