57. Decision Trees: The AI That Plays 20 Questions

You've played 20 questions before. You think of something. Someone asks yes/no questions to figure out what it is. Is it alive? Is it bigger than a car? Does it have legs?

Each question narrows down the possibilities until they get to the answer.

A decision tree does exactly that. It learns a series of yes/no questions from your data. Then it uses those questions to classify new examples.

It's one of the most intuitive ML models out there. You can actually look at it and understand every decision it makes. That's rare.


What You'll Learn Here

  • How a decision tree builds itself by asking questions
  • What entropy and information gain are (plain English, no scary math)
  • How to build and visualize a decision tree with scikit-learn
  • Why trees overfit so badly and how to control it
  • How to read feature importance from a tree

How a Tree Makes a Decision

Imagine you're trying to classify whether someone will buy a product based on their age and income.

A decision tree might learn this logic:

Is income > 50k?
├── YES: Is age > 30?
│         ├── YES: → Will Buy  (leaf node)
│         └── NO:  → Won't Buy (leaf node)
└── NO:  → Won't Buy            (leaf node)

Every internal node is a question. Every branch is an answer. Every leaf node at the bottom is a final prediction.

To classify a new person, you start at the top and follow the branches that match their features until you hit a leaf.
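
Here's that walk written out as plain if/else, just to make the mechanics concrete. This is a hand-coded sketch of the diagram above, using its made-up 50k and 30 thresholds, not a tree learned from data.

def will_buy(income, age):
    if income > 50_000:           # root question
        if age > 30:              # second question on the YES branch
            return "Will Buy"     # leaf
        return "Won't Buy"        # leaf
    return "Won't Buy"            # leaf

print(will_buy(income=80_000, age=45))   # Will Buy
print(will_buy(income=80_000, age=22))   # Won't Buy
print(will_buy(income=30_000, age=45))   # Won't Buy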


How the Tree Learns Which Questions to Ask

This is where it gets interesting. The tree doesn't randomly pick questions. It picks the question that does the best job of separating the classes at each step.

The measurement it uses is called entropy.

Entropy measures how mixed up a group is.

  • Entropy = 0: perfectly pure. All examples in this group are the same class.
  • Entropy = 1: perfectly mixed. 50% one class, 50% the other.
import numpy as np

def entropy(p):
    # p = proportion of positive class
    if p == 0 or p == 1:
        return 0  # pure group, no uncertainty
    return -p * np.log2(p) - (1 - p) * np.log2(1 - p)

# See how entropy changes with class proportion
proportions = np.linspace(0.01, 0.99, 100)
entropies   = [entropy(p) for p in proportions]

import matplotlib.pyplot as plt
plt.figure(figsize=(7, 4))
plt.plot(proportions, entropies, color='blue', linewidth=2)
plt.xlabel('Proportion of Class 1')
plt.ylabel('Entropy')
plt.title('Entropy is highest when classes are 50/50')
plt.grid(True, alpha=0.3)
plt.savefig('entropy_curve.png', dpi=100)
plt.show()

# Quick examples
print(f"All one class (p=1.0):  entropy = {entropy(1.0):.3f}")
print(f"50/50 split  (p=0.5):  entropy = {entropy(0.5):.3f}")
print(f"90/10 split  (p=0.9):  entropy = {entropy(0.9):.3f}")

Output:

All one class (p=1.0):  entropy = 0.000
50/50 split  (p=0.5):  entropy = 1.000
90/10 split  (p=0.9):  entropy = 0.469

Information Gain is the reduction in entropy after a split. The tree picks the split that gives the highest information gain. In other words, the question that makes the resulting groups as pure as possible.

def information_gain(parent_entropy, left_group, right_group):
    n_left  = len(left_group)
    n_right = len(right_group)
    n_total = n_left + n_right

    p_left  = sum(left_group)  / n_left   if n_left  > 0 else 0
    p_right = sum(right_group) / n_right  if n_right > 0 else 0

    weighted_entropy = (
        (n_left  / n_total) * entropy(p_left) +
        (n_right / n_total) * entropy(p_right)
    )

    return parent_entropy - weighted_entropy

# Example: 10 samples, 5 positive, 5 negative
parent = [1, 1, 1, 1, 1, 0, 0, 0, 0, 0]
parent_ent = entropy(0.5)  # 50/50, entropy = 1.0

# Split A: left = [1,1,1,1], right = [1,0,0,0,0,0]
gain_a = information_gain(parent_ent, [1,1,1,1], [1,0,0,0,0,0])

# Split B: left = [1,1,1,1,1], right = [0,0,0,0,0]  <- perfect split
gain_b = information_gain(parent_ent, [1,1,1,1,1], [0,0,0,0,0])

print(f"Information Gain - Split A: {gain_a:.3f}")
print(f"Information Gain - Split B: {gain_b:.3f}  <- tree picks this one")

Output:

Information Gain - Split A: 0.610
Information Gain - Split B: 1.000  <- tree picks this one

The tree tests every possible feature and every possible split point. It picks the one with the highest information gain. Then it repeats for each child node. It keeps going until the leaves are pure or it hits a stopping condition.
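
To make that search concrete, here's a rough sketch of what happens at a single node for a single feature. It reuses the entropy() and information_gain() helpers defined above; best_split is a made-up name for illustration, and real scikit-learn trees default to Gini impurity and a far more optimized search.

def best_split(feature_values, labels):
    # Try every unique value as a threshold, keep the one with the highest gain
    parent_ent = entropy(sum(labels) / len(labels))
    best_gain, best_threshold = 0.0, None

    for threshold in sorted(set(feature_values)):
        left  = [y for x, y in zip(feature_values, labels) if x <= threshold]
        right = [y for x, y in zip(feature_values, labels) if x >  threshold]
        if not left or not right:
            continue  # useless split: everything ends up on one side
        gain = information_gain(parent_ent, left, right)
        if gain > best_gain:
            best_gain, best_threshold = gain, threshold

    return best_threshold, best_gain

# Toy feature: values 1-5 are class 0, values 6-10 are class 1
x = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
y = [0, 0, 0, 0, 0, 1, 1, 1, 1, 1]

thr, gain = best_split(x, y)
print(f"Best threshold: {thr}, gain: {gain:.3f}")  # Best threshold: 5, gain: 1.000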


Building Your First Decision Tree

from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier, export_text
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score

iris = load_iris()
X, y = iris.data, iris.target

X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42, stratify=y
)

# Start with a simple shallow tree
tree = DecisionTreeClassifier(max_depth=3, random_state=42)
tree.fit(X_train, y_train)

y_pred = tree.predict(X_test)
print(f"Accuracy: {accuracy_score(y_test, y_pred):.3f}")

Output:

Accuracy: 0.967

Now let's actually read the tree it built:

# Print the tree as text - you can read every decision
rules = export_text(tree, feature_names=iris.feature_names)
print(rules)

Output:

|--- petal length (cm) <= 2.45
|   |--- class: setosa
|--- petal length (cm) >  2.45
|   |--- petal width (cm) <= 1.75
|   |   |--- petal length (cm) <= 4.95
|   |   |   |--- class: versicolor
|   |   |--- petal length (cm) >  4.95
|   |   |   |--- class: virginica
|   |--- petal width (cm) >  1.75
|   |   |--- class: virginica

Read this out loud. It's literally asking questions and following answers.

"Is petal length <= 2.45? Yes? That's setosa. No? Check petal width..."

You can explain every single prediction this model makes. That's called interpretability and it's a big deal in real-world ML.
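
You can even ask scikit-learn for the exact path one sample takes. Here's a small sketch reusing the fitted tree and the iris objects from above: decision_path() reports which nodes a sample visits, and apply() tells you which leaf it lands in.

# Trace a single test sample through the tree, node by node
sample = X_test[0].reshape(1, -1)

node_path = tree.decision_path(sample)   # nodes this sample visits
leaf_id   = tree.apply(sample)[0]        # the leaf it ends up in

feature   = tree.tree_.feature
threshold = tree.tree_.threshold

for node_id in node_path.indices:
    if node_id == leaf_id:
        predicted = iris.target_names[tree.predict(sample)[0]]
        print(f"Leaf reached -> predict {predicted}")
        break
    name   = iris.feature_names[feature[node_id]]
    value  = sample[0, feature[node_id]]
    answer = "yes, go left" if value <= threshold[node_id] else "no, go right"
    print(f"Is {name} <= {threshold[node_id]:.2f}?  value = {value:.2f} -> {answer}")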


Visualizing the Tree Properly

from sklearn.tree import plot_tree
import matplotlib.pyplot as plt

plt.figure(figsize=(15, 6))
plot_tree(
    tree,
    feature_names=iris.feature_names,
    class_names=iris.target_names,
    filled=True,        # color nodes by majority class
    rounded=True,
    fontsize=10
)
plt.title('Decision Tree - Iris Dataset (max_depth=3)')
plt.savefig('decision_tree_viz.png', dpi=100, bbox_inches='tight')
plt.show()

The colors show which class dominates each node. Darker color means purer node. Light color means mixed.


Why Trees Overfit So Badly

A decision tree with no depth limit will keep splitting until every leaf has exactly one training example. That's 100% training accuracy and terrible test accuracy.

# Watch the overfitting happen
print(f"{'Depth':<8} {'Train Acc':<12} {'Test Acc':<12} {'Gap'}")
print("-" * 45)

for depth in [1, 2, 3, 4, 5, 10, None]:
    t = DecisionTreeClassifier(max_depth=depth, random_state=42)
    t.fit(X_train, y_train)

    train_acc = accuracy_score(y_train, t.predict(X_train))
    test_acc  = accuracy_score(y_test,  t.predict(X_test))

    label = str(depth) if depth else 'None'
    print(f"{label:<8} {train_acc:.3f}        {test_acc:.3f}        {train_acc - test_acc:.3f}")

Output:

Depth    Train Acc    Test Acc     Gap
---------------------------------------------
1        0.675        0.667        0.008
2        0.942        0.900        0.042
3        0.975        0.967        0.008   <- sweet spot
4        0.983        0.933        0.050
5        0.992        0.933        0.059
10       1.000        0.933        0.067
None     1.000        0.933        0.067

Depth 3 gives the best test accuracy on this dataset. Beyond that, the gap grows. The tree is memorizing training examples, not learning the pattern.


Controlling Tree Complexity

You have several ways to stop a tree from going too deep.

# All the main hyperparameters that control overfitting
tree_controlled = DecisionTreeClassifier(
    max_depth=4,          # max levels in the tree
    min_samples_split=10, # need at least 10 samples to split a node
    min_samples_leaf=5,   # leaf nodes must have at least 5 samples
    max_features='sqrt',  # only consider sqrt(n_features) features per split
    random_state=42
)

tree_controlled.fit(X_train, y_train)
print(f"Controlled tree test accuracy: {accuracy_score(y_test, tree_controlled.predict(X_test)):.3f}")
print(f"Number of leaves: {tree_controlled.get_n_leaves()}")
print(f"Tree depth: {tree_controlled.get_depth()}")

Each parameter limits how much the tree can memorize:

  • max_depth is the most direct control
  • min_samples_split stops splits that affect very few examples
  • min_samples_leaf ensures leaf nodes aren't based on single examples
  • max_features adds randomness which reduces overfitting
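
Here's a quick way to watch one of those knobs at work. A rough sketch reusing the iris X_train / y_train split from above: raising min_samples_leaf forces the tree to stay smaller.

for leaf_size in [1, 5, 10, 20]:
    t = DecisionTreeClassifier(min_samples_leaf=leaf_size, random_state=42)
    t.fit(X_train, y_train)
    print(f"min_samples_leaf={leaf_size:<3} "
          f"leaves={t.get_n_leaves():<3} depth={t.get_depth()}")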

Feature Importance

Decision trees tell you which features were most useful for making decisions. This is called feature importance.

from sklearn.datasets import load_breast_cancer
import pandas as pd

data = load_breast_cancer()
X_bc = pd.DataFrame(data.data, columns=data.feature_names)
y_bc = data.target

X_train_bc, X_test_bc, y_train_bc, y_test_bc = train_test_split(
    X_bc, y_bc, test_size=0.2, random_state=42
)

tree_bc = DecisionTreeClassifier(max_depth=5, random_state=42)
tree_bc.fit(X_train_bc, y_train_bc)

# Feature importances
importance_df = pd.DataFrame({
    'Feature':    data.feature_names,
    'Importance': tree_bc.feature_importances_
}).sort_values('Importance', ascending=False)

print("Top 10 most important features:")
print(importance_df.head(10).to_string(index=False))

Output:

Top 10 most important features:
             Feature  Importance
worst concave points       0.731
        worst radius       0.094
        mean texture       0.056
     worst perimeter       0.048
     mean smoothness       0.031
...

Features with importance close to 0 contribute almost nothing. You could drop them with little effect on accuracy.
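
You can check that claim directly. A sketch reusing tree_bc, importance_df, and the breast cancer split from above; the 0.01 cutoff is an arbitrary choice for illustration, and the exact accuracies will depend on your split.

# Keep only features the tree actually leaned on
keep = importance_df.loc[importance_df['Importance'] > 0.01, 'Feature'].tolist()
print(f"Keeping {len(keep)} of {X_bc.shape[1]} features")

tree_small = DecisionTreeClassifier(max_depth=5, random_state=42)
tree_small.fit(X_train_bc[keep], y_train_bc)

acc_full  = accuracy_score(y_test_bc, tree_bc.predict(X_test_bc))
acc_small = accuracy_score(y_test_bc, tree_small.predict(X_test_bc[keep]))
print(f"All features:     {acc_full:.3f}")
print(f"Reduced features: {acc_small:.3f}")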


A Complete Example With the Breast Cancer Dataset

from sklearn.datasets import load_breast_cancer
from sklearn.tree import DecisionTreeClassifier
from sklearn.model_selection import train_test_split, cross_val_score
from sklearn.metrics import classification_report, accuracy_score
import pandas as pd
import numpy as np

data = load_breast_cancer()
X = pd.DataFrame(data.data, columns=data.feature_names)
y = data.target

# Split
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42, stratify=y
)

# Find best depth with cross-validation
best_depth, best_score = None, 0

for depth in range(1, 15):
    t = DecisionTreeClassifier(max_depth=depth, random_state=42)
    score = cross_val_score(t, X_train, y_train, cv=5).mean()
    if score > best_score:
        best_score = score
        best_depth = depth

print(f"Best depth: {best_depth}, CV accuracy: {best_score:.3f}")

# Train final model with best depth
final_tree = DecisionTreeClassifier(max_depth=best_depth, random_state=42)
final_tree.fit(X_train, y_train)
y_pred = final_tree.predict(X_test)

print(f"\nTest accuracy: {accuracy_score(y_test, y_pred):.3f}")
print("\nClassification Report:")
print(classification_report(y_test, y_pred, target_names=data.target_names))

The Things Everyone Gets Wrong

Mistake 1: Not limiting tree depth

An unlimited tree always overfits. Always set max_depth or at least min_samples_leaf.

Mistake 2: Trusting a single tree too much

Decision trees are unstable. Change a few training examples and you get a completely different tree. This is why Random Forest (Post 58) exists. It builds many trees and averages them.
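
If you want to see that instability for yourself, here's a rough sketch: train two shallow trees on two disjoint halves of the breast cancer training data (reusing X_train_bc and y_train_bc from above) and compare their rules. The splits below the root often disagree.

from sklearn.tree import export_text

half = len(X_train_bc) // 2

for run, rows in enumerate([slice(None, half), slice(half, None)], start=1):
    t = DecisionTreeClassifier(max_depth=2, random_state=42)
    t.fit(X_train_bc.iloc[rows], y_train_bc[rows])
    print(f"--- Tree trained on half {run} of the data ---")
    print(export_text(t, feature_names=list(data.feature_names)))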

Mistake 3: Ignoring class imbalance

If 95% of your data is class 0, the tree will just predict class 0 everywhere and look accurate. Set class_weight='balanced' to fix this.

tree = DecisionTreeClassifier(max_depth=5, class_weight='balanced', random_state=42)
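
To see why this matters, here's a sketch on a synthetic 95/5 dataset built with make_classification (the sample sizes and weights are arbitrary choices for illustration). Plain accuracy looks fine either way; recall on the minority class is where class_weight='balanced' usually earns its keep.

from sklearn.datasets import make_classification
from sklearn.metrics import recall_score

X_imb, y_imb = make_classification(
    n_samples=2000, n_features=10, weights=[0.95, 0.05], random_state=42
)
Xi_train, Xi_test, yi_train, yi_test = train_test_split(
    X_imb, y_imb, test_size=0.2, random_state=42, stratify=y_imb
)

for cw in [None, 'balanced']:
    t = DecisionTreeClassifier(max_depth=5, class_weight=cw, random_state=42)
    t.fit(Xi_train, yi_train)
    preds = t.predict(Xi_test)
    print(f"class_weight={str(cw):<9} "
          f"accuracy={accuracy_score(yi_test, preds):.3f}  "
          f"minority recall={recall_score(yi_test, preds):.3f}")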

Mistake 4: Not visualizing the tree

The whole point of a decision tree is interpretability. Always visualize it, at least for shallow trees. If you don't look at it, you're missing half the value.


Quick Cheat Sheet

Task                  Code
Train                 DecisionTreeClassifier(max_depth=5).fit(X_train, y_train)
Print rules           export_text(tree, feature_names=names)
Visualize             plot_tree(tree, filled=True)
Feature importance    tree.feature_importances_
Limit overfitting     max_depth, min_samples_leaf, min_samples_split
Imbalanced classes    class_weight='balanced'
Tree depth            tree.get_depth()
Number of leaves      tree.get_n_leaves()

Practice Challenges

Level 1:
Train a decision tree on load_wine(). Print the rules using export_text. Read it out loud and make sure each question makes sense.

Level 2:
Try max_depth from 1 to 20 on the breast cancer dataset. Plot train accuracy and test accuracy on the same graph. Find the depth where test accuracy peaks.

Level 3:
Train two trees on two different random 80% samples of the iris data (for example, draw the samples with train_test_split using random_state=0 and random_state=99). Print both with export_text and compare them. See how different they can look even though they come from the same dataset. That instability is exactly why Random Forest was invented.


Next up, Post 58: Random Forest: Why One Tree Isn't Enough. We take 100 imperfect trees and combine them into something much better. Bagging, feature randomness, and why ensemble methods dominate real-world ML.
