generated from mwc/lab_tinylm
Initial commit
This commit is contained in:
0
tlm/__init__.py
Normal file
0
tlm/__init__.py
Normal file
88
tlm/cli.py
Normal file
88
tlm/cli.py
Normal file
@@ -0,0 +1,88 @@
|
||||
import click
|
||||
from .model import TinyLanguageModel
|
||||
from .helpers import read_mail_text, join
|
||||
from .tokenization import tokenize_text, tokenize_words
|
||||
|
||||
|
||||
@click.group()
def cli():
    """TinyLM - A simple n-gram language model."""
    # Root command group: subcommands (e.g. `generate`) attach via @cli.command().
    # The docstring doubles as the --help text shown by click.
|
||||
|
||||
|
||||
@cli.command()
@click.option('-l', "--length", default=50, help="Number of tokens to generate.")
@click.option('-n', "--context-window-words", default=2, help="Number of words in the context window.")
@click.option('-f', "--filepath", type=click.Path(exists=True), multiple=True, help="Text file(s) to use as training corpus. Can be specified multiple times.")
@click.option('-g', "--gutenberg", multiple=True, help="NLTK Gutenberg corpus key(s). Can be specified multiple times.")
@click.option('-G', "--list-gutenberg", is_flag=True, help="List available Gutenberg corpus keys.")
@click.option('-m', "--mbox", type=click.Path(exists=True), help="Mbox file to use for training.")
@click.option('-p', "--prompt", help="Prompt to start generation.")
@click.option('-i', "--interact", is_flag=True, help="Drop into interactive shell after generating.")
@click.option('-t', "--tokenize", 'tokenize_opts', multiple=True, type=click.Choice(['lower', 'char', 'alpha']), help="Preprocessing option (can be specified multiple times). 'lower': lowercase all input text. 'char': use characters as tokens instead of words.")
@click.option('-v', "--verbose", is_flag=True, help="Display step-by-step generation as a table.")
def generate(length, context_window_words, filepath, gutenberg, list_gutenberg, mbox, prompt, interact, tokenize_opts, verbose):
    """Generate text using the language model.

    Builds a training corpus from any combination of --filepath, --gutenberg
    and --mbox sources, trains a TinyLanguageModel on it, then prints
    `length` generated tokens (optionally seeded with --prompt).
    """
    # nltk is only needed by this command, so import lazily to keep other
    # CLI invocations fast.
    import nltk

    # Handle --list-gutenberg: list available keys and exit early.
    if list_gutenberg:
        nltk.download("gutenberg", quiet=True)
        from nltk.corpus import gutenberg as gutenberg_corpus
        click.echo("Available Gutenberg corpus keys:")
        for key in gutenberg_corpus.fileids():
            click.echo(f"  {key}")
        return

    # Determine training corpus: the three sources are additive, so a user
    # may combine files, Gutenberg texts, and a mailbox in one run.
    corpus = []

    if filepath:
        for fp in filepath:
            # Explicit UTF-8: relying on the platform default encoding makes
            # corpus loading behave differently across systems.
            with open(fp, "r", encoding="utf-8") as f:
                corpus.extend(tokenize_text(f.read(), tokenize_opts))
    if gutenberg:
        nltk.download("gutenberg", quiet=True)
        from nltk.corpus import gutenberg as gutenberg_corpus
        for key in gutenberg:
            corpus.extend(tokenize_words(gutenberg_corpus.words(key), tokenize_opts))
    if mbox:
        mail_text = read_mail_text(mbox)
        corpus.extend(tokenize_text(mail_text, tokenize_opts))

    if not corpus:
        # Fixed: the message previously said "--text", which is not an
        # option of this command; the file option is --filepath.
        raise click.UsageError("Must specify at least one of --filepath, --gutenberg, or --mbox for training data.")

    # Train and generate.
    model = TinyLanguageModel(n=context_window_words)
    model.train(corpus)
    if prompt:
        # The prompt goes through the same tokenization as the corpus so
        # its tokens can match learned patterns (e.g. after 'lower'/'char').
        prompt_tokens = tokenize_text(prompt, tokenize_opts)
    else:
        prompt_tokens = None
    # In character mode, tokens are single chars and join with no separator.
    join_fn = ''.join if 'char' in tokenize_opts else None
    display_join = join_fn or join

    if verbose:
        from tabulate import tabulate
        rows = []
        import textwrap

        def step_callback(pattern, options, chosen):
            # Record each generation step: context window, candidate set, pick.
            opts = textwrap.fill(', '.join(sorted(set(options))), width=60)
            rows.append([display_join(list(pattern)), opts, chosen])

        output = model.generate(length, prompt=prompt_tokens, join_fn=join_fn, step_callback=step_callback)
        click.echo(tabulate(rows, headers=["Context", "Options", "Selected"], tablefmt="simple"))
        click.echo()

    else:
        output = model.generate(length, prompt=prompt_tokens, join_fn=join_fn)

    click.echo(output)

    if interact:
        import code
        code.interact(local=locals(), banner="Entering interactive shell. 'model' and 'output' are available.")
