import json
import random

import numpy as np

from .helpers import rolling_window, join, softmax


class TinyLanguageModel:
    """
    A language model that uses learned embeddings to predict the next word.

    Instead of counting how often each word follows each context (as in the
    matrix model), this model learns dense vector representations of words,
    called embeddings, and uses them to estimate the probability of the next
    word.

    The model has two learned matrices:

        E: (vocab_size x embedding_dim) -- one row per word in the vocabulary,
           representing that word as a dense vector
        W: (embedding_dim x vocab_size) -- maps an embedding to one score per word

    To predict the next word, the model:

        1. Looks up the embedding E[i] for each word i in the context window.
        2. Averages those embeddings to get a single context vector.
        3. Computes logits = context_vector @ W (one score per vocabulary word).
        4. Applies softmax to turn scores into probabilities.
        5. Samples the next word from those probabilities.

    Training works by gradient descent: the model repeatedly sees
    (context, target) pairs, makes a prediction, computes how wrong it was
    (cross-entropy loss), and nudges E and W in the direction that reduces
    the loss.
    """

    def __init__(self, n=2, embedding_dim=32):
        """Create a new model with context window of n words and given embedding size."""
        self.n = n
        self.embedding_dim = embedding_dim
        self.vocab = None         # sorted list of unique words, set by train() or load()
        self.word_to_idx = None   # word -> row index into E / column index into W
        self.E = None  # input embeddings: (vocab_size, embedding_dim)
        self.W = None  # output weights: (embedding_dim, vocab_size)
        self.b = None  # output bias: (vocab_size,)

    def train(self, words, epochs=5, lr=0.05, resume=False):
        """Learn word embeddings from a list of words using gradient descent.

        If resume=True, the existing vocabulary and learned matrices (E, W, b)
        are kept and training continues from where it left off. Words in the
        corpus that are not in the saved vocabulary are skipped with a warning.

        Args:
            words: corpus as an ordered list of word tokens.
            epochs: number of full passes over the training pairs.
            lr: learning rate for each gradient-descent step.
            resume: continue training an already-trained/loaded model.

        Raises:
            ValueError: if resume=True but the model has no vocabulary yet,
                or if no (context, target) pairs can be formed from `words`.
        """
        if resume:
            # Bug fix: resuming a model that was never trained or loaded used
            # to crash with a cryptic TypeError ("argument of type 'NoneType'
            # is not iterable"); fail fast with a clear message instead.
            if self.word_to_idx is None:
                raise ValueError(
                    "Cannot resume: the model has no vocabulary. "
                    "Train from scratch or load a saved model first."
                )
            oov = sorted({w for w in words if w not in self.word_to_idx})
            if oov:
                print(f"Warning: {len(oov)} word(s) not in saved vocabulary will be skipped: {oov[:10]}{'...' if len(oov) > 10 else ''}")
        else:
            self.vocab = sorted(set(words))
            self.word_to_idx = {word: idx for idx, word in enumerate(self.vocab)}
            self.initialize_matrices()

        # Every window of n + 1 consecutive in-vocabulary words yields one
        # (context, target) pair: the first n words predict the last one.
        training_data = [
            (window[:-1], window[-1])
            for window in rolling_window(words, self.n + 1)
            if all(w in self.word_to_idx for w in window)
        ]
        # Bug fix: an empty pair list (corpus shorter than n + 1 words, or
        # all words out-of-vocabulary on resume) previously caused a
        # ZeroDivisionError in the per-epoch loss report below.
        if not training_data:
            raise ValueError(
                "No training pairs could be formed; "
                "need at least n + 1 consecutive in-vocabulary words."
            )

        # Local import keeps tqdm optional for inference-only use of the class.
        from tqdm import tqdm

        for epoch in range(epochs):
            total_loss = 0.0
            # Fresh presentation order each epoch helps stochastic gradient descent.
            random.shuffle(training_data)
            for context, target in tqdm(training_data, desc=f"Epoch {epoch + 1}/{epochs}", leave=False):
                total_loss += self._step(context, target, lr)
            print(f"Epoch {epoch + 1}/{epochs} loss={total_loss / len(training_data):.4f}")

    def initialize_matrices(self):
        """Randomly initializes embedding, weight, and bias matrices."""
        vocab_size = len(self.vocab)
        # Fixed seed makes initialization (and thus training runs) reproducible.
        rng = np.random.default_rng(42)
        # Small random values break symmetry without saturating softmax early.
        self.E = rng.standard_normal((vocab_size, self.embedding_dim)) * 0.01
        self.W = rng.standard_normal((self.embedding_dim, vocab_size)) * 0.01
        self.b = np.zeros(vocab_size)

    def _context_embedding(self, context):
        """Average the embeddings of the words in the context window."""
        indices = [self.word_to_idx[w] for w in context]
        return self.E[indices].mean(axis=0)

    def _forward(self, context):
        """Return (context_emb, probs) for a given context tuple."""
        ctx_emb = self._context_embedding(context)
        logits = ctx_emb @ self.W + self.b
        probs = softmax(logits)
        return ctx_emb, probs

    def _step(self, context, target, lr):
        """One gradient-descent step on a single (context, target) pair. Returns the loss."""
        ctx_emb, probs = self._forward(context)
        target_idx = self.word_to_idx[target]
        # Epsilon guards against log(0) when the model assigns ~zero probability.
        loss = -np.log(probs[target_idx] + 1e-12)

        # Gradient of cross-entropy loss w.r.t. logits: probs with 1 subtracted at target
        d_logits = probs.copy()
        d_logits[target_idx] -= 1.0

        # Gradients for output weights and bias
        d_W = ctx_emb[:, None] @ d_logits[None, :]  # (embedding_dim, vocab_size)
        d_b = d_logits  # (vocab_size,)

        # Gradient flows back through averaging to each context word's embedding
        d_ctx_emb = self.W @ d_logits  # (embedding_dim,)
        d_per_word = d_ctx_emb / len(context)
        # A word appearing twice in the context correctly receives the update twice.
        for idx in [self.word_to_idx[w] for w in context]:
            self.E[idx] -= lr * d_per_word

        self.W -= lr * d_W
        self.b -= lr * d_b
        return loss

    def generate(self, length, prompt=None, join_fn=None, step_callback=None):
        """Create new text using the learned embeddings.

        Args:
            length: total number of words in the output (including the prompt).
            prompt: optional iterable of seed words; out-of-vocabulary words
                are dropped, and if fewer than n usable words remain a random
                prompt is drawn instead.
            join_fn: callable turning the word list into a string (default: join).
            step_callback: optional callable invoked per step with
                (context, top_words, chosen_word) for visualization/debugging.

        Raises:
            Exception: if the model has not been trained.
        """
        if self.E is None:
            raise Exception("The model has not been trained.")
        output = list(prompt or self.get_random_prompt())
        # Drop any prompt tokens not in vocabulary
        output = [w for w in output if w in self.word_to_idx]
        if len(output) < self.n:
            output = list(self.get_random_prompt())
        while len(output) < length:
            context = tuple(output[-self.n:])
            _, probs = self._forward(context)
            chosen_idx = np.random.choice(len(self.vocab), p=probs)
            chosen_word = self.vocab[chosen_idx]
            if step_callback:
                # Report the 10 highest-probability candidates (above a small
                # threshold) so callers can visualize the model's "thinking".
                top_indices = np.argsort(probs)[-10:][::-1]
                top_words = [self.vocab[i] for i in top_indices if probs[i] > 0.001]
                step_callback(context, top_words, chosen_word)
            output.append(chosen_word)
        return (join_fn or join)(output)

    def get_random_prompt(self):
        """Return a random context window drawn from the vocabulary."""
        return tuple(random.choices(self.vocab, k=self.n))

    def save(self, filepath):
        """Save the model to a JSON file."""
        model_data = {
            "n": self.n,
            "embedding_dim": self.embedding_dim,
            "vocab": self.vocab,
            # JSON has no ndarray type, so matrices are stored as nested lists.
            "E": self.E.tolist(),
            "W": self.W.tolist(),
            "b": self.b.tolist(),
        }
        with open(filepath, "w") as f:
            json.dump(model_data, f)

    def load(self, filepath):
        """Load a model from a JSON file."""
        with open(filepath) as f:
            data = json.load(f)
        self.n = data["n"]
        self.embedding_dim = data["embedding_dim"]
        self.vocab = data["vocab"]
        # word_to_idx is derived, not stored: rebuild it from the vocabulary.
        self.word_to_idx = {word: idx for idx, word in enumerate(self.vocab)}
        self.E = np.array(data["E"])
        self.W = np.array(data["W"])
        self.b = np.array(data["b"])