Refactor lab with Strategy classes

This commit is contained in:
Chris Proctor 2022-05-06 16:46:03 -04:00
parent f49e78c35f
commit 6ad2672bd3
6 changed files with 179 additions and 86 deletions

View File

@ -9,6 +9,6 @@ view = TTTView()
view.greet(game)
while not game.is_over():
move = view.get_move(game)
game.play_move(move)
action = view.get_action(game)
game.play_action(action)
view.conclude(game)

98
strategy.py Normal file
View File

@ -0,0 +1,98 @@
from types import MethodType
from random import choice
class RandomStrategy:
"""A Strategy which randomly chooses a move. Not a great choice.
"""
def __init__(self, game):
self.validate_game(game)
self.game = game
def choose_action(self, state):
possible_actions = game.get_actions(state)
return choice(possible_actions)
class LookaheadStrategy:
"""A Strategy which considers the future consequences of an action.
To initialize a LookaheadStrategy, pass in an instance of a game containing
the following methods. These methods encode the rules of the game,
which a LookaheadStrategy needs to know in order to determine which move is best.
- get_next_state: state, action -> state
- get_actions: state -> [actions]
- get_reward: state -> int
- is_over: state -> bool
- get_objective: str -> function
Optionally, pass the following arguments to control the behavior of the LookaheadStrategy:
- max_depth: int. A game may be too complex to search the full state tree.
Setting max_depth will set a cutoff on how far ahead the LookaheadStrategy will look.
- deterministic: bool. It's possible that there are multiple equally-good actions.
When deterministic is True, LookaheadStrategy will always choose the first of the
equally-good actions, so that LookaheadStrategy will always play out the same game.
When deterministic is False, LookaheadStrategy will choose randomly from all actions
which are equally-good.
- Explain: When set to True, LookaheadStrategy will print out its reasoning.
"""
def __init__(self, game, max_depth=None, deterministic=True, explain=False):
self.validate_game(game)
self.game = game
self.max_depth = max_depth
self.deterministic = deterministic
self.explain = explain
def choose_action(self, state):
"""Given a state, chooses an action.
This is the most important method of a Strategy, corresponding to the situation where
it's a player's turn to play a game and she needs to decide what to do.
Strategy chooses an action by considering all possible actions, and finding the
total current and future reward which would come from playing that action.
Then we use the game's objective to choose the "best" reward. Usually bigger is better,
but in zero-sum games like tic tac toe, the players want opposite outcomes. One player
wants the reward to be high, while the other wants the reward to be low.
Once we know which reward is best, we choose an action which will lead to that reward.
"""
possible_actions = self.game.get_actions(state)
rewards = {}
for action in possible_actions:
future_state = self.game.get_next_state(state, action)
rewards[action] = self.game.get_reward(future_state)
objective = self.game.get_objective(state)
best_reward = objective(rewards.values())
best_actions = [action for action in possible_actions if rewards[action] == best_reward]
if self.deterministic:
return best_actions[0]
else:
return choice(best_actions)
def get_current_and_future_reward(self, state):
"""Calculates the reward from this state, and from all future states which would be
reached, assuming all players are using this Strategy.
"""
reward = self.game.get_reward(state)
if not self.game.is_over(state):
future_state = self.choose_action(state)
reward += self.get_current_and_future_reward(future_state)
return reward
def validate_game(self, game):
"Checks that the game has all the required methods."
required_methods = [
"get_next_state",
"get_actions",
"get_reward",
"is_over",
"get_objective",
]
for method in required_methods:
if not (hasattr(game, method) and isinstance(getattr(game, method), MethodType)):
message = f"Game {game} does not have method {method}."
raise ValueError(message)

View File

