54 lines
1.7 KiB
Python
54 lines
1.7 KiB
Python
import json
|
|
import random
|
|
from .helpers import rolling_window, join
|
|
|
|
|
|
class TinyLanguageModel:
    """Word-level n-gram model: maps each context of ``n`` words to the
    words observed immediately after it."""

    def __init__(self, n: int = 2):
        """Create a new model that looks at n words at a time.

        Args:
            n: Number of preceding words used as context when choosing the
                next word.
        """
        self.n = n
        # Maps a tuple of context words to the list of words seen right after
        # that context. Duplicates are kept, so random.choice in generate()
        # picks frequent continuations proportionally more often.
        self.model = {}

    def train(self, words) -> None:
        """Learn word patterns from a list of words.

        Each window yielded by ``rolling_window(words, self.n + 1)`` is split
        into its leading words (the context) and its last word (the
        continuation). Repeated calls add to what was already learned.
        """
        for sequence in rolling_window(words, self.n + 1):
            pattern = tuple(sequence[:-1])
            next_word = sequence[-1]
            self.model.setdefault(pattern, []).append(next_word)

    def generate(self, length: int, prompt=None):
        """Create new words based on what the model learned.

        Args:
            length: Generation stops once the output (prompt included) has at
                least this many words. A longer prompt is not truncated.
            prompt: Optional starting words (presumably a sequence of word
                strings -- confirm against callers). If omitted or empty, a
                randomly chosen learned context is used instead.

        Returns:
            The output words combined by ``join``. Generation also stops early
            when the current context was never seen during training.

        Raises:
            RuntimeError: If the model is empty (nothing trained or loaded).
        """
        if not self.model:
            raise RuntimeError("The model has not been trained")
        output = list(prompt or self.get_random_pattern())
        while len(output) < length:
            # Use an explicit start index instead of output[-self.n:]: with
            # n == 0, -0 == 0 would select the whole output rather than the
            # empty context a unigram model needs.
            start = max(len(output) - self.n, 0)
            pattern = tuple(output[start:])
            if pattern not in self.model:
                break
            output.append(random.choice(self.model[pattern]))
        return join(output)

    def get_random_pattern(self):
        """Return a randomly chosen context observed during training.

        Raises:
            IndexError: From ``random.choice`` if the model is empty.
        """
        return random.choice(list(self.model))

    def save(self, filepath) -> None:
        """Save the model to a JSON file.

        JSON object keys must be strings, so each context tuple is stored as
        its words joined by single spaces. ``load`` reverses this encoding.
        """
        # NOTE(review): a word that itself contains a space cannot be
        # recovered by load(); it would come back split into several words.
        model_data = {
            "n": self.n,
            "model": {" ".join(pattern): words for pattern, words in self.model.items()},
        }
        with open(filepath, "w", encoding="utf-8") as f:
            json.dump(model_data, f)

    def load(self, filepath) -> None:
        """Load a model previously written by ``save``, replacing the current one.

        Raises:
            KeyError: If the file lacks the ``"n"`` or ``"model"`` entries.
        """
        with open(filepath, "r", encoding="utf-8") as f:
            data = json.load(f)
        n = data["n"]
        model = {
            self._decode_pattern(key, n): words
            for key, words in data["model"].items()
        }
        # Assign only after parsing succeeds, so a malformed file leaves the
        # existing n and model untouched.
        self.n = n
        self.model = model

    @staticmethod
    def _decode_pattern(key, n):
        """Turn a key written by ``save`` back into a context tuple."""
        # With n == 0 the only context is the empty tuple, whose key "" would
        # otherwise decode as ("",).
        if n == 0:
            return ()
        # split(" ") (not split()) exactly inverts " ".join, so tokens such
        # as "\n" or "" survive the round trip.
        return tuple(key.split(" "))
|
|
|