Initial commit
This commit is contained in:
53
tlm/model.py
Normal file
53
tlm/model.py
Normal file
@@ -0,0 +1,53 @@
|
||||
import json
|
||||
import random
|
||||
from .helpers import rolling_window, join
|
||||
|
||||
|
||||
class TinyLanguageModel:
    """A tiny n-gram language model.

    The model maps each observed pattern (a tuple of ``n`` consecutive
    words) to the list of words that followed it in the training data.
    Followers are stored with repetition, so picking one with
    ``random.choice`` favours words that were seen more often.
    """

    def __init__(self, n=2):
        """Create a new model that looks at n words at a time."""
        self.n = n
        # Maps pattern tuple -> list of words seen right after that pattern.
        self.model = {}

    def train(self, words):
        """Learn word patterns from a list of words.

        May be called repeatedly; each call adds to what was already
        learned instead of replacing it.
        """
        # Each window is read as n context words plus the word after them.
        # NOTE(review): assumes rolling_window yields consecutive windows of
        # the requested size -- confirm against helpers.
        for sequence in rolling_window(words, self.n + 1):
            pattern = tuple(sequence[:-1])
            self.model.setdefault(pattern, []).append(sequence[-1])

    def generate(self, length, prompt=None):
        """Create new words based on what the model learned.

        Args:
            length: Stop extending once the output holds this many words.
                A prompt or starting pattern that is already longer is
                returned unshortened.
            prompt: Optional sequence of words to start from. When omitted
                or empty, a randomly chosen observed pattern is used.

        Returns:
            Whatever ``join`` from helpers produces for the generated words.
            Generation ends early if the last ``n`` words form a pattern
            the model has no followers for.

        Raises:
            RuntimeError: If the model has not been trained.
        """
        if not self.model:
            # RuntimeError subclasses Exception, so existing handlers that
            # catch Exception keep working.
            raise RuntimeError("The model has not been trained")
        output = list(prompt or self.get_random_pattern())
        while len(output) < length:
            pattern = tuple(output[-self.n:])
            followers = self.model.get(pattern)
            if not followers:
                # Dead end: pattern never seen, or no followers recorded.
                break
            output.append(random.choice(followers))
        return join(output)

    def get_random_pattern(self):
        """Randomly choose one of the observed patterns."""
        return random.choice(list(self.model))

    def save(self, filepath):
        """Save the model to a file.

        The file is JSON with the keys "n" and "model". JSON object keys
        must be strings, so each pattern is stored as its words joined by
        single spaces; a word that itself contains a space therefore cannot
        be told apart from a word boundary when the file is loaded again.
        """
        model_data = {
            "n": self.n,
            "model": {" ".join(k): v for k, v in self.model.items()},
        }
        with open(filepath, "w", encoding="utf-8") as f:
            json.dump(model_data, f)

    def load(self, filepath):
        """Load a model from a file.

        Replaces both ``n`` and everything this instance had learned.
        """
        with open(filepath, "r", encoding="utf-8") as f:
            data = json.load(f)
        self.n = data["n"]
        # Split on the exact separator save() joins with. A bare split()
        # would also break on tabs/newlines inside a word and drop empty
        # words. The empty key is the zero-word pattern.
        self.model = {
            (tuple(k.split(" ")) if k else ()): v
            for k, v in data["model"].items()
        }
|
||||
|
||||
Reference in New Issue
Block a user