Initial commit
This commit is contained in:
71
cli/output.py
Normal file
71
cli/output.py
Normal file
@@ -0,0 +1,71 @@
|
||||
import numpy as np
|
||||
from sklearn.metrics import accuracy_score, confusion_matrix, f1_score
|
||||
|
||||
|
||||
def show_digit(pixels):
|
||||
img = pixels.reshape(28, 28)
|
||||
for row in img:
|
||||
print(" ".join("#" if p > 0.5 else "." for p in row))
|
||||
|
||||
|
||||
def explore(X_train, y_train):
|
||||
print("=" * 60)
|
||||
print("SAMPLE DIGITS")
|
||||
print("=" * 60)
|
||||
for i in range(3):
|
||||
print(f"\nLabel: {y_train[i]}")
|
||||
show_digit(X_train[i])
|
||||
|
||||
unique, counts = np.unique(y_train, return_counts=True)
|
||||
print()
|
||||
print("=" * 60)
|
||||
print("LABEL DISTRIBUTION")
|
||||
print("=" * 60)
|
||||
for digit, count in zip(unique, counts):
|
||||
print(f" {digit}: {count:5d} ({100 * count / len(y_train):.1f}%)")
|
||||
print()
|
||||
|
||||
|
||||
def dataset_summary(n_train, n_test):
|
||||
print("=" * 60)
|
||||
print("DATASET")
|
||||
print("=" * 60)
|
||||
total = n_train + n_test
|
||||
print(f"\n train: {n_train} ({100 * n_train / total:.0f}%) test: {n_test} ({100 * n_test / total:.0f}%)")
|
||||
print()
|
||||
|
||||
|
||||
def evaluation(y_true, y_pred, clf_name):
|
||||
print("=" * 60)
|
||||
print(f"RESULTS: {clf_name}")
|
||||
print("=" * 60)
|
||||
print()
|
||||
accuracy = accuracy_score(y_true, y_pred)
|
||||
avg_f1 = f1_score(y_true, y_pred, average="macro")
|
||||
print(f" accuracy {accuracy:>10.3f}")
|
||||
print(f" average f1 {avg_f1:>10.3f}")
|
||||
print()
|
||||
|
||||
cm = confusion_matrix(y_true, y_pred, labels=list(range(10)))
|
||||
per_digit = cm.diagonal() / cm.sum(axis=1)
|
||||
print("Per-digit accuracy:")
|
||||
for digit, acc in enumerate(per_digit):
|
||||
bar = "+" * int(acc * 30)
|
||||
print(f" {digit}: {acc:.3f} {bar}")
|
||||
print()
|
||||
|
||||
|
||||
def error_analysis(X, y_true, y_pred, n):
|
||||
errors = [
|
||||
(pixels, t, p)
|
||||
for pixels, t, p in zip(X, y_true, y_pred)
|
||||
if t != p
|
||||
]
|
||||
shown = errors[:n]
|
||||
print("=" * 60)
|
||||
print(f"ERROR ANALYSIS ({len(shown)} of {len(errors)} misclassified)")
|
||||
print("=" * 60)
|
||||
for pixels, true_label, pred_label in shown:
|
||||
print(f"\n true={true_label} pred={pred_label}")
|
||||
show_digit(pixels)
|
||||
print()
|
||||
Reference in New Issue
Block a user