import json
import random

from .helpers import rolling_window, join


class TinyLanguageModel:
    """A minimal n-gram style text generator.

    ``self.model`` maps each observed pattern (a tuple of words) to the list
    of words that were seen right after it. Followers are stored with
    duplicates, so ``random.choice`` picks them in proportion to how often
    they occurred during training.
    """

    def __init__(self, n=2):
        """Create a new model that looks at n words at a time."""
        self.n = n
        # pattern (tuple of words) -> list of words observed after it
        self.model = {}

    def train(self, words):
        """Learn word patterns from a list of words.

        Can be called several times; observations accumulate in
        ``self.model`` rather than replacing earlier ones.
        """
        # NOTE(review): rolling_window comes from .helpers and is not visible
        # here; presumably it yields every run of n + 1 consecutive words.
        for sequence in rolling_window(words, self.n + 1):
            pattern = tuple(sequence[:-1])
            next_word = sequence[-1]
            self.model.setdefault(pattern, []).append(next_word)

    def generate(self, length, prompt=None):
        """Create new words based on what the model learned.

        Starts from ``prompt`` (any iterable of words) or, when the prompt is
        missing or empty, from a randomly chosen learned pattern. Words are
        appended until ``length`` words are reached or the current trailing
        pattern has no known follower. A prompt already longer than
        ``length`` is returned untruncated.

        Returns whatever ``join`` from ``.helpers`` builds from the word list
        (presumably a single string; confirm against that helper).

        Raises:
            RuntimeError: if the model has not been trained or loaded.
        """
        if not self.model:
            raise RuntimeError("The model has not been trained")

        output = list(prompt or self.get_random_pattern())
        while len(output) < length:
            # output[-0:] would be the WHOLE list, so a zero-width pattern
            # has to be spelled out explicitly.
            pattern = tuple(output[-self.n:]) if self.n else ()
            followers = self.model.get(pattern)
            if not followers:
                # Unknown pattern (or an empty follower list coming from a
                # hand-edited file): nothing sensible to append, so stop.
                break
            output.append(random.choice(followers))
        return join(output)

    def get_random_pattern(self):
        """Randomly choose one of the observed patterns."""
        return random.choice(list(self.model))

    def save(self, filepath):
        """Save the model to a JSON file.

        Each pattern is stored as its words joined by a single space, so
        words that themselves contain a space cannot be round-tripped.
        """
        model_data = {
            "n": self.n,
            "model": {
                " ".join(pattern): followers
                for pattern, followers in self.model.items()
            },
        }
        with open(filepath, "w", encoding="utf-8") as f:
            json.dump(model_data, f)

    def load(self, filepath):
        """Load a model from a file written by ``save``.

        Replaces ``self.n`` and ``self.model``. Both are parsed before either
        is assigned, so a malformed file leaves the instance untouched.
        """
        with open(filepath, "r", encoding="utf-8") as f:
            data = json.load(f)

        n = data["n"]

        def to_pattern(key):
            # The empty key stands for the zero-width pattern only when
            # n == 0; for n >= 1 it is a pattern holding an empty word.
            if not key and not n:
                return ()
            # split(" ") is the exact inverse of " ".join(...). A bare
            # split() would also break on other whitespace inside a word
            # (e.g. a non-breaking space) and drop empty words.
            return tuple(key.split(" "))

        model = {
            to_pattern(key): followers
            for key, followers in data["model"].items()
        }
        self.n = n
        self.model = model