Initial commit
This commit is contained in:
89
cli/main.py
Normal file
89
cli/main.py
Normal file
@@ -0,0 +1,89 @@
|
||||
"""Evaluate a spam classifier.
|
||||
|
||||
Usage:
|
||||
spam -e
|
||||
spam classifiers.manual.ManualClassifier
|
||||
spam classifiers.feature_classifier.FeatureClassifier
|
||||
spam classifiers.manual.ManualClassifier -t 0.2
|
||||
spam classifiers.manual.ManualClassifier -a
|
||||
spam classifiers.manual.ManualClassifier -a 5
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import importlib
|
||||
|
||||
from sklearn.model_selection import train_test_split
|
||||
|
||||
import cli.output as out
|
||||
from cli.data import load_spam
|
||||
|
||||
|
||||
def load_classifier(class_path):
|
||||
module_path, class_name = class_path.rsplit(".", 1)
|
||||
module = importlib.import_module(module_path)
|
||||
return getattr(module, class_name)()
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Train and evaluate a spam classifier."
|
||||
)
|
||||
parser.add_argument(
|
||||
"classifier",
|
||||
nargs="?",
|
||||
help="Fully-qualified class, e.g. classifiers.manual.ManualClassifier",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-e", "--explore",
|
||||
action="store_true",
|
||||
help="Drop into an interactive shell with the dataset loaded as `df`",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-t", "--test-ratio",
|
||||
type=float,
|
||||
default=0.3,
|
||||
help="Fraction held out for testing (default: 0.3)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-a", "--error-analysis",
|
||||
type=int,
|
||||
nargs="?",
|
||||
const=10,
|
||||
default=None,
|
||||
metavar="N",
|
||||
help="Show up to N misclassified examples (default: 10)",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
if not args.classifier and not args.explore:
|
||||
parser.print_help()
|
||||
return
|
||||
|
||||
df = load_spam()
|
||||
|
||||
if args.explore:
|
||||
out.explore(df)
|
||||
if not args.classifier:
|
||||
return
|
||||
|
||||
X = df["message"].values
|
||||
y = df["label"].values
|
||||
X_train, X_test, y_train, y_test = train_test_split(
|
||||
X, y, test_size=args.test_ratio, random_state=42
|
||||
)
|
||||
|
||||
out.dataset_summary(df, len(X_train), len(X_test))
|
||||
|
||||
clf = load_classifier(args.classifier)
|
||||
clf.fit(X_train, y_train)
|
||||
y_pred = clf.predict(X_test)
|
||||
|
||||
out.evaluation(y_test, y_pred, type(clf).__name__)
|
||||
out.feature_weights(clf)
|
||||
|
||||
if args.error_analysis is not None:
|
||||
out.error_analysis(X_test, y_test, y_pred, args.error_analysis)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Reference in New Issue
Block a user