Files
lab_embeddings/tlm/cli.py
Chris Proctor 039a467a9f initial commit
2026-03-09 12:28:21 -04:00

107 lines
5.1 KiB
Python

import click
from .model import TinyLanguageModel
from .helpers import read_mail_text, join
from .tokenization import tokenize_text, tokenize_words
@click.group()
def cli():
    # Root command group; subcommands (train, generate) register themselves
    # via the @cli.command() decorator. The docstring is rendered by click
    # as the top-level --help text.
    """TinyLM - A language model with learned word embeddings."""
@cli.command()
@click.option('-n', "--context-window-words", default=None, type=int, help="Number of words in the context window. (default: 2)")
@click.option('-d', "--embedding-dim", default=None, type=int, help="Dimension of the word embeddings. (default: 32)")
@click.option('-e', "--epochs", default=5, show_default=True, help="Number of training epochs.")
@click.option('-r', "--learning-rate", default=0.05, show_default=True, help="Learning rate for gradient descent.")
@click.option('-f', "--filepath", type=click.Path(exists=True), multiple=True, help="Text file(s) to train on. Can be specified multiple times.")
@click.option('-g', "--gutenberg", multiple=True, help="NLTK Gutenberg corpus key(s). Can be specified multiple times.")
@click.option('-G', "--list-gutenberg", is_flag=True, help="List available Gutenberg corpus keys.")
@click.option('-m', "--mbox", type=click.Path(exists=True), help="Mbox file to train on.")
@click.option('-o', "--output", default="model.json", show_default=True, help="File path to save the trained model.")
@click.option('-R', "--resume", 'resume_path', type=click.Path(exists=True), default=None, help="Load a saved model and continue training it.")
@click.option('-t', "--tokenize", 'tokenize_opts', multiple=True, type=click.Choice(['lower', 'char', 'alpha']), help="Tokenization options (can be specified multiple times).")
def train(context_window_words, embedding_dim, epochs, learning_rate, filepath, gutenberg, list_gutenberg, mbox, output, resume_path, tokenize_opts):
    """Train a language model on a corpus and save it to a file."""
    # NOTE: nltk is imported lazily inside the branches that need it, so that
    # training purely from --filepath/--mbox works even when nltk is not
    # installed. (Previously `import nltk` ran unconditionally and would
    # crash the command for non-Gutenberg use.)
    if list_gutenberg:
        import nltk
        nltk.download("gutenberg", quiet=True)
        from nltk.corpus import gutenberg as gutenberg_corpus
        click.echo("Available Gutenberg corpus keys:")
        for key in gutenberg_corpus.fileids():
            click.echo(f" {key}")
        return
    # Assemble the training corpus as a flat token list from all sources.
    corpus = []
    for fp in filepath:  # multiple=True yields a tuple; empty tuple is a no-op
        with open(fp) as f:
            corpus.extend(tokenize_text(f.read(), tokenize_opts))
    if gutenberg:
        import nltk
        nltk.download("gutenberg", quiet=True)
        from nltk.corpus import gutenberg as gutenberg_corpus
        for key in gutenberg:
            corpus.extend(tokenize_words(gutenberg_corpus.words(key), tokenize_opts))
    if mbox:
        corpus.extend(tokenize_text(read_mail_text(mbox), tokenize_opts))
    if not corpus:
        raise click.UsageError("No training data provided. Use --filepath, --gutenberg, or --mbox.")
    if resume_path:
        # Architecture hyperparameters are baked into the saved model;
        # reject options that would silently be ignored.
        if context_window_words is not None or embedding_dim is not None:
            raise click.UsageError("-n/--context-window-words and -d/--embedding-dim are ignored when resuming a saved model. Remove these options.")
        model = TinyLanguageModel()
        model.load(resume_path)
        model.train(corpus, epochs=epochs, lr=learning_rate, resume=True)
    else:
        # None defaults let us distinguish "unset" (for the resume check
        # above) from an explicit value; fall back to 2 / 32 here.
        model = TinyLanguageModel(n=context_window_words or 2, embedding_dim=embedding_dim or 32)
        model.train(corpus, epochs=epochs, lr=learning_rate)
    model.save(output)
    click.echo(f"Model saved to {output}")
@cli.command()
@click.option('-m', "--model", 'model_path', required=True, type=click.Path(exists=True), help="Trained model file to load.")
@click.option('-l', "--length", default=50, show_default=True, help="Number of tokens to generate.")
@click.option('-p', "--prompt", help="Prompt to start generation.")
@click.option('-v', "--verbose", is_flag=True, help="Display step-by-step generation as a table.")
@click.option('-i', "--interact", is_flag=True, help="Drop into interactive shell after generating.")
@click.option('-t', "--tokenize", 'tokenize_opts', multiple=True, type=click.Choice(['lower', 'char', 'alpha']), help="Tokenization options for the prompt.")
def generate(model_path, length, prompt, verbose, interact, tokenize_opts):
    """Generate text using a trained model."""
    model = TinyLanguageModel()
    model.load(model_path)
    # Character-level models concatenate tokens directly; word-level models
    # pass join_fn=None and let the model use its default joining.
    if 'char' in tokenize_opts:
        join_fn = ''.join
    else:
        join_fn = None
    # For table display we always need a concrete join function.
    display_join = join if join_fn is None else join_fn
    prompt_tokens = None
    if prompt:
        prompt_tokens = tokenize_text(prompt, tokenize_opts)
    if not verbose:
        output = model.generate(length, prompt=prompt_tokens, join_fn=join_fn)
    else:
        from tabulate import tabulate
        import textwrap
        step_rows = []

        def step_callback(context, options, chosen):
            # Record one generation step: the context window, the candidate
            # tokens (wrapped for display), and the token that was sampled.
            wrapped = textwrap.fill(', '.join(options), width=60)
            step_rows.append([display_join(list(context)), wrapped, chosen])

        output = model.generate(length, prompt=prompt_tokens, join_fn=join_fn, step_callback=step_callback)
        click.echo(tabulate(step_rows, headers=["Context", "Options", "Selected"], tablefmt="simple"))
        click.echo()
    click.echo(output)
    if interact:
        import code
        code.interact(local=locals(), banner="Entering interactive shell. 'model' and 'output' are available.")
# Allow running this module directly as a script (`python cli.py`);
# normal installs invoke `cli` through the package entry point instead.
if __name__ == "__main__":
    cli()