@ -2,47 +2,81 @@ class TTTGame:
"Models a tic-tac-toe game."
def __init__(self, playerX, playerO):
self.board = [None] * 9
self.turn_index = 0
self.state = self.get_initial_state()
self.players = {
'X': playerX,
'O': playerO,
}
def play_move(self, move):
"Updates the game's state by recording a move"
if not self.is_valid_move(move):
raise ValueError(f"Illegal move {move} with board {self.board}.")
self.board[move] = self.get_current_player_symbol()
self.turn_index += 1
def get_initial_state(self):
"Returns the game's initial state."
return {
"board": ['-', '-', '-', '-', '-', '-', '-', '-', '-'],
"player": "X",
}
def get_valid_moves(self):
def get_next_state(self, state, action):
"""Given a state and an action, returns the resulting state.
In the resulting state, the current player's symbol has been placed
in an empty board space, and it is the opposite player's turn.
"""
new_board = state["board"]
new_board[action] = state["player"]
if state["player"] == "O":
new_player = "X"
else:
new_player = "O"
return {
"board": new_board,
"player": new_player,
}
def get_actions(self, state):
"Returns a list of the indices of empty spaces"
return [index for index in range(9) if self.board[index] is None]
return [index for index in range(9) if state["board"][index] == '-']
def is_over(self):
"Checks whether the game is over."
return self.board_is_full() or self.check_winner('X') or self.check_winner('O')
def get_reward(self, state):
"""Determines the reward associated with reaching this state.
For tic-tac-toe, the two opponents each want a different game outcome. So
we set the reward for X winning to 1 and the reward for O winning to -1.
All other states (unfinished games and games which ended in a draw) are worth 0.
"""
if self.check_winner('X'):
return 1
elif self.check_winner('O'):
return -1
else:
return 0
def get_objective(self, state):
"""Returns a player's objective, or a function describing what a player wants.
This function should choose the best value from a list. In tic tac toe, the players
want opposite things, so we set X's objective to the built-in function `max`
(which chooses the largest number), and we set O's objective to the built-in function `min`.
"""
if state["player"] == 'X':
return max
elif state["player"] == 'O':
return min
else:
raise ValueError(f"Unrecognized player {state['player']}")
def play_action(self, action):
"Plays a move, updating the game's state."
self.state = self.get_next_state(self.state, action)
def is_valid_move(self, move):
"Checks whether a move is valid"
return move in self.get_valid_moves()
def get_current_player_symbol(self):
"Returns the symbol of the current player"
if self.turn_index % 2 == 0:
return 'X'
else:
return 'O'
def get_current_player(self):
"Returns the symbol of the current player and the current player"
return self.players[self.get_current_player_symbol()]
def is_over(self):
"Checks whether the game is over."
return self.board_is_full() or self.check_winner('X') or self.check_winner('O')
def board_is_full(self):
"Checks whether all the spaces in the board are occupied."
for space in self.board:
if space == None:
for space in self.state["board"]:
if space == '-':
return False
return True

View File

