Add lots of features

This commit is contained in:
chris
2026-02-19 14:59:57 -05:00
parent f372786dbc
commit 01f58ded9a
5 changed files with 91 additions and 23 deletions

View File

@@ -1,6 +1,7 @@
import click
from .model import TinyLanguageModel
from .helpers import read_mail_text
from .helpers import read_mail_text, join
from .tokenization import tokenize_text, tokenize_words
@click.group()
@@ -10,15 +11,17 @@ def cli():
@cli.command()
@click.option("--length", default=50, help="Number of words to generate.")
@click.option("--n", default=2, help="Number of words in the context window.")
@click.option("--text", type=click.Path(exists=True), multiple=True, help="Text file(s) to use as training corpus. Can be specified multiple times.")
@click.option("--gutenberg", multiple=True, help="NLTK Gutenberg corpus key(s). Can be specified multiple times.")
@click.option("--list-gutenberg", is_flag=True, help="List available Gutenberg corpus keys.")
@click.option("--mbox", type=click.Path(exists=True), help="Mbox file to use for training.")
@click.option("--prompt", help="Prompt to start generation.")
@click.option("--interact", is_flag=True, help="Drop into interactive shell after generating.")
def generate(length, n, text, gutenberg, list_gutenberg, mbox, prompt, interact):
@click.option('-l', "--length", default=50, help="Number of tokens to generate.")
@click.option('-n', "--context-window-words", default=2, help="Number of words in the context window.")
@click.option('-f', "--filepath", type=click.Path(exists=True), multiple=True, help="Text file(s) to use as training corpus. Can be specified multiple times.")
@click.option('-g', "--gutenberg", multiple=True, help="NLTK Gutenberg corpus key(s). Can be specified multiple times.")
@click.option('-G', "--list-gutenberg", is_flag=True, help="List available Gutenberg corpus keys.")
@click.option('-m', "--mbox", type=click.Path(exists=True), help="Mbox file to use for training.")
@click.option('-p', "--prompt", help="Prompt to start generation.")
@click.option('-i', "--interact", is_flag=True, help="Drop into interactive shell after generating.")
@click.option('-t', "--tokenize", 'tokenize_opts', multiple=True, type=click.Choice(['lower', 'char', 'alpha']), help="Preprocessing option (can be specified multiple times). 'lower': lowercase all input text. 'char': use characters as tokens instead of words.")
@click.option('-v', "--verbose", is_flag=True, help="Display step-by-step generation as a table.")
def generate(length, context_window_words, filepath, gutenberg, list_gutenberg, mbox, prompt, interact, tokenize_opts, verbose):
"""Generate text using the language model."""
import nltk
@@ -34,27 +37,46 @@ def generate(length, n, text, gutenberg, list_gutenberg, mbox, prompt, interact)
# Determine training corpus
corpus = []
if text:
for filepath in text:
with open(filepath, "r") as f:
corpus.extend(f.read().split())
if filepath:
for fp in filepath:
with open(fp, "r") as f:
corpus.extend(tokenize_text(f.read(), tokenize_opts))
if gutenberg:
nltk.download("gutenberg", quiet=True)
from nltk.corpus import gutenberg as gutenberg_corpus
for key in gutenberg:
corpus.extend(gutenberg_corpus.words(key))
corpus.extend(tokenize_words(gutenberg_corpus.words(key), tokenize_opts))
if mbox:
mail_text = read_mail_text(mbox)
corpus.extend(mail_text.split())
corpus.extend(tokenize_text(mail_text, tokenize_opts))
if not corpus:
raise click.UsageError("Must specify at least one of --text, --gutenberg, or --mbox for training data.")
# Train and generate
model = TinyLanguageModel(n=n)
model = TinyLanguageModel(n=context_window_words)
model.train(corpus)
prompt_words = prompt.split() if prompt else None
output = model.generate(length, prompt=prompt_words)
if prompt:
prompt_tokens = tokenize_text(prompt, tokenize_opts)
else:
prompt_tokens = None
join_fn = ''.join if 'char' in tokenize_opts else None
display_join = join_fn or join
if verbose:
from tabulate import tabulate
rows = []
import textwrap
def step_callback(pattern, options, chosen):
opts = textwrap.fill(', '.join(sorted(set(options))), width=60)
rows.append([display_join(list(pattern)), opts, chosen])
output = model.generate(length, prompt=prompt_tokens, join_fn=join_fn, step_callback=step_callback)
click.echo(tabulate(rows, headers=["Context", "Options", "Selected"], tablefmt="simple"))
click.echo()
else:
output = model.generate(length, prompt=prompt_tokens, join_fn=join_fn)
click.echo(output)
if interact:

View File

@@ -18,7 +18,7 @@ class TinyLanguageModel:
self.model[pattern] = []
self.model[pattern].append(next_word)
def generate(self, length, prompt=None):
def generate(self, length, prompt=None, join_fn=None, step_callback=None):
"Create new words based on what the model learned."
if not self.model:
raise Exception("The model has not been trained")
@@ -27,9 +27,12 @@ class TinyLanguageModel:
pattern = tuple(output[-self.n:])
if pattern not in self.model:
break
next_word = random.choice(self.model[pattern])
output.append(next_word)
return join(output)
options = self.model[pattern]
chosen = random.choice(options)
if step_callback:
step_callback(pattern, options, chosen)
output.append(chosen)
return (join_fn or join)(output)
def get_random_pattern(self):
"Randomly chooses one of the observed patterns"

31
tlm/tokenization.py Normal file
View File

@@ -0,0 +1,31 @@
def compress_whitespace(text):
"""Collapse sequences of whitespace into a single space."""
import re
return re.sub(r'\s+', ' ', text).strip()
def _is_alpha_token(token):
return all(c.isalpha() or c in " '" for c in token)
def tokenize_text(text, options):
"""Tokenize a raw text string according to the given options."""
if 'lower' in options:
text = text.lower()
words = text.split()
if 'alpha' in options:
words = [w for w in words if _is_alpha_token(w)]
if 'char' in options:
return list(compress_whitespace(' '.join(words)))
return words
def tokenize_words(words, options):
"""Apply tokenization options to an already word-tokenized sequence."""
if 'lower' in options:
words = [w.lower() for w in words]
if 'alpha' in options:
words = [w for w in words if _is_alpha_token(w)]
if 'char' in options:
return list(compress_whitespace(' '.join(words)))
return list(words)