Files
lab_embeddings/tlm/model.py
Chris Proctor 039a467a9f initial commit
2026-03-09 12:28:21 -04:00

171 lines
6.9 KiB
Python

import json
import random
import numpy as np
from .helpers import rolling_window, join, softmax
class TinyLanguageModel:
    """
    A language model that uses learned embeddings to predict the next word.

    Instead of counting how often each word follows each context (as in the
    matrix model), this model learns dense vector representations of words,
    called embeddings, and uses them to estimate the probability of the next word.

    The model has two learned matrices (plus a bias vector):

        E: (vocab_size x embedding_dim) -- one row per word in the vocabulary,
           representing that word as a dense vector
        W: (embedding_dim x vocab_size) -- maps an embedding to one score per word
        b: (vocab_size,)                -- per-word bias added to the scores

    To predict the next word, the model:

        1. Looks up the embedding E[i] for each word i in the context window.
        2. Averages those embeddings to get a single context vector.
        3. Computes logits = context_vector @ W + b (one score per vocabulary word).
        4. Applies softmax to turn scores into probabilities.
        5. Samples the next word from those probabilities.

    Training works by gradient descent: the model repeatedly sees (context, target)
    pairs, makes a prediction, computes how wrong it was (cross-entropy loss),
    and nudges E, W, and b in the direction that reduces the loss.
    """

    def __init__(self, n=2, embedding_dim=32):
        "Create a new model with context window of n words and given embedding size."
        self.n = n
        self.embedding_dim = embedding_dim
        self.vocab = None         # sorted list of known words (set by train or load)
        self.word_to_idx = None   # word -> row index into E
        self.E = None  # input embeddings: (vocab_size, embedding_dim)
        self.W = None  # output weights: (embedding_dim, vocab_size)
        self.b = None  # output bias: (vocab_size,)

    def train(self, words, epochs=5, lr=0.05, resume=False):
        """Learn word embeddings from a list of words using gradient descent.

        If resume=True, the existing vocabulary and learned matrices (E, W, b) are
        kept and training continues from where it left off. Words in the corpus that
        are not in the saved vocabulary are skipped with a warning.

        Raises an Exception when resume=True but the model has never been trained
        or loaded, and a ValueError when the corpus yields no usable training
        examples (e.g. it is shorter than n + 1 words).
        """
        if resume:
            # Fail with a clear message rather than a confusing TypeError from
            # `w not in None` when the model was never trained or loaded.
            if self.word_to_idx is None:
                raise Exception("Cannot resume: the model has not been trained.")
            oov = sorted({w for w in words if w not in self.word_to_idx})
            if oov:
                print(f"Warning: {len(oov)} word(s) not in saved vocabulary will be skipped: {oov[:10]}{'...' if len(oov) > 10 else ''}")
        else:
            self.vocab = sorted(set(words))
            self.word_to_idx = {word: idx for idx, word in enumerate(self.vocab)}
            self.initialize_matrices()
        # Each example is (context words, target word); windows containing any
        # out-of-vocabulary word are dropped.
        training_data = [
            (window[:-1], window[-1])
            for window in rolling_window(words, self.n + 1)
            if all(w in self.word_to_idx for w in window)
        ]
        if not training_data:
            # Guard the per-epoch average below against division by zero.
            raise ValueError(
                f"No training examples: the corpus must contain a run of at least {self.n + 1} in-vocabulary words."
            )
        from tqdm import tqdm
        for epoch in range(epochs):
            total_loss = 0.0
            random.shuffle(training_data)
            for context, target in tqdm(training_data, desc=f"Epoch {epoch + 1}/{epochs}", leave=False):
                total_loss += self._step(context, target, lr)
            print(f"Epoch {epoch + 1}/{epochs} loss={total_loss / len(training_data):.4f}")

    def initialize_matrices(self):
        "Randomly initializes embedding, weight, and bias matrices"
        vocab_size = len(self.vocab)
        # Fixed seed so two models trained on the same corpus start identically
        # (reproducible runs). Small scale (0.01) keeps initial logits near zero.
        rng = np.random.default_rng(42)
        self.E = rng.standard_normal((vocab_size, self.embedding_dim)) * 0.01
        self.W = rng.standard_normal((self.embedding_dim, vocab_size)) * 0.01
        self.b = np.zeros(vocab_size)

    def _context_embedding(self, context):
        "Average the embeddings of the words in the context window."
        indices = [self.word_to_idx[w] for w in context]
        return self.E[indices].mean(axis=0)

    def _forward(self, context):
        "Return (context_emb, probs) for a given context tuple."
        ctx_emb = self._context_embedding(context)
        logits = ctx_emb @ self.W + self.b
        probs = softmax(logits)
        return ctx_emb, probs

    def _step(self, context, target, lr):
        "One gradient-descent step on a single (context, target) pair. Returns the loss."
        ctx_emb, probs = self._forward(context)
        target_idx = self.word_to_idx[target]
        # Cross-entropy loss; the epsilon avoids log(0) if the target's
        # probability underflows to zero.
        loss = -np.log(probs[target_idx] + 1e-12)
        # Gradient of cross-entropy loss w.r.t. logits: probs with 1 subtracted at target
        d_logits = probs.copy()
        d_logits[target_idx] -= 1.0
        # Gradients for output weights and bias
        d_W = ctx_emb[:, None] @ d_logits[None, :]  # (embedding_dim, vocab_size)
        d_b = d_logits  # (vocab_size,)
        # Gradient flows back through averaging to each context word's embedding:
        # the context vector is a mean, so each word gets 1/len(context) of it.
        d_ctx_emb = self.W @ d_logits  # (embedding_dim,)
        d_per_word = d_ctx_emb / len(context)
        for idx in [self.word_to_idx[w] for w in context]:
            self.E[idx] -= lr * d_per_word
        self.W -= lr * d_W
        self.b -= lr * d_b
        return loss

    def generate(self, length, prompt=None, join_fn=None, step_callback=None):
        """Create new text using the learned embeddings.

        Samples one word at a time from the model until the output reaches
        `length` words, then joins the words with `join_fn` (default: `join`).
        If given, `step_callback(context, top_words, chosen_word)` is called at
        each step with up to the 10 most probable next words (those with
        probability > 0.001). Raises if the model has not been trained.
        """
        if self.E is None:
            raise Exception("The model has not been trained.")
        output = list(prompt or self.get_random_prompt())
        # Drop any prompt tokens not in vocabulary
        output = [w for w in output if w in self.word_to_idx]
        # If too little of the prompt survived to fill a context window,
        # fall back to a random prompt.
        if len(output) < self.n:
            output = list(self.get_random_prompt())
        while len(output) < length:
            context = tuple(output[-self.n:])
            _, probs = self._forward(context)
            chosen_idx = np.random.choice(len(self.vocab), p=probs)
            chosen_word = self.vocab[chosen_idx]
            if step_callback:
                top_indices = np.argsort(probs)[-10:][::-1]
                top_words = [self.vocab[i] for i in top_indices if probs[i] > 0.001]
                step_callback(context, top_words, chosen_word)
            output.append(chosen_word)
        return (join_fn or join)(output)

    def get_random_prompt(self):
        "Return a random context window drawn from the vocabulary."
        return tuple(random.choices(self.vocab, k=self.n))

    def save(self, filepath):
        "Save the model to a JSON file. Raises if the model has not been trained."
        # Same guard as generate(): without it, self.E.tolist() fails with an
        # opaque AttributeError on an untrained model.
        if self.E is None:
            raise Exception("The model has not been trained.")
        model_data = {
            "n": self.n,
            "embedding_dim": self.embedding_dim,
            "vocab": self.vocab,
            "E": self.E.tolist(),
            "W": self.W.tolist(),
            "b": self.b.tolist(),
        }
        with open(filepath, "w") as f:
            json.dump(model_data, f)

    def load(self, filepath):
        "Load a model from a JSON file."
        with open(filepath) as f:
            data = json.load(f)
        self.n = data["n"]
        self.embedding_dim = data["embedding_dim"]
        self.vocab = data["vocab"]
        # word_to_idx is derived, not stored, so rebuild it from the vocab.
        self.word_to_idx = {word: idx for idx, word in enumerate(self.vocab)}
        self.E = np.array(data["E"])
        self.W = np.array(data["W"])
        self.b = np.array(data["b"])