@ -9,16 +9,6 @@
from itertools import count
counter = count()
def get_next_state(state, action):
"""Returns the state which would result from taking an action at a particular state.
"""
new_board = [space for space in state["board"]]
new_board[action] = state["player"]
return {
"board": new_board,
"player": get_opponent(state["player"]),
}
def get_actions(state):
"""Given a board state, returns a dictionary whose keys are possible actions and whose
values are the resulting state from each action. If the game is over, returns an empty
@ -29,7 +19,7 @@ def get_actions(state):
else:
actions = {}
for i in range(9):
if state["board"][i] is None or state["board"][i] == '-':
if state["board"][i] is None or state["board"][i] == ' ':
actions[i] = get_next_state(state, i)
return actions
@ -72,34 +62,6 @@ def get_value(state, depth=0, explain=False):
# ================================== HELPERS ==============================================
# =========================================================================================
def get_opponent(player):
"Returns 'X' when player is 'O' and 'O' when player is 'X'"
if player == 'X':
return 'O'
elif player == 'O':
return 'X'
else:
raise ValueError(f"Unrecognized player {player}")
def is_over(state):
"Returns True if the game is over"
return is_draw(state) or is_win(state, 'X') or is_win(state, 'O')
def is_draw(state):
"Returns True if the game ended in a draw."
for space in state["board"]:
if space is None or space == '-':
return False
return True
def is_win(state, player):
"Returns True if `player` has won the game."
win_lines = [[0,1,2], [3,4,5], [6,7,8], [0,3,6], [1,4,7], [2,5,8], [0, 4, 8], [2, 4, 6]]
for a, b, c in win_lines:
if state["board"][a] == player and state["board"][b] == player and state["board"][c] == player:
return True
return False
def pose_question(state, index, depth):
"Logs a question asking about a player's best move at a state."
board = format_board_inline(state['board'])
@ -114,8 +76,3 @@ def log(message, index, depth):
"Prints a message at the appropriate depth, with each message numbered."
indent = ' ' * depth
print(f"{indent}{index}. {message}")
def format_board_inline(board):
"Formats a board like '[ OOX | -X- | --- ]' "
symbols = [sym or '-' for sym in board]
return '[ ' + ' | '.join([''.join(symbols[:3]), ''.join(symbols[3:6]), ''.join(symbols[6:])]) + ' ]'

View File

@ -1,4 +1,5 @@
from click import Choice, prompt
from strategy import RandomStrategy
import random
class TTTHumanPlayer:
@ -8,9 +9,9 @@ class TTTHumanPlayer:
"Sets up the player."
self.name = name
def choose_move(self, game):
"Chooses a move by prompting the player for a choice."
choices = Choice([str(i) for i in game.get_valid_moves()])
def choose_action(self, game):
"Chooses an action by prompting the player for a choice."
choices = Choice([str(i) for i in game.get_actions(game.state)])
move = prompt("> ", type=choices, show_choices=False)
return int(move)
@ -21,10 +22,11 @@ class TTTComputerPlayer:
"Sets up the player."
self.name = name
def choose_move(self, game):
def choose_action(self, game):
"Chooses a random move from the moves available."
move = random.choice(game.get_valid_moves())
print(f"{self.name} chooses {move}.")
strategy = RandomStrategy(game)
action = strategy.choose_action(game.state)
print(f"{self.name} chooses {action}.")
return move
def get_symbol(self, game):

View File

@ -17,18 +17,19 @@ class TTTView:
print(f"{x_name} will play as X.")
print(f"{o_name} will play as O.")
def get_move(self, game):
"Shows the board and asks the current player for their choice of move."
def get_action(self, game):
"Shows the board and asks the current player for their choice of action."
self.print_board_with_options(game)
player = game.get_current_player()
current_player_symbol = game.state["player"]
player = game.players[current_player_symbol]
print(f"{player.name}, it's your move.")
return player.choose_move(game)
return player.choose_action(game)
def print_board_with_options(self, game):
"Prints the current board, showing indices of available spaces"
print(self.format_row(game, [0, 1, 2]))
print(self.divider)
print(self.format_row(game, [3, 4, 6]))
print(self.format_row(game, [3, 4, 5]))
print(self.divider)
print(self.format_row(game, [6, 7, 8]))
@ -42,9 +43,9 @@ class TTTView:
If the game board already has a symbol in that space, formats that value for the Terminal.
If the space is empty, instead formats the index of the space.
"""
if game.board[index] == 'X':
if game.state["board"][index] == 'X':
return click.style('X', fg=self.x_color)
elif game.board[index] == 'O':
elif game.state["board"][index] == 'O':
return click.style('O', fg=self.o_color)
else:
return click.style(index, fg=self.option_color)
@ -52,6 +53,7 @@ class TTTView:
def conclude(self, game):
"""Says goodbye.
"""
self.print_board_with_options(game)
if game.check_winner('X'):
winner = game.players['X']
elif game.check_winner('O'):