import json
|
|
import random
|
|
import numpy as np
|
|
from .helpers import rolling_window, join, softmax
|
|
|
|
class TinyLanguageModel:
    """A next-word predictor built on learned word embeddings.

    Rather than tallying how often each word follows each context (as a
    count-matrix model would), this model learns a dense vector
    representation -- an embedding -- for every vocabulary word and derives
    next-word probabilities from those vectors.

    Learned parameters:

        E: (vocab_size x embedding_dim) -- one dense vector per vocabulary word
        W: (embedding_dim x vocab_size) -- maps an embedding to per-word scores
        b: (vocab_size,)                -- per-word output bias

    Prediction pipeline for a context window:

        1. Look up the embedding E[i] for each context word i.
        2. Average those vectors into a single context embedding.
        3. Score every word: logits = context_vector @ W + b.
        4. Turn the scores into probabilities with softmax.
        5. Sample the next word from that distribution.

    Training is plain stochastic gradient descent over (context, target)
    pairs: predict, measure how wrong the prediction was (cross-entropy
    loss), and nudge E, W, and b in the direction that reduces the loss.
    """

    def __init__(self, n=2, embedding_dim=32):
        """Create a new model with context window of n words and given embedding size."""
        self.n = n
        self.embedding_dim = embedding_dim
        # Everything below is populated by train() or load().
        self.vocab = None        # sorted list of known words
        self.word_to_idx = None  # word -> index into vocab / row of E
        self.E = None            # input embeddings: (vocab_size, embedding_dim)
        self.W = None            # output weights: (embedding_dim, vocab_size)
        self.b = None            # output bias: (vocab_size,)
|
|
|
|
def train(self, words, epochs=5, lr=0.05, resume=False):
|
|
"""Learn word embeddings from a list of words using gradient descent.
|
|
|
|
If resume=True, the existing vocabulary and learned matrices (E, W, b) are
|
|
kept and training continues from where it left off. Words in the corpus that
|
|
are not in the saved vocabulary are skipped with a warning.
|
|
"""
|
|
if resume:
|
|
oov = sorted({w for w in words if w not in self.word_to_idx})
|
|
if oov:
|
|
print(f"Warning: {len(oov)} word(s) not in saved vocabulary will be skipped: {oov[:10]}{'...' if len(oov) > 10 else ''}")
|
|
else:
|
|
self.vocab = sorted(set(words))
|
|
self.word_to_idx = {word: idx for idx, word in enumerate(self.vocab)}
|
|
self.initialize_matrices()
|
|
|
|
training_data = [
|
|
(window[:-1], window[-1])
|
|
for window in rolling_window(words, self.n + 1)
|
|
if all(w in self.word_to_idx for w in window)
|
|
]
|
|
|
|
from tqdm import tqdm
|
|
for epoch in range(epochs):
|
|
total_loss = 0.0
|
|
random.shuffle(training_data)
|
|
for context, target in tqdm(training_data, desc=f"Epoch {epoch + 1}/{epochs}", leave=False):
|
|
total_loss += self._step(context, target, lr)
|
|
print(f"Epoch {epoch + 1}/{epochs} loss={total_loss / len(training_data):.4f}")
|
|
|
|
def initialize_matrices(self):
|
|
"Randomly initializes embedding, weight, and bias matrices"
|
|
vocab_size = len(self.vocab)
|
|
rng = np.random.default_rng(42)
|
|
self.E = rng.standard_normal((vocab_size, self.embedding_dim)) * 0.01
|
|
self.W = rng.standard_normal((self.embedding_dim, vocab_size)) * 0.01
|
|
self.b = np.zeros(vocab_size)
|
|
|
|
def _context_embedding(self, context):
|
|
"Average the embeddings of the words in the context window."
|
|
indices = [self.word_to_idx[w] for w in context]
|
|
return self.E[indices].mean(axis=0)
|
|
|
|
def _forward(self, context):
|
|
"Return (context_emb, probs) for a given context tuple."
|
|
ctx_emb = self._context_embedding(context)
|
|
logits = ctx_emb @ self.W + self.b
|
|
probs = softmax(logits)
|
|
return ctx_emb, probs
|
|
|
|
def _step(self, context, target, lr):
|
|
"One gradient-descent step on a single (context, target) pair. Returns the loss."
|
|
ctx_emb, probs = self._forward(context)
|
|
target_idx = self.word_to_idx[target]
|
|
loss = -np.log(probs[target_idx] + 1e-12)
|
|
|
|
# Gradient of cross-entropy loss w.r.t. logits: probs with 1 subtracted at target
|
|
d_logits = probs.copy()
|
|
d_logits[target_idx] -= 1.0
|
|
|
|
# Gradients for output weights and bias
|
|
d_W = ctx_emb[:, None] @ d_logits[None, :] # (embedding_dim, vocab_size)
|
|
d_b = d_logits # (vocab_size,)
|
|
|
|
# Gradient flows back through averaging to each context word's embedding
|
|
d_ctx_emb = self.W @ d_logits # (embedding_dim,)
|
|
d_per_word = d_ctx_emb / len(context)
|
|
for idx in [self.word_to_idx[w] for w in context]:
|
|
self.E[idx] -= lr * d_per_word
|
|
|
|
self.W -= lr * d_W
|
|
self.b -= lr * d_b
|
|
|
|
return loss
|
|
|
|
def generate(self, length, prompt=None, join_fn=None, step_callback=None):
|
|
"Create new text using the learned embeddings."
|
|
if self.E is None:
|
|
raise Exception("The model has not been trained.")
|
|
|
|
output = list(prompt or self.get_random_prompt())
|
|
# Drop any prompt tokens not in vocabulary
|
|
output = [w for w in output if w in self.word_to_idx]
|
|
if len(output) < self.n:
|
|
output = list(self.get_random_prompt())
|
|
|
|
while len(output) < length:
|
|
context = tuple(output[-self.n:])
|
|
_, probs = self._forward(context)
|
|
chosen_idx = np.random.choice(len(self.vocab), p=probs)
|
|
chosen_word = self.vocab[chosen_idx]
|
|
|
|
if step_callback:
|
|
top_indices = np.argsort(probs)[-10:][::-1]
|
|
top_words = [self.vocab[i] for i in top_indices if probs[i] > 0.001]
|
|
step_callback(context, top_words, chosen_word)
|
|
|
|
output.append(chosen_word)
|
|
|
|
return (join_fn or join)(output)
|
|
|
|
def get_random_prompt(self):
|
|
"Return a random context window drawn from the vocabulary."
|
|
return tuple(random.choices(self.vocab, k=self.n))
|
|
|
|
def save(self, filepath):
|
|
"Save the model to a JSON file."
|
|
model_data = {
|
|
"n": self.n,
|
|
"embedding_dim": self.embedding_dim,
|
|
"vocab": self.vocab,
|
|
"E": self.E.tolist(),
|
|
"W": self.W.tolist(),
|
|
"b": self.b.tolist(),
|
|
}
|
|
with open(filepath, "w") as f:
|
|
json.dump(model_data, f)
|
|
|
|
def load(self, filepath):
|
|
"Load a model from a JSON file."
|
|
with open(filepath) as f:
|
|
data = json.load(f)
|
|
self.n = data["n"]
|
|
self.embedding_dim = data["embedding_dim"]
|
|
self.vocab = data["vocab"]
|
|
self.word_to_idx = {word: idx for idx, word in enumerate(self.vocab)}
|
|
self.E = np.array(data["E"])
|
|
self.W = np.array(data["W"])
|
|
self.b = np.array(data["b"])
|