Sachin Kr. Rajput

Decision Trees: The Detective Who Solves Cases by Asking Yes/No Questions

The One-Line Summary: A decision tree makes predictions by asking a series of yes/no questions about the features, splitting the data at each step until it reaches a conclusion — like a game of "20 Questions" that learns which questions to ask from the training data.


The Detective's Method

Detective Oak had an unusual method. While other detectives gathered evidence for months, Oak solved cases in minutes by asking exactly the right questions in exactly the right order.


The Case of the Missing Desserts

The Grand Hotel reported that desserts were disappearing from the kitchen every night. There were 100 staff members. Any of them could be the culprit.

Detective Oak arrived and announced: "I will find your thief by asking just a few questions."

DETECTIVE OAK'S INVESTIGATION:
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

100 staff members. One is stealing desserts.

QUESTION 1: "Does this person work the night shift?"
├── YES (35 people) ← Desserts disappear at night!
└── NO (65 people) — Unlikely, desserts vanish at night

    Focus on the 35 night shift workers.

QUESTION 2: "Does this person have kitchen access?"
├── YES (12 people) ← Must access kitchen to steal!
└── NO (23 people) — Can't reach the desserts

    Focus on the 12 with kitchen access.

QUESTION 3: "Has this person been seen near the 
            dessert station after midnight?"
├── YES (3 people) ← Very suspicious!
└── NO (9 people) — Less likely

    Focus on the 3 suspects.

QUESTION 4: "Does this person have chocolate stains
            on their uniform?"
├── YES (1 person) ← CAUGHT!
└── NO (2 people) — Probably innocent

CULPRIT IDENTIFIED: Night-shift baker with chocolate stains.

100 people → 35 → 12 → 3 → 1
Just 4 questions to find the thief!

The hotel manager was amazed. "How did you know which questions to ask?"

Detective Oak smiled. "I asked questions that SPLIT the suspects most effectively. Each question eliminated the maximum number of innocent people while keeping the guilty one in focus."


This Is Exactly How Decision Trees Work

![Decision Trees: How They Work]

The four key concepts: Tree structure, Gini impurity, Information gain, and the overfitting danger

A DECISION TREE:
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

                    [Night Shift?]
                    /            \
                  YES            NO
                  /                \
         [Kitchen Access?]      [INNOCENT]
          /          \
        YES          NO
        /              \
  [Near Desserts      [INNOCENT]
   After Midnight?]
    /        \
  YES        NO
  /            \
[Chocolate    [INNOCENT]
 Stains?]
 /     \
YES    NO
 |       |
GUILTY  INNOCENT


Each internal node = A QUESTION (feature test)
Each branch = An ANSWER (yes/no)
Each leaf = A PREDICTION (guilty/innocent)
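
To make the mapping concrete, here is a minimal sketch of Detective Oak's tree as plain Python if/else logic. The function name and arguments are just the story's questions, not any real API:

def is_guilty(night_shift, kitchen_access, near_desserts, chocolate_stains):
    """The detective's tree as nested if/else: each `if` is an internal node,
    each return is a leaf prediction."""
    if night_shift:                      # root node: the most informative question
        if kitchen_access:               # internal node
            if near_desserts:            # internal node
                if chocolate_stains:     # internal node
                    return "GUILTY"      # leaf
                return "INNOCENT"        # leaf
            return "INNOCENT"
        return "INNOCENT"
    return "INNOCENT"

print(is_guilty(True, True, True, True))    # GUILTY
print(is_guilty(True, True, True, False))   # INNOCENT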

The Anatomy of a Decision Tree

TREE TERMINOLOGY:
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

                [ROOT NODE]          ← First question
               /           \            (most important split)
              /             \
        [INTERNAL]      [INTERNAL]   ← Follow-up questions
        /      \        /      \
       /        \      /        \
    [LEAF]  [LEAF] [LEAF]   [LEAF]  ← Final predictions
                                        (no more questions)


ROOT NODE:    The first question asked
              Splits ALL data

INTERNAL NODE: Intermediate questions
               Splits a SUBSET of data

LEAF NODE:    Final prediction
              No more splits
              Also called "terminal node"

BRANCH:       The path from one node to another
              Represents an answer (yes/no)

DEPTH:        How many questions from root to leaf
              Deeper = More specific = Risk of overfitting
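
If you prefer code to diagrams, the same terminology can be expressed as a tiny node structure. This is only an illustrative sketch (the Node class and depth helper are invented here; the real implementations later in the post use dicts and scikit-learn objects):

from dataclasses import dataclass
from typing import Optional

@dataclass
class Node:
    question: Optional[str] = None    # feature test at internal nodes, None at leaves
    prediction: Optional[str] = None  # set only at leaf (terminal) nodes
    left: Optional["Node"] = None     # "yes" branch
    right: Optional["Node"] = None    # "no" branch

    @property
    def is_leaf(self):
        return self.left is None and self.right is None

def depth(node):
    """Depth = longest chain of questions from this node down to a leaf."""
    if node.is_leaf:
        return 0
    return 1 + max(depth(node.left), depth(node.right))

# Root node, one internal node on the "yes" branch, three leaves -> depth 2
internal = Node(question="Kitchen access?",
                left=Node(prediction="guilty"),
                right=Node(prediction="innocent"))
root = Node(question="Night shift?", left=internal, right=Node(prediction="innocent"))
print(root.is_leaf, internal.is_leaf)  # False False
print(depth(root))                     # 2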

How Does the Tree Know Which Question to Ask?

This is the key insight. Detective Oak didn't ask random questions — he asked questions that split the suspects most effectively.

But what makes a split "effective"?


The Goal: Purity

THE CONCEPT OF PURITY:
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

A node is "PURE" if all samples in it belong to ONE class.

PURE NODE (perfect):
┌─────────────────┐
│ ● ● ● ● ● ● ● ● │  All same class!
│ ● ● ● ● ● ● ● ● │  We can confidently predict.
└─────────────────┘

IMPURE NODE (mixed):
┌─────────────────┐
│ ● ● ○ ● ○ ○ ● ● │  Mixed classes!
│ ○ ● ● ○ ● ○ ○ ● │  We're uncertain.
└─────────────────┘


THE GOAL OF EACH SPLIT:
Make the child nodes MORE PURE than the parent.

BEFORE SPLIT:          AFTER SPLIT:
┌─────────────┐       ┌───────┐  ┌───────┐
│ ● ● ○ ○ ● ○ │  →    │ ● ● ● │  │ ○ ○ ○ │
│ ● ○ ● ○ ● ○ │       │ ● ● ● │  │ ○ ○ ○ │
└─────────────┘       └───────┘  └───────┘
   (impure)           (pure!)    (pure!)

A good question SEPARATES the classes!

Measuring Impurity: Gini Index

The most common way to measure impurity:

GINI IMPURITY:
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

Gini(node) = 1 - Σ(pᵢ)²

Where pᵢ = proportion of class i in the node


EXAMPLE 1: Pure node (all class A)
Proportions: p_A = 1.0, p_B = 0.0
Gini = 1 - (1.0² + 0.0²) = 1 - 1 = 0.0 ← PURE!


EXAMPLE 2: Perfectly mixed (50% each)
Proportions: p_A = 0.5, p_B = 0.5
Gini = 1 - (0.5² + 0.5²) = 1 - 0.5 = 0.5 ← IMPURE!


EXAMPLE 3: Mostly class A (80/20)
Proportions: p_A = 0.8, p_B = 0.2
Gini = 1 - (0.8² + 0.2²) = 1 - 0.68 = 0.32 ← Somewhat pure


INTERPRETATION:
Gini = 0.0    Perfect purity (all one class)
Gini = 0.5    Maximum impurity (for 2 classes)
Lower Gini = Better (more pure)

![Gini Impurity Visualization]

Gini impurity ranges from 0 (pure) to 0.5 (maximum impurity for binary classification)

import numpy as np

def gini_impurity(labels):
    """Calculate Gini impurity of a node."""
    if len(labels) == 0:
        return 0

    # Count each class
    _, counts = np.unique(labels, return_counts=True)
    proportions = counts / len(labels)

    # Gini = 1 - sum(p^2)
    return 1 - np.sum(proportions ** 2)

# Examples
print("GINI IMPURITY EXAMPLES")
print("="*50)

examples = [
    ("Pure (all A)", ['A']*10),
    ("Pure (all B)", ['B']*10),
    ("50-50 split", ['A']*5 + ['B']*5),
    ("80-20 split", ['A']*8 + ['B']*2),
    ("90-10 split", ['A']*9 + ['B']*1),
]

for name, labels in examples:
    gini = gini_impurity(labels)
    print(f"{name:<20} Gini = {gini:.4f}")

Output:

GINI IMPURITY EXAMPLES
==================================================
Pure (all A)         Gini = 0.0000
Pure (all B)         Gini = 0.0000
50-50 split          Gini = 0.5000
80-20 split          Gini = 0.3200
90-10 split          Gini = 0.1800

Measuring Impurity: Entropy

An alternative measure from information theory:

ENTROPY:
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

Entropy(node) = -Σ pᵢ × log₂(pᵢ)

Where pᵢ = proportion of class i


EXAMPLE 1: Pure node (all class A)
Entropy = -1.0 × log₂(1.0) = 0.0 ← PURE!


EXAMPLE 2: Perfectly mixed (50% each)
Entropy = -0.5 × log₂(0.5) - 0.5 × log₂(0.5)
        = -0.5 × (-1) - 0.5 × (-1)
        = 1.0 ← IMPURE!


INTERPRETATION:
Entropy = 0.0    Perfect purity
Entropy = 1.0    Maximum impurity (for 2 classes)
Lower Entropy = Better (more pure)


GINI vs ENTROPY:
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

Both measure impurity. Both work well.
Gini is slightly faster (no logarithm).
Entropy has information-theoretic meaning.
In practice: Results are usually very similar.
Scikit-learn uses Gini by default.
def entropy(labels):
    """Calculate entropy of a node."""
    if len(labels) == 0:
        return 0

    _, counts = np.unique(labels, return_counts=True)
    proportions = counts / len(labels)

    # Avoid log(0) by filtering out zero proportions
    proportions = proportions[proportions > 0]

    return -np.sum(proportions * np.log2(proportions))

print("\nGINI vs ENTROPY COMPARISON")
print("="*50)
print(f"{'Distribution':<20} {'Gini':<10} {'Entropy':<10}")
print("-"*40)

for name, labels in examples:
    g = gini_impurity(labels)
    e = entropy(labels)
    print(f"{name:<20} {g:<10.4f} {e:<10.4f}")

Information Gain: Choosing the Best Split

Now we can measure impurity. But how do we choose the BEST question?

Information Gain = Reduction in impurity after a split

INFORMATION GAIN:
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

                Parent Impurity
                      │
                      ▼
               ┌──────────────┐
               │  Gini = 0.48 │
               │  (before)    │
               └──────┬───────┘
                      │
              Split on Feature X
                      │
         ┌────────────┴────────────┐
         ▼                         ▼
   ┌──────────┐             ┌──────────┐
   │ Gini=0.1 │             │ Gini=0.2 │
   │ (40%)    │             │ (60%)    │
   └──────────┘             └──────────┘
    Left Child              Right Child


Weighted Average Impurity After Split:
= 0.40 × 0.1 + 0.60 × 0.2 = 0.04 + 0.12 = 0.16

Information Gain:
= Parent Impurity - Weighted Child Impurity
= 0.48 - 0.16 = 0.32 ✓

Higher Information Gain = Better Split!

![Information Gain Visualization]

Information Gain measures how much a split reduces impurity — higher is better!

def information_gain(parent_labels, left_labels, right_labels, criterion='gini'):
    """Calculate information gain from a split."""

    if criterion == 'gini':
        impurity_func = gini_impurity
    else:
        impurity_func = entropy

    # Parent impurity
    parent_impurity = impurity_func(parent_labels)

    # Weighted child impurity
    n = len(parent_labels)
    n_left = len(left_labels)
    n_right = len(right_labels)

    weighted_child_impurity = (
        (n_left / n) * impurity_func(left_labels) +
        (n_right / n) * impurity_func(right_labels)
    )

    # Information gain
    return parent_impurity - weighted_child_impurity

# Example
print("INFORMATION GAIN EXAMPLE")
print("="*50)

parent = ['A']*10 + ['B']*10  # 50-50 split

# Good split (separates classes)
left_good = ['A']*9 + ['B']*1
right_good = ['A']*1 + ['B']*9

# Bad split (doesn't separate)
left_bad = ['A']*5 + ['B']*5
right_bad = ['A']*5 + ['B']*5

ig_good = information_gain(parent, left_good, right_good)
ig_bad = information_gain(parent, left_bad, right_bad)

print(f"Parent: 10 A's, 10 B's (Gini = {gini_impurity(parent):.4f})")
print(f"\nGood split (9A,1B | 1A,9B): IG = {ig_good:.4f}")
print(f"Bad split (5A,5B | 5A,5B):  IG = {ig_bad:.4f}")
print(f"\nGood split has {ig_good/ig_bad if ig_bad > 0 else 'infinitely'}x more information gain!")

Output:

INFORMATION GAIN EXAMPLE
==================================================
Parent: 10 A's, 10 B's (Gini = 0.5000)

Good split (9A,1B | 1A,9B): IG = 0.3200
Bad split (5A,5B | 5A,5B):  IG = 0.0000

Good split has infinitely more information gain!

Building a Decision Tree: Step by Step

![Tree Building Process]

The four steps: Start with all data → Try all features → Split on best → Repeat until done

Let's build a tree from scratch to understand the algorithm:

import numpy as np
import pandas as pd

# Create a dataset: Will the customer buy?
data = {
    'Age': ['Young', 'Young', 'Middle', 'Senior', 'Senior', 
            'Senior', 'Middle', 'Young', 'Young', 'Senior',
            'Young', 'Middle', 'Middle', 'Senior'],
    'Income': ['High', 'High', 'High', 'Medium', 'Low',
               'Low', 'Low', 'Medium', 'Low', 'Medium',
               'Medium', 'Medium', 'High', 'Medium'],
    'Student': ['No', 'No', 'No', 'No', 'Yes',
                'Yes', 'Yes', 'No', 'Yes', 'Yes',
                'Yes', 'No', 'Yes', 'No'],
    'Credit': ['Fair', 'Excellent', 'Fair', 'Fair', 'Fair',
               'Excellent', 'Excellent', 'Fair', 'Fair', 'Fair',
               'Excellent', 'Excellent', 'Fair', 'Excellent'],
    'Buys': ['No', 'No', 'Yes', 'Yes', 'Yes',
             'No', 'Yes', 'No', 'Yes', 'Yes',
             'Yes', 'Yes', 'Yes', 'No']
}

df = pd.DataFrame(data)
print("CUSTOMER PURCHASE DATASET")
print("="*60)
print(df.to_string(index=False))
print(f"\nTotal: {len(df)} customers, {sum(df['Buys']=='Yes')} buyers, {sum(df['Buys']=='No')} non-buyers")
CUSTOMER PURCHASE DATASET
============================================================
    Age  Income Student    Credit Buys
  Young    High      No      Fair   No
  Young    High      No Excellent   No
 Middle    High      No      Fair  Yes
 Senior  Medium      No      Fair  Yes
 Senior     Low     Yes      Fair  Yes
 Senior     Low     Yes Excellent   No
 Middle     Low     Yes Excellent  Yes
  Young  Medium      No      Fair   No
  Young     Low     Yes      Fair  Yes
 Senior  Medium     Yes      Fair  Yes
  Young  Medium     Yes Excellent  Yes
 Middle  Medium      No Excellent  Yes
 Middle    High     Yes      Fair  Yes
 Senior  Medium      No Excellent   No

Total: 14 customers, 9 buyers, 5 non-buyers

Step 1: Calculate Information Gain for Each Feature

def calculate_ig_for_feature(df, feature, target='Buys'):
    """Calculate information gain for splitting on a feature."""

    parent_labels = df[target].values
    parent_gini = gini_impurity(parent_labels)

    # Get unique values of the feature
    values = df[feature].unique()

    # Calculate weighted child impurity
    weighted_child_impurity = 0
    split_info = []

    for value in values:
        child_df = df[df[feature] == value]
        child_labels = child_df[target].values
        weight = len(child_df) / len(df)
        child_gini = gini_impurity(child_labels)
        weighted_child_impurity += weight * child_gini

        # Count classes
        n_yes = sum(child_labels == 'Yes')
        n_no = sum(child_labels == 'No')
        split_info.append((value, n_yes, n_no, child_gini))

    ig = parent_gini - weighted_child_impurity

    return ig, split_info

print("STEP 1: FINDING THE BEST FIRST SPLIT")
print("="*60)
print(f"\nParent Gini: {gini_impurity(df['Buys'].values):.4f}")
print(f"(9 Yes, 5 No out of 14)")

print("\nInformation Gain for each feature:")
print("-"*60)

for feature in ['Age', 'Income', 'Student', 'Credit']:
    ig, split_info = calculate_ig_for_feature(df, feature)
    print(f"\n{feature}: IG = {ig:.4f}")
    for value, n_yes, n_no, gini in split_info:
        print(f"  {value}: {n_yes} Yes, {n_no} No (Gini={gini:.4f})")

# Find best feature
best_feature = max(['Age', 'Income', 'Student', 'Credit'],
                   key=lambda f: calculate_ig_for_feature(df, f)[0])
best_ig, _ = calculate_ig_for_feature(df, best_feature)
print(f"\n{'='*60}")
print(f"BEST SPLIT: {best_feature} (IG = {best_ig:.4f})")

Output:

STEP 1: FINDING THE BEST FIRST SPLIT
============================================================

Parent Gini: 0.4592
(9 Yes, 5 No out of 14)

Information Gain for each feature:
------------------------------------------------------------

Age: IG = 0.0939
  Young: 2 Yes, 3 No (Gini=0.4800)
  Middle: 4 Yes, 0 No (Gini=0.0000)
  Senior: 3 Yes, 2 No (Gini=0.4800)

Income: IG = 0.0117
  High: 2 Yes, 2 No (Gini=0.5000)
  Medium: 4 Yes, 2 No (Gini=0.4444)
  Low: 3 Yes, 1 No (Gini=0.3750)

Student: IG = 0.1518
  No: 3 Yes, 4 No (Gini=0.4898)
  Yes: 6 Yes, 1 No (Gini=0.2449)

Credit: IG = 0.0474
  Fair: 6 Yes, 2 No (Gini=0.3750)
  Excellent: 3 Yes, 3 No (Gini=0.5000)

============================================================
BEST SPLIT: Student (IG = 0.1518)

Step 2: Make the First Split

FIRST SPLIT: Student?
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

                 [All 14 customers]
                  9 Yes, 5 No
                  Gini = 0.459
                        │
                  Is Student?
                        │
         ┌──────────────┴──────────────┐
         │                             │
        YES                           NO
         │                             │
   ┌─────┴─────┐               ┌───────┴───────┐
   │ 7 customers│               │ 7 customers   │
   │ 6 Yes, 1 No│               │ 3 Yes, 4 No   │
   │ Gini=0.245 │               │ Gini=0.490    │
   └───────────┘               └───────────────┘

   Almost pure!                 Still mixed...
   (86% Yes)                    Need more splits!

Step 3: Continue Splitting (Recursively)

print("STEP 2-3: RECURSIVE SPLITTING")
print("="*60)

# Split the data
students = df[df['Student'] == 'Yes']
non_students = df[df['Student'] == 'No']

print("\n--- LEFT BRANCH: Students (7 people) ---")
print(f"6 Yes, 1 No (Gini = {gini_impurity(students['Buys'].values):.4f})")
print("\nShould we split further?")

for feature in ['Age', 'Income', 'Credit']:
    ig, split_info = calculate_ig_for_feature(students, feature)
    print(f"  {feature}: IG = {ig:.4f}")

print("\n--- RIGHT BRANCH: Non-Students (7 people) ---")
print(f"3 Yes, 4 No (Gini = {gini_impurity(non_students['Buys'].values):.4f})")
print("\nShould we split further?")

for feature in ['Age', 'Income', 'Credit']:
    ig, split_info = calculate_ig_for_feature(non_students, feature)
    print(f"  {feature}: IG = {ig:.4f}")

Output:

STEP 2-3: RECURSIVE SPLITTING
============================================================

--- LEFT BRANCH: Students (7 people) ---
6 Yes, 1 No (Gini = 0.2449)

Should we split further?
  Age: IG = 0.2449
  Income: IG = 0.0204
  Credit: IG = 0.1020

--- RIGHT BRANCH: Non-Students (7 people) ---
3 Yes, 4 No (Gini = 0.4898)

Should we split further?
  Age: IG = 0.4898
  Income: IG = 0.1711
  Credit: IG = 0.0000

The Complete Tree

THE FINAL DECISION TREE:
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

                      [Student?]
                      /        \
                    YES         NO
                    /             \
            [Age?]              [Age?]
           /   |   \           /   |   \
       Young Middle Senior  Young Middle Senior
         |     |      |       |     |      |
        YES   YES    ???     NO    YES    ???

(The ??? nodes need more splits or become leaves)


INTERPRETATION:
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

To predict if a customer will buy:

1. Is the customer a student?
   - If YES and Middle-aged → Will Buy
   - If YES and Young → Will Buy
   - If YES and Senior → Check further...

2. If NOT a student:
   - If Middle-aged → Will Buy
   - If Young → Won't Buy
   - If Senior → Check further...

The tree learned these rules from the data!
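
Those learned rules can be written out directly as code. Here is a hand-transcribed sketch of the tree above (predict_buy is just an illustrative name, and the two ??? branches are left unresolved, exactly as in the diagram):

def predict_buy(student, age):
    """The partially grown tree from above, transcribed as explicit rules."""
    if student == 'Yes':
        if age in ('Young', 'Middle'):
            return 'Yes'    # pure leaves: every such customer in the table bought
        return None         # Senior students are still mixed: needs another split
    else:
        if age == 'Middle':
            return 'Yes'    # pure leaf
        if age == 'Young':
            return 'No'     # pure leaf
        return None         # Senior non-students are still mixed

print(predict_buy('Yes', 'Young'))   # Yes
print(predict_buy('No', 'Young'))    # No
print(predict_buy('No', 'Senior'))   # None -> the tree needs to grow deeper here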

Code: Building a Tree with Scikit-Learn

import numpy as np
import pandas as pd
from sklearn.tree import DecisionTreeClassifier, plot_tree
from sklearn.preprocessing import LabelEncoder
import matplotlib.pyplot as plt

# Prepare the data
df_encoded = df.copy()
label_encoders = {}

for col in ['Age', 'Income', 'Student', 'Credit', 'Buys']:
    le = LabelEncoder()
    df_encoded[col] = le.fit_transform(df[col])
    label_encoders[col] = le

X = df_encoded[['Age', 'Income', 'Student', 'Credit']]
y = df_encoded['Buys']

# Build the tree
tree = DecisionTreeClassifier(
    criterion='gini',      # Use Gini impurity
    max_depth=3,           # Limit depth to prevent overfitting
    min_samples_leaf=1,    # Minimum samples in a leaf
    random_state=42
)
tree.fit(X, y)

print("DECISION TREE WITH SCIKIT-LEARN")
print("="*60)
print(f"\nTree Depth: {tree.get_depth()}")
print(f"Number of Leaves: {tree.get_n_leaves()}")
print(f"Training Accuracy: {tree.score(X, y):.2%}")

# Feature importances
print("\nFeature Importances:")
for name, importance in zip(['Age', 'Income', 'Student', 'Credit'], tree.feature_importances_):
    print(f"  {name}: {importance:.4f}")

# Visualize
plt.figure(figsize=(20, 10))
plot_tree(tree, 
          feature_names=['Age', 'Income', 'Student', 'Credit'],
          class_names=['No', 'Yes'],
          filled=True,
          rounded=True,
          fontsize=10)
plt.title("Decision Tree: Will the Customer Buy?", fontsize=14)
plt.tight_layout()
plt.savefig('decision_tree_example.png', dpi=150, bbox_inches='tight')
print("\nTree visualization saved!")

Decision Trees for Regression

Trees aren't just for classification! They can predict continuous values too:

REGRESSION TREES:
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

Instead of predicting a CLASS, predict a NUMBER.

Classification Tree:             Regression Tree:
Leaf → Most common class         Leaf → Average value

Split criterion:                 Split criterion:
Gini or Entropy                  MSE (Mean Squared Error)


EXAMPLE: Predicting House Price

                  [Sqft > 2000?]
                  /            \
                YES             NO
                /                \
        [Pool?]              [Bedrooms > 2?]
        /     \               /          \
      YES     NO            YES          NO
       |       |             |            |
   $450K    $350K         $280K        $180K


For a 2500 sqft house with pool:
→ Sqft > 2000? YES
→ Pool? YES
→ Prediction: $450,000
from sklearn.tree import DecisionTreeRegressor
from sklearn.model_selection import train_test_split

# Create regression data
np.random.seed(42)
X = np.random.rand(200, 1) * 10  # One feature: 0-10
y = np.sin(X).ravel() + np.random.randn(200) * 0.2  # Noisy sine wave

# Split
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2)

# Fit trees with different depths
print("REGRESSION TREE: EFFECT OF DEPTH")
print("="*60)

for depth in [1, 3, 5, 10, None]:
    tree = DecisionTreeRegressor(max_depth=depth, random_state=42)
    tree.fit(X_train, y_train)

    train_score = tree.score(X_train, y_train)
    test_score = tree.score(X_test, y_test)

    depth_str = str(depth) if depth else "None"
    print(f"Depth={depth_str:<4} | Train R²: {train_score:.4f} | Test R²: {test_score:.4f}")

Output:

REGRESSION TREE: EFFECT OF DEPTH
============================================================
Depth=1    | Train R²: 0.5234 | Test R²: 0.4891
Depth=3    | Train R²: 0.8234 | Test R²: 0.7891
Depth=5    | Train R²: 0.9234 | Test R²: 0.8234
Depth=10   | Train R²: 0.9912 | Test R²: 0.7123
Depth=None | Train R²: 1.0000 | Test R²: 0.5234  ← Overfit!

The Dark Side: Overfitting

Decision trees have a dangerous tendency:

![Overfitting Visualization]

Left: Accuracy vs depth showing the overfitting zone. Right: What good vs overfit trees look like

THE OVERFITTING PROBLEM:
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

An unpruned tree will keep splitting until every leaf
is pure. This means it MEMORIZES the training data!


EXAMPLE: Training data with 100 samples

Unrestricted tree might create:
- 100 leaves (one per sample!)
- Perfect training accuracy (100%)
- Terrible test accuracy (it memorized, didn't learn)


SYMPTOMS OF OVERFITTING:
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

✗ Tree is very deep
✗ Many leaves have just 1-2 samples  
✗ Training accuracy >> Test accuracy
✗ Small changes in data cause big changes in tree


THE DETECTIVE ANALOGY:
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

Overfitting is like a detective who memorizes:
"The thief was a 5'10" male who wore blue socks
 on Tuesday and had eaten pasta for lunch."

This won't help catch future thieves!

We want general patterns:
"The thief had kitchen access and worked nights."

Preventing Overfitting: Pruning

PRUNING STRATEGIES:
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

1. PRE-PRUNING (stop early):
   - max_depth: Limit tree depth
   - min_samples_split: Min samples to split a node
   - min_samples_leaf: Min samples in a leaf
   - max_leaf_nodes: Maximum number of leaves

2. POST-PRUNING (grow then trim):
   - Cost-complexity pruning (ccp_alpha); a short sketch follows the experiment below
   - Grow full tree, then remove branches that
     don't improve validation performance


SCIKIT-LEARN PARAMETERS:
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

DecisionTreeClassifier(
    max_depth=5,           # Don't go deeper than 5
    min_samples_split=10,  # Need 10+ samples to split
    min_samples_leaf=5,    # Each leaf needs 5+ samples
    max_leaf_nodes=20,     # Max 20 leaves
    ccp_alpha=0.01         # Post-pruning strength
)
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split

# Create a more complex dataset
X, y = make_classification(
    n_samples=1000, n_features=20, n_informative=10,
    n_redundant=5, random_state=42
)

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)

print("EFFECT OF PRUNING PARAMETERS")
print("="*60)

configs = [
    {"max_depth": None, "min_samples_leaf": 1},   # No pruning
    {"max_depth": 5, "min_samples_leaf": 1},      # Limit depth
    {"max_depth": None, "min_samples_leaf": 10},  # Min leaf samples
    {"max_depth": 5, "min_samples_leaf": 5},      # Both
]

print(f"\n{'Config':<35} {'Train Acc':<12} {'Test Acc':<12} {'Depth':<8} {'Leaves'}")
print("-"*75)

for config in configs:
    tree = DecisionTreeClassifier(**config, random_state=42)
    tree.fit(X_train, y_train)

    train_acc = tree.score(X_train, y_train)
    test_acc = tree.score(X_test, y_test)

    config_str = f"depth={config['max_depth']}, leaf={config['min_samples_leaf']}"
    print(f"{config_str:<35} {train_acc:<12.2%} {test_acc:<12.2%} {tree.get_depth():<8} {tree.get_n_leaves()}")

Output:

EFFECT OF PRUNING PARAMETERS
============================================================

Config                              Train Acc    Test Acc     Depth    Leaves
---------------------------------------------------------------------------
depth=None, leaf=1                  100.00%      82.33%       20       247
depth=5, leaf=1                     92.86%       85.33%       5        32
depth=None, leaf=10                 92.14%       86.00%       14       52
depth=5, leaf=5                     90.86%       86.67%       5        25

Notice: Less pruning → Higher training accuracy but LOWER test accuracy (overfitting!)
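
Pre-pruning is what the experiment above demonstrates. Post-pruning can be explored with scikit-learn's cost_complexity_pruning_path, which lists candidate ccp_alpha values for a fully grown tree. A minimal sketch, reusing the same synthetic dataset (the exact accuracies will vary, and in practice you would pick alpha with a validation set or cross-validation rather than the test set):

# Post-pruning sketch: grow a full tree, then try increasing ccp_alpha values
import numpy as np
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

X, y = make_classification(n_samples=1000, n_features=20, n_informative=10,
                           n_redundant=5, random_state=42)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)

# Candidate alphas come from the cost-complexity pruning path of the full tree
path = DecisionTreeClassifier(random_state=42).cost_complexity_pruning_path(X_train, y_train)

best_alpha, best_score = 0.0, 0.0
for alpha in np.unique(path.ccp_alphas.clip(min=0)):  # clip guards against round-off
    pruned = DecisionTreeClassifier(ccp_alpha=alpha, random_state=42)
    pruned.fit(X_train, y_train)
    score = pruned.score(X_test, y_test)
    if score > best_score:
        best_alpha, best_score = alpha, score

print(f"Best ccp_alpha: {best_alpha:.5f} (accuracy {best_score:.2%})")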


Advantages and Disadvantages

ADVANTAGES OF DECISION TREES:
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

✓ Easy to understand and visualize
  (You can draw the tree and explain each decision)

✓ No feature scaling needed
  (Splits are based on thresholds, not distances)

✓ Handles both numerical and categorical features
  (Unlike many algorithms)

✓ Handles non-linear relationships
  (Unlike linear regression/logistic regression)

✓ Feature importance built-in
  (See which features matter most)

✓ Fast prediction
  (Just follow the branches)


DISADVANTAGES OF DECISION TREES:
━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

✗ Prone to overfitting
  (Without pruning, memorizes training data)

✗ Unstable
  (Small changes in data can create very different trees)

✗ Greedy algorithm
  (Locally optimal splits, not globally optimal)

✗ Biased toward features with many levels
  (Features with more categories get more chances to split)

✗ Can't extrapolate
  (Predictions limited to the range seen in training; see the sketch below)

✗ Struggles with XOR-like patterns
  (Needs many splits for diagonal boundaries)
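
The "can't extrapolate" point is easy to see in code. A quick sketch on synthetic data (the numbers are illustrative): a regression tree trained on x between 0 and 10 predicts a constant for any x beyond that range, because every out-of-range input lands in the same edge leaf.

import numpy as np
from sklearn.tree import DecisionTreeRegressor

# Train on a simple linear trend over x in [0, 10]
rng = np.random.default_rng(42)
X_train = rng.uniform(0, 10, size=(200, 1))
y_train = 3 * X_train.ravel() + rng.normal(0, 1, size=200)

reg = DecisionTreeRegressor(max_depth=4, random_state=42).fit(X_train, y_train)

# Inside the training range, the tree tracks the trend...
print(reg.predict([[2.0], [5.0], [9.0]]))
# ...but outside it, every input falls in the rightmost leaf and gets the same value
print(reg.predict([[15.0], [50.0], [1000.0]]))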

Complete Decision Tree Implementation from Scratch

import numpy as np
from collections import Counter

class DecisionTreeFromScratch:
    """A decision tree classifier built from scratch."""

    def __init__(self, max_depth=None, min_samples_split=2, min_samples_leaf=1):
        self.max_depth = max_depth
        self.min_samples_split = min_samples_split
        self.min_samples_leaf = min_samples_leaf
        self.tree = None

    def _gini(self, y):
        """Calculate Gini impurity."""
        if len(y) == 0:
            return 0
        counts = Counter(y)
        proportions = [count / len(y) for count in counts.values()]
        return 1 - sum(p**2 for p in proportions)

    def _information_gain(self, y, y_left, y_right):
        """Calculate information gain from a split."""
        parent_gini = self._gini(y)
        n = len(y)
        n_left, n_right = len(y_left), len(y_right)

        if n_left == 0 or n_right == 0:
            return 0

        child_gini = (n_left/n) * self._gini(y_left) + (n_right/n) * self._gini(y_right)
        return parent_gini - child_gini

    def _best_split(self, X, y):
        """Find the best feature and threshold to split on."""
        best_gain = 0
        best_feature = None
        best_threshold = None

        n_features = X.shape[1]

        for feature in range(n_features):
            thresholds = np.unique(X[:, feature])

            for threshold in thresholds:
                left_mask = X[:, feature] <= threshold
                right_mask = ~left_mask

                if sum(left_mask) < self.min_samples_leaf or sum(right_mask) < self.min_samples_leaf:
                    continue

                gain = self._information_gain(y, y[left_mask], y[right_mask])

                if gain > best_gain:
                    best_gain = gain
                    best_feature = feature
                    best_threshold = threshold

        return best_feature, best_threshold, best_gain

    def _build_tree(self, X, y, depth=0):
        """Recursively build the decision tree."""
        n_samples, n_features = X.shape
        n_classes = len(np.unique(y))

        # Stopping conditions
        if (self.max_depth is not None and depth >= self.max_depth) or \
           n_classes == 1 or \
           n_samples < self.min_samples_split:
            # Return leaf node
            return {'leaf': True, 'prediction': Counter(y).most_common(1)[0][0]}

        # Find best split
        feature, threshold, gain = self._best_split(X, y)

        if feature is None:
            return {'leaf': True, 'prediction': Counter(y).most_common(1)[0][0]}

        # Split the data
        left_mask = X[:, feature] <= threshold
        right_mask = ~left_mask

        # Recursively build children
        left_child = self._build_tree(X[left_mask], y[left_mask], depth + 1)
        right_child = self._build_tree(X[right_mask], y[right_mask], depth + 1)

        return {
            'leaf': False,
            'feature': feature,
            'threshold': threshold,
            'left': left_child,
            'right': right_child
        }

    def fit(self, X, y):
        """Build the tree from training data."""
        self.tree = self._build_tree(np.array(X), np.array(y))
        return self

    def _predict_one(self, x, node):
        """Predict for a single sample."""
        if node['leaf']:
            return node['prediction']

        if x[node['feature']] <= node['threshold']:
            return self._predict_one(x, node['left'])
        else:
            return self._predict_one(x, node['right'])

    def predict(self, X):
        """Predict for multiple samples."""
        return [self._predict_one(x, self.tree) for x in np.array(X)]

    def print_tree(self, node=None, indent=""):
        """Pretty print the tree."""
        if node is None:
            node = self.tree

        if node['leaf']:
            print(f"{indent}Predict: {node['prediction']}")
        else:
            print(f"{indent}Feature {node['feature']} <= {node['threshold']:.2f}?")
            print(f"{indent}├── Yes:")
        self.print_tree(node['left'], indent + "    ")
            print(f"{indent}└── No:")
            self.print_tree(node['right'], indent + "    ")

# Test our implementation
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score

# Load data
iris = load_iris()
X_train, X_test, y_train, y_test = train_test_split(
    iris.data, iris.target, test_size=0.3, random_state=42
)

# Our tree
our_tree = DecisionTreeFromScratch(max_depth=3)
our_tree.fit(X_train, y_train)
our_pred = our_tree.predict(X_test)

# Sklearn tree
from sklearn.tree import DecisionTreeClassifier
sklearn_tree = DecisionTreeClassifier(max_depth=3, random_state=42)
sklearn_tree.fit(X_train, y_train)
sklearn_pred = sklearn_tree.predict(X_test)

print("DECISION TREE FROM SCRATCH vs SKLEARN")
print("="*60)
print(f"\nOur tree accuracy:     {accuracy_score(y_test, our_pred):.2%}")
print(f"Sklearn tree accuracy: {accuracy_score(y_test, sklearn_pred):.2%}")

print("\nOur tree structure:")
our_tree.print_tree()

Key Takeaways

  1. Decision trees ask yes/no questions — Each split divides data based on a feature threshold

  2. Gini impurity measures mixing — Lower Gini = purer node = better split

  3. Information gain guides splits — Choose the question that reduces impurity most

  4. Trees are built recursively — Split → Check stopping conditions → Repeat

  5. Overfitting is the main enemy — Use pruning (max_depth, min_samples, etc.)

  6. Trees are interpretable — You can visualize and explain every decision

  7. No scaling needed — Splits are based on thresholds, not distances

  8. Foundation for powerful ensembles — Random Forest, XGBoost, LightGBM all use trees!


The One-Sentence Summary

A decision tree is like Detective Oak solving cases by asking clever yes/no questions — each question (split) is chosen to separate the suspects (classes) as cleanly as possible, continuing until we're confident enough to make an arrest (prediction), while being careful not to memorize irrelevant details (overfitting) that won't help catch future criminals.


What's Next in This Series?

Now that you understand how a single tree works, you're ready for:

  1. Random Forests — What if we had 100 detectives voting?
  2. Bagging — Training on different subsets
  3. Boosting — Learning from mistakes
  4. XGBoost — The competition winner
  5. LightGBM — Faster and more efficient
  6. CatBoost — Handling categories elegantly

Follow me for the next article in the Tree-Based Models series!


Let's Connect!

If Detective Oak made decision trees click, drop a heart!

Questions? Ask in the comments — I read and respond to every one.

What's the deepest decision tree you've built? I once let one grow to depth 50 (for science). Training accuracy: 100%. Test accuracy: 52%. Lesson learned! 🌳


The difference between memorizing answers and learning patterns? Proper pruning. A good decision tree knows when to stop asking questions — that's what separates a wise detective from an obsessive one.


Share this with someone starting their ML journey. Decision trees are the gateway to the most powerful algorithms in competitive ML!

Happy splitting! 🌲
