initial commit
This commit is contained in:
106
tlm/cli.py
Normal file
106
tlm/cli.py
Normal file
@@ -0,0 +1,106 @@
|
||||
import click
|
||||
from .model import TinyLanguageModel
|
||||
from .helpers import read_mail_text, join
|
||||
from .tokenization import tokenize_text, tokenize_words
|
||||
|
||||
|
||||
@click.group()
def cli():
    """TinyLM - A language model with learned word embeddings."""
    # Root command group; subcommands are registered via @cli.command().
|
||||
|
||||
|
||||
def _gutenberg_corpus():
    """Download (if missing) and return the NLTK Gutenberg corpus reader."""
    # nltk is imported lazily so the CLI works without it unless a
    # Gutenberg source is actually requested.
    import nltk
    nltk.download("gutenberg", quiet=True)
    from nltk.corpus import gutenberg as gutenberg_corpus
    return gutenberg_corpus


def _build_corpus(filepath, gutenberg, mbox, tokenize_opts):
    """Assemble a flat token list from the selected training sources.

    Returns an empty list when no source was given; the caller decides
    whether that is an error.
    """
    corpus = []
    for fp in filepath:
        # Explicit encoding: the platform default may not be UTF-8.
        with open(fp, encoding="utf-8") as f:
            corpus.extend(tokenize_text(f.read(), tokenize_opts))
    if gutenberg:
        reader = _gutenberg_corpus()
        for key in gutenberg:
            corpus.extend(tokenize_words(reader.words(key), tokenize_opts))
    if mbox:
        corpus.extend(tokenize_text(read_mail_text(mbox), tokenize_opts))
    return corpus


@cli.command()
@click.option('-n', "--context-window-words", default=None, type=int, help="Number of words in the context window. (default: 2)")
@click.option('-d', "--embedding-dim", default=None, type=int, help="Dimension of the word embeddings. (default: 32)")
@click.option('-e', "--epochs", default=5, show_default=True, help="Number of training epochs.")
@click.option('-r', "--learning-rate", default=0.05, show_default=True, help="Learning rate for gradient descent.")
@click.option('-f', "--filepath", type=click.Path(exists=True), multiple=True, help="Text file(s) to train on. Can be specified multiple times.")
@click.option('-g', "--gutenberg", multiple=True, help="NLTK Gutenberg corpus key(s). Can be specified multiple times.")
@click.option('-G', "--list-gutenberg", is_flag=True, help="List available Gutenberg corpus keys.")
@click.option('-m', "--mbox", type=click.Path(exists=True), help="Mbox file to train on.")
@click.option('-o', "--output", default="model.json", show_default=True, help="File path to save the trained model.")
@click.option('-R', "--resume", 'resume_path', type=click.Path(exists=True), default=None, help="Load a saved model and continue training it.")
@click.option('-t', "--tokenize", 'tokenize_opts', multiple=True, type=click.Choice(['lower', 'char', 'alpha']), help="Tokenization options (can be specified multiple times).")
def train(context_window_words, embedding_dim, epochs, learning_rate, filepath, gutenberg, list_gutenberg, mbox, output, resume_path, tokenize_opts):
    """Train a language model on a corpus and save it to a file."""
    # -G short-circuits: just print the corpus keys and exit.
    if list_gutenberg:
        click.echo("Available Gutenberg corpus keys:")
        for key in _gutenberg_corpus().fileids():
            click.echo(f"  {key}")
        return

    corpus = _build_corpus(filepath, gutenberg, mbox, tokenize_opts)
    if not corpus:
        raise click.UsageError("No training data provided. Use --filepath, --gutenberg, or --mbox.")

    if resume_path:
        # Architecture hyperparameters are baked into the saved model;
        # reject explicit overrides rather than silently ignoring them.
        if context_window_words is not None or embedding_dim is not None:
            raise click.UsageError("-n/--context-window-words and -d/--embedding-dim are ignored when resuming a saved model. Remove these options.")
        model = TinyLanguageModel()
        model.load(resume_path)
        model.train(corpus, epochs=epochs, lr=learning_rate, resume=True)
    else:
        model = TinyLanguageModel(n=context_window_words or 2, embedding_dim=embedding_dim or 32)
        model.train(corpus, epochs=epochs, lr=learning_rate)

    model.save(output)
    click.echo(f"Model saved to {output}")
|
||||
|
||||
|
||||
@cli.command()
@click.option('-m', "--model", 'model_path', required=True, type=click.Path(exists=True), help="Trained model file to load.")
@click.option('-l', "--length", default=50, show_default=True, help="Number of tokens to generate.")
@click.option('-p', "--prompt", help="Prompt to start generation.")
@click.option('-v', "--verbose", is_flag=True, help="Display step-by-step generation as a table.")
@click.option('-i', "--interact", is_flag=True, help="Drop into interactive shell after generating.")
@click.option('-t', "--tokenize", 'tokenize_opts', multiple=True, type=click.Choice(['lower', 'char', 'alpha']), help="Tokenization options for the prompt.")
def generate(model_path, length, prompt, verbose, interact, tokenize_opts):
    """Generate text using a trained model."""
    model = TinyLanguageModel()
    model.load(model_path)

    # Character-level models are stitched back together without separators;
    # word-level models fall back to the project's `join` helper for display.
    join_fn = ''.join if 'char' in tokenize_opts else None
    display_join = join_fn if join_fn is not None else join

    prompt_tokens = None
    if prompt:
        prompt_tokens = tokenize_text(prompt, tokenize_opts)

    gen_kwargs = dict(prompt=prompt_tokens, join_fn=join_fn)

    if verbose:
        import textwrap
        from tabulate import tabulate

        rows = []

        def step_callback(context, options, chosen):
            # Wrap the candidate list so the table column stays readable.
            wrapped = textwrap.fill(', '.join(options), width=60)
            rows.append([display_join(list(context)), wrapped, chosen])

        output = model.generate(length, step_callback=step_callback, **gen_kwargs)
        click.echo(tabulate(rows, headers=["Context", "Options", "Selected"], tablefmt="simple"))
        click.echo()
    else:
        output = model.generate(length, **gen_kwargs)

    click.echo(output)

    if interact:
        import code
        # `model` and `output` stay in locals() for the interactive session.
        code.interact(local=locals(), banner="Entering interactive shell. 'model' and 'output' are available.")
|
||||
|
||||
|
||||
# Allow running this module directly (e.g. `python -m tlm.cli`);
# when imported as part of the package, only the definitions load.
if __name__ == "__main__":
    cli()
|
||||
Reference in New Issue
Block a user