Files
lab_matrices/tlm/model.py
2026-03-02 12:29:18 -05:00

115 lines
4.3 KiB
Python

import json
import random
import numpy as np
from .helpers import rolling_window, join
class TinyLanguageModel:
    """
    A simple language model that predicts the next word based on the last n words.

    The model stores everything it has learned in a matrix W with shape:

        (vocabulary size) x (number of context windows seen)

    Each row of W corresponds to one word in the vocabulary.
    Each column of W corresponds to one context window (e.g. the words "the cat").
    W[i, j] counts how many times word i was observed following context j.

    To predict the next word, the model:

    1. Represents the current context as a one-hot column vector x.
    2. Computes Wx to get the counts for each word.
    3. Divides by the total count to get a probability distribution.
    4. Samples the next word from those probabilities.

    Attributes:
        n: Number of preceding words used as context.
        vocab: List of known words; index i is row i of W.
        contexts: List of context tuples; index j is column j of W.
        word_to_idx: Maps each word to its row in W.
        context_to_idx: Maps each context tuple to its column in W.
        W: Count matrix, or None until the model is trained or loaded.

    NOTE: the starting context is drawn with the ``random`` module while
    next words are drawn with ``np.random``; seed both for reproducible output.
    """

    def __init__(self, n=2):
        "Create a new model that looks at n words at a time."
        self.n = n
        self.vocab = None
        self.contexts = None
        self.word_to_idx = None
        self.context_to_idx = None
        self.W = None

    def train(self, words):
        """Learn word patterns from a sequence of words.

        Rebuilds the vocabulary, the context index and W from scratch, so any
        previously learned state is discarded.

        Args:
            words: Sequence (or iterable) of words to learn from.
        """
        if iter(words) is words:
            # Single-pass iterator (e.g. a generator): materialize it, because
            # the words are scanned twice below (once to index, once to count).
            words = list(words)
        self.vocab, self.contexts = self.get_unique_contexts_and_words(words)
        self.word_to_idx = {word: idx for idx, word in enumerate(self.vocab)}
        self.context_to_idx = {ctx: idx for idx, ctx in enumerate(self.contexts)}
        self.W = np.zeros((len(self.vocab), len(self.contexts)))
        self.count_contexts_and_words(words)

    def get_unique_contexts_and_words(self, words):
        """Scan words and collect the distinct words and context windows.

        Returns:
            A pair ``(sorted_unique_words, sorted_unique_contexts)`` of lists.
            Only words that follow some full context window are included.
        """
        unique_words = set()
        unique_contexts = set()
        for window in rolling_window(words, self.n + 1):
            context, word = window[:-1], window[-1]
            unique_words.add(word)
            unique_contexts.add(context)
        return sorted(unique_words), sorted(unique_contexts)

    def count_contexts_and_words(self, words):
        "Fill W by counting how often each word follows each context."
        for window in rolling_window(words, self.n + 1):
            context, word = window[:-1], window[-1]
            self.W[self.word_to_idx[word], self.context_to_idx[context]] += 1

    def generate(self, length, prompt=None, join_fn=None, step_callback=None):
        """Create new text based on what the model learned.

        Args:
            length: Target number of words in the output; the prompt counts
                toward it. Generation stops early if the current context was
                never seen during training.
            prompt: Optional sequence of starting words. If falsy, a random
                observed context window is used instead.
            join_fn: Callable that turns the list of output words into the
                return value; defaults to the ``join`` helper.
            step_callback: Optional callable invoked once per generated word as
                ``step_callback(context, possible_next_words, chosen_word)``.

        Returns:
            Whatever ``join_fn`` returns for the list of output words.

        Raises:
            RuntimeError: if the model has not been trained or loaded, or if no
                prompt is given and the model learned no context windows.
        """
        if self.W is None:
            raise RuntimeError("The model has not been trained")
        output = list(prompt or self.get_random_pattern())
        while len(output) < length:
            # The last n words; max(..., 0) keeps n == 0 giving the empty
            # context (output[-0:] would return the whole list).
            context = tuple(output[max(len(output) - self.n, 0):])
            if context not in self.context_to_idx:
                # Unseen context: there is no column of W to sample from.
                break
            one_hot = np.zeros(len(self.context_to_idx))
            one_hot[self.context_to_idx[context]] = 1
            counts = self.W @ one_hot
            probs = counts / counts.sum()
            chosen_word = self.vocab[np.random.choice(len(self.vocab), p=probs)]
            if step_callback:
                possible_next_words = [self.vocab[j] for j in np.flatnonzero(counts)]
                step_callback(context, possible_next_words, chosen_word)
            output.append(chosen_word)
        return (join_fn or join)(output)

    def get_random_pattern(self):
        """Randomly choose one of the observed context windows.

        Raises:
            RuntimeError: if the model has no context windows, e.g. it is
                untrained or was trained on n or fewer words.
        """
        if not self.context_to_idx:
            raise RuntimeError(
                "The model has no context windows to start from; "
                "train it on at least n + 1 words"
            )
        return random.choice(list(self.context_to_idx))

    def save(self, filepath):
        """Save the model to ``filepath`` as JSON.

        Contexts are written in column order so that ``load`` rebuilds the
        same context-to-column mapping.

        Raises:
            RuntimeError: if the model has not been trained or loaded.
        """
        if self.W is None:
            raise RuntimeError("The model has not been trained")
        ordered_contexts = sorted(self.context_to_idx, key=self.context_to_idx.get)
        model_data = {
            "n": self.n,
            "vocab": self.vocab,
            "contexts": [list(ctx) for ctx in ordered_contexts],
            "W": self.W.tolist(),
        }
        with open(filepath, "w", encoding="utf-8") as f:
            json.dump(model_data, f)

    def load(self, filepath):
        """Load a model previously written by ``save``, replacing current state.

        Raises:
            ValueError: if the stored W does not match the stored vocabulary
                and contexts in shape.
        """
        with open(filepath, "r", encoding="utf-8") as f:
            data = json.load(f)
        self.n = data["n"]
        self.vocab = data["vocab"]
        self.word_to_idx = {word: idx for idx, word in enumerate(self.vocab)}
        # JSON has no tuples; restore contexts as tuples so they match lookups.
        self.contexts = [tuple(ctx) for ctx in data["contexts"]]
        self.context_to_idx = {ctx: idx for idx, ctx in enumerate(self.contexts)}
        # Reshape explicitly: an empty 0x0 matrix round-trips through JSON as [],
        # which np.array would otherwise turn into shape (0,).
        self.W = np.array(data["W"], dtype=float).reshape(
            len(self.vocab), len(self.contexts)
        )