Sachin Kr. Rajput

The Assumptions of Linear Regression: The GPS That Only Works on Straight, Flat Roads in Perfect Weather

The One-Line Summary: Linear regression assumes your data follows a straight line, your errors are random and normally distributed, your error spread is constant everywhere, and your data points are independent. Violate these and your "accurate" model is actually lying to you.


The GPS That Only Works Sometimes

TrueNav GPS was the most accurate navigation system ever built. In tests, it predicted arrival times within 30 seconds. Amazing!

The company shipped it. Reviews poured in:

⭐⭐⭐⭐⭐ "Incredible accuracy on the highway!"
⭐⭐⭐⭐⭐ "Perfect for my daily commute!"
⭐⭐⭐⭐⭐ "Never been late to work!"

Then...

⭐ "Told me to drive through a lake"
⭐ "Said I'd arrive in 20 mins. Took 3 hours through mountains"
⭐ "Completely wrong during rush hour traffic"
⭐ "Works great on clear days, useless in rain"

The Investigation:

TrueNav's testing data:
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
✓ All test routes were STRAIGHT highways
✓ All test routes were FLAT terrain
✓ All tests were done in CLEAR weather
✓ All tests were done in LIGHT traffic

TrueNav's assumptions:
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
1. Roads are straight (LINEARITY)
2. Road conditions are consistent (HOMOSCEDASTICITY)
3. Weather doesn't affect driving (NORMALITY of errors)
4. Each mile is independent of previous miles (INDEPENDENCE)

Real world:
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
✗ Roads curve through mountains
✗ City traffic ≠ highway traffic
✗ Weather causes unpredictable delays
✗ Traffic jams cascade (one slow car affects everyone)

TrueNav violated its assumptions. The "accurate" model was garbage in the real world.


The Four Assumptions of Linear Regression

Linear regression is TrueNav. It works brilliantly — IF four assumptions hold:

┌─────────────────────────────────────────────────────────────┐
│           THE FOUR ASSUMPTIONS (L.I.N.E.)                   │
├─────────────────────────────────────────────────────────────┤
│                                                             │
│  L - LINEARITY                                              │
│      The relationship between X and Y is a straight line    │
│                                                             │
│  I - INDEPENDENCE                                           │
│      Each data point is independent of others               │
│                                                             │
│  N - NORMALITY                                              │
│      The residuals follow a normal distribution             │
│                                                             │
│  E - EQUAL VARIANCE (Homoscedasticity)                      │
│      The spread of residuals is constant across all X       │
│                                                             │
└─────────────────────────────────────────────────────────────┘

Violate ANY of these → Your model may be LYING to you
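
Before looking at each violation, it helps to know what well-behaved data looks like. Below is a minimal sketch (variable names and numbers are just illustrative) that generates data satisfying all four assumptions by construction: a truly linear relationship with independent, normally distributed, constant-variance noise. Every diagnostic in this article should come back clean on it.

import numpy as np
from sklearn.linear_model import LinearRegression

# A "well-behaved" dataset: every L.I.N.E. assumption holds by construction
np.random.seed(0)
X = np.random.uniform(0, 10, 200)     # independent draws
noise = np.random.normal(0, 3, 200)   # normal errors, constant variance
y = 4 + 2.5 * X + noise               # truly linear relationship

model = LinearRegression()
model.fit(X.reshape(-1, 1), y)
print(f"Slope ≈ {model.coef_[0]:.2f}, intercept ≈ {model.intercept_:.2f}")
# Residual plots for this data should show nothing but random scatter around 0.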

Assumption 1: LINEARITY

The relationship between X and Y must be a straight line.

✓ LINEAR (Assumption MET):        ✗ NON-LINEAR (Assumption VIOLATED):
━━━━━━━━━━━━━━━━━━━━━━━━━━        ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

    Y                                  Y
    │                                  │        ××
    │            ×                     │      ×    ×
    │          × ×                     │    ×        ×
    │        ×                         │  ×            ×
    │      × ×                         │ ×              ×
    │    ×                             │×                ×
    │  × ×                             │                  ×
    │ ×                                │
    └──────────────── X                └──────────────────── X

    A straight line fits well.         A straight line misses the curve!

What Happens If Violated?

import numpy as np
import matplotlib.pyplot as plt
from sklearn.linear_model import LinearRegression
from sklearn.metrics import r2_score

# True relationship is CURVED (quadratic)
np.random.seed(42)
X = np.linspace(0, 10, 100)
y_true = 2 * X**2 - 10 * X + 50  # Parabola!
y = y_true + np.random.normal(0, 10, 100)

# Fit linear regression (assumes straight line)
model = LinearRegression()
model.fit(X.reshape(-1, 1), y)
y_pred = model.predict(X.reshape(-1, 1))

# Calculate R²
r2 = r2_score(y, y_pred)

print("LINEARITY VIOLATION EXAMPLE")
print("="*50)
print(f"True relationship: y = 2x² - 10x + 50 (CURVED)")
print(f"Model assumes: y = mx + b (STRAIGHT)")
print(f"\nR² score: {r2:.3f}")
print(f"→ Looks decent, but model is FUNDAMENTALLY WRONG!")

Output:

LINEARITY VIOLATION EXAMPLE
==================================================
True relationship: y = 2x² - 10x + 50 (CURVED)
Model assumes: y = mx + b (STRAIGHT)

R² score: 0.847
→ Looks decent, but model is FUNDAMENTALLY WRONG!

R² of 0.85 sounds good, but the model completely misses the pattern!
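
To see how much the straight line misses, here's a quick sketch that reuses the X, y, and r2 from the snippet above and fits a quadratic on the same data:

from sklearn.preprocessing import PolynomialFeatures
from sklearn.linear_model import LinearRegression
from sklearn.metrics import r2_score

# Give the model the x² term the true relationship actually has
poly = PolynomialFeatures(degree=2)
X_poly = poly.fit_transform(X.reshape(-1, 1))

quad_model = LinearRegression()
quad_model.fit(X_poly, y)
r2_quad = r2_score(y, quad_model.predict(X_poly))

print(f"Linear fit R²:    {r2:.3f}")
print(f"Quadratic fit R²: {r2_quad:.3f}")  # should be substantially higher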

How to Detect Linearity Violations

def check_linearity(X, y, model):
    """Check linearity assumption with residual plot."""
    y_pred = model.predict(X.reshape(-1, 1) if X.ndim == 1 else X)
    residuals = y - y_pred

    plt.figure(figsize=(12, 4))

    # Plot 1: Scatter with regression line
    plt.subplot(1, 2, 1)
    plt.scatter(X, y, alpha=0.6, label='Data')
    plt.plot(X, y_pred, 'r-', linewidth=2, label='Linear fit')
    plt.xlabel('X')
    plt.ylabel('Y')
    plt.title('Data with Linear Fit')
    plt.legend()

    # Plot 2: Residuals vs X (KEY DIAGNOSTIC!)
    plt.subplot(1, 2, 2)
    plt.scatter(X, residuals, alpha=0.6)
    plt.axhline(y=0, color='r', linestyle='--')
    plt.xlabel('X')
    plt.ylabel('Residuals')
    plt.title('Residuals vs X\n(Should be RANDOM scatter around 0)')

    plt.tight_layout()
    plt.savefig('linearity_check.png', dpi=150)
    plt.show()

    # Detection rule
    print("\nLINEARITY CHECK:")
    print("Look at residuals plot:")
    print("  ✓ Random scatter around 0 → Linearity OK")
    print("  ✗ Curved pattern → Linearity VIOLATED")

check_linearity(X, y, model)

WHAT YOU'LL SEE:
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

LINEARITY MET:                    LINEARITY VIOLATED:
(Residuals are random)            (Residuals show a pattern)

Residuals                         Residuals
    │   ×     ×    ×                  │         × × ×
    │ ×    ×     ×                    │       ×       ×
  0 │──────────────────             0 │× ×               × ×
    │    ×   ×   ×                    │   × ×       × ×
    │  ×       ×                      │       × × ×
    └───────────────── X              └───────────────────── X

    RANDOM = GOOD                     CURVED = BAD!

How to Fix Linearity Violations

# Option 1: Transform the features
X_transformed = np.column_stack([X, X**2])  # Add polynomial term
model.fit(X_transformed, y)

# Option 2: Use polynomial regression
from sklearn.preprocessing import PolynomialFeatures
poly = PolynomialFeatures(degree=2)
X_poly = poly.fit_transform(X.reshape(-1, 1))
model.fit(X_poly, y)

# Option 3: Use a non-linear model
from sklearn.ensemble import RandomForestRegressor
model = RandomForestRegressor()
model.fit(X.reshape(-1, 1), y)

Assumption 2: INDEPENDENCE

Each data point must be independent of others.

✓ INDEPENDENT (Assumption MET):    ✗ DEPENDENT (Assumption VIOLATED):
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

Random sample of 100 people:       Time series of stock prices:
Person 1's height doesn't          Today's price DEPENDS on yesterday's!
affect Person 2's height.          
                                   Day 1: $100 → Day 2: $102 → Day 3: $98
Each data point is a fresh         ↑          ↑           ↑
independent observation.           These are CONNECTED, not independent!

Also OK:                           Also VIOLATED:
- Random sample of houses          - Multiple measurements per patient
- Random sample of products        - Students in the same classroom
- Random sample of transactions    - Employees at the same company

What Happens If Violated?

import numpy as np
from sklearn.linear_model import LinearRegression

# Create DEPENDENT data (time series)
np.random.seed(42)
n = 100

# Autoregressive process: today depends on yesterday
y = np.zeros(n)
y[0] = 50
for i in range(1, n):
    y[i] = 0.9 * y[i-1] + np.random.normal(0, 5)  # Depends on previous!

X = np.arange(n)

# Fit linear regression (assumes independence)
model = LinearRegression()
model.fit(X.reshape(-1, 1), y)

print("INDEPENDENCE VIOLATION EXAMPLE")
print("="*50)
print("Data: Time series where y[t] = 0.9 × y[t-1] + noise")
print("      Each point DEPENDS on the previous one!")
print(f"\nR² score: {model.score(X.reshape(-1, 1), y):.3f}")
print("\nProblems:")
print("  1. Standard errors are UNDERESTIMATED")
print("  2. Confidence intervals are TOO NARROW")
print("  3. Statistical tests are UNRELIABLE")
print("  4. Model may see patterns that aren't real")
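
To make "standard errors are underestimated" concrete, you can compare the naive OLS standard error with an autocorrelation-robust (Newey-West / HAC) one. A sketch, reusing the autocorrelated X and y generated above; the maxlags value is just an illustrative choice:

import statsmodels.api as sm

X_const = sm.add_constant(X)

# Naive OLS: assumes independent errors
naive = sm.OLS(y, X_const).fit()

# Same model, but with autocorrelation-robust (Newey-West) standard errors
robust = sm.OLS(y, X_const).fit(cov_type='HAC', cov_kwds={'maxlags': 5})

print(f"Slope estimate:       {naive.params[1]:.4f} (identical in both fits)")
print(f"Naive std error:      {naive.bse[1]:.4f}")
print(f"HAC-robust std error: {robust.bse[1]:.4f}")
# With positive autocorrelation, the robust standard error is typically much larger.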

How to Detect Independence Violations

from scipy import stats

def check_independence(y_true, y_pred):
    """Check independence using Durbin-Watson test."""
    residuals = y_true - y_pred

    # Durbin-Watson test
    # DW ≈ 2: No autocorrelation (GOOD)
    # DW < 1.5: Positive autocorrelation (BAD)
    # DW > 2.5: Negative autocorrelation (BAD)

    n = len(residuals)
    diff_squared = np.sum(np.diff(residuals)**2)
    sum_squared = np.sum(residuals**2)
    dw = diff_squared / sum_squared

    print("\nINDEPENDENCE CHECK (Durbin-Watson Test):")
    print(f"Durbin-Watson statistic: {dw:.3f}")

    if 1.5 < dw < 2.5:
        print("✓ No significant autocorrelation detected")
    elif dw <= 1.5:
        print("✗ POSITIVE autocorrelation detected!")
        print("  → Residuals are following each other")
    else:
        print("✗ NEGATIVE autocorrelation detected!")
        print("  → Residuals are alternating")

    # Visual check: plot residuals in order
    plt.figure(figsize=(12, 4))

    plt.subplot(1, 2, 1)
    plt.plot(residuals, 'o-', alpha=0.7)
    plt.axhline(y=0, color='r', linestyle='--')
    plt.xlabel('Observation Order')
    plt.ylabel('Residual')
    plt.title('Residuals Over Time\n(Should look RANDOM)')

    plt.subplot(1, 2, 2)
    plt.scatter(residuals[:-1], residuals[1:], alpha=0.6)
    plt.xlabel('Residual[t]')
    plt.ylabel('Residual[t+1]')
    plt.title('Residual Autocorrelation\n(Should be RANDOM cloud)')

    plt.tight_layout()
    plt.savefig('independence_check.png', dpi=150)
    plt.show()

    return dw

dw = check_independence(y, model.predict(X.reshape(-1, 1)))

How to Fix Independence Violations

# Option 1: Use time series models instead
from statsmodels.tsa.arima.model import ARIMA
model = ARIMA(y, order=(1, 0, 0))  # AR(1) model
results = model.fit()

# Option 2: Add lagged features (refit a plain linear model on them)
from sklearn.linear_model import LinearRegression
X_lagged = np.column_stack([X[1:], y[:-1]])  # Include y[t-1] as a feature
model = LinearRegression()
model.fit(X_lagged, y[1:])

# Option 3: Use robust standard errors
import statsmodels.api as sm
X_with_const = sm.add_constant(X)
model = sm.OLS(y, X_with_const)
results = model.fit(cov_type='HAC', cov_kwds={'maxlags': 5})  # Newey-West

Assumption 3: NORMALITY of Residuals

The residuals must follow a normal (bell curve) distribution.

✓ NORMAL (Assumption MET):         ✗ NON-NORMAL (Assumption VIOLATED):
━━━━━━━━━━━━━━━━━━━━━━━━━━         ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

        ┌───┐                              ┌─┐
       ┌┘   └┐                             │ │
      ┌┘     └┐                            │ │
     ┌┘       └┐                           │ │
    ┌┘         └┐                          │ │
───┘            └───                     ──┘ └─────────────────
   -3  -2  -1  0  1  2  3                  0   5   10   15   20

   Bell-shaped, symmetric                  Skewed right (long tail)
   Most errors near zero                   Many small errors, few HUGE ones

What Happens If Violated?

  • Coefficient estimates are still unbiased (good news!)
  • BUT: Confidence intervals are wrong
  • AND: P-values are unreliable
  • AND: Prediction intervals are meaningless

import numpy as np
from scipy import stats
from sklearn.linear_model import LinearRegression

# Create data with NON-NORMAL errors
np.random.seed(42)
X = np.random.uniform(0, 100, 200)

# Non-normal errors: exponential distribution (always positive, skewed)
errors = np.random.exponential(scale=20, size=200) - 20  # Shifted
y = 50 + 2 * X + errors

# Fit model
model = LinearRegression()
model.fit(X.reshape(-1, 1), y)
residuals = y - model.predict(X.reshape(-1, 1))

# Test normality
stat, p_value = stats.shapiro(residuals)

print("NORMALITY VIOLATION EXAMPLE")
print("="*50)
print("True errors: Exponential distribution (skewed)")
print(f"\nShapiro-Wilk test:")
print(f"  Statistic: {stat:.4f}")
print(f"  P-value: {p_value:.4f}")
if p_value < 0.05:
    print("  ✗ Residuals are NOT normally distributed!")
else:
    print("  ✓ Residuals appear normal")

How to Detect Normality Violations

import scipy.stats as stats

def check_normality(residuals):
    """Check normality of residuals."""

    fig, axes = plt.subplots(1, 3, figsize=(15, 4))

    # Plot 1: Histogram
    axes[0].hist(residuals, bins=30, density=True, alpha=0.7, edgecolor='black')

    # Overlay normal distribution
    xmin, xmax = axes[0].get_xlim()
    x = np.linspace(xmin, xmax, 100)
    p = stats.norm.pdf(x, residuals.mean(), residuals.std())
    axes[0].plot(x, p, 'r-', linewidth=2, label='Normal distribution')
    axes[0].set_xlabel('Residuals')
    axes[0].set_ylabel('Density')
    axes[0].set_title('Histogram of Residuals')
    axes[0].legend()

    # Plot 2: Q-Q Plot (most important!)
    stats.probplot(residuals, dist="norm", plot=axes[1])
    axes[1].set_title('Q-Q Plot\n(Points should follow diagonal line)')

    # Plot 3: Box plot
    axes[2].boxplot(residuals, vert=True)
    axes[2].set_ylabel('Residuals')
    axes[2].set_title('Box Plot\n(Should be symmetric)')

    plt.tight_layout()
    plt.savefig('normality_check.png', dpi=150)
    plt.show()

    # Statistical tests
    print("\nNORMALITY TESTS:")
    print("-"*50)

    # Shapiro-Wilk (best for n < 5000)
    if len(residuals) < 5000:
        stat, p = stats.shapiro(residuals)
        print(f"Shapiro-Wilk: statistic={stat:.4f}, p-value={p:.4f}")
        print(f"{'✓ Normal' if p > 0.05 else '✗ NOT Normal'}")

    # D'Agostino and Pearson's test
    stat, p = stats.normaltest(residuals)
    print(f"D'Agostino-Pearson: statistic={stat:.4f}, p-value={p:.4f}")
    print(f"{'✓ Normal' if p > 0.05 else '✗ NOT Normal'}")

    # Skewness and Kurtosis
    skew = stats.skew(residuals)
    kurt = stats.kurtosis(residuals)
    print(f"\nSkewness: {skew:.3f} (should be near 0)")
    print(f"Kurtosis: {kurt:.3f} (should be near 0)")

check_normality(residuals)

Q-Q PLOT INTERPRETATION:
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

NORMAL (GOOD):                    SKEWED (BAD):
Points follow the line            Points curve away

Sample │        ×                 Sample │              ×
Quant. │      ×                   Quant. │           ×
       │    ×                            │         ×
       │  ×                              │      ×
       │×                                │   ×
       └────────────── Theoretical       └×─────────── Theoretical

HEAVY TAILS (BAD):                LIGHT TAILS (usually OK):
Points deviate at extremes        Points compress at extremes

Sample │          ×               Sample │        ×─────
Quant. │        ×                 Quant. │      ×
       │      ×                          │    ×
       │    ×                            │  ×
       │  ×                              │×
       │×                                └─────────────
       └──────────────

How to Fix Normality Violations

# Option 1: Transform the target variable
y_log = np.log(y)  # If y is always positive and right-skewed
model.fit(X.reshape(-1, 1), y_log)

# Option 2: Use Box-Cox transformation
from scipy.stats import boxcox
y_transformed, lambda_param = boxcox(y + 1)  # Box-Cox requires y > 0; shift if y has zeros
model.fit(X.reshape(-1, 1), y_transformed)

# Option 3: Remove outliers
from scipy import stats
z_scores = np.abs(stats.zscore(residuals))
mask = z_scores < 3  # Keep only within 3 std devs
model.fit(X[mask].reshape(-1, 1), y[mask])

# Option 4: Use robust regression
from sklearn.linear_model import HuberRegressor
model = HuberRegressor()  # Less sensitive to outliers
model.fit(X.reshape(-1, 1), y)

# Option 5: Bootstrap confidence intervals (no normality required!)
from sklearn.utils import resample
boot_coefs = []
for _ in range(1000):
    X_boot, y_boot = resample(X, y)
    model.fit(X_boot.reshape(-1, 1), y_boot)
    boot_coefs.append(model.coef_[0])
ci_low, ci_high = np.percentile(boot_coefs, [2.5, 97.5])

Assumption 4: HOMOSCEDASTICITY (Equal Variance)

The spread of residuals must be constant across all X values.

✓ HOMOSCEDASTIC (Assumption MET):  ✗ HETEROSCEDASTIC (Assumption VIOLATED):
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━  ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

Residuals                          Residuals
    │   ×   ×       ×   ×              │                    ×   ×
    │ ×   ×   × × ×   ×                │                  ×   ×
  0 │────────────────────────        0 │──────────────────────────
    │   × ×   × ×   × ×                │     ×   ×   × ×   ×
    │ ×     ×     ×     ×              │   × × ×
    └───────────────────── X           └───────────────────────── X

    Constant spread (GOOD!)            Spread INCREASES with X (BAD!)
    "Homoscedastic"                    "Heteroscedastic" (funnel shape)

Real-World Example

SALARY PREDICTION:
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

Entry-level employees (X = low experience):
  Salary range: $45,000 - $55,000
  Variance: SMALL (everyone earns similar amounts)

Senior executives (X = high experience):
  Salary range: $150,000 - $500,000
  Variance: HUGE (some CEOs earn millions, others don't)

This is HETEROSCEDASTICITY!
The spread of salaries INCREASES with experience.
Linear regression's uncertainty estimates will be WRONG.

What Happens If Violated?

import numpy as np
from sklearn.linear_model import LinearRegression

# Create HETEROSCEDASTIC data
np.random.seed(42)
X = np.random.uniform(10, 100, 200)

# Error variance INCREASES with X
errors = np.random.normal(0, 1, 200) * X * 0.3  # Bigger X → bigger error!
y = 100 + 2 * X + errors

# Fit model
model = LinearRegression()
model.fit(X.reshape(-1, 1), y)

print("HETEROSCEDASTICITY VIOLATION EXAMPLE")
print("="*50)
print("True relationship: y = 100 + 2x + ε")
print("Where ε has variance proportional to X")
print("\nProblems:")
print("  1. Coefficient estimates are INEFFICIENT (not best possible)")
print("  2. Standard errors are WRONG")
print("  3. Confidence intervals are WRONG")
print("  4. Hypothesis tests are UNRELIABLE")
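
Here's a sketch that makes the "standard errors are wrong" point concrete, reusing the heteroscedastic X and y generated above and comparing plain OLS standard errors with heteroscedasticity-consistent (HC3) ones:

import statsmodels.api as sm

X_const = sm.add_constant(X)

plain = sm.OLS(y, X_const).fit()                  # assumes constant variance
robust = sm.OLS(y, X_const).fit(cov_type='HC3')   # heteroscedasticity-consistent

print(f"Slope estimate:       {plain.params[1]:.4f} (same point estimate)")
print(f"Plain std error:      {plain.bse[1]:.4f}")
print(f"HC3-robust std error: {robust.bse[1]:.4f}")
# When the spread grows with X, the plain standard error tends to understate
# the real uncertainty, so its confidence intervals are too narrow.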

How to Detect Heteroscedasticity

import statsmodels.api as sm
from statsmodels.stats.diagnostic import het_breuschpagan, het_white

def check_homoscedasticity(X, y, model):
    """Check for constant variance of residuals."""

    y_pred = model.predict(X.reshape(-1, 1) if X.ndim == 1 else X)
    residuals = y - y_pred

    # Visual check
    plt.figure(figsize=(12, 4))

    plt.subplot(1, 2, 1)
    plt.scatter(y_pred, residuals, alpha=0.6)
    plt.axhline(y=0, color='r', linestyle='--')
    plt.xlabel('Predicted Values')
    plt.ylabel('Residuals')
    plt.title('Residuals vs Predicted\n(Should be constant spread)')

    plt.subplot(1, 2, 2)
    plt.scatter(y_pred, np.abs(residuals), alpha=0.6)
    plt.xlabel('Predicted Values')
    plt.ylabel('|Residuals|')
    plt.title('Absolute Residuals vs Predicted\n(Should be FLAT, not increasing)')

    plt.tight_layout()
    plt.savefig('homoscedasticity_check.png', dpi=150)
    plt.show()

    # Statistical tests
    X_with_const = sm.add_constant(X.reshape(-1, 1) if X.ndim == 1 else X)

    # Breusch-Pagan test
    bp_stat, bp_p, _, _ = het_breuschpagan(residuals, X_with_const)

    print("\nHOMOSCEDASTICITY TESTS:")
    print("-"*50)
    print(f"Breusch-Pagan test:")
    print(f"  Statistic: {bp_stat:.4f}")
    print(f"  P-value: {bp_p:.4f}")
    if bp_p < 0.05:
        print("  ✗ HETEROSCEDASTICITY detected!")
    else:
        print("  ✓ No significant heteroscedasticity")

    return bp_p

p_value = check_homoscedasticity(X, y, model)

VISUAL PATTERNS:
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

HOMOSCEDASTIC (GOOD):             HETEROSCEDASTIC (BAD):
Constant band width               Funnel shape

    │× × × × × × × × ×                │× ×           × ×  ×
    │─────────────────              0 │──× × × × × ×─────────
    │× × × × × × × × ×                │× ×     × × ×  ×  ×
    └───────────────── Predicted      └───────────────────── Predicted

                                   Width INCREASES → HETEROSCEDASTIC!

How to Fix Heteroscedasticity

# Option 1: Transform the target variable
y_log = np.log(y)
model.fit(X.reshape(-1, 1), y_log)

# Option 2: Weighted Least Squares (WLS)
# Give less weight to high-variance observations
weights = 1 / (X ** 2)  # If variance ∝ X²
import statsmodels.api as sm
X_const = sm.add_constant(X)
wls_model = sm.WLS(y, X_const, weights=weights)
results = wls_model.fit()

# Option 3: Use robust standard errors
ols_model = sm.OLS(y, X_const)
results = ols_model.fit(cov_type='HC3')  # Heteroscedasticity-consistent

# Option 4: Use generalized least squares
from statsmodels.regression.linear_model import GLS
gls_model = GLS(y, X_const, sigma=X**2)  # sigma is the error variance (here ∝ X²), not its inverse
results = gls_model.fit()
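
Option 2 above assumes you already know how the variance depends on X. When you don't, one rough approach (often called feasible WLS; this is a sketch under that assumption, not the only way) is to estimate the variance pattern from the residuals first:

import numpy as np
import statsmodels.api as sm

X_const = sm.add_constant(X)

# Step 1: plain OLS to get residuals
ols_results = sm.OLS(y, X_const).fit()
resid = ols_results.resid

# Step 2: model how the (log) squared residuals depend on X
var_fit = sm.OLS(np.log(resid**2), X_const).fit()
est_variance = np.exp(var_fit.fittedvalues)

# Step 3: refit with weights = 1 / estimated variance
fwls_results = sm.WLS(y, X_const, weights=1 / est_variance).fit()
print(fwls_results.params)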

The Complete Assumption Check

import numpy as np
import matplotlib.pyplot as plt
from sklearn.linear_model import LinearRegression
from scipy import stats
import statsmodels.api as sm
from statsmodels.stats.diagnostic import het_breuschpagan

def full_assumption_check(X, y, feature_name='X', target_name='y'):
    """
    Complete diagnostic check for linear regression assumptions.
    """
    # Fit model
    model = LinearRegression()
    X_reshaped = X.reshape(-1, 1) if X.ndim == 1 else X
    model.fit(X_reshaped, y)
    y_pred = model.predict(X_reshaped)
    residuals = y - y_pred

    print("="*70)
    print("LINEAR REGRESSION ASSUMPTION DIAGNOSTICS")
    print("="*70)

    # Create diagnostic plots
    fig, axes = plt.subplots(2, 2, figsize=(14, 10))

    # =====================================================
    # 1. LINEARITY CHECK
    # =====================================================
    ax1 = axes[0, 0]
    ax1.scatter(y_pred, residuals, alpha=0.5)
    ax1.axhline(y=0, color='r', linestyle='--', linewidth=2)
    ax1.set_xlabel('Predicted Values')
    ax1.set_ylabel('Residuals')
    ax1.set_title('1. LINEARITY: Residuals vs Predicted\n(Should be random around 0)')

    # Crude linearity heuristic: rank correlation between squared residuals and predictions
    corr, _ = stats.spearmanr(y_pred, residuals**2)

    print("\n1. LINEARITY")
    print("-"*50)
    if abs(corr) < 0.3:
        print("   ✓ No obvious non-linear pattern detected")
    else:
        print("   ⚠️  Possible non-linear pattern (corr={:.3f})".format(corr))
    print("   → Check residual plot for curved patterns")

    # =====================================================
    # 2. INDEPENDENCE CHECK
    # =====================================================
    ax2 = axes[0, 1]
    ax2.plot(range(len(residuals)), residuals, 'o-', alpha=0.5, markersize=3)
    ax2.axhline(y=0, color='r', linestyle='--', linewidth=2)
    ax2.set_xlabel('Observation Order')
    ax2.set_ylabel('Residuals')
    ax2.set_title('2. INDEPENDENCE: Residuals Over Time\n(Should be random, no pattern)')

    # Durbin-Watson
    n = len(residuals)
    dw = np.sum(np.diff(residuals)**2) / np.sum(residuals**2)

    print("\n2. INDEPENDENCE")
    print("-"*50)
    print(f"   Durbin-Watson statistic: {dw:.3f}")
    if 1.5 < dw < 2.5:
        print("   ✓ No significant autocorrelation (DW in [1.5, 2.5])")
    else:
        print("   ⚠️  Possible autocorrelation!")

    # =====================================================
    # 3. NORMALITY CHECK
    # =====================================================
    ax3 = axes[1, 0]
    stats.probplot(residuals, dist="norm", plot=ax3)
    ax3.set_title('3. NORMALITY: Q-Q Plot\n(Points should follow diagonal line)')

    # Shapiro-Wilk test
    if len(residuals) <= 5000:
        sw_stat, sw_p = stats.shapiro(residuals)
    else:
        # For large samples, use D'Agostino
        sw_stat, sw_p = stats.normaltest(residuals)

    print("\n3. NORMALITY")
    print("-"*50)
    print(f"   Shapiro-Wilk p-value: {sw_p:.4f}")
    if sw_p > 0.05:
        print("   ✓ Residuals appear normally distributed")
    else:
        print("   ⚠️  Residuals may NOT be normally distributed")
    print(f"   Skewness: {stats.skew(residuals):.3f} (ideal: 0)")
    print(f"   Kurtosis: {stats.kurtosis(residuals):.3f} (ideal: 0)")

    # =====================================================
    # 4. HOMOSCEDASTICITY CHECK
    # =====================================================
    ax4 = axes[1, 1]
    ax4.scatter(y_pred, np.sqrt(np.abs(residuals)), alpha=0.5)
    ax4.set_xlabel('Predicted Values')
    ax4.set_ylabel('√|Residuals|')
    ax4.set_title('4. HOMOSCEDASTICITY: Scale-Location\n(Should be flat, not funnel-shaped)')

    # Add trend line
    z = np.polyfit(y_pred, np.sqrt(np.abs(residuals)), 1)
    p = np.poly1d(z)
    ax4.plot(sorted(y_pred), p(sorted(y_pred)), 'r-', linewidth=2)

    # Breusch-Pagan test
    X_const = sm.add_constant(X_reshaped)
    bp_stat, bp_p, _, _ = het_breuschpagan(residuals, X_const)

    print("\n4. HOMOSCEDASTICITY (Equal Variance)")
    print("-"*50)
    print(f"   Breusch-Pagan p-value: {bp_p:.4f}")
    if bp_p > 0.05:
        print("   ✓ Constant variance (homoscedastic)")
    else:
        print("   ⚠️  Non-constant variance (heteroscedastic)")

    plt.tight_layout()
    plt.savefig('full_assumption_check.png', dpi=150, bbox_inches='tight')
    plt.show()

    # =====================================================
    # SUMMARY
    # =====================================================
    print("\n" + "="*70)
    print("SUMMARY")
    print("="*70)

    issues = []
    if abs(corr) >= 0.3:
        issues.append("Linearity")
    if not (1.5 < dw < 2.5):
        issues.append("Independence")
    if sw_p <= 0.05:
        issues.append("Normality")
    if bp_p <= 0.05:
        issues.append("Homoscedasticity")

    if not issues:
        print("✓ All assumptions appear satisfied!")
        print("  Your linear regression results should be reliable.")
    else:
        print(f"⚠️  Potential issues with: {', '.join(issues)}")
        print("  Consider transformations, different models, or robust methods.")

    return {
        'linearity_corr': corr,
        'durbin_watson': dw,
        'shapiro_p': sw_p,
        'breusch_pagan_p': bp_p
    }

# Usage
results = full_assumption_check(X, y, 'Square Feet', 'Price')

Quick Reference: The L.I.N.E. Check

Assumption      | What It Means                              | How to Check                  | Fix If Violated
----------------|--------------------------------------------|-------------------------------|----------------------------------
Linearity       | X and Y have a straight-line relationship  | Residuals vs Predicted plot   | Polynomial features, transform X
Independence    | Each observation is independent            | Durbin-Watson test, time plot | Time series models, clustered SE
Normality       | Residuals follow a bell curve              | Q-Q plot, Shapiro-Wilk test   | Transform Y, robust methods
Equal Variance  | Residual spread is constant                | Scale-Location plot, BP test  | WLS, transform Y, robust SE

Key Takeaways

  1. Assumptions make or break your model — High R² means nothing if assumptions are violated

  2. L.I.N.E. is your checklist — Linearity, Independence, Normality, Equal variance

  3. Residual plots are your best friend — Most violations are visible in plots

  4. Violations have different severity:

    • Non-linearity → Predictions are wrong
    • Non-independence → Standard errors are wrong
    • Non-normality → Confidence intervals are wrong
    • Heteroscedasticity → Efficiency and inference are wrong
  5. Most violations can be fixed — Transformations, robust methods, different models

  6. Check BEFORE trusting results — Not after deployment fails


The One-Sentence Summary

TrueNav GPS was 99% accurate on straight highways in clear weather but completely useless in mountains during storms — linear regression is the same, incredibly powerful when its assumptions hold (linearity, independence, normality, equal variance) but potentially garbage when they don't, which is why you must check before trusting any result.


What's Next?

Now that you understand the assumptions, you're ready for:

  • Residual Analysis Deep Dive — Diagnosing problems from residuals
  • Feature Transformations — When to log, square, or polynomialize
  • Ridge and Lasso Regression — When assumptions are partially violated
  • Robust Regression — When outliers and violations are unavoidable

Follow me for the next article in this series!


Let's Connect!

If L.I.N.E. is now burned into your memory, drop a heart!

Questions? Ask in the comments — I read and respond to every one.

What assumption violation have you encountered? I once had a time series disguised as cross-sectional data. The R² was amazing, the model was useless! 📉


The difference between a model that works in testing and one that works in production? Checking assumptions. TrueNav's engineers never tested on mountain roads. Don't make the same mistake with your linear regression.


Share this with someone who just runs model.fit() and calls it a day. The assumptions check might save them from shipping garbage.

Happy diagnosing! 🔍
