115 lines
4.3 KiB
Python
115 lines
4.3 KiB
Python
import json
|
|
import random
|
|
import numpy as np
|
|
from .helpers import rolling_window, join
|
|
|
|
|
|
class TinyLanguageModel:
    """
    A simple language model that predicts the next word based on the last n words.

    The model stores everything it has learned in a matrix W with shape:

        (vocabulary size) x (number of context windows seen)

    Each row of W corresponds to one word in the vocabulary.
    Each column of W corresponds to one context window (e.g. the words "the cat").
    W[i, j] counts how many times word i was observed following context j.

    To predict the next word, the model:
    1. Represents the current context as a one-hot column vector x.
    2. Computes Wx to get the counts for each word.
    3. Divides by the total count to get a probability distribution.
    4. Samples the next word from those probabilities.

    Randomness comes from two sources: the standard-library ``random`` module
    (picking a starting context when no prompt is given) and ``numpy.random``
    (sampling each next word). Seed both for reproducible output.
    """

    def __init__(self, n=2):
        """Create a new model that looks at n words at a time."""
        self.n = n
        self.vocab = None           # list of predictable words; row order of W
        self.contexts = None        # list of context tuples; column order of W
        self.word_to_idx = None     # word -> row index in W
        self.context_to_idx = None  # context tuple -> column index in W
        self.W = None               # count matrix, shape (len(vocab), len(contexts))

    def train(self, words):
        """Learn word patterns from a list of words.

        Builds the vocabulary and context index from every (n + 1)-word window
        of ``words``, then fills W with follow-counts. All learned state is
        replaced; ``n`` is kept.
        """
        # NOTE(review): ``words`` is passed to rolling_window twice (here and in
        # count_contexts_and_words), so a one-shot iterator would likely leave
        # W unfilled — pass a sequence.
        self.vocab, self.contexts = self.get_unique_contexts_and_words(words)
        self.word_to_idx = {word: idx for idx, word in enumerate(self.vocab)}
        self.context_to_idx = {ctx: idx for idx, ctx in enumerate(self.contexts)}
        self.W = np.zeros((len(self.vocab), len(self.contexts)))
        self.count_contexts_and_words(words)

    def get_unique_contexts_and_words(self, words):
        """Scan words and return (sorted unique words, sorted unique context windows).

        A context window is the first n items of each (n + 1)-item window; the
        word is the last item.
        """
        # Contexts go into a set, so rolling_window must yield hashable windows
        # (presumably tuples, matching the tuple contexts built in generate).
        unique_words = set()
        unique_contexts = set()
        for window in rolling_window(words, self.n + 1):
            context, word = window[:-1], window[-1]
            unique_words.add(word)
            unique_contexts.add(context)
        return sorted(unique_words), sorted(unique_contexts)

    def count_contexts_and_words(self, words):
        """Fill W by counting how often each word follows each context."""
        for window in rolling_window(words, self.n + 1):
            context, word = window[:-1], window[-1]
            self.W[self.word_to_idx[word], self.context_to_idx[context]] += 1

    def _require_trained(self):
        """Raise RuntimeError unless train() or load() has populated W."""
        if self.W is None:
            raise RuntimeError("The model has not been trained")

    def generate(self, length, prompt=None, join_fn=None, step_callback=None):
        """Create new text based on what the model learned.

        Args:
            length: Target number of words in the output, prompt included.
                Generation stops early when the current context was never
                observed.
            prompt: Optional sequence of starting words. When falsy, a randomly
                chosen observed context window is used instead.
            join_fn: Callable turning the list of output words into the return
                value. Defaults to ``join`` from ``.helpers``.
            step_callback: Optional callable invoked once per generated word as
                ``step_callback(context, possible_next_words, chosen_word)``.

        Returns:
            Whatever ``join_fn`` (or ``join``) returns for the output words.

        Raises:
            RuntimeError: If the model has not been trained or loaded.
        """
        self._require_trained()

        output = list(prompt or self.get_random_pattern())

        while len(output) < length:
            context = tuple(output[-self.n:])
            context_col = self.context_to_idx.get(context)
            if context_col is None:
                # Unseen context: there are no counts to sample from.
                break

            # Select this context's column via an explicit one-hot vector,
            # mirroring the Wx formulation in the class docstring.
            one_hot = np.zeros(len(self.context_to_idx))
            one_hot[context_col] = 1
            counts = self.W @ one_hot

            total = counts.sum()
            if total == 0:
                # An all-zero column (e.g. from a hand-edited saved model) has
                # no distribution; dividing would produce NaN probabilities.
                break
            probs = counts / total

            chosen_idx = np.random.choice(len(self.vocab), p=probs)
            chosen_word = self.vocab[chosen_idx]

            if step_callback:
                possible_next_words = [
                    word for word, count in zip(self.vocab, counts) if count > 0
                ]
                step_callback(context, possible_next_words, chosen_word)

            output.append(chosen_word)

        return (join_fn or join)(output)

    def get_random_pattern(self):
        """Randomly choose one of the observed context windows.

        Uses the standard-library ``random`` module, not ``numpy.random``.

        Raises:
            RuntimeError: If the model has not been trained or loaded.
            IndexError: If no context windows were observed.
        """
        self._require_trained()
        return random.choice(list(self.context_to_idx))

    def save(self, filepath):
        """Save the model to a JSON file.

        Contexts are written in W's column order so ``load`` rebuilds the same
        index mapping.

        Raises:
            RuntimeError: If the model has not been trained or loaded.
        """
        self._require_trained()
        ordered_contexts = sorted(self.context_to_idx, key=self.context_to_idx.get)
        model_data = {
            "n": self.n,
            "vocab": self.vocab,
            "contexts": [list(ctx) for ctx in ordered_contexts],
            "W": self.W.tolist(),
        }
        with open(filepath, "w", encoding="utf-8") as f:
            json.dump(model_data, f)

    def load(self, filepath):
        """Load a model from a JSON file written by ``save``.

        Restores every attribute that ``train`` sets, including ``contexts``.

        Raises:
            ValueError: If the stored W does not match the stored vocabulary
                and context counts.
        """
        with open(filepath, "r", encoding="utf-8") as f:
            data = json.load(f)
        self.n = data["n"]
        self.vocab = data["vocab"]
        self.word_to_idx = {word: idx for idx, word in enumerate(self.vocab)}
        self.contexts = [tuple(ctx) for ctx in data["contexts"]]
        self.context_to_idx = {ctx: idx for idx, ctx in enumerate(self.contexts)}
        # reshape keeps W 2-D for an empty model (np.array([]) is 1-D) and
        # rejects a matrix whose size disagrees with vocab/contexts.
        self.W = np.array(data["W"], dtype=float).reshape(
            len(self.vocab), len(self.contexts)
        )
|