Initial commit

This commit is contained in:
mwc
2026-02-09 12:15:12 -05:00
commit ce251fddbe
10 changed files with 410 additions and 0 deletions

0
tlm/__init__.py Normal file
View File

60
tlm/cli.py Normal file
View File

@@ -0,0 +1,60 @@
import click
from .model import TinyLanguageModel
from .helpers import read_mail_text
@click.group()
def cli():
    """TinyLM - A simple n-gram language model."""
    # Click group entry point; subcommands (e.g. `generate`) attach themselves
    # via the @cli.command() decorator. The docstring above is the CLI help text.
    pass
@cli.command()
@click.option("--length", default=50, help="Number of words to generate.")
@click.option("--text", type=click.Path(exists=True), help="Text file to use as training corpus.")
@click.option("--gutenberg", help="NLTK Gutenberg corpus key (use --list-gutenberg to see available keys).")
@click.option("--list-gutenberg", is_flag=True, help="List available Gutenberg corpus keys.")
@click.option("--mbox", type=click.Path(exists=True), help="Mbox file to use for training.")
@click.option("--interact", is_flag=True, help="Drop into interactive shell after generating.")
def generate(length, text, gutenberg, list_gutenberg, mbox, interact):
    """Generate text using the language model."""
    # Imported lazily so simply importing this module does not require nltk.
    import nltk

    # Handle --list-gutenberg: informational mode, print keys and exit early.
    if list_gutenberg:
        nltk.download("gutenberg", quiet=True)
        from nltk.corpus import gutenberg as gutenberg_corpus
        click.echo("Available Gutenberg corpus keys:")
        for key in gutenberg_corpus.fileids():
            click.echo(f" {key}")
        return

    # Determine training corpus from exactly one of the three sources.
    corpus = None
    if text:
        # Explicit encoding: the platform default (e.g. cp1252 on Windows)
        # would silently mis-decode UTF-8 training text.
        with open(text, "r", encoding="utf-8") as f:
            corpus = f.read().split()
    elif gutenberg:
        nltk.download("gutenberg", quiet=True)
        from nltk.corpus import gutenberg as gutenberg_corpus
        corpus = gutenberg_corpus.words(gutenberg)
    elif mbox:
        mail_text = read_mail_text(mbox)
        corpus = mail_text.split()
    else:
        raise click.UsageError("Must specify one of --text, --gutenberg, or --mbox for training data.")

    # Train and generate
    model = TinyLanguageModel()
    model.train(corpus)
    output = model.generate(length)
    click.echo(output)

    if interact:
        import code
        code.interact(local=locals(), banner="Entering interactive shell. 'model' and 'output' are available.")
if __name__ == "__main__":
    # Allow running this module directly (python -m tlm.cli / python cli.py).
    cli()

63
tlm/helpers.py Normal file
View File

@@ -0,0 +1,63 @@
import mailbox
import email
from email.policy import default
from tqdm import tqdm
def rolling_window(iterable, n):
    """Pass a rolling window over the iterable, yielding each n-length tuple.

    rolling_window(range(5), 3) -> (0, 1, 2), (1, 2, 3), (2, 3, 4)

    Yields nothing if the iterable has fewer than n items.
    """
    window = []
    for item in iterable:
        window.append(item)
        if len(window) == n:
            yield tuple(window)
            # Drop the oldest element to make room for the next item.
            window.pop(0)
def read_mail_text(mbox_path):
    """
    Extract and concatenate all plaintext content from an mbox file.

    mbox_path: path to an mbox mailbox on disk.
    Returns a single string: each text/plain part's stripped content,
    joined with blank lines.
    """
    texts = []
    mbox = mailbox.mbox(
        mbox_path,
        factory=lambda f: email.message_from_binary_file(f, policy=default)
    )
    for msg in tqdm(mbox):
        # EmailMessage.walk() yields the message itself when it is not
        # multipart, so one loop handles both the multipart and flat cases
        # (the original duplicated this extraction logic in two branches).
        for part in msg.walk():
            if part.get_content_type() != "text/plain":
                continue
            try:
                text = part.get_content()
            except Exception:
                # Deliberate best-effort: real-world mailboxes contain parts
                # with broken charsets/MIME; skip them rather than abort.
                continue
            if text:
                texts.append(text.strip())
    return "\n\n".join(texts)