|
||||
|
||||
|
||||
# Allow running this module directly (e.g. `python tlm/cli.py`) in addition
# to any installed console-script entry point.
if __name__ == "__main__":
    cli()
|
||||
63
tlm/helpers.py
Normal file
63
tlm/helpers.py
Normal file
@@ -0,0 +1,63 @@
|
||||
import mailbox
|
||||
import email
|
||||
from email.policy import default
|
||||
from tqdm import tqdm
|
||||
|
||||
def rolling_window(iterable, n):
    """Passes a rolling window over the iterable, yielding each n-length tuple.

    rolling_window(range(5), 3) -> (0, 1, 2), (1, 2, 3), (2, 3, 4)

    Yields nothing if the iterable has fewer than n items. Works on any
    iterable, including one-shot iterators.
    """
    from collections import deque
    from itertools import islice

    it = iter(iterable)
    # A deque with maxlen=n slides in O(1) per step; the previous
    # implementation rebuilt the window with a list slice (O(n) per step).
    # This also avoids the infinite loop the old version hit when n == 0.
    window = deque(islice(it, n), maxlen=n)
    if len(window) < n:
        return  # too few items to fill even one window
    yield tuple(window)
    for item in it:
        window.append(item)  # maxlen evicts the oldest element automatically
        yield tuple(window)
|
||||
|
||||
def read_mail_text(mbox_path):
    """
    Extract and concatenate all plaintext content from an mbox file.

    Returns all text/plain bodies found in the mailbox, stripped and joined
    with blank lines. Parts whose payload cannot be decoded are skipped
    (best-effort extraction).
    """
    texts = []
    mbox = mailbox.mbox(
        mbox_path,
        # Parse with the modern email policy so get_content() is available.
        factory=lambda f: email.message_from_binary_file(f, policy=default)
    )
    for msg in tqdm(mbox):
        # msg.walk() yields the message itself when it is not multipart, so a
        # single loop replaces the previously duplicated multipart/flat branches.
        for part in msg.walk():
            if part.get_content_type() != "text/plain":
                continue
            try:
                text = part.get_content()
            except Exception:
                # Deliberate best-effort: malformed encodings or payloads in
                # real-world mailboxes should not abort the whole extraction.
                continue
            if text:
                texts.append(text.strip())
    return "\n\n".join(texts)
|
||||
|
||||
def clean_corpus(corpus, max_length=10, remove_numbers=False, exclude=None):
    """Filter a token list, dropping unwanted words.

    Removes words longer than max_length (when max_length is truthy),
    numeric words (when remove_numbers is set), and words in exclude
    (when given). Returns a new list; the input is untouched.
    """
    def _keep(word):
        # A word survives only if it passes every enabled filter.
        if max_length and len(word) > max_length:
            return False
        if remove_numbers and word.isnumeric():
            return False
        return not (exclude and word in exclude)

    return [word for word in corpus if _keep(word)]
|
||||
|
||||
def join(tokens, punctuation=".,?!:;'\""):
    "Joins text, but does not give extra space for punctuation"
    pieces = []
    for token in tokens:
        # Punctuation attaches directly to the previous token; every other
        # token is preceded by a single space. strip() removes the leading
        # space introduced before the first word.
        if token in punctuation:
            pieces.append(token)
        else:
            pieces.append(' ' + token)
    return ''.join(pieces).strip()
