generated from mwc/lab_tinylm
Applying patch
This commit is contained in:
12
tlm/cli.py
12
tlm/cli.py
@@ -54,13 +54,13 @@ def generate(length, context_window_words, filepath, gutenberg, list_gutenberg,
|
||||
raise click.UsageError("No training data provided. Must specify at least one of --filepath, --gutenberg, or --mbox.")
|
||||
|
||||
# Train and generate
|
||||
model = TinyLanguageModel(n=context_window_words)
|
||||
model.train(corpus)
|
||||
tlm = TinyLanguageModel(n=context_window_words)
|
||||
tlm.train(corpus)
|
||||
if prompt:
|
||||
prompt_tokens = tokenize_text(prompt, tokenize_opts)
|
||||
else:
|
||||
prompt_tokens = None
|
||||
join_fn = ''.join if 'char' in tokenize_opts else None
|
||||
join_fn = ''.join if 'char' in tokenize_opts else join
|
||||
display_join = join_fn or join
|
||||
|
||||
if verbose:
|
||||
@@ -70,18 +70,18 @@ def generate(length, context_window_words, filepath, gutenberg, list_gutenberg,
|
||||
def step_callback(pattern, options, chosen):
|
||||
opts = textwrap.fill(', '.join(sorted(set(options))), width=60)
|
||||
rows.append([display_join(list(pattern)), opts, chosen])
|
||||
output = model.generate(length, prompt=prompt_tokens, join_fn=join_fn, step_callback=step_callback)
|
||||
output = tlm.generate(length, prompt=prompt_tokens, join_fn=join_fn, step_callback=step_callback)
|
||||
click.echo(tabulate(rows, headers=["Context", "Options", "Selected"], tablefmt="simple"))
|
||||
click.echo()
|
||||
|
||||
else:
|
||||
output = model.generate(length, prompt=prompt_tokens, join_fn=join_fn)
|
||||
output = tlm.generate(length, prompt=prompt_tokens, join_fn=join_fn)
|
||||
|
||||
click.echo(output)
|
||||
|
||||
if interact:
|
||||
import code
|
||||
code.interact(local=locals(), banner="Entering interactive shell. 'model' and 'output' are available.")
|
||||
code.interact(local=locals(), banner="Entering interactive shell. 'tlm' and 'output' are available.")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
19
tlm/model.py
19
tlm/model.py
@@ -5,21 +5,20 @@ from .helpers import rolling_window, join
|
||||
|
||||
class TinyLanguageModel:
|
||||
def __init__(self, n=2):
|
||||
"Create a new model that looks at n words at a time."
|
||||
"Create a new model that looks at a n-word context window."
|
||||
self.n = n
|
||||
self.model = {}
|
||||
|
||||
def train(self, words):
|
||||
"Learn word patterns from a list of words."
|
||||
for sequence in rolling_window(words, self.n + 1):
|
||||
pattern = tuple(sequence[:-1])
|
||||
next_word = sequence[-1]
|
||||
if pattern not in self.model:
|
||||
self.model[pattern] = []
|
||||
self.model[pattern].append(next_word)
|
||||
context, next_word = tuple(sequence[:-1]), sequence[-1]
|
||||
if context not in self.model:
|
||||
self.model[context] = []
|
||||
self.model[context].append(next_word)
|
||||
|
||||
def generate(self, length, prompt=None, join_fn=None, step_callback=None):
|
||||
"Create new words based on what the model learned."
|
||||
def generate(self, length, prompt=None, join_fn=join, step_callback=None):
|
||||
"Generate new text based on what the model learned."
|
||||
if not self.model:
|
||||
raise Exception("The model has not been trained")
|
||||
output = list(prompt or self.get_random_pattern())
|
||||
@@ -32,10 +31,10 @@ class TinyLanguageModel:
|
||||
if step_callback:
|
||||
step_callback(pattern, options, chosen)
|
||||
output.append(chosen)
|
||||
return (join_fn or join)(output)
|
||||
return join_fn(output)
|
||||
|
||||
def get_random_pattern(self):
|
||||
"Randomly chooses one of the observed patterns"
|
||||
"Randomly chooses one of the observed contexts"
|
||||
return random.choice(list(self.model.keys()))
|
||||
|
||||
def save(self, filepath):
|
||||
|
||||
Reference in New Issue
Block a user