def clean_corpus(corpus, max_length=10, remove_numbers=False, exclude=None):
    """Filter a word list, dropping words that fail any enabled check.

    max_length: drop words longer than this (falsy disables the check).
    remove_numbers: when True, drop purely numeric words.
    exclude: optional collection of words to drop.
    """
    def _keep(word):
        # Each filter applies only when its option is enabled (truthy).
        if max_length and len(word) > max_length:
            return False
        if remove_numbers and word.isnumeric():
            return False
        return not (exclude and word in exclude)

    return [word for word in corpus if _keep(word)]
def join(tokens, punctuation=".,?!:;'\""):
    """Join tokens with spaces, but attach punctuation tokens directly
    to the preceding token (no space before them)."""
    pieces = []
    for token in tokens:
        # Non-punctuation tokens get a separating space in front.
        if token not in punctuation:
            pieces.append(' ')
        pieces.append(token)
    return ''.join(pieces).strip()

53
tlm/model.py Normal file
View File

@@ -0,0 +1,53 @@
import json
import random
from .helpers import rolling_window, join
class TinyLanguageModel:
    """A simple n-gram language model.

    Maps each observed n-word pattern to the list of words seen to follow
    it, then generates text by repeatedly sampling a follower for the
    last n words of the output.
    """

    def __init__(self, n=2):
        """Create a new model that looks at n words at a time."""
        self.n = n
        # pattern (n-tuple of words) -> list of observed next words.
        # Duplicates are kept on purpose so random.choice samples in
        # proportion to observed frequency.
        self.model = {}

    def train(self, words):
        """Learn word patterns from a list of words.

        May be called repeatedly to accumulate observations.
        """
        for sequence in rolling_window(words, self.n + 1):
            # rolling_window yields tuples, so the slice is already a
            # hashable tuple (the original wrapped it in tuple() redundantly).
            pattern = sequence[:-1]
            self.model.setdefault(pattern, []).append(sequence[-1])

    def generate(self, length, prompt=None):
        """Create new words based on what the model learned.

        length: target number of tokens in the output.
        prompt: optional iterable of seed words; defaults to a random
            learned pattern.
        Raises RuntimeError if the model has not been trained.
        Generation stops early if the current n-word context was never
        observed during training.
        """
        if not self.model:
            # Specific exception type instead of bare Exception; still
            # caught by any caller using `except Exception`.
            raise RuntimeError("The model has not been trained")
        output = list(prompt or self.get_random_pattern())
        while len(output) < length:
            pattern = tuple(output[-self.n:])
            if pattern not in self.model:
                # Dead end: this context was never seen, so stop early.
                break
            output.append(random.choice(self.model[pattern]))
        return join(output)

    def get_random_pattern(self):
        """Randomly choose one of the observed patterns."""
        return random.choice(list(self.model.keys()))

    def save(self, filepath):
        """Save the model to a JSON file.

        NOTE(review): patterns are serialized by joining words with spaces,
        so tokens that themselves contain whitespace will not round-trip
        through load() — acceptable for whitespace-split corpora.
        """
        model_data = {
            "n": self.n,
            "model": {" ".join(k): v for k, v in self.model.items()},
        }
        # Explicit encoding avoids platform-dependent defaults.
        with open(filepath, "w", encoding="utf-8") as f:
            json.dump(model_data, f)

    def load(self, filepath):
        """Load a model previously written by save()."""
        with open(filepath, "r", encoding="utf-8") as f:
            data = json.load(f)
        self.n = data["n"]
        self.model = {tuple(k.split()): v for k, v in data["model"].items()}