diff --git a/tlm/cli.py b/tlm/cli.py index 441b98a..3fb5c18 100644 --- a/tlm/cli.py +++ b/tlm/cli.py @@ -11,12 +11,14 @@ def cli(): @cli.command() @click.option("--length", default=50, help="Number of words to generate.") +@click.option("--n", default=2, help="Number of words in the context window.") @click.option("--text", type=click.Path(exists=True), help="Text file to use as training corpus.") @click.option("--gutenberg", help="NLTK Gutenberg corpus key (use --list-gutenberg to see available keys).") @click.option("--list-gutenberg", is_flag=True, help="List available Gutenberg corpus keys.") @click.option("--mbox", type=click.Path(exists=True), help="Mbox file to use for training.") +@click.option("--prompt", help="Prompt to start generation.") @click.option("--interact", is_flag=True, help="Drop into interactive shell after generating.") -def generate(length, text, gutenberg, list_gutenberg, mbox, interact): +def generate(length, n, text, gutenberg, list_gutenberg, mbox, prompt, interact): """Generate text using the language model.""" import nltk @@ -46,9 +48,10 @@ def generate(length, text, gutenberg, list_gutenberg, mbox, interact): raise click.UsageError("Must specify one of --text, --gutenberg, or --mbox for training data.") # Train and generate - model = TinyLanguageModel() + model = TinyLanguageModel(n=n) model.train(corpus) - output = model.generate(length) + prompt_words = prompt.split() if prompt else None + output = model.generate(length, prompt=prompt_words) click.echo(output) if interact: