Files
lab_estimation/plotting.py
Chris Proctor b81e182942 Initial commit
2026-06-22 16:11:05 -04:00

80 lines
1.9 KiB
Python

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
def _is_discrete(values, max_unique=30):
    """Report whether *values* has few enough distinct entries to count as discrete.

    Args:
        values: pandas Series (anything exposing ``nunique()``); missing values
            are not counted as a distinct entry.
        max_unique: largest number of distinct values still treated as discrete.

    Returns:
        True when ``values.nunique()`` does not exceed ``max_unique``.
    """
    distinct_count = values.nunique()
    return not distinct_count > max_unique
def _jittered(values, fraction, rng=None):
    """Return *values* plus Gaussian noise scaled to their tightest spacing.

    The noise standard deviation is ``fraction`` times the smallest gap between
    distinct non-missing values, or ``fraction`` itself when there are fewer
    than two distinct values. Neighbouring levels of a discrete variable
    therefore stay visually separated.

    Args:
        values: pandas Series to perturb. Missing entries stay missing.
        fraction: noise standard deviation as a fraction of the smallest gap.
        rng: optional NumPy random source (``np.random.Generator`` or
            ``np.random.RandomState``) for reproducible jitter. When None, the
            global ``np.random`` state is used, as before.

    Returns:
        A new Series (same index and name) with the noise added.
    """
    # Drop NaN before measuring gaps. A NaN among the sorted uniques makes
    # np.diff emit NaN, gaps.min() propagates it, and the resulting NaN noise
    # scale would turn every point (not just the missing ones) into NaN.
    levels = np.sort(values.dropna().unique())
    gaps = np.diff(levels)
    gap = gaps.min() if len(gaps) > 0 else 1
    source = np.random if rng is None else rng
    noise = source.normal(0, gap * fraction, size=len(values))
    return values + noise
def _maybe_jittered(values, jitter):
    """Apply jitter to *values* according to the ``jitter`` setting.

    ``jitter`` may be:
      * ``False`` -- return the values untouched;
      * ``None``  -- jitter by 0.15 of the smallest gap, but only when the
        values look discrete (per ``_is_discrete``);
      * ``True``  -- always jitter by 0.15 of the smallest gap;
      * a number  -- always jitter, using that number as the fraction.

    Returns:
        Either the very same Series object or a jittered copy of it.
    """
    if jitter is None:
        if not _is_discrete(values):
            return values
        return _jittered(values, 0.15)
    if jitter is False:
        return values
    if jitter is True:
        return _jittered(values, 0.15)
    return _jittered(values, jitter)
def _default_size(n):
    """Pick a scatter marker size that shrinks as the point count *n* grows.

    Returns 40 for at most 500 points, 25 for at most 5,000, 15 for at most
    50,000, and 6 beyond that.
    """
    for threshold, marker_size in ((50000, 6), (5000, 15), (500, 25)):
        if n > threshold:
            return marker_size
    return 40
def _default_opacity(n):
    """Pick a marker alpha that fades as the point count *n* grows.

    Returns 0.5 for at most 500 points, 0.3 for at most 5,000, 0.15 for at
    most 50,000, and 0.05 beyond that.
    """
    cutoffs = [(50000, 0.05), (5000, 0.15), (500, 0.3)]
    return next((alpha for limit, alpha in cutoffs if n > limit), 0.5)
def _predicted_line(model, x, column):
    """Evaluate *model* on 100 evenly spaced points spanning the range of *x*.

    Args:
        model: object exposing ``predict`` that accepts a DataFrame.
        x: observed predictor values; only their min and max are used.
        column: name given to the grid column passed to ``model.predict``.

    Returns:
        ``(grid, predictions)``: the 100 grid points and ``model.predict``
        evaluated on them.
    """
    low, high = x.min(), x.max()
    grid = np.linspace(low, high, 100)
    predictions = model.predict(pd.DataFrame({column: grid}))
    return grid, predictions
def plot_regression(X, y, model, size=None, opacity=None, jitter=None):
    """Draw a scatterplot of one predictor against *y*, overlaid with the model's line.

    Only the first column of *X* is plotted. When not given, marker size and
    opacity are chosen from the number of points (smaller and fainter for
    larger data). ``jitter`` controls random displacement of the points (see
    ``_maybe_jittered``; by default only discrete-looking variables are
    jittered). Each axis whose observed values look discrete gets one tick per
    distinct value. The figure is displayed with ``plt.show()``.

    Args:
        X: DataFrame whose first column is the predictor.
        y: Series of target values; its ``name`` labels the y axis and title.
        model: object whose ``predict`` accepts a one-column DataFrame named
            like the predictor.
        size: marker size; picked from the point count when None.
        opacity: marker alpha; picked from the point count when None.
        jitter: False, None, True, or a numeric fraction (see ``_maybe_jittered``).
    """
    predictor = X.columns[0]
    x_values = X[predictor]
    point_count = len(x_values)
    marker_size = _default_size(point_count) if size is None else size
    alpha = _default_opacity(point_count) if opacity is None else opacity

    # Jitter x before y: the order of random draws matters when the global
    # NumPy seed is fixed.
    shown_x = _maybe_jittered(x_values, jitter)
    shown_y = _maybe_jittered(y, jitter)
    sns.scatterplot(x=shown_x, y=shown_y, s=marker_size, alpha=alpha, edgecolor="none")

    # Ticks come from the observed (un-jittered) levels of discrete axes.
    for set_ticks, observed in ((plt.xticks, x_values), (plt.yticks, y)):
        if _is_discrete(observed):
            set_ticks(sorted(observed.unique()))

    grid_x, grid_y = _predicted_line(model, x_values, predictor)
    plt.plot(grid_x, grid_y, color="crimson")

    target = y.name
    plt.xlabel(predictor)
    plt.ylabel(target)
    plt.title(f"{target} vs {predictor}")
    plt.show()