# 80 lines · 1.9 KiB · Python
import numpy as np
|
|
import pandas as pd
|
|
import matplotlib.pyplot as plt
|
|
import seaborn as sns
|
|
|
|
|
|
def _is_discrete(values, max_unique=30):
|
|
return values.nunique() <= max_unique
|
|
|
|
|
|
def _jittered(values, fraction):
|
|
gaps = np.diff(np.sort(values.unique()))
|
|
gap = gaps.min() if len(gaps) > 0 else 1
|
|
noise = np.random.normal(0, gap * fraction, size=len(values))
|
|
return values + noise
|
|
|
|
|
|
def _maybe_jittered(values, jitter):
|
|
if jitter is False:
|
|
return values
|
|
if jitter is None:
|
|
return _jittered(values, 0.15) if _is_discrete(values) else values
|
|
fraction = 0.15 if jitter is True else jitter
|
|
return _jittered(values, fraction)
|
|
|
|
|
|
def _default_size(n):
|
|
if n > 50000:
|
|
return 6
|
|
if n > 5000:
|
|
return 15
|
|
if n > 500:
|
|
return 25
|
|
return 40
|
|
|
|
|
|
def _default_opacity(n):
|
|
if n > 50000:
|
|
return 0.05
|
|
if n > 5000:
|
|
return 0.15
|
|
if n > 500:
|
|
return 0.3
|
|
return 0.5
|
|
|
|
|
|
def _predicted_line(model, x, column):
|
|
line_x = np.linspace(x.min(), x.max(), 100)
|
|
line_X = pd.DataFrame({column: line_x})
|
|
return line_x, model.predict(line_X)
|
|
|
|
|
|
def plot_regression(X, y, model, size=None, opacity=None, jitter=None):
    """Scatterplot of a single predictor against y, with the model's predicted line.

    Args:
        X: DataFrame whose first column is the predictor to plot. Only that
            column is passed to ``model.predict`` (as a one-column DataFrame).
        y: pandas Series of observed responses.
        model: object with a ``predict`` method accepting that DataFrame.
        size: marker area. When None, it is chosen from ``len(X)`` by ``_default_size``.
        opacity: marker alpha. When None, it is chosen from ``len(X)`` by ``_default_opacity``.
        jitter: forwarded to ``_maybe_jittered`` for both axes. ``False``
            disables jitter, ``None`` jitters only discrete axes, ``True`` uses
            the default fraction, and a number sets the fraction.

    Draws on the current pyplot figure and then calls ``plt.show()``.
    """
    column = X.columns[0]
    x = X[column]
    n = len(x)
    # An unnamed Series would otherwise produce the title "None vs <column>".
    y_label = "y" if y.name is None else y.name

    if size is None:
        size = _default_size(n)
    if opacity is None:
        opacity = _default_opacity(n)

    plot_x = _maybe_jittered(x, jitter)
    plot_y = _maybe_jittered(y, jitter)
    sns.scatterplot(x=plot_x, y=plot_y, s=size, alpha=opacity, edgecolor="none")

    # Label each discrete level at its true (un-jittered) position. NaN is
    # excluded, matching how _is_discrete (nunique) counts levels; a NaN tick
    # would make sorting and axis limits ill-defined.
    if _is_discrete(x):
        plt.xticks(sorted(x.dropna().unique()))
    if _is_discrete(y):
        plt.yticks(sorted(y.dropna().unique()))

    line_x, line_y = _predicted_line(model, x, column)
    plt.plot(line_x, line_y, color="crimson")

    plt.xlabel(column)
    plt.ylabel(y_label)
    plt.title(f"{y_label} vs {column}")
    plt.show()
|