From 36fe692392871bc0726f5a6518b4edc98a90b177 Mon Sep 17 00:00:00 2001
From: cplockport
Date: Sun, 1 Mar 2026 15:14:35 -0500
Subject: [PATCH] Applying patch

---
 tlm/cli.py   | 12 ++++++------
 tlm/model.py | 19 +++++++++----------
 2 files changed, 15 insertions(+), 16 deletions(-)

diff --git a/tlm/cli.py b/tlm/cli.py
index a6006c0..adc1b24 100644
--- a/tlm/cli.py
+++ b/tlm/cli.py
@@ -54,13 +54,13 @@ def generate(length, context_window_words, filepath, gutenberg, list_gutenberg,
         raise click.UsageError("No training data provided. Must specify at least one of --filepath, --gutenberg, or --mbox.")
 
     # Train and generate
-    model = TinyLanguageModel(n=context_window_words)
-    model.train(corpus)
+    tlm = TinyLanguageModel(n=context_window_words)
+    tlm.train(corpus)
     if prompt:
         prompt_tokens = tokenize_text(prompt, tokenize_opts)
     else:
         prompt_tokens = None
-    join_fn = ''.join if 'char' in tokenize_opts else None
+    join_fn = ''.join if 'char' in tokenize_opts else join
     display_join = join_fn or join
 
     if verbose:
@@ -70,18 +70,18 @@ def generate(length, context_window_words, filepath, gutenberg, list_gutenberg,
         def step_callback(pattern, options, chosen):
             opts = textwrap.fill(', '.join(sorted(set(options))), width=60)
             rows.append([display_join(list(pattern)), opts, chosen])
-        output = model.generate(length, prompt=prompt_tokens, join_fn=join_fn, step_callback=step_callback)
+        output = tlm.generate(length, prompt=prompt_tokens, join_fn=join_fn, step_callback=step_callback)
         click.echo(tabulate(rows, headers=["Context", "Options", "Selected"], tablefmt="simple"))
         click.echo()
     else:
-        output = model.generate(length, prompt=prompt_tokens, join_fn=join_fn)
+        output = tlm.generate(length, prompt=prompt_tokens, join_fn=join_fn)
         click.echo(output)
 
     if interact:
         import code
-        code.interact(local=locals(), banner="Entering interactive shell. 'model' and 'output' are available.")
+        code.interact(local=locals(), banner="Entering interactive shell. 'tlm' and 'output' are available.")
 
 
 if __name__ == "__main__":
diff --git a/tlm/model.py b/tlm/model.py
index 0db8c29..df3baac 100644
--- a/tlm/model.py
+++ b/tlm/model.py
@@ -5,21 +5,20 @@ from .helpers import rolling_window, join
 
 class TinyLanguageModel:
     def __init__(self, n=2):
-        "Create a new model that looks at n words at a time."
+        "Create a new model that looks at a n-word context window."
         self.n = n
         self.model = {}
 
     def train(self, words):
         "Learn word patterns from a list of words."
         for sequence in rolling_window(words, self.n + 1):
-            pattern = tuple(sequence[:-1])
-            next_word = sequence[-1]
-            if pattern not in self.model:
-                self.model[pattern] = []
-            self.model[pattern].append(next_word)
+            context, next_word = tuple(sequence[:-1]), sequence[-1]
+            if context not in self.model:
+                self.model[context] = []
+            self.model[context].append(next_word)
 
-    def generate(self, length, prompt=None, join_fn=None, step_callback=None):
-        "Create new words based on what the model learned."
+    def generate(self, length, prompt=None, join_fn=join, step_callback=None):
+        "Generate new text based on what the model learned."
         if not self.model:
             raise Exception("The model has not been trained")
         output = list(prompt or self.get_random_pattern())
@@ -32,10 +31,10 @@ class TinyLanguageModel:
             if step_callback:
                 step_callback(pattern, options, chosen)
             output.append(chosen)
-        return (join_fn or join)(output)
+        return join_fn(output)
 
     def get_random_pattern(self):
-        "Randomly chooses one of the observed patterns"
+        "Randomly chooses one of the observed contexts"
         return random.choice(list(self.model.keys()))
 
     def save(self, filepath):