From 7ee660ed4a41bc72463bb694e63c8e82ca454cc2 Mon Sep 17 00:00:00 2001 From: cpub Date: Mon, 23 Feb 2026 11:35:46 -0500 Subject: [PATCH] Applying patch --- pyproject.toml | 1 + tlm/cli.py | 60 ++++++++++++++++++++++++++++++++++---------------- tlm/model.py | 11 +++++---- uv.lock | 11 +++++++++ 4 files changed, 60 insertions(+), 23 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 84a608a..add4f15 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -6,6 +6,7 @@ requires-python = ">=3.10,<4.0" dependencies = [ "click>=8.3.1", "nltk>=3.9.2", + "tabulate>=0.9.0", "tqdm>=4.67.3", ] diff --git a/tlm/cli.py b/tlm/cli.py index d1ff2d5..3733b4e 100644 --- a/tlm/cli.py +++ b/tlm/cli.py @@ -1,6 +1,7 @@ import click from .model import TinyLanguageModel -from .helpers import read_mail_text +from .helpers import read_mail_text, join +from .tokenization import tokenize_text, tokenize_words @click.group() @@ -10,15 +11,17 @@ def cli(): @cli.command() -@click.option("--length", default=50, help="Number of words to generate.") -@click.option("--n", default=2, help="Number of words in the context window.") -@click.option("--text", type=click.Path(exists=True), multiple=True, help="Text file(s) to use as training corpus. Can be specified multiple times.") -@click.option("--gutenberg", multiple=True, help="NLTK Gutenberg corpus key(s). Can be specified multiple times.") -@click.option("--list-gutenberg", is_flag=True, help="List available Gutenberg corpus keys.") -@click.option("--mbox", type=click.Path(exists=True), help="Mbox file to use for training.") -@click.option("--prompt", help="Prompt to start generation.") -@click.option("--interact", is_flag=True, help="Drop into interactive shell after generating.") -def generate(length, n, text, gutenberg, list_gutenberg, mbox, prompt, interact): +@click.option('-l', "--length", default=50, help="Number of tokens to generate.") +@click.option('-n', "--context-window-words", default=2, help="Number of words in the context window.") +@click.option('-f', "--filepath", type=click.Path(exists=True), multiple=True, help="Text file(s) to use as training corpus. Can be specified multiple times.") +@click.option('-g', "--gutenberg", multiple=True, help="NLTK Gutenberg corpus key(s). Can be specified multiple times.") +@click.option('-G', "--list-gutenberg", is_flag=True, help="List available Gutenberg corpus keys.") +@click.option('-m', "--mbox", type=click.Path(exists=True), help="Mbox file to use for training.") +@click.option('-p', "--prompt", help="Prompt to start generation.") +@click.option('-i', "--interact", is_flag=True, help="Drop into interactive shell after generating.") +@click.option('-t', "--tokenize", 'tokenize_opts', multiple=True, type=click.Choice(['lower', 'char', 'alpha']), help="Preprocessing option (can be specified multiple times). 'lower': lowercase all input text. 'char': use characters as tokens instead of words.") +@click.option('-v', "--verbose", is_flag=True, help="Display step-by-step generation as a table.") +def generate(length, context_window_words, filepath, gutenberg, list_gutenberg, mbox, prompt, interact, tokenize_opts, verbose): """Generate text using the language model.""" import nltk @@ -34,27 +37,46 @@ def generate(length, n, text, gutenberg, list_gutenberg, mbox, prompt, interact) # Determine training corpus corpus = [] - if text: - for filepath in text: - with open(filepath, "r") as f: - corpus.extend(f.read().split()) + if filepath: + for fp in filepath: + with open(fp, "r") as f: + corpus.extend(tokenize_text(f.read(), tokenize_opts)) if gutenberg: nltk.download("gutenberg", quiet=True) from nltk.corpus import gutenberg as gutenberg_corpus for key in gutenberg: - corpus.extend(gutenberg_corpus.words(key)) + corpus.extend(tokenize_words(gutenberg_corpus.words(key), tokenize_opts)) if mbox: mail_text = read_mail_text(mbox) - corpus.extend(mail_text.split()) + corpus.extend(tokenize_text(mail_text, tokenize_opts)) if not corpus: raise click.UsageError("Must specify at least one of --text, --gutenberg, or --mbox for training data.") # Train and generate - model = TinyLanguageModel(n=n) + model = TinyLanguageModel(n=context_window_words) model.train(corpus) - prompt_words = prompt.split() if prompt else None - output = model.generate(length, prompt=prompt_words) + if prompt: + prompt_tokens = tokenize_text(prompt, tokenize_opts) + else: + prompt_tokens = None + join_fn = ''.join if 'char' in tokenize_opts else None + display_join = join_fn or join + + if verbose: + from tabulate import tabulate + rows = [] + import textwrap + def step_callback(pattern, options, chosen): + opts = textwrap.fill(', '.join(sorted(set(options))), width=60) + rows.append([display_join(list(pattern)), opts, chosen]) + output = model.generate(length, prompt=prompt_tokens, join_fn=join_fn, step_callback=step_callback) + click.echo(tabulate(rows, headers=["Context", "Options", "Selected"], tablefmt="simple")) + click.echo() + + else: + output = model.generate(length, prompt=prompt_tokens, join_fn=join_fn) + click.echo(output) if interact: diff --git a/tlm/model.py b/tlm/model.py index d9b6020..0db8c29 100644 --- a/tlm/model.py +++ b/tlm/model.py @@ -18,7 +18,7 @@ class TinyLanguageModel: self.model[pattern] = [] self.model[pattern].append(next_word) - def generate(self, length, prompt=None): + def generate(self, length, prompt=None, join_fn=None, step_callback=None): "Create new words based on what the model learned." if not self.model: raise Exception("The model has not been trained") @@ -27,9 +27,12 @@ class TinyLanguageModel: pattern = tuple(output[-self.n:]) if pattern not in self.model: break - next_word = random.choice(self.model[pattern]) - output.append(next_word) - return join(output) + options = self.model[pattern] + chosen = random.choice(options) + if step_callback: + step_callback(pattern, options, chosen) + output.append(chosen) + return (join_fn or join)(output) def get_random_pattern(self): "Randomly chooses one of the observed patterns" diff --git a/uv.lock b/uv.lock index b3631f2..90947bf 100644 --- a/uv.lock +++ b/uv.lock @@ -38,6 +38,7 @@ source = { editable = "." } dependencies = [ { name = "click" }, { name = "nltk" }, + { name = "tabulate" }, { name = "tqdm" }, ] @@ -45,6 +46,7 @@ dependencies = [ requires-dist = [ { name = "click", specifier = ">=8.3.1" }, { name = "nltk", specifier = ">=3.9.2" }, + { name = "tabulate", specifier = ">=0.9.0" }, { name = "tqdm", specifier = ">=4.67.3" }, ] @@ -184,6 +186,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/95/e4/a3b9480c78cf8ee86626cb06f8d931d74d775897d44201ccb813097ae697/regex-2026.1.15-cp314-cp314t-win_arm64.whl", hash = "sha256:ca89c5e596fc05b015f27561b3793dc2fa0917ea0d7507eebb448efd35274a70", size = 274837 }, ] +[[package]] +name = "tabulate" +version = "0.9.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/ec/fe/802052aecb21e3797b8f7902564ab6ea0d60ff8ca23952079064155d1ae1/tabulate-0.9.0.tar.gz", hash = "sha256:0095b12bf5966de529c0feb1fa08671671b3368eec77d7ef7ab114be2c068b3c", size = 81090 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/40/44/4a5f08c96eb108af5cb50b41f76142f0afa346dfa99d5296fe7202a11854/tabulate-0.9.0-py3-none-any.whl", hash = "sha256:024ca478df22e9340661486f85298cff5f6dcdba14f3813e8830015b9ed1948f", size = 35252 }, +] + [[package]] name = "tqdm" version = "4.67.3"