80 lines
2.6 KiB
Python
80 lines
2.6 KiB
Python
import code
|
|
|
|
import pandas as pd
|
|
from sklearn.metrics import confusion_matrix, f1_score, precision_score, recall_score
|
|
|
|
|
|
def explore(df):
    """Open an interactive Python console with the dataset preloaded.

    The console namespace exposes the given object as ``df`` and the pandas
    module as ``pd``. Control returns to the caller once the session ends.
    """
    namespace = {"df": df, "pd": pd}
    print("The dataset is available as `df`. Press Ctrl-D or type exit() to quit.\n")
    # An empty banner suppresses the interpreter's default greeting, leaving
    # the hint printed above as the only introduction.
    code.interact(banner="", local=namespace)
|
|
|
|
|
|
def dataset_summary(df, n_train, n_test):
    """Print a label-balance and train/test split overview of the dataset.

    Args:
        df: DataFrame with a ``label`` column. Each distinct label is listed
            with its count and its share of ``len(df)``.
        n_train: Number of training examples.
        n_test: Number of test examples.

    Split percentages are relative to ``n_train + n_test``. When that sum is
    zero both percentages are reported as 0% instead of raising
    ``ZeroDivisionError``.
    """
    rule = "=" * 60
    print(rule)
    print("DATASET")
    print(rule)

    n_rows = len(df)
    print(f"\n Total messages: {n_rows}")
    for label, n in df["label"].value_counts().items():
        # str() keeps the ":4s" padding valid for non-string labels (e.g. 0/1).
        print(f" {str(label):4s}: {n:5d} ({100 * n / n_rows:.1f}%)")

    total = n_train + n_test
    # Guard the empty-split case so the summary never divides by zero.
    train_pct = 100 * n_train / total if total else 0.0
    test_pct = 100 * n_test / total if total else 0.0
    print(f"\n train: {n_train} ({train_pct:.0f}%) test: {n_test} ({test_pct:.0f}%)")
    print()
|
|
|
|
|
|
def evaluation(y_true, y_pred, clf_name):
    """Print per-class metrics, the macro-averaged F1 and the confusion matrix.

    Args:
        y_true: Ground-truth labels; the report covers "ham" and "spam".
        y_pred: Predicted labels, aligned with ``y_true``.
        clf_name: Name shown in the report header.
    """
    classes = ["ham", "spam"]
    separator = "=" * 60

    print(separator)
    print(f"RESULTS: {clf_name}")
    print(separator)
    print()

    print(f"{'':12s} {'precision':>10} {'recall':>10} {'f1':>10}")
    for cls in classes:
        # Each class in turn is scored as the positive label; zero_division=0
        # makes an undefined score come out as 0.
        scores = [
            metric(y_true, y_pred, pos_label=cls, zero_division=0)
            for metric in (precision_score, recall_score, f1_score)
        ]
        print(f" {cls:<10} {scores[0]:>10.3f} {scores[1]:>10.3f} {scores[2]:>10.3f}")

    macro_f1 = f1_score(y_true, y_pred, average="macro", zero_division=0)
    print(f"\n {'average f1':10} {'':>10} {'':>10} {macro_f1:>10.3f}")
    print()

    # Rows are actual classes, columns are predicted classes, both in
    # ham-then-spam order.
    matrix = confusion_matrix(y_true, y_pred, labels=classes)
    print("Confusion matrix:")
    print(f"{'':18s} {'pred ham':>10} {'pred spam':>10}")
    for cls, row in zip(classes, matrix):
        row_title = f"actual {cls}"
        print(f"{row_title:18s} {row[0]:>10} {row[1]:>10}")
    print()
|
|
|
|
|
|
def feature_weights(clf, top_n=10):
    """Print the classifier's top weighted features as a signed text bar chart.

    Prints nothing when ``clf`` has no ``feature_weights`` attribute or when
    that method returns an empty result.

    Args:
        clf: Classifier object; ``clf.feature_weights(top_n=...)`` is expected
            to yield ``(name, weight)`` pairs.
        top_n: Forwarded unchanged to ``clf.feature_weights``.
    """
    ranked = clf.feature_weights(top_n=top_n) if hasattr(clf, "feature_weights") else None
    if not ranked:
        return

    divider = "=" * 60
    print(divider)
    print(f"TOP {len(ranked)} FEATURES BY WEIGHT")
    print(divider)

    for feature, weight in ranked:
        positive = weight > 0
        # Five bar characters per unit of |weight|, capped at 25 characters.
        length = min(int(abs(weight) * 5), 25)
        if positive:
            side, glyph = "spam", "+"
        else:
            # Trailing space keeps the column as wide as "spam".
            side, glyph = "ham ", "-"
        print(f" {feature:<28} {weight:+.3f} → {side} {glyph * length}")
    print()
|
|
|
|
|
|
def error_analysis(X, y_true, y_pred, n):
    """Print up to ``n`` misclassified messages with their true and predicted labels.

    Args:
        X: Iterable of message texts, aligned with the label sequences.
        y_true: Ground-truth labels.
        y_pred: Predicted labels.
        n: Maximum number of misclassified messages to display.
    """
    mistakes = []
    for text, actual, predicted in zip(X, y_true, y_pred):
        if actual != predicted:
            mistakes.append((text, actual, predicted))
    preview = mistakes[:n]

    border = "=" * 60
    print(border)
    print(f"ERROR ANALYSIS ({len(preview)} of {len(mistakes)} misclassified)")
    print(border)

    for text, actual, predicted in preview:
        # Keep each message to at most 80 characters, marking cuts with "...".
        if len(text) > 80:
            snippet = text[:77] + "..."
        else:
            snippet = text
        print(f"\n true={actual:<4} pred={predicted}")
        print(f" {snippet}")
    print()
|