Initial commit
This commit is contained in:
79
plotting.py
Normal file
79
plotting.py
Normal file
@@ -0,0 +1,79 @@
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import matplotlib.pyplot as plt
|
||||
import seaborn as sns
|
||||
|
||||
|
||||
def _is_discrete(values, max_unique=30):
    """Tell whether ``values`` has few enough distinct entries to count as discrete.

    Args:
        values: Object exposing ``nunique()`` (e.g. a pandas Series).
        max_unique: Largest distinct-value count still treated as discrete.

    Returns:
        True when ``values.nunique()`` does not exceed ``max_unique``.
    """
    distinct_count = values.nunique()
    return distinct_count <= max_unique
|
||||
|
||||
|
||||
def _jittered(values, fraction):
    """Return ``values`` with Gaussian noise scaled to the spacing of the data.

    The noise standard deviation is ``fraction`` times the smallest gap
    between distinct values, or ``fraction`` itself when fewer than two
    distinct values exist.

    Args:
        values: pandas Series of numeric values.
        fraction: Noise standard deviation as a fraction of the smallest gap.

    Returns:
        ``values + noise``; missing entries in ``values`` stay missing.
    """
    # Drop missing values before measuring gaps: a NaN among the unique
    # values makes the smallest gap NaN, which turns every jittered point
    # into NaN instead of only the ones that were missing to begin with.
    distinct = np.sort(values.dropna().unique())
    gaps = np.diff(distinct)
    gap = gaps.min() if len(gaps) > 0 else 1
    noise = np.random.normal(0, gap * fraction, size=len(values))
    return values + noise
|
||||
|
||||
|
||||
def _maybe_jittered(values, jitter):
    """Apply jitter to ``values`` according to the ``jitter`` setting.

    Args:
        values: Series to (possibly) jitter.
        jitter: ``False`` returns ``values`` untouched; ``None`` jitters only
            when ``_is_discrete(values)`` holds; ``True`` jitters with the
            default fraction; any other value is used as the fraction.

    Returns:
        Either ``values`` itself or the result of ``_jittered``.
    """
    default_fraction = 0.15

    # Explicitly disabled.
    if jitter is False:
        return values

    # Automatic mode: only discrete data gets spread out.
    if jitter is None:
        if not _is_discrete(values):
            return values
        return _jittered(values, default_fraction)

    # Explicitly enabled, with either the default or a caller-chosen fraction.
    if jitter is True:
        return _jittered(values, default_fraction)
    return _jittered(values, jitter)
|
||||
|
||||
|
||||
def _default_size(n):
    """Pick a marker size for a scatterplot of ``n`` points.

    Larger datasets get smaller markers: 6 above 50000 points, 15 above
    5000, 25 above 500, and 40 otherwise.
    """
    # Checked from the largest cut-off down; the first one exceeded wins.
    tiers = ((50000, 6), (5000, 15), (500, 25))
    for cutoff, marker_size in tiers:
        if n > cutoff:
            return marker_size
    return 40
|
||||
|
||||
|
||||
def _default_opacity(n):
    """Pick a marker opacity for a scatterplot of ``n`` points.

    Larger datasets get more transparent markers: 0.05 above 50000 points,
    0.15 above 5000, 0.3 above 500, and 0.5 otherwise.
    """
    # The cut-offs are nested, so the number of cut-offs exceeded identifies
    # the tier directly.
    opacities = (0.5, 0.3, 0.15, 0.05)
    tier = sum(n > cutoff for cutoff in (500, 5000, 50000))
    return opacities[tier]
|
||||
|
||||
|
||||
def _predicted_line(model, x, column):
    """Evaluate ``model`` on an evenly spaced grid spanning the range of ``x``.

    Args:
        model: Object whose ``predict`` accepts a DataFrame.
        x: Values whose ``min()`` and ``max()`` bound the grid.
        column: Column name given to the grid in the DataFrame passed to
            ``model.predict``.

    Returns:
        Tuple ``(grid, predictions)`` where ``grid`` holds 100 points from
        ``x.min()`` to ``x.max()``.
    """
    lower, upper = x.min(), x.max()
    grid = np.linspace(lower, upper, 100)
    frame = pd.DataFrame({column: grid})
    predictions = model.predict(frame)
    return grid, predictions
|
||||
|
||||
|
||||
def plot_regression(X, y, model, size=None, opacity=None, jitter=None):
    """Scatterplot of a single predictor against y, with the model's predicted line.

    Only the first column of ``X`` is plotted. The figure is drawn through
    ``seaborn``/``pyplot`` and displayed with ``plt.show()``.

    Args:
        X: DataFrame; its first column is used as the predictor.
        y: Series of target values. Its ``name`` labels the y axis and the
            title; an unnamed Series is labelled ``"y"``.
        model: Object whose ``predict`` accepts a one-column DataFrame named
            after the predictor.
        size: Marker size; chosen from the number of rows when ``None``.
        opacity: Marker alpha; chosen from the number of rows when ``None``.
        jitter: ``False`` disables jitter, ``None`` jitters only discrete
            data, ``True`` uses the default amount, and a number is used as
            the jitter fraction.
    """
    column = X.columns[0]
    x = X[column]
    n = len(x)

    # An unnamed Series has ``name`` None, which would otherwise render the
    # title as "None vs <column>".
    y_label = "y" if y.name is None else y.name

    if size is None:
        size = _default_size(n)
    if opacity is None:
        opacity = _default_opacity(n)

    # Jitter only moves the drawn points; ticks and the predicted line are
    # computed from the raw values.
    plot_x = _maybe_jittered(x, jitter)
    plot_y = _maybe_jittered(y, jitter)
    sns.scatterplot(x=plot_x, y=plot_y, s=size, alpha=opacity, edgecolor="none")

    # Discrete axes get one tick per observed value; missing values are
    # dropped so NaN is never passed as a tick position.
    if _is_discrete(x):
        plt.xticks(sorted(x.dropna().unique()))
    if _is_discrete(y):
        plt.yticks(sorted(y.dropna().unique()))

    line_x, line_y = _predicted_line(model, x, column)
    plt.plot(line_x, line_y, color="crimson")

    plt.xlabel(column)
    plt.ylabel(y_label)
    plt.title(f"{y_label} vs {column}")
    plt.show()
|
||||
Reference in New Issue
Block a user