import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns


def _is_discrete(values, max_unique=30):
    """Return True if *values* has at most *max_unique* distinct non-null values.

    This is a heuristic: a short continuous column can also pass it.
    """
    return values.nunique() <= max_unique


def _jittered(values, fraction, rng=None):
    """Return *values* plus Gaussian noise scaled to the smallest gap between levels.

    The noise standard deviation is ``fraction`` times the minimum spacing
    between distinct non-null values (1 when there are fewer than two).
    Missing values stay missing, and they are excluded from the gap
    computation so that a single NaN cannot turn the noise scale (and thus
    every jittered point) into NaN.

    Args:
        values: pandas Series to perturb.
        fraction: noise standard deviation as a fraction of the smallest gap.
        rng: optional ``numpy.random.Generator``; when None the global
            ``numpy.random`` state is used, so ``np.random.seed`` still applies.

    Returns:
        ``values + noise`` (a new Series with the same index and name).
    """
    levels = np.sort(values.dropna().unique())
    gaps = np.diff(levels)
    gap = gaps.min() if len(gaps) > 0 else 1
    sampler = np.random if rng is None else rng
    noise = sampler.normal(0, gap * fraction, size=len(values))
    return values + noise


def _maybe_jittered(values, jitter, rng=None):
    """Apply jitter to *values* according to the *jitter* setting.

    Args:
        values: pandas Series to (possibly) perturb.
        jitter: False for no jitter; None to jitter only discrete-looking
            data with the default fraction 0.15; True to always jitter with
            0.15; any other value is used directly as the noise fraction.
        rng: optional ``numpy.random.Generator`` forwarded to ``_jittered``.

    Returns:
        Either *values* unchanged or a jittered copy.
    """
    if jitter is False:
        return values
    if jitter is None:
        return _jittered(values, 0.15, rng) if _is_discrete(values) else values
    fraction = 0.15 if jitter is True else jitter
    return _jittered(values, fraction, rng)


def _default_size(n):
    """Return a marker size that shrinks as the number of points *n* grows."""
    if n > 50000:
        return 6
    if n > 5000:
        return 15
    if n > 500:
        return 25
    return 40


def _default_opacity(n):
    """Return a marker alpha that shrinks as the number of points *n* grows."""
    if n > 50000:
        return 0.05
    if n > 5000:
        return 0.15
    if n > 500:
        return 0.3
    return 0.5


def _predicted_line(model, x, column):
    """Evaluate *model* on 100 evenly spaced points spanning the range of *x*.

    ``model.predict`` receives a one-column DataFrame whose column is named
    *column* (presumably the feature name the model was fit with).

    Returns:
        ``(line_x, predictions)``.
    """
    line_x = np.linspace(x.min(), x.max(), 100)
    line_X = pd.DataFrame({column: line_x})
    return line_x, model.predict(line_X)


def plot_regression(X, y, model, size=None, opacity=None, jitter=None,
                    random_state=None):
    """Scatterplot of a single predictor against y, with the model's predicted line.

    Args:
        X: DataFrame whose first column is the predictor to plot.
        y: target values. A Series is used as-is; anything else is flattened
            and wrapped in a Series aligned to X's index.
        model: object with a ``predict`` method accepting a one-column DataFrame.
        size: marker size; defaults to a value based on the number of points.
        opacity: marker alpha; defaults to a value based on the number of points.
        jitter: False, None, True, or a noise fraction (see ``_maybe_jittered``);
            applied to x and y independently.
        random_state: optional int seed or ``numpy.random.Generator`` for
            reproducible jitter. None uses the global ``numpy.random`` state.
    """
    column = X.columns[0]
    x = X[column]
    if not isinstance(y, pd.Series):
        y = pd.Series(np.ravel(y), index=x.index)
    # An unnamed Series would otherwise be labelled "None".
    y_label = "y" if y.name is None else y.name

    n = len(x)
    if size is None:
        size = _default_size(n)
    if opacity is None:
        opacity = _default_opacity(n)

    # Build one generator shared by both axes so x and y receive independent
    # noise even when random_state is a fixed integer seed.
    rng = None if random_state is None else np.random.default_rng(random_state)
    plot_x = _maybe_jittered(x, jitter, rng)
    plot_y = _maybe_jittered(y, jitter, rng)
    sns.scatterplot(x=plot_x, y=plot_y, s=size, alpha=opacity, edgecolor="none")

    # Put a tick on every level of discrete data; NaN is not a tick position.
    if _is_discrete(x):
        plt.xticks(sorted(x.dropna().unique()))
    if _is_discrete(y):
        plt.yticks(sorted(y.dropna().unique()))

    line_x, line_y = _predicted_line(model, x, column)
    plt.plot(line_x, line_y, color="crimson")
    plt.xlabel(column)
    plt.ylabel(y_label)
    plt.title(f"{y_label} vs {column}")
    plt.show()