Allow combined corpora

This commit is contained in:
chris
2026-02-09 12:35:23 -05:00
parent 40207b7623
commit f372786dbc

View File

@@ -32,24 +32,23 @@ def generate(length, n, text, gutenberg, list_gutenberg, mbox, prompt, interact)
return return
# Determine training corpus # Determine training corpus
corpus = None corpus = []
if text: if text:
corpus = []
for filepath in text: for filepath in text:
with open(filepath, "r") as f: with open(filepath, "r") as f:
corpus.extend(f.read().split()) corpus.extend(f.read().split())
elif gutenberg: if gutenberg:
nltk.download("gutenberg", quiet=True) nltk.download("gutenberg", quiet=True)
from nltk.corpus import gutenberg as gutenberg_corpus from nltk.corpus import gutenberg as gutenberg_corpus
corpus = []
for key in gutenberg: for key in gutenberg:
corpus.extend(gutenberg_corpus.words(key)) corpus.extend(gutenberg_corpus.words(key))
elif mbox: if mbox:
mail_text = read_mail_text(mbox) mail_text = read_mail_text(mbox)
corpus = mail_text.split() corpus.extend(mail_text.split())
else:
raise click.UsageError("Must specify one of --text, --gutenberg, or --mbox for training data.") if not corpus:
raise click.UsageError("Must specify at least one of --text, --gutenberg, or --mbox for training data.")
# Train and generate # Train and generate
model = TinyLanguageModel(n=n) model = TinyLanguageModel(n=n)