DEV Community

Sachin Kr. Rajput


Pruning Decision Trees: The Bonsai Master Who Taught ML Engineers When to Stop

The One-Line Summary: Prevent decision tree overfitting by limiting growth (pre-pruning with max_depth, min_samples_split, min_samples_leaf) or by growing fully then cutting back (post-pruning with cost-complexity pruning), finding the sweet spot where the tree captures patterns without memorizing noise.


The Tale of Two Trees

In the Garden of Machine Learning, two decision trees were planted on the same day and fed the same training data.


Tree #1: Wild Willow (The Overfitter)

Wild Willow had one philosophy: "More splits = More knowledge!"

WILD WILLOW'S GROWTH:
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

Training Data: 100 patients, 10 features

Wild Willow kept splitting...
  Depth 1: "Age > 50?"
  Depth 2: "Blood pressure > 140?"
  Depth 3: "Cholesterol > 200?"
  Depth 5: "Patient ID = 47?" ← Wait, what?!
  Depth 10: "Visited on a Tuesday?" ← This is getting weird...
  Depth 20: "Had coffee that morning?" ← STOP!

Final tree:
  - Depth: 25 levels
  - Leaves: 98 (almost one per patient!)
  - Training accuracy: 100% 🎉
  - Test accuracy: 52% 😱

Wild Willow MEMORIZED the training data!
Each patient got their own personal leaf node.
New patients? Complete failure.

Tree #2: Balanced Bonsai (The Generalizer)

Balanced Bonsai had a different philosophy: "Split only when it truly helps."

BALANCED BONSAI'S GROWTH:
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

Same Training Data: 100 patients, 10 features

Balanced Bonsai was selective...
  Depth 1: "Age > 50?" ← Strong predictor!
  Depth 2: "Blood pressure > 140?" ← Important!
  Depth 3: "Cholesterol > 200?" ← Useful!
  Depth 4: "Hmm, further splits don't help much..."
  STOP. No more splitting needed.

Final tree:
  - Depth: 4 levels
  - Leaves: 12
  - Training accuracy: 87%
  - Test accuracy: 85% ✓

Balanced Bonsai learned PATTERNS, not examples!
Slightly worse on training data.
MUCH better on new patients.

The Gardener's Wisdom

The old gardener who tended both trees explained:

"Wild Willow grew without restraint, reaching for every data point like a branch reaching for every ray of sunlight. It captured everything — including the noise, the accidents, the meaningless quirks.

Balanced Bonsai knew when to stop. It captured the strong patterns and ignored the noise. That's why it thrives with new data while Wild Willow withers."

This is the essence of preventing overfitting.

[Figure: Overfitting Overview] The overfitting problem: Wild Willow memorizes while Balanced Bonsai generalizes


What Is Overfitting?

OVERFITTING DEFINED:
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

Overfitting occurs when a model learns the training data
TOO WELL — including the noise and random fluctuations —
and fails to generalize to new, unseen data.

SYMPTOMS:
✗ Training accuracy MUCH higher than test accuracy
✗ Model is overly complex (deep tree, many leaves)
✗ Small changes in data cause big changes in predictions
✗ Model captures noise as if it were signal

ANALOGY:
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

Imagine a student who memorizes:
"Q: What's 2+2? A: 4"
"Q: What's 3+3? A: 6"
"Q: What's 5+5? A: 10"

They get 100% on the practice test!

But when asked "What's 4+4?", they're lost.
They memorized ANSWERS, not ADDITION.

An overfit decision tree does the same thing —
memorizing training examples instead of learning patterns.

Why Do Decision Trees Overfit?

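Decision trees are grown greedily: at each node the algorithm picks the best split for that node alone, without looking ahead. With scikit-learn's defaults (max_depth=None, min_samples_split=2, min_samples_leaf=1), it keeps splitting until every leaf is pure or cannot be split further, unless you stop it.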

import numpy as np
from sklearn.tree import DecisionTreeClassifier
from sklearn.model_selection import train_test_split
from sklearn.datasets import make_classification

# Create a dataset
X, y = make_classification(
    n_samples=500, n_features=20, n_informative=10,
    n_redundant=5, random_state=42
)

X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.3, random_state=42
)

print("THE OVERFITTING DEMONSTRATION")
print("="*60)

# Unrestricted tree (Wild Willow)
wild_tree = DecisionTreeClassifier(random_state=42)
wild_tree.fit(X_train, y_train)

print(f"\n🌳 WILD WILLOW (No restrictions):")
print(f"   Depth: {wild_tree.get_depth()}")
print(f"   Leaves: {wild_tree.get_n_leaves()}")
print(f"   Training Accuracy: {wild_tree.score(X_train, y_train):.2%}")
print(f"   Test Accuracy: {wild_tree.score(X_test, y_test):.2%}")
print(f"   Gap: {wild_tree.score(X_train, y_train) - wild_tree.score(X_test, y_test):.2%} ← OVERFITTING!")

# Restricted tree (Balanced Bonsai)
bonsai_tree = DecisionTreeClassifier(max_depth=5, min_samples_leaf=10, random_state=42)
bonsai_tree.fit(X_train, y_train)

print(f"\n🌿 BALANCED BONSAI (Pruned):")
print(f"   Depth: {bonsai_tree.get_depth()}")
print(f"   Leaves: {bonsai_tree.get_n_leaves()}")
print(f"   Training Accuracy: {bonsai_tree.score(X_train, y_train):.2%}")
print(f"   Test Accuracy: {bonsai_tree.score(X_test, y_test):.2%}")
print(f"   Gap: {bonsai_tree.score(X_train, y_train) - bonsai_tree.score(X_test, y_test):.2%} ← Healthy!")

Output:

THE OVERFITTING DEMONSTRATION
============================================================

🌳 WILD WILLOW (No restrictions):
   Depth: 19
   Leaves: 156
   Training Accuracy: 100.00%
   Test Accuracy: 78.67%
   Gap: 21.33% ← OVERFITTING!

🌿 BALANCED BONSAI (Pruned):
   Depth: 5
   Leaves: 22
   Training Accuracy: 89.43%
   Test Accuracy: 86.00%
   Gap: 3.43% ← Healthy!
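The claim that an unrestricted tree keeps splitting until its leaves are pure can be checked on the fitted model itself. The sketch below is one way to do it, assuming the same synthetic data as the demonstration above. It reads scikit-learn's low-level tree_ arrays (children_left, impurity, n_node_samples), where a leaf is any node whose children_left entry is -1.

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

# Same synthetic setup as the demonstration above
X, y = make_classification(
    n_samples=500, n_features=20, n_informative=10,
    n_redundant=5, random_state=42
)
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.3, random_state=42
)

# Unrestricted tree: max_depth=None, min_samples_split=2, min_samples_leaf=1
wild_tree = DecisionTreeClassifier(random_state=42).fit(X_train, y_train)

t = wild_tree.tree_
is_leaf = t.children_left == -1           # leaves have no children
leaf_impurity = t.impurity[is_leaf]       # Gini impurity of each leaf
leaf_size = t.n_node_samples[is_leaf]     # training samples that reach each leaf

print(f"Leaves:                  {is_leaf.sum()}")
print(f"Largest leaf impurity:   {leaf_impurity.max():.3f}")
print(f"Leaves with 1 sample:    {(leaf_size == 1).sum()}")
print(f"Median samples per leaf: {np.median(leaf_size):.0f}")
```

A largest leaf impurity of 0 means every leaf holds a single class, which is exactly how the unrestricted tree reaches 100% training accuracy.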

The Two Pruning Strategies

Just like a real gardener, we have two approaches:

PRUNING STRATEGIES:
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

1. PRE-PRUNING (Stop Early)
   "Don't let it grow wild in the first place!"

   Set limits BEFORE training:
   • max_depth: Maximum levels
   • min_samples_split: Min samples to split
   • min_samples_leaf: Min samples in leaf
   • max_leaf_nodes: Maximum leaves
   • max_features: Features to consider

   ✓ Fast and simple
   ✗ Might stop too early (a weak-looking split can hide strong splits below it)


2. POST-PRUNING (Grow Then Cut)
   "Let it grow fully, then trim the excess!"

   Build full tree, then remove branches:
   • Cost-complexity pruning (ccp_alpha)
   • Reduced error pruning (not built into scikit-learn)

   ✓ Considers the full picture
   ✓ Often finds better trees
   ✗ More computationally expensive

Pre-Pruning: Setting Growth Limits

1. max_depth: How Deep Can It Grow?

import numpy as np
import matplotlib.pyplot as plt
from sklearn.tree import DecisionTreeClassifier
from sklearn.model_selection import train_test_split
from sklearn.datasets import make_classification

# Create dataset
X, y = make_classification(n_samples=1000, n_features=20, 
                           n_informative=10, random_state=42)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)

print("EFFECT OF max_depth")
print("="*60)
print(f"\n{'Depth':<10} {'Train Acc':<12} {'Test Acc':<12} {'Leaves':<10} {'Status'}")
print("-"*55)

depths = [1, 2, 3, 4, 5, 7, 10, 15, 20, None]
results = []

for depth in depths:
    tree = DecisionTreeClassifier(max_depth=depth, random_state=42)
    tree.fit(X_train, y_train)

    train_acc = tree.score(X_train, y_train)
    test_acc = tree.score(X_test, y_test)
    leaves = tree.get_n_leaves()

    gap = train_acc - test_acc
    if gap > 0.15:
        status = "⚠️ OVERFIT"
    elif gap > 0.05:
        status = "⚡ Moderate"
    else:
        status = "✅ Good"

    depth_str = str(depth) if depth else "None"
    print(f"{depth_str:<10} {train_acc:<12.2f} {test_acc:<12.2f} {leaves:<10} {status}")

    results.append((depth if depth else 25, train_acc, test_acc))

Output:

EFFECT OF max_depth
============================================================

Depth      Train Acc    Test Acc     Leaves     Status
-------------------------------------------------------
1          0.77         0.76         2          ✅ Good
2          0.84         0.81         4          ✅ Good
3          0.88         0.84         8          ✅ Good
4          0.91         0.86         14         ✅ Good
5          0.93         0.87         22         ⚡ Moderate
7          0.97         0.86         54         ⚡ Moderate
10         0.99         0.84         118        ⚠️ OVERFIT
15         1.00         0.81         198        ⚠️ OVERFIT
20         1.00         0.80         224        ⚠️ OVERFIT
None       1.00         0.79         238        ⚠️ OVERFIT

[Figure: Max Depth Effect] As depth increases, training accuracy climbs to 100% but test accuracy peaks and then drops: classic overfitting!
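The table above scores each depth on a single train/test split. The sketch below runs a similar sweep with 5-fold cross-validation on the training set only, using scikit-learn's validation_curve. The depth grid mirrors the table (minus None), and picking the depth with the best mean CV score is one simple selection rule assumed here, not the only option.

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split, validation_curve
from sklearn.tree import DecisionTreeClassifier

X, y = make_classification(n_samples=1000, n_features=20,
                           n_informative=10, random_state=42)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)

depths = [1, 2, 3, 4, 5, 7, 10, 15, 20]

# Returned arrays have one row per depth and one column per CV fold
train_scores, val_scores = validation_curve(
    DecisionTreeClassifier(random_state=42), X_train, y_train,
    param_name="max_depth", param_range=depths, cv=5,
)

train_mean = train_scores.mean(axis=1)
val_mean = val_scores.mean(axis=1)

for d, tr, va in zip(depths, train_mean, val_mean):
    print(f"depth={d:<3} train={tr:.3f}  cv={va:.3f}  gap={tr - va:.3f}")

best_depth = depths[int(np.argmax(val_mean))]
print(f"Depth with the best mean CV score: {best_depth}")
```

Because the test set is never touched here, it stays available for one final, honest check of the chosen depth.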


2. min_samples_split: Minimum Samples to Split

print("\nEFFECT OF min_samples_split")
print("="*60)
print(f"\n{'Min Split':<12} {'Train Acc':<12} {'Test Acc':<12} {'Depth':<8} {'Leaves':<10}")
print("-"*55)

min_splits = [2, 5, 10, 20, 50, 100, 200]

for min_split in min_splits:
    tree = DecisionTreeClassifier(min_samples_split=min_split, random_state=42)
    tree.fit(X_train, y_train)

    print(f"{min_split:<12} {tree.score(X_train, y_train):<12.2%} "
          f"{tree.score(X_test, y_test):<12.2%} {tree.get_depth():<8} {tree.get_n_leaves():<10}")

Output:

EFFECT OF min_samples_split
============================================================

Min Split    Train Acc    Test Acc     Depth    Leaves    
-------------------------------------------------------
2            100.00%      79.00%       20       238       
5            100.00%      80.33%       18       192       
10           99.00%       82.00%       15       132       
20           96.57%       84.67%       12       76        
50           91.71%       86.33%       8        37        
100          87.00%       85.33%       6        18        
200          81.29%       81.00%       4        8         
min_samples_split EXPLAINED:
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

"A node must have AT LEAST this many samples to be split."

min_samples_split=2 (default):
  Even a node with just 2 samples can be split!
  → Leads to very deep trees, overfitting

min_samples_split=50:
  A node needs 50+ samples to consider splitting.
  → Stops splitting when data gets too thin
  → Prevents memorizing small groups

INTUITION:
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━
With only 10 samples in a node, a split can easily fit
chance fluctuations (NOISE) rather than a real pattern.

With 100+ samples, patterns are more likely to be REAL.
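The rule can be confirmed on a fitted tree: with min_samples_split=50, every internal (split) node should hold at least 50 training samples. A minimal sketch, assuming the 1000-sample synthetic data from the max_depth section and again reading scikit-learn's tree_ arrays:

```python
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

X, y = make_classification(n_samples=1000, n_features=20,
                           n_informative=10, random_state=42)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)

MIN_SPLIT = 50
tree = DecisionTreeClassifier(min_samples_split=MIN_SPLIT, random_state=42)
tree.fit(X_train, y_train)

t = tree.tree_
is_internal = t.children_left != -1            # nodes that were split
internal_sizes = t.n_node_samples[is_internal]
leaf_sizes = t.n_node_samples[~is_internal]

print(f"Smallest node that was split: {internal_sizes.min()} samples")
print(f"Smallest leaf:                {leaf_sizes.min()} samples")
```

Leaves can still end up with fewer than 50 samples, because min_samples_split only decides whether a node may be split, not how large its children are. Controlling child size is the job of min_samples_leaf.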

3. min_samples_leaf: Minimum Samples in a Leaf

print("\nEFFECT OF min_samples_leaf")
print("="*60)
print(f"\n{'Min Leaf':<12} {'Train Acc':<12} {'Test Acc':<12} {'Depth':<8} {'Leaves':<10}")
print("-"*55)

min_leafs = [1, 2, 5, 10, 20, 50, 100]

for min_leaf in min_leafs:
    tree = DecisionTreeClassifier(min_samples_leaf=min_leaf, random_state=42)
    tree.fit(X_train, y_train)

    print(f"{min_leaf:<12} {tree.score(X_train, y_train):<12.2%} "
          f"{tree.score(X_test, y_test):<12.2%} {tree.get_depth():<8} {tree.get_n_leaves():<10}")

Output:

EFFECT OF min_samples_leaf
============================================================

Min Leaf     Train Acc    Test Acc     Depth    Leaves    
-------------------------------------------------------
1            100.00%      79.00%       20       238       
2            100.00%      80.00%       19       190       
5            97.71%       83.33%       15       113       
10           94.00%       85.67%       11       58        
20           89.43%       86.00%       8        32        
50           83.43%       83.00%       5        14        
100          77.14%       77.67%       3        7         
min_samples_leaf EXPLAINED:
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

"Every leaf must have AT LEAST this many samples."

min_samples_leaf=1 (default):
  A leaf can have just 1 sample!
  → Creates very specific (memorized) leaves

min_samples_leaf=20:
  Every leaf needs 20+ samples.
  → Each prediction is based on 20+ examples
  → More statistically reliable predictions

DIFFERENCE FROM min_samples_split:
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

min_samples_split: Can I split this node?
min_samples_leaf:  Are the resulting leaves big enough?

Example with min_samples_leaf=10:
  Node has 50 samples.
  Split would create: 45 left, 5 right.
  REJECTED! Right leaf has only 5 < 10.
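The 45/5 example boils down to a one-line rule, and the matching guarantee can be checked on a fitted tree. In the sketch below, split_allowed is a hypothetical helper that mirrors the constraint (it is not a scikit-learn function); the second half fits a tree with min_samples_leaf=20 on the same synthetic data and reports its smallest leaf.

```python
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

def split_allowed(n_left, n_right, min_samples_leaf):
    """Hypothetical helper: a split is valid only if BOTH children are big enough."""
    return n_left >= min_samples_leaf and n_right >= min_samples_leaf

print(split_allowed(45, 5, min_samples_leaf=10))    # False: right child has 5 < 10
print(split_allowed(30, 20, min_samples_leaf=10))   # True

X, y = make_classification(n_samples=1000, n_features=20,
                           n_informative=10, random_state=42)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)

MIN_LEAF = 20
tree = DecisionTreeClassifier(min_samples_leaf=MIN_LEAF, random_state=42).fit(X_train, y_train)

t = tree.tree_
leaf_sizes = t.n_node_samples[t.children_left == -1]   # training samples in each leaf
print(f"Leaves: {len(leaf_sizes)}, smallest leaf: {leaf_sizes.min()} samples")
```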

4. max_leaf_nodes: Cap the Total Leaves

print("\nEFFECT OF max_leaf_nodes")
print("="*60)
print(f"\n{'Max Leaves':<12} {'Train Acc':<12} {'Test Acc':<12} {'Depth':<8} {'Actual Leaves':<15}")
print("-"*60)

max_leaves_list = [2, 5, 10, 20, 50, 100, None]

for max_leaves in max_leaves_list:
    tree = DecisionTreeClassifier(max_leaf_nodes=max_leaves, random_state=42)
    tree.fit(X_train, y_train)

    max_str = str(max_leaves) if max_leaves else "None"
    print(f"{max_str:<12} {tree.score(X_train, y_train):<12.2%} "
          f"{tree.score(X_test, y_test):<12.2%} {tree.get_depth():<8} {tree.get_n_leaves():<15}")

Output:

EFFECT OF max_leaf_nodes
============================================================

Max Leaves   Train Acc    Test Acc     Depth    Actual Leaves  
------------------------------------------------------------
2            77.14%       76.33%       1        2              
5            85.00%       82.67%       3        5              
10           89.86%       85.33%       5        10             
20           93.14%       86.33%       7        20             
50           97.57%       85.33%       12       50             
100          99.43%       83.67%       16       100            
None         100.00%      79.00%       20       238            
max_leaf_nodes EXPLAINED:
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

"The tree can have AT MOST this many leaf nodes."

max_leaf_nodes=20:
  Tree grows, but stops when it hits 20 leaves.
  With this cap set, scikit-learn grows the tree best-first:
  it always expands the node whose split reduces impurity most.

ADVANTAGE:
  - Direct control over model complexity
  - Tree picks the most valuable splits
  - Very interpretable (you know the maximum size)

WHEN TO USE:
  - When you need a specific complexity level
  - When interpretability is crucial
  - When you want to compare models of equal size
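To compare models of equal size, one option is to pit a best-first tree capped at 8 leaves against a depth-3 tree, which can have at most 2³ = 8 leaves. This is a sketch on the same synthetic data; which one wins depends on the data, so no result is claimed here.

```python
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

X, y = make_classification(n_samples=1000, n_features=20,
                           n_informative=10, random_state=42)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)

# Default depth-first growth, stopped by depth: at most 2**3 = 8 leaves
depth_tree = DecisionTreeClassifier(max_depth=3, random_state=42).fit(X_train, y_train)

# Best-first growth: expands the most impurity-reducing node next, up to 8 leaves
leaf_tree = DecisionTreeClassifier(max_leaf_nodes=8, random_state=42).fit(X_train, y_train)

for name, model in [("max_depth=3", depth_tree), ("max_leaf_nodes=8", leaf_tree)]:
    print(f"{name:<17} leaves={model.get_n_leaves():<3} depth={model.get_depth():<3} "
          f"test acc={model.score(X_test, y_test):.3f}")
```

The best-first tree may spend its 8 leaves unevenly, going deeper on one branch where the splits pay off and leaving other branches shallow.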

5. max_features: Limit Features Per Split

print("\nEFFECT OF max_features")
print("="*60)
print(f"\n{'Max Features':<15} {'Train Acc':<12} {'Test Acc':<12} {'Depth':<8}")
print("-"*50)

max_features_list = [1, 5, 10, 'sqrt', 'log2', None]

for max_feat in max_features_list:
    tree = DecisionTreeClassifier(max_features=max_feat, random_state=42)
    tree.fit(X_train, y_train)

    print(f"{str(max_feat):<15} {tree.score(X_train, y_train):<12.2%} "
          f"{tree.score(X_test, y_test):<12.2%} {tree.get_depth():<8}")
max_features EXPLAINED:
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

"At each split, consider only this many features."

max_features=None (default):
  Consider ALL features at every split.

max_features='sqrt':
  Consider √n features (n = total features).
  For 20 features: √20 ≈ 4 features per split.

max_features='log2':
  Consider log₂(n) features.
  For 20 features: log₂(20) ≈ 4 features.

WHY LIMIT FEATURES?
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

1. Adds randomness → reduces variance mainly when many such trees
   are averaged (a single tree still grows until its leaves are pure)
2. Faster training (fewer comparisons)
3. Key ingredient in Random Forests!
4. Prevents over-reliance on dominant features
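Two details are easy to check in code: how many features scikit-learn actually samples per split (stored in the fitted attribute max_features_), and the fact that max_features does not by itself stop growth, so a single unrestricted tree still fits the training set perfectly. A minimal sketch on the same 20-feature data:

```python
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

X, y = make_classification(n_samples=1000, n_features=20,
                           n_informative=10, random_state=42)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)

models = {}
for setting in ["sqrt", "log2", None]:
    tree = DecisionTreeClassifier(max_features=setting, random_state=42).fit(X_train, y_train)
    models[str(setting)] = tree
    # max_features_ is the resolved number of features tried at each split
    print(f"max_features={str(setting):<5} features per split={tree.max_features_:<3} "
          f"depth={tree.get_depth():<3} train acc={tree.score(X_train, y_train):.3f}")
```

To actually curb overfitting in a single tree, max_features has to be combined with one of the growth limits above; its bigger payoff comes inside Random Forests.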

Post-Pruning: Grow Then Cut Back

Cost-Complexity Pruning (ccp_alpha)

This is the post-pruning technique built into scikit-learn, exposed as the ccp_alpha parameter:

COST-COMPLEXITY PRUNING:
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

1. Grow the full tree (no restrictions)
2. Calculate the "cost" of each subtree
3. Remove subtrees that aren't worth their complexity

THE FORMULA:
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

Cost(T) = Impurity(T) + α × (Number of leaves in T)

Where Impurity(T) is the total sample-weighted impurity
of T's leaves, and α (alpha) is the complexity penalty.

• α = 0: No penalty → Full tree (overfit)
• α = large: Heavy penalty → Tiny tree (underfit)
• α = just right: Optimal trade-off!


INTUITION:
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

"Is this split WORTH the added complexity?"

If a split reduces total weighted impurity by 0.001 but adds a leaf (one leaf becomes two),
and α = 0.01, then:
  Benefit: 0.001 (impurity reduction)
  Cost: 0.01 (penalty for extra leaf)
  → NOT WORTH IT! Prune this split.
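The "is this split worth it?" test can be written out as plain arithmetic. Scikit-learn's documentation describes the weakest-link rule through a node's effective alpha, (R(t) - R(T_t)) / (number of leaves in T_t - 1): the impurity saved by keeping the subtree T_t below node t, divided by the extra leaves it costs. The sketch below uses made-up impurity values (the first case mirrors the 0.001 example) and two hypothetical helper functions; it does not call scikit-learn.

```python
def cost(total_impurity, n_leaves, alpha):
    """Cost-complexity measure: total leaf impurity plus alpha per leaf."""
    return total_impurity + alpha * n_leaves

def effective_alpha(r_node, r_subtree, n_leaves):
    """Penalty at which collapsing the subtree into one leaf breaks even."""
    return (r_node - r_subtree) / (n_leaves - 1)

ALPHA = 0.01

# (impurity if collapsed to a leaf, impurity of the subtree, leaves in the subtree)
cases = {
    "one split saving 0.001": (0.020, 0.019, 2),
    "3-leaf subtree saving 0.030": (0.050, 0.020, 3),
}

decisions = {}
for name, (r_node, r_sub, n_leaves) in cases.items():
    a_eff = effective_alpha(r_node, r_sub, n_leaves)
    keep = a_eff > ALPHA  # the subtree pays for itself only above the penalty
    decisions[name] = keep
    print(f"{name}: alpha_eff={a_eff:.4f}, "
          f"cost as leaf={cost(r_node, 1, ALPHA):.3f}, "
          f"cost as subtree={cost(r_sub, n_leaves, ALPHA):.3f} "
          f"-> {'keep' if keep else 'prune'}")
```

Both views agree by construction: keeping a subtree lowers the total cost exactly when its effective alpha exceeds α. Scikit-learn prunes the node with the smallest effective alpha first, so raising ccp_alpha removes the weakest branches before the strong ones.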
from sklearn.tree import DecisionTreeClassifier
import numpy as np

# Get the cost-complexity pruning path
tree = DecisionTreeClassifier(random_state=42)
path = tree.cost_complexity_pruning_path(X_train, y_train)
# Drop the last alpha: it prunes the tree all the way down to a single root node
ccp_alphas = path.ccp_alphas[:-1]
impurities = path.impurities[:-1]

print("COST-COMPLEXITY PRUNING PATH")
print("="*60)
print(f"\nFound {len(ccp_alphas)} alpha values to test")
print(f"Alpha range: {ccp_alphas.min():.6f} to {ccp_alphas.max():.6f}")

# Train trees for different alphas
trees = []
train_scores = []
test_scores = []

for alpha in ccp_alphas:
    tree = DecisionTreeClassifier(ccp_alpha=alpha, random_state=42)
    tree.fit(X_train, y_train)
    trees.append(tree)
    train_scores.append(tree.score(X_train, y_train))
    test_scores.append(tree.score(X_test, y_test))

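# Pick the alpha with the best TEST accuracy (illustration only: this tunes on the
# test set; the cross-validation approach in the next section avoids that)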
best_idx = np.argmax(test_scores)
best_alpha = ccp_alphas[best_idx]
best_tree = trees[best_idx]

print(f"\n{'Alpha':<12} {'Train Acc':<12} {'Test Acc':<12} {'Leaves':<10}")
print("-"*50)

# Show selected alphas
indices = [0, len(ccp_alphas)//4, len(ccp_alphas)//2, 
           best_idx, 3*len(ccp_alphas)//4, len(ccp_alphas)-1]
indices = sorted(set(indices))

for i in indices:
    print(f"{ccp_alphas[i]:<12.6f} {train_scores[i]:<12.2%} "
          f"{test_scores[i]:<12.2%} {trees[i].get_n_leaves():<10}")

print(f"\n🏆 OPTIMAL: alpha={best_alpha:.6f}, Test Acc={test_scores[best_idx]:.2%}")

Output:

COST-COMPLEXITY PRUNING PATH
============================================================

Found 156 alpha values to test
Alpha range: 0.000000 to 0.064286

Alpha        Train Acc    Test Acc     Leaves    
--------------------------------------------------
0.000000     100.00%      79.00%       238       
0.000429     98.29%       82.33%       139       
0.001286     95.57%       85.00%       73        
0.002667     91.86%       87.00%       37        
0.007273     86.43%       85.67%       17        
0.064286     77.14%       76.33%       2         

🏆 OPTIMAL: alpha=0.002667, Test Acc=87.00%

[Figure: CCP Alpha Effect] Cost-complexity pruning finds the alpha where test accuracy peaks


Finding the Best Alpha with Cross-Validation

import numpy as np
from sklearn.model_selection import cross_val_score
from sklearn.tree import DecisionTreeClassifier

print("FINDING OPTIMAL ALPHA WITH CROSS-VALIDATION")
print("="*60)

# Get alpha candidates
tree = DecisionTreeClassifier(random_state=42)
path = tree.cost_complexity_pruning_path(X_train, y_train)
ccp_alphas = path.ccp_alphas

# Use fewer alphas for speed
alpha_candidates = ccp_alphas[::5]  # Every 5th alpha

cv_scores = []
for alpha in alpha_candidates:
    tree = DecisionTreeClassifier(ccp_alpha=alpha, random_state=42)
    scores = cross_val_score(tree, X_train, y_train, cv=5)
    cv_scores.append(scores.mean())

# Find best
best_idx = np.argmax(cv_scores)
best_alpha = alpha_candidates[best_idx]

print(f"\nBest alpha from CV: {best_alpha:.6f}")
print(f"CV Score: {cv_scores[best_idx]:.2%}")

# Train final model
final_tree = DecisionTreeClassifier(ccp_alpha=best_alpha, random_state=42)
final_tree.fit(X_train, y_train)

print(f"\nFinal Model:")
print(f"  Depth: {final_tree.get_depth()}")
print(f"  Leaves: {final_tree.get_n_leaves()}")
print(f"  Training Accuracy: {final_tree.score(X_train, y_train):.2%}")
print(f"  Test Accuracy: {final_tree.score(X_test, y_test):.2%}")
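The manual loop above can also be handed to GridSearchCV, which runs the same 5-fold cross-validation over candidate ccp_alpha values and refits the winner on the full training set. A sketch, assuming the 1000-sample synthetic data from earlier; dropping the last alpha (the single-node tree) follows the convention in scikit-learn's own pruning example, and the clip only guarantees the candidates respect ccp_alpha >= 0.

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.model_selection import GridSearchCV, train_test_split
from sklearn.tree import DecisionTreeClassifier

X, y = make_classification(n_samples=1000, n_features=20,
                           n_informative=10, random_state=42)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)

# Candidate alphas come from the training data only
path = DecisionTreeClassifier(random_state=42).cost_complexity_pruning_path(X_train, y_train)
alphas = np.clip(path.ccp_alphas[:-1], 0.0, None)  # last alpha = root-only tree
candidates = alphas[::5]                            # thin the grid for speed

search = GridSearchCV(
    DecisionTreeClassifier(random_state=42),
    param_grid={"ccp_alpha": candidates},
    cv=5,
    scoring="accuracy",
)
search.fit(X_train, y_train)

best = search.best_estimator_   # already refit on all of X_train
print(f"Best ccp_alpha: {search.best_params_['ccp_alpha']:.6f}")
print(f"CV accuracy:    {search.best_score_:.2%}")
print(f"Leaves: {best.get_n_leaves()}, test accuracy: {best.score(X_test, y_test):.2%}")
```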

The Complete Pruning Toolkit

from sklearn.tree import DecisionTreeClassifier
from sklearn.model_selection import GridSearchCV
import numpy as np

print("THE COMPLETE PRUNING TOOLKIT")
print("="*60)

# All pruning parameters in one place
param_grid = {
    'max_depth': [3, 5, 7, 10, None],
    'min_samples_split': [2, 10, 20, 50],
    'min_samples_leaf': [1, 5, 10, 20],
    'max_leaf_nodes': [10, 20, 50, None],
}

print(f"\nSearching {np.prod([len(v) for v in param_grid.values()])} combinations...")

tree = DecisionTreeClassifier(random_state=42)
grid_search = GridSearchCV(tree, param_grid, cv=5, scoring='accuracy', n_jobs=-1)
grid_search.fit(X_train, y_train)

print(f"\n🏆 Best Parameters:")
for param, value in grid_search.best_params_.items():
    print(f"   {param}: {value}")

print(f"\nCV Score: {grid_search.best_score_:.2%}")
print(f"Test Score: {grid_search.score(X_test, y_test):.2%}")

best_tree = grid_search.best_estimator_
print(f"\nBest Tree Structure:")
print(f"   Depth: {best_tree.get_depth()}")
print(f"   Leaves: {best_tree.get_n_leaves()}")

Output:

THE COMPLETE PRUNING TOOLKIT
============================================================

Searching 320 combinations...

🏆 Best Parameters:
   max_depth: 5
   max_leaf_nodes: 20
   min_samples_leaf: 5
   min_samples_split: 10

CV Score: 86.14%
Test Score: 87.33%

Best Tree Structure:
   Depth: 5
   Leaves: 19
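The "320 combinations" figure is the product of the list lengths (5 × 4 × 4 × 4), and with cv=5 every combination is fit once per fold. A quick sketch that counts both with scikit-learn's ParameterGrid, using the same grid as above:

```python
from sklearn.model_selection import ParameterGrid

param_grid = {
    'max_depth': [3, 5, 7, 10, None],
    'min_samples_split': [2, 10, 20, 50],
    'min_samples_leaf': [1, 5, 10, 20],
    'max_leaf_nodes': [10, 20, 50, None],
}
cv_folds = 5

n_combinations = len(ParameterGrid(param_grid))
n_fits = n_combinations * cv_folds
print(f"{n_combinations} combinations x {cv_folds} folds = {n_fits} fits (+1 final refit)")
# → 320 combinations x 5 folds = 1600 fits (+1 final refit)
```

Many of these combinations behave identically. For example, scikit-learn treats min_samples_split as at least 2 × min_samples_leaf, so with min_samples_leaf=10 the values 2, 10 and 20 for min_samples_split all give the same tree. Trimming such overlaps, or switching to RandomizedSearchCV, cuts the cost.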

Visualizing Overfitting

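import numpy as np
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeRegressor  # regression tree for a continuous target

# Create data with noise
np.random.seed(42)
X = np.linspace(0, 10, 100).reshape(-1, 1)
y_true = np.sin(X).ravel()
y = y_true + np.random.randn(100) * 0.3

# Random split: slicing X[:70] / X[70:] would test only on x > 7,
# a region the trees never saw (extrapolation, not overfitting)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)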

# Fit trees of different depths
depths = [1, 3, 5, 10, 20]
X_plot = np.linspace(0, 10, 200).reshape(-1, 1)

fig, axes = plt.subplots(1, 5, figsize=(20, 4))

for ax, depth in zip(axes, depths):
    tree = DecisionTreeRegressor(max_depth=depth, random_state=42)
    tree.fit(X_train, y_train)

    y_pred = tree.predict(X_plot)

    ax.scatter(X_train, y_train, c='blue', alpha=0.5, label='Train')
    ax.scatter(X_test, y_test, c='red', alpha=0.5, label='Test')
    ax.plot(X_plot, y_pred, 'g-', linewidth=2, label='Prediction')
    ax.plot(X_plot, np.sin(X_plot), 'k--', alpha=0.5, label='True')

    train_score = tree.score(X_train, y_train)
    test_score = tree.score(X_test, y_test)

    ax.set_title(f'Depth={depth}\nTrain R²={train_score:.2f}, Test R²={test_score:.2f}')
    ax.legend(fontsize=8)
    ax.set_xlim(0, 10)

plt.suptitle('Effect of Tree Depth on Overfitting', fontsize=14, fontweight='bold')
plt.tight_layout()
plt.savefig('depth_overfitting_visual.png', dpi=150, bbox_inches='tight')
plt.show()

[Figure: Depth Overfitting Visual] As depth increases, the tree fits the training data better but test performance degrades, and the predictions become jagged as they overfit to noise


The Bias-Variance Trade-off

THE FUNDAMENTAL TRADE-OFF:
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

Expected Error = Bias² + Variance + Irreducible Noise
(for squared error, averaged over possible training sets)


BIAS (Underfitting):
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

"The model is too simple to capture the pattern."

Symptoms:
• Both training AND test accuracy are low
• Model makes systematic errors
• Tree is too shallow

Example: Depth=1 tree trying to fit a complex pattern.


VARIANCE (Overfitting):
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

"The model is too complex and captures noise."

Symptoms:
• Training accuracy high, test accuracy low
• Model changes drastically with different data
• Tree is too deep

Example: Depth=20 tree memorizing training data.


THE SWEET SPOT:
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

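As complexity (tree depth) increases:
  Bias²        falls steadily
  Variance     rises steadily
  Total error  falls, bottoms out, then rises (a U-shape)

  Too shallow  → high bias, low variance   (underfitting)
  Too deep     → low bias, high variance   (overfitting)
  Sweet Spot   → lowest total error        (Optimal Complexity)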
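The variance side of the trade-off can be measured rather than drawn. The sketch below refits a shallow and a deep regression tree on many bootstrap resamples of a noisy sine curve like the one in the depth visualization, then averages over fixed evaluation points how much each model's predictions move between resamples (variance) and how far its average prediction sits from the true curve (squared bias). Bootstrap resampling stands in for drawing fresh training sets, and the number of rounds and the evaluation grid are arbitrary choices; no particular numbers are claimed.

```python
import numpy as np
from sklearn.tree import DecisionTreeRegressor

rng = np.random.default_rng(42)
X = np.linspace(0, 10, 100).reshape(-1, 1)
y = np.sin(X).ravel() + rng.normal(scale=0.3, size=100)

X_eval = np.linspace(0.5, 9.5, 50).reshape(-1, 1)
y_true = np.sin(X_eval).ravel()
n_rounds = 200

results = {}
for depth in [2, None]:
    preds = np.empty((n_rounds, len(X_eval)))
    for r in range(n_rounds):
        idx = rng.integers(0, len(X), size=len(X))   # bootstrap resample
        model = DecisionTreeRegressor(max_depth=depth, random_state=0)
        preds[r] = model.fit(X[idx], y[idx]).predict(X_eval)

    avg_pred = preds.mean(axis=0)
    bias_sq = np.mean((avg_pred - y_true) ** 2)   # squared bias, averaged over X_eval
    variance = np.mean(preds.var(axis=0))         # spread across resamples
    results[depth] = (bias_sq, variance)
    print(f"max_depth={str(depth):<5} bias^2={bias_sq:.4f}  variance={variance:.4f}")
```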

![Bias Variance Tradeoff]

Finding the sweet spot: enough complexity to capture patterns, not so much that we capture noise


Practical Guidelines

WHEN TO USE EACH TECHNIQUE:
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

max_depth: 
  Start here! Most intuitive.
  Try: 3, 5, 7, 10
  Use when: You want direct control over tree size.

min_samples_leaf:
  Very effective! Ensures statistical reliability.
  Try: 1% to 5% of training data (e.g., 10-50)
  Use when: You want each prediction backed by data.

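min_samples_split:
  Similar to min_samples_leaf but less strict.
  Try: values above 2× your min_samples_leaf (min_samples_leaf=k
  already means a node needs 2k+ samples to split, so smaller
  values add no extra constraint).
  Use when: You want nodes to have enough data before splitting.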

max_leaf_nodes:
  Direct complexity control.
  Try: 10, 20, 50
  Use when: You want to cap the tree at N leaves.

ccp_alpha:
  Most principled: prunes the branches that buy the least impurity reduction per extra leaf.
  Find with cross-validation.
  Use when: You want the algorithm to find optimal pruning.


RECOMMENDED WORKFLOW:
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

1. Start with max_depth=5, min_samples_leaf=10
2. Check train vs test gap
3. If gap > 10%: More pruning needed
4. If train and test accuracy are both low: Less pruning needed
5. Use GridSearchCV to find optimal combination
6. Consider ccp_alpha for fine-tuning
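Steps 1 to 4 of the workflow can be sketched as a small diagnostic. The 10% gap threshold comes from the list above; LOW_ACCURACY = 0.80 is an assumed cut-off for "both low" that you would set per problem, and diagnose is a hypothetical helper, not a library function. The third candidate also shows that scikit-learn accepts a float for min_samples_leaf, read as a fraction of the training samples (rounded up), which is a convenient way to express the "1% to 5% of training data" guideline.

```python
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

GAP_THRESHOLD = 0.10   # from the workflow: a gap above 10% calls for more pruning
LOW_ACCURACY = 0.80    # assumption: what counts as "low" depends on the problem

def diagnose(train_acc, test_acc):
    """Hypothetical helper mirroring steps 2-4 of the workflow."""
    if train_acc - test_acc > GAP_THRESHOLD:
        return "overfitting: prune more"
    if train_acc < LOW_ACCURACY and test_acc < LOW_ACCURACY:
        return "underfitting: prune less"
    return "reasonable: fine-tune with GridSearchCV or ccp_alpha"

X, y = make_classification(n_samples=1000, n_features=20,
                           n_informative=10, random_state=42)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)

candidates = {
    "no limits": DecisionTreeClassifier(random_state=42),
    "max_depth=5, min_samples_leaf=10": DecisionTreeClassifier(
        max_depth=5, min_samples_leaf=10, random_state=42),
    # A float min_samples_leaf is a fraction: ceil(0.02 * 700) = 14 samples per leaf
    "min_samples_leaf=0.02": DecisionTreeClassifier(
        min_samples_leaf=0.02, random_state=42),
}

verdicts = {}
for name, model in candidates.items():
    model.fit(X_train, y_train)
    train_acc = model.score(X_train, y_train)
    test_acc = model.score(X_test, y_test)
    verdicts[name] = diagnose(train_acc, test_acc)
    print(f"{name:<34} train={train_acc:.2%} test={test_acc:.2%} -> {verdicts[name]}")
```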

Complete Example: From Overfit to Optimal

import numpy as np
from sklearn.datasets import load_breast_cancer
from sklearn.tree import DecisionTreeClassifier
from sklearn.model_selection import train_test_split, GridSearchCV, cross_val_score
import warnings
warnings.filterwarnings('ignore')

# Load real dataset
data = load_breast_cancer()
X, y = data.data, data.target
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)

print("FROM OVERFIT TO OPTIMAL: A COMPLETE WORKFLOW")
print("="*60)
print(f"\nDataset: Breast Cancer (569 samples, 30 features)")
print(f"Training: {len(X_train)}, Test: {len(X_test)}")

# Step 1: Baseline (overfit)
print("\n" + "="*60)
print("STEP 1: Baseline (No Pruning)")
print("="*60)

baseline = DecisionTreeClassifier(random_state=42)
baseline.fit(X_train, y_train)

print(f"Depth: {baseline.get_depth()}, Leaves: {baseline.get_n_leaves()}")
print(f"Training: {baseline.score(X_train, y_train):.2%}")
print(f"Test: {baseline.score(X_test, y_test):.2%}")
print(f"Gap: {baseline.score(X_train, y_train) - baseline.score(X_test, y_test):.2%} ← OVERFITTING!")

# Step 2: Simple pruning
print("\n" + "="*60)
print("STEP 2: Simple Pre-Pruning")
print("="*60)

simple = DecisionTreeClassifier(max_depth=5, min_samples_leaf=5, random_state=42)
simple.fit(X_train, y_train)

print(f"Depth: {simple.get_depth()}, Leaves: {simple.get_n_leaves()}")
print(f"Training: {simple.score(X_train, y_train):.2%}")
print(f"Test: {simple.score(X_test, y_test):.2%}")
print(f"Gap: {simple.score(X_train, y_train) - simple.score(X_test, y_test):.2%}")

# Step 3: Grid search
print("\n" + "="*60)
print("STEP 3: Grid Search Optimization")
print("="*60)

param_grid = {
    'max_depth': [3, 5, 7, 10],
    'min_samples_split': [2, 10, 20],
    'min_samples_leaf': [1, 5, 10],
    'max_leaf_nodes': [10, 20, 30, None]
}

grid = GridSearchCV(
    DecisionTreeClassifier(random_state=42),
    param_grid, cv=5, scoring='accuracy', n_jobs=-1
)
grid.fit(X_train, y_train)

print(f"Best Parameters: {grid.best_params_}")
print(f"CV Score: {grid.best_score_:.2%}")
print(f"Test Score: {grid.score(X_test, y_test):.2%}")

best_tree = grid.best_estimator_
print(f"Best Tree: Depth={best_tree.get_depth()}, Leaves={best_tree.get_n_leaves()}")

# Step 4: Cost-complexity pruning
print("\n" + "="*60)
print("STEP 4: Cost-Complexity Pruning")
print("="*60)

# Find optimal alpha
path = DecisionTreeClassifier(random_state=42).cost_complexity_pruning_path(X_train, y_train)
alphas = path.ccp_alphas[:-1]  # Remove last (trivial tree)

cv_scores = []
for alpha in alphas:
    tree = DecisionTreeClassifier(ccp_alpha=alpha, random_state=42)
    scores = cross_val_score(tree, X_train, y_train, cv=5)
    cv_scores.append(scores.mean())

best_alpha = alphas[np.argmax(cv_scores)]
ccp_tree = DecisionTreeClassifier(ccp_alpha=best_alpha, random_state=42)
ccp_tree.fit(X_train, y_train)

print(f"Optimal Alpha: {best_alpha:.6f}")
print(f"Depth: {ccp_tree.get_depth()}, Leaves: {ccp_tree.get_n_leaves()}")
print(f"Training: {ccp_tree.score(X_train, y_train):.2%}")
print(f"Test: {ccp_tree.score(X_test, y_test):.2%}")

# Summary
print("\n" + "="*60)
print("SUMMARY: TEST ACCURACY COMPARISON")
print("="*60)
print(f"Baseline (overfit):     {baseline.score(X_test, y_test):.2%}")
print(f"Simple pruning:         {simple.score(X_test, y_test):.2%}")
print(f"Grid search optimized:  {grid.score(X_test, y_test):.2%}")
print(f"CCP optimized:          {ccp_tree.score(X_test, y_test):.2%}")

Output:

FROM OVERFIT TO OPTIMAL: A COMPLETE WORKFLOW
============================================================

Dataset: Breast Cancer (569 samples, 30 features)
Training: 398, Test: 171

============================================================
STEP 1: Baseline (No Pruning)
============================================================
Depth: 7, Leaves: 21
Training: 100.00%
Test: 93.57%
Gap: 6.43% ← OVERFITTING!

============================================================
STEP 2: Simple Pre-Pruning
============================================================
Depth: 5, Leaves: 14
Training: 98.49%
Test: 95.32%
Gap: 3.17%

============================================================
STEP 3: Grid Search Optimization
============================================================
Best Parameters: {'max_depth': 5, 'max_leaf_nodes': 10, 'min_samples_leaf': 5, 'min_samples_split': 10}
CV Score: 93.72%
Test Score: 95.91%
Best Tree: Depth=5, Leaves=10

============================================================
STEP 4: Cost-Complexity Pruning
============================================================
Optimal Alpha: 0.010050
Depth: 4, Leaves: 8
Training: 96.48%
Test: 96.49%

============================================================
SUMMARY: TEST ACCURACY COMPARISON
============================================================
Baseline (overfit):     93.57%
Simple pruning:         95.32%
Grid search optimized:  95.91%
CCP optimized:          96.49%
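A pruned tree is small enough to read as a set of rules. The sketch below refits a breast-cancer tree with ccp_alpha=0.01 (an assumed value close to the 0.010050 reported above; the alpha your own CV picks may differ) and prints it with scikit-learn's export_text.

```python
from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier, export_text

data = load_breast_cancer()
X_train, X_test, y_train, y_test = train_test_split(
    data.data, data.target, test_size=0.3, random_state=42
)

# ccp_alpha near the CV-selected value above (assumption, not re-derived here)
pruned = DecisionTreeClassifier(ccp_alpha=0.01, random_state=42).fit(X_train, y_train)

# One "class: ..." line per leaf, with the split thresholds leading to it
rules = export_text(pruned, feature_names=list(data.feature_names))
print(rules)
print(f"Leaves: {pruned.get_n_leaves()}, test accuracy: {pruned.score(X_test, y_test):.2%}")
```

The unpruned baseline's rules would run several times longer, which is the interpretability half of the case for pruning.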

Quick Reference Card

DECISION TREE PRUNING: CHEAT SHEET
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

PRE-PRUNING PARAMETERS:
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

max_depth:         Maximum tree depth
                   Start with: 3-10

min_samples_split: Min samples to split a node
                   Start with: 10-50 (or 1-5% of data)

min_samples_leaf:  Min samples in each leaf
                   Start with: 5-20 (or 0.5-2% of data)

max_leaf_nodes:    Maximum number of leaves
                   Start with: 10-50

max_features:      Features considered per split
                   Options: None, 'sqrt', 'log2', int


POST-PRUNING:
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

ccp_alpha:         Complexity penalty
                   Find with cross-validation
                   Higher = more pruning


SIGNS OF OVERFITTING:
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

✗ Training acc >> Test acc (gap > 10%)
✗ Very deep tree (depth > 15)
✗ Many leaves (close to number of samples)
✗ Perfect training accuracy (100%)


WORKFLOW:
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

1. Train baseline (no restrictions)
2. Check train vs test gap
3. Apply pre-pruning (max_depth, min_samples_leaf)
4. Use GridSearchCV for optimization
5. Try ccp_alpha for fine-tuning
6. Select model with best cross-validation score

Key Takeaways

  1. Overfitting = memorizing, not learning — The tree captures noise instead of patterns

  2. Signs of overfitting: High training accuracy, low test accuracy, very deep tree

  3. Pre-pruning stops early — Set limits before training (max_depth, min_samples_*)

  4. Post-pruning cuts back — Grow fully, then remove branches (ccp_alpha)

  5. max_depth is your first tool — Start with 3-10, adjust based on results

  6. min_samples_leaf ensures reliability — Each prediction is backed by enough data

  7. ccp_alpha is the most principled option: cross-validating alpha along the pruning path finds the pruning level

  8. Use cross-validation — Never tune on test data, use GridSearchCV


The One-Sentence Summary

Preventing decision tree overfitting is like being a bonsai master: Wild Willow grew in every direction and captured every quirk (memorized), while Balanced Bonsai grew strategically with max_depth, min_samples_leaf, and ccp_alpha to capture only the important patterns (generalized) — and that's why Balanced Bonsai thrives with new data while Wild Willow withers.


What's Next?

Now that you understand overfitting prevention, you're ready for:

  1. Random Forests: Many randomized trees (often left unpruned) voting together
  2. Ensemble Methods: Combining models for better results
  3. Gradient Boosting: Trees that learn from mistakes
  4. Feature Importance: Which features matter most?

Follow me for the next article in the Tree Based Models series!


Let's Connect!

If the bonsai master made pruning click, drop a heart!

Questions? Ask in the comments — I read and respond to every one.

What's your go-to pruning strategy? I usually start with max_depth=5 and min_samples_leaf=10, then fine-tune with ccp_alpha! 🌿


The difference between a wild tree and a bonsai? Strategic cuts. The wild tree reaches everywhere but masters nothing; the bonsai focuses its energy and creates beauty. Your decision tree can be either — the choice is in your hyperparameters.


Share this with someone struggling with overfitting. The bonsai master awaits!

Happy pruning! ✂️🌳
