import json
import random

import numpy as np

from .helpers import rolling_window, join


class TinyLanguageModel:
    """
    A simple language model that predicts the next word based on the last n words.

    The model stores everything it has learned in a matrix W with shape:

        (vocabulary size) x (number of context windows seen)

    Each row of W corresponds to one word in the vocabulary.
    Each column of W corresponds to one context window (e.g. the words "the cat").
    W[i, j] counts how many times word i was observed following context j.

    To predict the next word, the model:

    1. Represents the current context as a one-hot column vector x.
    2. Computes Wx to get the counts for each word.
    3. Divides by the total count to get a probability distribution.
    4. Samples the next word from those probabilities.
    """

    def __init__(self, n=2):
        """Create a new model that looks at n words at a time.

        Args:
            n: Size of the context window (number of preceding words used
               to predict the next one).
        """
        self.n = n
        # All attributes below are populated by train() or load().
        self.vocab = None           # sorted list of unique words
        self.contexts = None        # sorted list of unique context tuples
        self.word_to_idx = None     # word  -> row index into W
        self.context_to_idx = None  # context tuple -> column index into W
        self.W = None               # counts matrix, shape (|vocab|, |contexts|)

    def train(self, words):
        """Learn word patterns from a list of words.

        Builds the vocabulary, the set of observed context windows, the
        index lookups, and the count matrix W.
        """
        self.vocab, self.contexts = self.get_unique_contexts_and_words(words)
        self.word_to_idx = {word: idx for idx, word in enumerate(self.vocab)}
        self.context_to_idx = {ctx: idx for idx, ctx in enumerate(self.contexts)}
        self.W = np.zeros((len(self.vocab), len(self.contexts)))
        self.count_contexts_and_words(words)

    def get_unique_contexts_and_words(self, words):
        """Scan words and return the set of unique words and unique context windows.

        Returns:
            A pair (sorted unique words, sorted unique context tuples).
            Sorting makes the row/column ordering deterministic across runs.
        """
        unique_words = set()
        unique_contexts = set()
        for window in rolling_window(words, self.n + 1):
            context, word = window[:-1], window[-1]
            unique_words.add(word)
            unique_contexts.add(context)
        return sorted(unique_words), sorted(unique_contexts)

    def count_contexts_and_words(self, words):
        """Fill W by counting how often each word follows each context."""
        for window in rolling_window(words, self.n + 1):
            context, word = window[:-1], window[-1]
            self.W[self.word_to_idx[word], self.context_to_idx[context]] += 1

    def generate(self, length, prompt=None, join_fn=None, step_callback=None):
        """Create new text based on what the model learned.

        Args:
            length: Target number of words in the output.
            prompt: Optional sequence of words to start from; when omitted,
                a random observed context window is used.
            join_fn: Optional callable used to join the output words into a
                single value; defaults to the module-level `join` helper.
            step_callback: Optional callable invoked once per generated word
                with (context, possible_next_words, chosen_word).

        Returns:
            The generated words joined by join_fn (or `join`).

        Raises:
            RuntimeError: If the model has not been trained or loaded.
        """
        if self.W is None:
            raise RuntimeError("The model has not been trained")
        output = list(prompt or self.get_random_pattern())
        while len(output) < length:
            context = tuple(output[-self.n:])
            # Stop when the model has never seen this context — it has no
            # statistics to sample from.
            context_col = self.context_to_idx.get(context)
            if context_col is None:
                break
            # Select the context's count column via a one-hot matrix product,
            # mirroring the Wx formulation described in the class docstring.
            one_hot = np.zeros(len(self.context_to_idx))
            one_hot[context_col] = 1
            counts = self.W @ one_hot
            total = counts.sum()
            if total == 0:
                # Defensive: an all-zero column can only occur in a
                # hand-edited saved model, but sampling from it would
                # produce NaN probabilities.
                break
            probs = counts / total
            chosen_idx = np.random.choice(len(self.vocab), p=probs)
            chosen_word = self.vocab[chosen_idx]
            if step_callback:
                possible_next_words = [
                    self.vocab[j] for j in range(len(self.vocab)) if counts[j] > 0
                ]
                step_callback(context, possible_next_words, chosen_word)
            output.append(chosen_word)
        return (join_fn or join)(output)

    def get_random_pattern(self):
        """Randomly chooses one of the observed context windows."""
        return random.choice(list(self.context_to_idx.keys()))

    def save(self, filepath):
        """Save the model to a file as JSON.

        Raises:
            RuntimeError: If the model has not been trained or loaded.
        """
        if self.W is None:
            raise RuntimeError("The model has not been trained")
        # Persist contexts in column order so load() can rebuild the exact
        # same context -> column mapping.
        ordered_contexts = sorted(self.context_to_idx, key=self.context_to_idx.get)
        model_data = {
            "n": self.n,
            "vocab": self.vocab,
            # Tuples are not JSON-serializable; store contexts as lists.
            "contexts": [list(ctx) for ctx in ordered_contexts],
            "W": self.W.tolist(),
        }
        with open(filepath, "w") as f:
            json.dump(model_data, f)

    def load(self, filepath):
        """Load a model from a file, restoring all trained state."""
        with open(filepath, "r") as f:
            data = json.load(f)
        self.n = data["n"]
        self.vocab = data["vocab"]
        self.word_to_idx = {word: idx for idx, word in enumerate(self.vocab)}
        # JSON round-trips contexts as lists; convert back to hashable tuples.
        self.contexts = [tuple(ctx) for ctx in data["contexts"]]
        self.context_to_idx = {ctx: idx for idx, ctx in enumerate(self.contexts)}
        self.W = np.array(data["W"])