Initial commit
This commit is contained in:
60
tlm/cli.py
Normal file
60
tlm/cli.py
Normal file
@@ -0,0 +1,60 @@
|
||||
import click
|
||||
from .model import TinyLanguageModel
|
||||
from .helpers import read_mail_text
|
||||
|
||||
|
||||
@click.group()
|
||||
def cli():
|
||||
"""TinyLM - A simple n-gram language model."""
|
||||
pass
|
||||
|
||||
|
||||
@cli.command()
|
||||
@click.option("--length", default=50, help="Number of words to generate.")
|
||||
@click.option("--text", type=click.Path(exists=True), help="Text file to use as training corpus.")
|
||||
@click.option("--gutenberg", help="NLTK Gutenberg corpus key (use --list-gutenberg to see available keys).")
|
||||
@click.option("--list-gutenberg", is_flag=True, help="List available Gutenberg corpus keys.")
|
||||
@click.option("--mbox", type=click.Path(exists=True), help="Mbox file to use for training.")
|
||||
@click.option("--interact", is_flag=True, help="Drop into interactive shell after generating.")
|
||||
def generate(length, text, gutenberg, list_gutenberg, mbox, interact):
|
||||
"""Generate text using the language model."""
|
||||
import nltk
|
||||
|
||||
# Handle --list-gutenberg: list available keys
|
||||
if list_gutenberg:
|
||||
nltk.download("gutenberg", quiet=True)
|
||||
from nltk.corpus import gutenberg as gutenberg_corpus
|
||||
click.echo("Available Gutenberg corpus keys:")
|
||||
for key in gutenberg_corpus.fileids():
|
||||
click.echo(f" {key}")
|
||||
return
|
||||
|
||||
# Determine training corpus
|
||||
corpus = None
|
||||
|
||||
if text:
|
||||
with open(text, "r") as f:
|
||||
corpus = f.read().split()
|
||||
elif gutenberg:
|
||||
nltk.download("gutenberg", quiet=True)
|
||||
from nltk.corpus import gutenberg as gutenberg_corpus
|
||||
corpus = gutenberg_corpus.words(gutenberg)
|
||||
elif mbox:
|
||||
mail_text = read_mail_text(mbox)
|
||||
corpus = mail_text.split()
|
||||
else:
|
||||
raise click.UsageError("Must specify one of --text, --gutenberg, or --mbox for training data.")
|
||||
|
||||
# Train and generate
|
||||
model = TinyLanguageModel()
|
||||
model.train(corpus)
|
||||
output = model.generate(length)
|
||||
click.echo(output)
|
||||
|
||||
if interact:
|
||||
import code
|
||||
code.interact(local=locals(), banner="Entering interactive shell. 'model' and 'output' are available.")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
cli()
|
||||
Reference in New Issue
Block a user