"""Evaluate a spam classifier. Usage: spam -e spam classifiers.manual.ManualClassifier spam classifiers.features.FeatureClassifier spam classifiers.manual.ManualClassifier -t 0.2 spam classifiers.manual.ManualClassifier -a spam classifiers.manual.ManualClassifier -a 5 """ import argparse import importlib from sklearn.model_selection import train_test_split import cli.output as out from cli.data import load_spam def load_classifier(class_path): module_path, class_name = class_path.rsplit(".", 1) module = importlib.import_module(module_path) return getattr(module, class_name)() def main(): parser = argparse.ArgumentParser( description="Train and evaluate a spam classifier." ) parser.add_argument( "classifier", nargs="?", help="Fully-qualified class, e.g. classifiers.manual.ManualClassifier", ) parser.add_argument( "-e", "--explore", action="store_true", help="Drop into an interactive shell with the dataset loaded as `df`", ) parser.add_argument( "-t", "--test-ratio", type=float, default=0.3, help="Fraction held out for testing (default: 0.3)", ) parser.add_argument( "-a", "--error-analysis", type=int, nargs="?", const=10, default=None, metavar="N", help="Show up to N misclassified examples (default: 10)", ) args = parser.parse_args() if not args.classifier and not args.explore: parser.print_help() return df = load_spam() if args.explore: out.explore(df) if not args.classifier: return X = df["message"].values y = df["label"].values X_train, X_test, y_train, y_test = train_test_split( X, y, test_size=args.test_ratio, random_state=42 ) out.dataset_summary(df, len(X_train), len(X_test)) clf = load_classifier(args.classifier) clf.fit(X_train, y_train) y_pred = clf.predict(X_test) out.evaluation(y_test, y_pred, type(clf).__name__) out.feature_weights(clf) if args.error_analysis is not None: out.error_analysis(X_test, y_test, y_pred, args.error_analysis) if __name__ == "__main__": main()