Initial commit
This commit is contained in:
0
cli/__init__.py
Normal file
0
cli/__init__.py
Normal file
32
cli/data.py
Normal file
32
cli/data.py
Normal file
@@ -0,0 +1,32 @@
|
||||
import io
|
||||
import os
|
||||
import zipfile
|
||||
import urllib.request
|
||||
|
||||
import pandas as pd
|
||||
from sklearn.datasets import get_data_home
|
||||
|
||||
|
||||
URL = (
|
||||
"https://archive.ics.uci.edu/ml/machine-learning-databases/"
|
||||
"00228/smsspamcollection.zip"
|
||||
)
|
||||
|
||||
|
||||
def load_spam():
|
||||
path = os.path.join(get_data_home(), "spam", "SMSSpamCollection")
|
||||
if not os.path.exists(path):
|
||||
_fetch(path)
|
||||
return pd.read_csv(path, sep="\t", header=None, names=["label", "message"])
|
||||
|
||||
|
||||
def _fetch(dest):
|
||||
os.makedirs(os.path.dirname(dest), exist_ok=True)
|
||||
print("Downloading SMS Spam Collection...")
|
||||
with urllib.request.urlopen(URL) as response:
|
||||
data = response.read()
|
||||
with zipfile.ZipFile(io.BytesIO(data)) as zf:
|
||||
with zf.open("SMSSpamCollection") as f:
|
||||
content = f.read()
|
||||
with open(dest, "wb") as f:
|
||||
f.write(content)
|
||||
89
cli/main.py
Normal file
89
cli/main.py
Normal file
@@ -0,0 +1,89 @@
|
||||
"""Evaluate a spam classifier.
|
||||
|
||||
Usage:
|
||||
spam -e
|
||||
spam classifiers.manual.ManualClassifier
|
||||
spam classifiers.feature_classifier.FeatureClassifier
|
||||
spam classifiers.manual.ManualClassifier -t 0.2
|
||||
spam classifiers.manual.ManualClassifier -a
|
||||
spam classifiers.manual.ManualClassifier -a 5
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import importlib
|
||||
|
||||
from sklearn.model_selection import train_test_split
|
||||
|
||||
import cli.output as out
|
||||
from cli.data import load_spam
|
||||
|
||||
|
||||
def load_classifier(class_path):
|
||||
module_path, class_name = class_path.rsplit(".", 1)
|
||||
module = importlib.import_module(module_path)
|
||||
return getattr(module, class_name)()
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Train and evaluate a spam classifier."
|
||||
)
|
||||
parser.add_argument(
|
||||
"classifier",
|
||||
nargs="?",
|
||||
help="Fully-qualified class, e.g. classifiers.manual.ManualClassifier",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-e", "--explore",
|
||||
action="store_true",
|
||||
help="Drop into an interactive shell with the dataset loaded as `df`",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-t", "--test-ratio",
|
||||
type=float,
|
||||
default=0.3,
|
||||
help="Fraction held out for testing (default: 0.3)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-a", "--error-analysis",
|
||||
type=int,
|
||||
nargs="?",
|
||||
const=10,
|
||||
default=None,
|
||||
metavar="N",
|
||||
help="Show up to N misclassified examples (default: 10)",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
if not args.classifier and not args.explore:
|
||||
parser.print_help()
|
||||
return
|
||||
|
||||
df = load_spam()
|
||||
|
||||
if args.explore:
|
||||
out.explore(df)
|
||||
if not args.classifier:
|
||||
return
|
||||
|
||||
X = df["message"].values
|
||||
y = df["label"].values
|
||||
X_train, X_test, y_train, y_test = train_test_split(
|
||||
X, y, test_size=args.test_ratio, random_state=42
|
||||
)
|
||||
|
||||
out.dataset_summary(df, len(X_train), len(X_test))
|
||||
|
||||
clf = load_classifier(args.classifier)
|
||||
clf.fit(X_train, y_train)
|
||||
y_pred = clf.predict(X_test)
|
||||
|
||||
out.evaluation(y_test, y_pred, type(clf).__name__)
|
||||
out.feature_weights(clf)
|
||||
|
||||
if args.error_analysis is not None:
|
||||
out.error_analysis(X_test, y_test, y_pred, args.error_analysis)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
79
cli/output.py
Normal file
79
cli/output.py
Normal file
@@ -0,0 +1,79 @@
|
||||
import code
|
||||
|
||||
import pandas as pd
|
||||
from sklearn.metrics import confusion_matrix, f1_score, precision_score, recall_score
|
||||
|
||||
|
||||
def explore(df):
|
||||
print("The dataset is available as `df`. Press Ctrl-D or type exit() to quit.\n")
|
||||
code.interact(local={"df": df, "pd": pd}, banner="")
|
||||
|
||||
|
||||
def dataset_summary(df, n_train, n_test):
|
||||
print("=" * 60)
|
||||
print("DATASET")
|
||||
print("=" * 60)
|
||||
counts = df["label"].value_counts()
|
||||
total = n_train + n_test
|
||||
print(f"\n Total messages: {len(df)}")
|
||||
for label, n in counts.items():
|
||||
print(f" {label:4s}: {n:5d} ({100 * n / len(df):.1f}%)")
|
||||
print(f"\n train: {n_train} ({100 * n_train / total:.0f}%) test: {n_test} ({100 * n_test / total:.0f}%)")
|
||||
print()
|
||||
|
||||
|
||||
def evaluation(y_true, y_pred, clf_name):
|
||||
print("=" * 60)
|
||||
print(f"RESULTS: {clf_name}")
|
||||
print("=" * 60)
|
||||
print()
|
||||
print(f"{'':12s} {'precision':>10} {'recall':>10} {'f1':>10}")
|
||||
for label in ["ham", "spam"]:
|
||||
p = precision_score(y_true, y_pred, pos_label=label, zero_division=0)
|
||||
r = recall_score(y_true, y_pred, pos_label=label, zero_division=0)
|
||||
f = f1_score(y_true, y_pred, pos_label=label, zero_division=0)
|
||||
print(f" {label:<10} {p:>10.3f} {r:>10.3f} {f:>10.3f}")
|
||||
avg_f1 = f1_score(y_true, y_pred, average="macro", zero_division=0)
|
||||
print(f"\n {'average f1':10} {'':>10} {'':>10} {avg_f1:>10.3f}")
|
||||
print()
|
||||
|
||||
cm = confusion_matrix(y_true, y_pred, labels=["ham", "spam"])
|
||||
print("Confusion matrix:")
|
||||
print(f"{'':18s} {'pred ham':>10} {'pred spam':>10}")
|
||||
print(f"{'actual ham':18s} {cm[0][0]:>10} {cm[0][1]:>10}")
|
||||
print(f"{'actual spam':18s} {cm[1][0]:>10} {cm[1][1]:>10}")
|
||||
print()
|
||||
|
||||
|
||||
def feature_weights(clf, top_n=10):
|
||||
if not hasattr(clf, "feature_weights"):
|
||||
return
|
||||
weights = clf.feature_weights(top_n=top_n)
|
||||
if not weights:
|
||||
return
|
||||
print("=" * 60)
|
||||
print(f"TOP {len(weights)} FEATURES BY WEIGHT")
|
||||
print("=" * 60)
|
||||
for name, w in weights:
|
||||
direction = "spam" if w > 0 else "ham "
|
||||
bar_len = min(int(abs(w) * 5), 25)
|
||||
bar = ("+" if w > 0 else "-") * bar_len
|
||||
print(f" {name:<28} {w:+.3f} → {direction} {bar}")
|
||||
print()
|
||||
|
||||
|
||||
def error_analysis(X, y_true, y_pred, n):
|
||||
errors = [
|
||||
(x, t, p)
|
||||
for x, t, p in zip(X, y_true, y_pred)
|
||||
if t != p
|
||||
]
|
||||
shown = errors[:n]
|
||||
print("=" * 60)
|
||||
print(f"ERROR ANALYSIS ({len(shown)} of {len(errors)} misclassified)")
|
||||
print("=" * 60)
|
||||
for msg, true_label, pred_label in shown:
|
||||
display = msg if len(msg) <= 80 else msg[:77] + "..."
|
||||
print(f"\n true={true_label:<4} pred={pred_label}")
|
||||
print(f" {display}")
|
||||
print()
|
||||
Reference in New Issue
Block a user