generated from mwc/lab_tinylm
Applying patch
This commit is contained in:
@@ -6,6 +6,7 @@ requires-python = ">=3.10,<4.0"
|
|||||||
dependencies = [
|
dependencies = [
|
||||||
"click>=8.3.1",
|
"click>=8.3.1",
|
||||||
"nltk>=3.9.2",
|
"nltk>=3.9.2",
|
||||||
|
"tabulate>=0.9.0",
|
||||||
"tqdm>=4.67.3",
|
"tqdm>=4.67.3",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|||||||
60
tlm/cli.py
60
tlm/cli.py
@@ -1,6 +1,7 @@
|
|||||||
import click
|
import click
|
||||||
from .model import TinyLanguageModel
|
from .model import TinyLanguageModel
|
||||||
from .helpers import read_mail_text
|
from .helpers import read_mail_text, join
|
||||||
|
from .tokenization import tokenize_text, tokenize_words
|
||||||
|
|
||||||
|
|
||||||
@click.group()
|
@click.group()
|
||||||
@@ -10,15 +11,17 @@ def cli():
|
|||||||
|
|
||||||
|
|
||||||
@cli.command()
|
@cli.command()
|
||||||
@click.option("--length", default=50, help="Number of words to generate.")
|
@click.option('-l', "--length", default=50, help="Number of tokens to generate.")
|
||||||
@click.option("--n", default=2, help="Number of words in the context window.")
|
@click.option('-n', "--context-window-words", default=2, help="Number of words in the context window.")
|
||||||
@click.option("--text", type=click.Path(exists=True), multiple=True, help="Text file(s) to use as training corpus. Can be specified multiple times.")
|
@click.option('-f', "--filepath", type=click.Path(exists=True), multiple=True, help="Text file(s) to use as training corpus. Can be specified multiple times.")
|
||||||
@click.option("--gutenberg", multiple=True, help="NLTK Gutenberg corpus key(s). Can be specified multiple times.")
|
@click.option('-g', "--gutenberg", multiple=True, help="NLTK Gutenberg corpus key(s). Can be specified multiple times.")
|
||||||
@click.option("--list-gutenberg", is_flag=True, help="List available Gutenberg corpus keys.")
|
@click.option('-G', "--list-gutenberg", is_flag=True, help="List available Gutenberg corpus keys.")
|
||||||
@click.option("--mbox", type=click.Path(exists=True), help="Mbox file to use for training.")
|
@click.option('-m', "--mbox", type=click.Path(exists=True), help="Mbox file to use for training.")
|
||||||
@click.option("--prompt", help="Prompt to start generation.")
|
@click.option('-p', "--prompt", help="Prompt to start generation.")
|
||||||
@click.option("--interact", is_flag=True, help="Drop into interactive shell after generating.")
|
@click.option('-i', "--interact", is_flag=True, help="Drop into interactive shell after generating.")
|
||||||
def generate(length, n, text, gutenberg, list_gutenberg, mbox, prompt, interact):
|
@click.option('-t', "--tokenize", 'tokenize_opts', multiple=True, type=click.Choice(['lower', 'char', 'alpha']), help="Preprocessing option (can be specified multiple times). 'lower': lowercase all input text. 'char': use characters as tokens instead of words.")
|
||||||
|
@click.option('-v', "--verbose", is_flag=True, help="Display step-by-step generation as a table.")
|
||||||
|
def generate(length, context_window_words, filepath, gutenberg, list_gutenberg, mbox, prompt, interact, tokenize_opts, verbose):
|
||||||
"""Generate text using the language model."""
|
"""Generate text using the language model."""
|
||||||
import nltk
|
import nltk
|
||||||
|
|
||||||
@@ -34,27 +37,46 @@ def generate(length, n, text, gutenberg, list_gutenberg, mbox, prompt, interact)
|
|||||||
# Determine training corpus
|
# Determine training corpus
|
||||||
corpus = []
|
corpus = []
|
||||||
|
|
||||||
if text:
|
if filepath:
|
||||||
for filepath in text:
|
for fp in filepath:
|
||||||
with open(filepath, "r") as f:
|
with open(fp, "r") as f:
|
||||||
corpus.extend(f.read().split())
|
corpus.extend(tokenize_text(f.read(), tokenize_opts))
|
||||||
if gutenberg:
|
if gutenberg:
|
||||||
nltk.download("gutenberg", quiet=True)
|
nltk.download("gutenberg", quiet=True)
|
||||||
from nltk.corpus import gutenberg as gutenberg_corpus
|
from nltk.corpus import gutenberg as gutenberg_corpus
|
||||||
for key in gutenberg:
|
for key in gutenberg:
|
||||||
corpus.extend(gutenberg_corpus.words(key))
|
corpus.extend(tokenize_words(gutenberg_corpus.words(key), tokenize_opts))
|
||||||
if mbox:
|
if mbox:
|
||||||
mail_text = read_mail_text(mbox)
|
mail_text = read_mail_text(mbox)
|
||||||
corpus.extend(mail_text.split())
|
corpus.extend(tokenize_text(mail_text, tokenize_opts))
|
||||||
|
|
||||||
if not corpus:
|
if not corpus:
|
||||||
raise click.UsageError("Must specify at least one of --text, --gutenberg, or --mbox for training data.")
|
raise click.UsageError("Must specify at least one of --text, --gutenberg, or --mbox for training data.")
|
||||||
|
|
||||||
# Train and generate
|
# Train and generate
|
||||||
model = TinyLanguageModel(n=n)
|
model = TinyLanguageModel(n=context_window_words)
|
||||||
model.train(corpus)
|
model.train(corpus)
|
||||||
prompt_words = prompt.split() if prompt else None
|
if prompt:
|
||||||
output = model.generate(length, prompt=prompt_words)
|
prompt_tokens = tokenize_text(prompt, tokenize_opts)
|
||||||
|
else:
|
||||||
|
prompt_tokens = None
|
||||||
|
join_fn = ''.join if 'char' in tokenize_opts else None
|
||||||
|
display_join = join_fn or join
|
||||||
|
|
||||||
|
if verbose:
|
||||||
|
from tabulate import tabulate
|
||||||
|
rows = []
|
||||||
|
import textwrap
|
||||||
|
def step_callback(pattern, options, chosen):
|
||||||
|
opts = textwrap.fill(', '.join(sorted(set(options))), width=60)
|
||||||
|
rows.append([display_join(list(pattern)), opts, chosen])
|
||||||
|
output = model.generate(length, prompt=prompt_tokens, join_fn=join_fn, step_callback=step_callback)
|
||||||
|
click.echo(tabulate(rows, headers=["Context", "Options", "Selected"], tablefmt="simple"))
|
||||||
|
click.echo()
|
||||||
|
|
||||||
|
else:
|
||||||
|
output = model.generate(length, prompt=prompt_tokens, join_fn=join_fn)
|
||||||
|
|
||||||
click.echo(output)
|
click.echo(output)
|
||||||
|
|
||||||
if interact:
|
if interact:
|
||||||
|
|||||||
11
tlm/model.py
11
tlm/model.py
@@ -18,7 +18,7 @@ class TinyLanguageModel:
|
|||||||
self.model[pattern] = []
|
self.model[pattern] = []
|
||||||
self.model[pattern].append(next_word)
|
self.model[pattern].append(next_word)
|
||||||
|
|
||||||
def generate(self, length, prompt=None):
|
def generate(self, length, prompt=None, join_fn=None, step_callback=None):
|
||||||
"Create new words based on what the model learned."
|
"Create new words based on what the model learned."
|
||||||
if not self.model:
|
if not self.model:
|
||||||
raise Exception("The model has not been trained")
|
raise Exception("The model has not been trained")
|
||||||
@@ -27,9 +27,12 @@ class TinyLanguageModel:
|
|||||||
pattern = tuple(output[-self.n:])
|
pattern = tuple(output[-self.n:])
|
||||||
if pattern not in self.model:
|
if pattern not in self.model:
|
||||||
break
|
break
|
||||||
next_word = random.choice(self.model[pattern])
|
options = self.model[pattern]
|
||||||
output.append(next_word)
|
chosen = random.choice(options)
|
||||||
return join(output)
|
if step_callback:
|
||||||
|
step_callback(pattern, options, chosen)
|
||||||
|
output.append(chosen)
|
||||||
|
return (join_fn or join)(output)
|
||||||
|
|
||||||
def get_random_pattern(self):
|
def get_random_pattern(self):
|
||||||
"Randomly chooses one of the observed patterns"
|
"Randomly chooses one of the observed patterns"
|
||||||
|
|||||||
11
uv.lock
generated
11
uv.lock
generated
@@ -38,6 +38,7 @@ source = { editable = "." }
|
|||||||
dependencies = [
|
dependencies = [
|
||||||
{ name = "click" },
|
{ name = "click" },
|
||||||
{ name = "nltk" },
|
{ name = "nltk" },
|
||||||
|
{ name = "tabulate" },
|
||||||
{ name = "tqdm" },
|
{ name = "tqdm" },
|
||||||
]
|
]
|
||||||
|
|
||||||
@@ -45,6 +46,7 @@ dependencies = [
|
|||||||
requires-dist = [
|
requires-dist = [
|
||||||
{ name = "click", specifier = ">=8.3.1" },
|
{ name = "click", specifier = ">=8.3.1" },
|
||||||
{ name = "nltk", specifier = ">=3.9.2" },
|
{ name = "nltk", specifier = ">=3.9.2" },
|
||||||
|
{ name = "tabulate", specifier = ">=0.9.0" },
|
||||||
{ name = "tqdm", specifier = ">=4.67.3" },
|
{ name = "tqdm", specifier = ">=4.67.3" },
|
||||||
]
|
]
|
||||||
|
|
||||||
@@ -184,6 +186,15 @@ wheels = [
|
|||||||
{ url = "https://files.pythonhosted.org/packages/95/e4/a3b9480c78cf8ee86626cb06f8d931d74d775897d44201ccb813097ae697/regex-2026.1.15-cp314-cp314t-win_arm64.whl", hash = "sha256:ca89c5e596fc05b015f27561b3793dc2fa0917ea0d7507eebb448efd35274a70", size = 274837 },
|
{ url = "https://files.pythonhosted.org/packages/95/e4/a3b9480c78cf8ee86626cb06f8d931d74d775897d44201ccb813097ae697/regex-2026.1.15-cp314-cp314t-win_arm64.whl", hash = "sha256:ca89c5e596fc05b015f27561b3793dc2fa0917ea0d7507eebb448efd35274a70", size = 274837 },
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "tabulate"
|
||||||
|
version = "0.9.0"
|
||||||
|
source = { registry = "https://pypi.org/simple" }
|
||||||
|
sdist = { url = "https://files.pythonhosted.org/packages/ec/fe/802052aecb21e3797b8f7902564ab6ea0d60ff8ca23952079064155d1ae1/tabulate-0.9.0.tar.gz", hash = "sha256:0095b12bf5966de529c0feb1fa08671671b3368eec77d7ef7ab114be2c068b3c", size = 81090 }
|
||||||
|
wheels = [
|
||||||
|
{ url = "https://files.pythonhosted.org/packages/40/44/4a5f08c96eb108af5cb50b41f76142f0afa346dfa99d5296fe7202a11854/tabulate-0.9.0-py3-none-any.whl", hash = "sha256:024ca478df22e9340661486f85298cff5f6dcdba14f3813e8830015b9ed1948f", size = 35252 },
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "tqdm"
|
name = "tqdm"
|
||||||
version = "4.67.3"
|
version = "4.67.3"
|
||||||
|
|||||||
Reference in New Issue
Block a user