Files
lab_tinylm/tlm/model.py
2026-02-19 14:59:57 -05:00

57 lines
1.9 KiB
Python

import json
import random
from .helpers import rolling_window, join
class TinyLanguageModel:
    """A small n-gram model that picks the next word from the previous ``n``.

    ``self.model`` maps each observed pattern (a tuple of words) to a list of
    every word that was seen right after it. Repeats are kept in that list, so
    a uniform ``random.choice`` over it samples words in proportion to how
    often they followed the pattern.
    """

    def __init__(self, n=2):
        """Create a new model that looks at n words at a time."""
        self.n = n
        self.model = {}

    def train(self, words):
        """Learn word patterns from a list of words.

        Every window produced by ``rolling_window(words, n + 1)`` is split
        into a pattern (all items but the last) and the word that followed it
        (the last item). Calling ``train`` again adds to what was already
        learned; nothing is reset.
        """
        # NOTE(review): presumably rolling_window yields each run of n + 1
        # consecutive words -- confirm against helpers.rolling_window.
        for sequence in rolling_window(words, self.n + 1):
            pattern = tuple(sequence[:-1])
            self.model.setdefault(pattern, []).append(sequence[-1])

    def generate(self, length, prompt=None, join_fn=None, step_callback=None):
        """Create new words based on what the model learned.

        Args:
            length: Number of words to stop at, counting the starting words.
                Generation ends early if the current pattern was never seen
                in training; a prompt already at least this long is returned
                without any words added.
            prompt: Optional iterable of starting words. When omitted or
                empty, a randomly chosen observed pattern is used instead.
            join_fn: Optional callable applied to the final list of words to
                build the return value. Defaults to ``join`` from helpers.
            step_callback: Optional callable invoked as
                ``step_callback(pattern, options, chosen)`` for each word,
                just before that word is appended.

        Returns:
            Whatever ``join_fn`` (or the default ``join``) returns for the
            list of words.

        Raises:
            RuntimeError: If the model has no learned patterns.
        """
        if not self.model:
            raise RuntimeError("The model has not been trained")
        output = list(prompt or self.get_random_pattern())
        while len(output) < length:
            # output[-0:] is the whole list, not an empty slice, so an n of 0
            # needs the empty pattern spelled out explicitly.
            pattern = tuple(output[-self.n:]) if self.n else ()
            if pattern not in self.model:
                break
            options = self.model[pattern]
            chosen = random.choice(options)
            if step_callback:
                step_callback(pattern, options, chosen)
            output.append(chosen)
        return (join_fn or join)(output)

    def get_random_pattern(self):
        """Randomly choose one of the observed patterns."""
        return random.choice(list(self.model))

    def save(self, filepath):
        """Save the model to a JSON file.

        Patterns are written as ``[pattern_words, next_words]`` pairs rather
        than as space-joined strings, so words that contain whitespace (or
        are empty) survive a save/load round trip unchanged.
        """
        model_data = {
            "n": self.n,
            "model": [
                [list(pattern), options]
                for pattern, options in self.model.items()
            ],
        }
        with open(filepath, "w", encoding="utf-8") as f:
            json.dump(model_data, f)

    def load(self, filepath):
        """Load a model from a JSON file, replacing ``n`` and ``model``.

        Accepts both the list-of-pairs layout written by ``save`` and the
        older layout in which ``"model"`` is a dict keyed by space-joined
        patterns.
        """
        with open(filepath, "r", encoding="utf-8") as f:
            data = json.load(f)
        self.n = data["n"]
        entries = data["model"]
        if isinstance(entries, dict):
            # Legacy layout: keys are patterns joined with single spaces.
            self.model = {tuple(k.split()): v for k, v in entries.items()}
        else:
            self.model = {tuple(pattern): options for pattern, options in entries}