Applying patch

This commit is contained in:
cplockport
2026-03-01 15:14:35 -05:00
parent cbbd194f33
commit 36fe692392
2 changed files with 15 additions and 16 deletions

View File

@@ -54,13 +54,13 @@ def generate(length, context_window_words, filepath, gutenberg, list_gutenberg,
raise click.UsageError("No training data provided. Must specify at least one of --filepath, --gutenberg, or --mbox.") raise click.UsageError("No training data provided. Must specify at least one of --filepath, --gutenberg, or --mbox.")
# Train and generate # Train and generate
model = TinyLanguageModel(n=context_window_words) tlm = TinyLanguageModel(n=context_window_words)
model.train(corpus) tlm.train(corpus)
if prompt: if prompt:
prompt_tokens = tokenize_text(prompt, tokenize_opts) prompt_tokens = tokenize_text(prompt, tokenize_opts)
else: else:
prompt_tokens = None prompt_tokens = None
join_fn = ''.join if 'char' in tokenize_opts else None join_fn = ''.join if 'char' in tokenize_opts else join
display_join = join_fn or join display_join = join_fn or join
if verbose: if verbose:
@@ -70,18 +70,18 @@ def generate(length, context_window_words, filepath, gutenberg, list_gutenberg,
def step_callback(pattern, options, chosen): def step_callback(pattern, options, chosen):
opts = textwrap.fill(', '.join(sorted(set(options))), width=60) opts = textwrap.fill(', '.join(sorted(set(options))), width=60)
rows.append([display_join(list(pattern)), opts, chosen]) rows.append([display_join(list(pattern)), opts, chosen])
output = model.generate(length, prompt=prompt_tokens, join_fn=join_fn, step_callback=step_callback) output = tlm.generate(length, prompt=prompt_tokens, join_fn=join_fn, step_callback=step_callback)
click.echo(tabulate(rows, headers=["Context", "Options", "Selected"], tablefmt="simple")) click.echo(tabulate(rows, headers=["Context", "Options", "Selected"], tablefmt="simple"))
click.echo() click.echo()
else: else:
output = model.generate(length, prompt=prompt_tokens, join_fn=join_fn) output = tlm.generate(length, prompt=prompt_tokens, join_fn=join_fn)
click.echo(output) click.echo(output)
if interact: if interact:
import code import code
code.interact(local=locals(), banner="Entering interactive shell. 'model' and 'output' are available.") code.interact(local=locals(), banner="Entering interactive shell. 'tlm' and 'output' are available.")
if __name__ == "__main__": if __name__ == "__main__":

View File

@@ -5,21 +5,20 @@ from .helpers import rolling_window, join
class TinyLanguageModel: class TinyLanguageModel:
def __init__(self, n=2): def __init__(self, n=2):
"Create a new model that looks at n words at a time." "Create a new model that looks at a n-word context window."
self.n = n self.n = n
self.model = {} self.model = {}
def train(self, words): def train(self, words):
"Learn word patterns from a list of words." "Learn word patterns from a list of words."
for sequence in rolling_window(words, self.n + 1): for sequence in rolling_window(words, self.n + 1):
pattern = tuple(sequence[:-1]) context, next_word = tuple(sequence[:-1]), sequence[-1]
next_word = sequence[-1] if context not in self.model:
if pattern not in self.model: self.model[context] = []
self.model[pattern] = [] self.model[context].append(next_word)
self.model[pattern].append(next_word)
def generate(self, length, prompt=None, join_fn=None, step_callback=None): def generate(self, length, prompt=None, join_fn=join, step_callback=None):
"Create new words based on what the model learned." "Generate new text based on what the model learned."
if not self.model: if not self.model:
raise Exception("The model has not been trained") raise Exception("The model has not been trained")
output = list(prompt or self.get_random_pattern()) output = list(prompt or self.get_random_pattern())
@@ -32,10 +31,10 @@ class TinyLanguageModel:
if step_callback: if step_callback:
step_callback(pattern, options, chosen) step_callback(pattern, options, chosen)
output.append(chosen) output.append(chosen)
return (join_fn or join)(output) return join_fn(output)
def get_random_pattern(self): def get_random_pattern(self):
"Randomly chooses one of the observed patterns" "Randomly chooses one of the observed contexts"
return random.choice(list(self.model.keys())) return random.choice(list(self.model.keys()))
def save(self, filepath): def save(self, filepath):