|
||||
56
tlm/model.py
Normal file
56
tlm/model.py
Normal file
@@ -0,0 +1,56 @@
|
||||
import json
|
||||
import random
|
||||
from .helpers import rolling_window, join
|
||||
|
||||
|
||||
class TinyLanguageModel:
    """A tiny n-gram language model.

    Learns which word follows each n-word context in a training corpus and
    generates new text by repeatedly sampling a continuation for the last
    n generated words.
    """

    def __init__(self, n=2):
        "Create a new model that looks at n words at a time."
        self.n = n          # context-window size (words per pattern)
        self.model = {}     # maps n-word tuple -> list of observed next words

    def train(self, words):
        "Learn word patterns from a list of words."
        for sequence in rolling_window(words, self.n + 1):
            # rolling_window yields tuples, so slicing already gives a tuple
            # key; the old extra tuple() call was redundant.
            pattern = sequence[:-1]
            next_word = sequence[-1]
            # setdefault replaces the separate membership check + assignment.
            self.model.setdefault(pattern, []).append(next_word)

    def generate(self, length, prompt=None, join_fn=None, step_callback=None):
        """Create new words based on what the model learned.

        length: total number of tokens to produce (including the prompt).
        prompt: optional starting token sequence; a random learned pattern
            is used when omitted.
        join_fn: callable joining the token list into the return value;
            defaults to the module-level join().
        step_callback: optional fn(pattern, options, chosen) invoked per step.

        Raises RuntimeError if the model has not been trained.
        """
        if not self.model:
            # RuntimeError is more specific than the bare Exception raised
            # before; callers catching Exception still catch it.
            raise RuntimeError("The model has not been trained")
        output = list(prompt or self.get_random_pattern())
        while len(output) < length:
            pattern = tuple(output[-self.n:])
            if pattern not in self.model:
                break  # dead end: this context was never seen in training
            options = self.model[pattern]
            chosen = random.choice(options)
            if step_callback:
                step_callback(pattern, options, chosen)
            output.append(chosen)
        return (join_fn or join)(output)

    def get_random_pattern(self):
        "Randomly chooses one of the observed patterns"
        return random.choice(list(self.model.keys()))

    def save(self, filepath):
        "Save the model to a file."
        # NOTE(review): keys are serialized as space-joined strings, so
        # tokens containing spaces (possible with char-mode corpora) would
        # not round-trip through save()/load() — confirm before relying on
        # persistence for char models. Format kept for compatibility.
        model_data = {
            "n": self.n,
            "model": {" ".join(k): v for k, v in self.model.items()}
        }
        with open(filepath, "w") as f:
            json.dump(model_data, f)

    def load(self, filepath):
        "Load a model from a file."
        with open(filepath, "r") as f:
            data = json.load(f)
        self.n = data["n"]
        # Inverse of save(): split each key back into its token tuple.
        self.model = {tuple(k.split()): v for k, v in data["model"].items()}
|
||||
|
||||
31
tlm/tokenization.py
Normal file
31
tlm/tokenization.py
Normal file
@@ -0,0 +1,31 @@
|
||||
def compress_whitespace(text):
    """Collapse sequences of whitespace into a single space."""
    # str.split() with no arguments splits on any whitespace run and drops
    # leading/trailing whitespace, so rejoining with single spaces matches
    # the regex-based re.sub(r'\s+', ' ', text).strip() it replaces.
    return ' '.join(text.split())
|
||||
|
||||
|
||||
def _is_alpha_token(token):
|
||||
return all(c.isalpha() or c in " '" for c in token)
|
||||
|
||||
|
||||
def tokenize_text(text, options):
    """Tokenize a raw text string according to the given options.

    options may contain 'lower' (lowercase first), 'alpha' (keep only
    letter/space/apostrophe tokens), and 'char' (return single characters
    instead of words).
    """
    source = text.lower() if 'lower' in options else text
    tokens = source.split()
    if 'alpha' in options:
        tokens = [tok for tok in tokens if _is_alpha_token(tok)]
    if 'char' in options:
        # Character mode: rebuild the text with normalized spacing, then
        # split into individual characters.
        return list(compress_whitespace(' '.join(tokens)))
    return tokens
|
||||
|
||||
|
||||
def tokenize_words(words, options):
    """Apply tokenization options to an already word-tokenized sequence.

    Mirrors tokenize_text() but skips the splitting step; always returns a
    fresh list, leaving the input sequence untouched.
    """
    tokens = list(words)
    if 'lower' in options:
        tokens = [tok.lower() for tok in tokens]
    if 'alpha' in options:
        tokens = [tok for tok in tokens if _is_alpha_token(tok)]
    if 'char' in options:
        # Character mode: join words with spaces, normalize, split to chars.
        return list(compress_whitespace(' '.join(tokens)))
    return tokens
|
||||
Reference in New Issue
Block a user