diff --git a/ttt_learn.py b/ttt_learn.py index f467902..a383390 100644 --- a/ttt_learn.py +++ b/ttt_learn.py @@ -1,3 +1,4 @@ + # A state is a dictionary with two keys, "board" and "player." Here's an example: # # { @@ -5,18 +6,17 @@ # "player": "X", # } +from itertools import count +counter = count() def get_next_state(state, action): """Returns the state which would result from taking an action at a particular state. """ - if state["board"][action] is not None: - raise ValueError(f"Action {action} is illegal at state {state}; the space is occupied.") - new_board = state["board"].copy() + new_board = [space for space in state["board"]] new_board[action] = state["player"] - new_player = get_opponent(state["player"]) return { "board": new_board, - "player": new_player, + "player": get_opponent(state["player"]), } def get_actions(state): @@ -29,26 +29,33 @@ def get_actions(state): else: actions = {} for i in range(9): - if state["board"][i] is None: + if state["board"][i] is None or state["board"][i] == '-': actions[i] = get_next_state(state, i) return actions -def choose_best_action(state): +def choose_best_action(state, depth=0, explain=False): """Given a state, returns the best action, its resulting state, and that state's value. For each possible action, we find the value of the resulting state. Then, if the player is 'X', choose the action corresponding to the highest value. If the player is 'O', choose the action corresponding to the lowest value. """ + if explain: + question_number = next(counter) + pose_question(state, question_number, depth) actions = get_actions(state) - values_and_actions = [[get_value(result), action] for action, result in actions.items()] + values_and_actions = [] + for action, result in actions.items(): + values_and_actions.append([get_value(result, depth=depth, explain=explain), action]) if state["player"] == "X": value, action = max(values_and_actions) else: value, action = min(values_and_actions) + if explain: + answer_question(state, action, question_number, depth) return action, actions[action], value - -def get_value(state, depth=0, debug=False): + +def get_value(state, depth=0, explain=False): """Determines the value of the state. """ if is_win(state, 'X'): @@ -58,10 +65,12 @@ def get_value(state, depth=0, debug=False): elif is_draw(state): return 0 else: - action, result, value = choose_best_action(state) + action, result, value = choose_best_action(state, depth=depth+1, explain=explain) return value +# ========================================================================================= # ================================== HELPERS ============================================== +# ========================================================================================= def get_opponent(player): "Returns 'X' when player is 'O' and 'O' when player is 'X'" @@ -79,7 +88,7 @@ def is_over(state): def is_draw(state): "Returns True if the game ended in a draw." for space in state["board"]: - if space is None: + if space is None or space == '-': return False return True @@ -90,3 +99,23 @@ def is_win(state, player): if state["board"][a] == player and state["board"][b] == player and state["board"][c] == player: return True return False + +def pose_question(state, index, depth): + "Logs a question asking about a player's best move at a state." + board = format_board_inline(state['board']) + log(f"What is the best action for {state['player']} at {board}?", index, depth) + +def answer_question(state, action, index, depth): + "Logs the answer to a question about a player's best move at a state." + board = format_board_inline(state['board']) + log(f"The best action for {state['player']} at {board} is {action}", index, depth) + +def log(message, index, depth): + "Prints a message at the appropriate depth, with each message numbered." + indent = ' ' * depth + print(f"{indent}{index}. {message}") + +def format_board_inline(board): + "Formats a board like '[ OOX | -X- | --- ]' " + symbols = [sym or '-' for sym in board] + return '[ ' + ' | '.join([''.join(symbols[:3]), ''.join(symbols[3:6]), ''.join(symbols[6:])]) + ' ]'