# Generated from mwc/lab_tinylm (67 lines, 2.4 KiB, Python).
import click
|
|
from .model import TinyLanguageModel
|
|
from .helpers import read_mail_text
|
|
|
|
|
|
# Root command group; subcommands attach themselves via ``@cli.command()``.
# The docstring below doubles as the top-level ``--help`` text shown by click,
# so maintainer notes belong in comments rather than in the docstring.
@click.group()
def cli():
    """TinyLM - A simple n-gram language model."""
|
|
|
|
|
|
def _load_gutenberg_corpus():
    """Fetch the NLTK Gutenberg corpus if needed and return its corpus reader."""
    # Imported lazily so that runs which only use --text or --mbox do not
    # require NLTK to be importable.
    import nltk

    # quiet=True keeps the downloader's progress output out of the CLI output.
    nltk.download("gutenberg", quiet=True)
    from nltk.corpus import gutenberg as gutenberg_corpus

    return gutenberg_corpus


@cli.command()
@click.option("--length", default=50, help="Number of words to generate.")
@click.option("--n", default=2, help="Number of words in the context window.")
@click.option(
    "--text",
    type=click.Path(exists=True),
    multiple=True,
    help="Text file(s) to use as training corpus. Can be specified multiple times.",
)
@click.option(
    "--gutenberg",
    multiple=True,
    help="NLTK Gutenberg corpus key(s). Can be specified multiple times.",
)
@click.option("--list-gutenberg", is_flag=True, help="List available Gutenberg corpus keys.")
@click.option("--mbox", type=click.Path(exists=True), help="Mbox file to use for training.")
@click.option("--prompt", help="Prompt to start generation.")
@click.option("--interact", is_flag=True, help="Drop into interactive shell after generating.")
def generate(length, n, text, gutenberg, list_gutenberg, mbox, prompt, interact):
    """Generate text using the language model."""
    # NOTE: click prints the docstring above as this command's --help text, so
    # it is kept short; maintainer documentation lives in these comments.
    #
    # Parameters (all supplied by click from the options declared above):
    #   length         -- passed to model.generate() as the amount to generate.
    #   n              -- passed to TinyLanguageModel(n=...).
    #   text           -- tuple of existing file paths, split on whitespace.
    #   gutenberg      -- tuple of NLTK Gutenberg file ids.
    #   list_gutenberg -- when set, print the available ids and return early.
    #   mbox           -- path handed to read_mail_text(); its result is split
    #                     on whitespace.
    #   prompt         -- optional seed string, split on whitespace.
    #   interact       -- when set, open a Python shell after printing output.
    #
    # Raises click.UsageError when no training source was given, or when the
    # given sources produced no words at all.

    # --list-gutenberg is informational only: print the keys and stop.
    if list_gutenberg:
        gutenberg_corpus = _load_gutenberg_corpus()
        click.echo("Available Gutenberg corpus keys:")
        for key in gutenberg_corpus.fileids():
            click.echo(f" {key}")
        return

    # Build the training corpus as one flat list of words from every source.
    corpus = []

    for filepath in text:
        # Decode explicitly as UTF-8 so the result does not depend on the
        # platform's locale default encoding.
        with open(filepath, "r", encoding="utf-8") as f:
            corpus.extend(f.read().split())

    if gutenberg:
        gutenberg_corpus = _load_gutenberg_corpus()
        for key in gutenberg:
            corpus.extend(gutenberg_corpus.words(key))

    if mbox:
        mail_text = read_mail_text(mbox)
        corpus.extend(mail_text.split())

    if not (text or gutenberg or mbox):
        raise click.UsageError("Must specify at least one of --text, --gutenberg, or --mbox for training data.")
    if not corpus:
        # Sources were given but contained no words; say so rather than
        # telling the user to specify something they already specified.
        raise click.UsageError("The specified training data contains no words.")

    # Train and generate.
    model = TinyLanguageModel(n=n)
    model.train(corpus)
    prompt_words = prompt.split() if prompt else None
    output = model.generate(length, prompt=prompt_words)
    click.echo(output)

    if interact:
        import code

        # locals() exposes 'model' and 'output' (plus the other locals of this
        # function) inside the interactive shell, as the banner promises.
        code.interact(local=locals(), banner="Entering interactive shell. 'model' and 'output' are available.")
|
|
|
|
|
|
# Run the click group only when this file is executed as a script; importing
# the module must not start the CLI.
if __name__ == "__main__":
    cli()
|