First version of lab_matrices
This commit is contained in:
114
tlm/model.py
Normal file
114
tlm/model.py
Normal file
@@ -0,0 +1,114 @@
|
||||
import json
|
||||
import random
|
||||
import numpy as np
|
||||
from .helpers import rolling_window, join
|
||||
|
||||
|
||||
class TinyLanguageModel:
    """
    A simple language model that predicts the next word based on the last n words.

    The model stores everything it has learned in a matrix W with shape:

        (vocabulary size) x (number of context windows seen)

    Each row of W corresponds to one word in the vocabulary.
    Each column of W corresponds to one context window (e.g. the words "the cat").
    W[i, j] counts how many times word i was observed following context j.

    To predict the next word, the model:
    1. Represents the current context as a one-hot column vector x.
    2. Computes Wx to get the counts for each word.
    3. Divides by the total count to get a probability distribution.
    4. Samples the next word from those probabilities.
    """

    def __init__(self, n=2):
        """Create a new model that looks at n words of context at a time."""
        self.n = n
        self.vocab = None
        # Declared here (train() assigns it) so trained and untrained
        # instances expose the same attribute set.
        self.contexts = None
        self.word_to_idx = None
        self.context_to_idx = None
        self.W = None

    def train(self, words):
        """Learn word patterns from a list of words.

        Builds the vocabulary, the list of observed context windows,
        both index mappings, and fills the count matrix W.
        """
        self.vocab, self.contexts = self.get_unique_contexts_and_words(words)
        self.word_to_idx = {word: idx for idx, word in enumerate(self.vocab)}
        self.context_to_idx = {ctx: idx for idx, ctx in enumerate(self.contexts)}
        self.W = np.zeros((len(self.vocab), len(self.contexts)))
        self.count_contexts_and_words(words)

    def get_unique_contexts_and_words(self, words):
        """Scan words and return (sorted unique words, sorted unique context windows)."""
        unique_words = set()
        unique_contexts = set()
        for window in rolling_window(words, self.n + 1):
            context, word = window[:-1], window[-1]
            unique_words.add(word)
            unique_contexts.add(context)
        # Sorting makes row/column index assignment deterministic across runs.
        return sorted(unique_words), sorted(unique_contexts)

    def count_contexts_and_words(self, words):
        """Fill W by counting how often each word follows each context."""
        for window in rolling_window(words, self.n + 1):
            context, word = window[:-1], window[-1]
            self.W[self.word_to_idx[word], self.context_to_idx[context]] += 1

    def generate(self, length, prompt=None, join_fn=None, step_callback=None):
        """Create new text based on what the model learned.

        Parameters:
            length: total number of words to produce (including the prompt).
            prompt: optional sequence of starting words; a random observed
                context is used when omitted. If the last n words were never
                seen as a context, generation stops early.
            join_fn: optional callable turning the word list into the return
                value; defaults to the module-level `join` helper.
            step_callback: optional callable invoked each step with
                (context, possible_next_words, chosen_word).

        Raises:
            RuntimeError: if the model has not been trained or loaded.
        """
        if self.W is None:
            # RuntimeError is a subclass of Exception, so existing
            # `except Exception` callers keep working.
            raise RuntimeError("The model has not been trained")

        output = list(prompt or self.get_random_pattern())

        while len(output) < length:
            context = tuple(output[-self.n:])
            if context not in self.context_to_idx:
                # Unseen context (or a prompt shorter than n words):
                # nothing to sample from, so stop.
                break

            # Pedagogical on purpose: select W's column for this context via
            # W @ x with a one-hot x, mirroring the matrix view described in
            # the class docstring (plain column indexing would also work).
            context_col = self.context_to_idx[context]
            one_hot = np.zeros(len(self.context_to_idx))
            one_hot[context_col] = 1

            counts = self.W @ one_hot
            # Every known context was observed at least once during training,
            # so counts.sum() >= 1 and this division is safe.
            probs = counts / counts.sum()
            chosen_idx = np.random.choice(len(self.vocab), p=probs)
            chosen_word = self.vocab[chosen_idx]

            if step_callback:
                possible_next_words = [
                    self.vocab[j] for j in range(len(self.vocab)) if counts[j] > 0
                ]
                step_callback(context, possible_next_words, chosen_word)

            output.append(chosen_word)

        return (join_fn or join)(output)

    def get_random_pattern(self):
        """Randomly choose one of the observed context windows."""
        return random.choice(list(self.context_to_idx.keys()))

    def save(self, filepath):
        """Save the model (n, vocab, contexts, W) to a JSON file."""
        if self.W is None:
            # Explicit error instead of an opaque TypeError from sorted(None).
            raise RuntimeError("The model has not been trained")
        # Order contexts by their column index so W's columns line up on load.
        ordered_contexts = sorted(self.context_to_idx, key=self.context_to_idx.get)
        model_data = {
            "n": self.n,
            "vocab": self.vocab,
            "contexts": [list(ctx) for ctx in ordered_contexts],
            "W": self.W.tolist(),
        }
        with open(filepath, "w") as f:
            json.dump(model_data, f)

    def load(self, filepath):
        """Load a model previously written by save()."""
        with open(filepath, "r") as f:
            data = json.load(f)
        self.n = data["n"]
        self.vocab = data["vocab"]
        self.word_to_idx = {word: idx for idx, word in enumerate(self.vocab)}
        # JSON stores contexts as lists; convert back to the tuples that are
        # used as dictionary keys throughout the model.
        contexts = [tuple(ctx) for ctx in data["contexts"]]
        # Also restore self.contexts so a loaded model has the same
        # attribute set as a freshly trained one.
        self.contexts = contexts
        self.context_to_idx = {ctx: idx for idx, ctx in enumerate(contexts)}
        self.W = np.array(data["W"])
|
||||
Reference in New Issue
Block a user