You've played 20 questions before. You think of something. Someone asks yes/no questions to figure out what it is. Is it alive? Is it bigger than a car? Does it have legs?
Each question narrows down the possibilities until they get to the answer.
A decision tree does exactly that. It learns a series of yes/no questions from your data. Then it uses those questions to classify new examples.
It's one of the most intuitive ML models out there. You can actually look at it and understand every decision it makes. That's rare.
## What You'll Learn Here
- How a decision tree builds itself by asking questions
- What entropy and information gain are (plain English, no scary math)
- How to build and visualize a decision tree with scikit-learn
- Why trees overfit so badly and how to control it
- How to read feature importance from a tree
## How a Tree Makes a Decision
Imagine you're trying to classify whether someone will buy a product based on their age and income.
A decision tree might learn this logic:
```
Is income > 50k?
├── YES: Is age > 30?
│   ├── YES: → Will Buy (leaf node)
│   └── NO:  → Won't Buy (leaf node)
└── NO: → Won't Buy (leaf node)
```
Every internal node is a question. Every branch is an answer. Every leaf node at the bottom is a final prediction.
To classify a new person, you start at the top and follow the branches that match their features until you hit a leaf.
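That walk is simple enough to write out as plain code. Here is a minimal sketch of the example tree above, with the thresholds (50k income, age 30) taken straight from the diagram:

```python
def classify(income, age):
    """Walk the example tree from the root question down to a leaf."""
    if income > 50_000:        # root question
        if age > 30:           # second question on the YES branch
            return "Will Buy"  # leaf
        return "Won't Buy"     # leaf
    return "Won't Buy"         # leaf

print(classify(income=80_000, age=40))  # Will Buy
print(classify(income=80_000, age=25))  # Won't Buy
print(classify(income=30_000, age=40))  # Won't Buy
```

A trained tree is exactly this kind of nested if/else, except the model chose the questions and thresholds itself.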
## How the Tree Learns Which Questions to Ask
This is where it gets interesting. The tree doesn't randomly pick questions. It picks the question that does the best job of separating the classes at each step.
The measurement it uses is called entropy.
Entropy measures how mixed up a group is.
- Entropy = 0: perfectly pure. All examples in this group are the same class.
- Entropy = 1: perfectly mixed. 50% one class, 50% the other.
```python
import numpy as np
import matplotlib.pyplot as plt

def entropy(p):
    # p = proportion of positive class
    if p == 0 or p == 1:
        return 0  # pure group, no uncertainty
    return -p * np.log2(p) - (1 - p) * np.log2(1 - p)

# See how entropy changes with class proportion
proportions = np.linspace(0.01, 0.99, 100)
entropies = [entropy(p) for p in proportions]

plt.figure(figsize=(7, 4))
plt.plot(proportions, entropies, color='blue', linewidth=2)
plt.xlabel('Proportion of Class 1')
plt.ylabel('Entropy')
plt.title('Entropy is highest when classes are 50/50')
plt.grid(True, alpha=0.3)
plt.savefig('entropy_curve.png', dpi=100)
plt.show()

# Quick examples
print(f"All one class (p=1.0): entropy = {entropy(1.0):.3f}")
print(f"50/50 split (p=0.5): entropy = {entropy(0.5):.3f}")
print(f"90/10 split (p=0.9): entropy = {entropy(0.9):.3f}")
```

Output:

```
All one class (p=1.0): entropy = 0.000
50/50 split (p=0.5): entropy = 1.000
90/10 split (p=0.9): entropy = 0.469
```
Information Gain is the reduction in entropy after a split. The tree picks the split that gives the highest information gain. In other words, the question that makes the resulting groups as pure as possible.
```python
def information_gain(parent_entropy, left_group, right_group):
    # Groups are lists of binary labels (1 = positive class)
    n_left = len(left_group)
    n_right = len(right_group)
    n_total = n_left + n_right
    p_left = sum(left_group) / n_left if n_left > 0 else 0
    p_right = sum(right_group) / n_right if n_right > 0 else 0
    weighted_entropy = (
        (n_left / n_total) * entropy(p_left) +
        (n_right / n_total) * entropy(p_right)
    )
    return parent_entropy - weighted_entropy

# Example: 10 samples, 5 positive, 5 negative
parent = [1, 1, 1, 1, 1, 0, 0, 0, 0, 0]
parent_ent = entropy(0.5)  # 50/50, entropy = 1.0

# Split A: left = [1,1,1,1], right = [1,0,0,0,0,0]
gain_a = information_gain(parent_ent, [1, 1, 1, 1], [1, 0, 0, 0, 0, 0])

# Split B: left = [1,1,1,1,1], right = [0,0,0,0,0] <- perfect split
gain_b = information_gain(parent_ent, [1, 1, 1, 1, 1], [0, 0, 0, 0, 0])

print(f"Information Gain - Split A: {gain_a:.3f}")
print(f"Information Gain - Split B: {gain_b:.3f} <- tree picks this one")
```

Output:

```
Information Gain - Split A: 0.610
Information Gain - Split B: 1.000 <- tree picks this one
```
The tree tests every possible feature and every possible split point, picks the one with the highest information gain, then repeats for each child node. It keeps going until the leaves are pure or it hits a stopping condition. (One note: scikit-learn's `DecisionTreeClassifier` actually uses Gini impurity by default rather than entropy; pass `criterion='entropy'` to use information gain. In practice the two usually produce very similar trees.)
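A brute-force version of that search, for one numeric feature, can be sketched in a few lines. This is an illustration of the idea, not scikit-learn's actual (much more optimized) implementation; the `entropy` helper is redefined here so the snippet stands alone:

```python
import numpy as np

def entropy(p):
    if p == 0 or p == 1:
        return 0.0
    return -p * np.log2(p) - (1 - p) * np.log2(1 - p)

def best_split(feature, labels):
    """Try every midpoint between consecutive sorted feature values;
    return the threshold with the highest information gain."""
    parent = entropy(labels.mean())
    order = np.argsort(feature)
    f, y = feature[order], labels[order]
    best_thr, best_gain = None, -1.0
    for i in range(1, len(f)):
        if f[i] == f[i - 1]:
            continue  # no threshold can separate identical values
        thr = (f[i] + f[i - 1]) / 2
        left, right = y[:i], y[i:]
        weighted = (len(left) * entropy(left.mean()) +
                    len(right) * entropy(right.mean())) / len(y)
        gain = parent - weighted
        if gain > best_gain:
            best_thr, best_gain = thr, gain
    return best_thr, best_gain

# Toy data: the label flips cleanly between feature values 4 and 6
x = np.array([1, 2, 3, 4, 6, 7, 8, 9], dtype=float)
y = np.array([0, 0, 0, 0, 1, 1, 1, 1])
thr, gain = best_split(x, y)
print(f"best threshold = {thr}, gain = {gain:.3f}")  # threshold 5.0, gain 1.000
```

The real tree runs this search over every feature at every node, which is why training can get expensive on wide datasets.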
## Building Your First Decision Tree

```python
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier, export_text
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score

iris = load_iris()
X, y = iris.data, iris.target

X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42, stratify=y
)

# Start with a simple shallow tree
tree = DecisionTreeClassifier(max_depth=3, random_state=42)
tree.fit(X_train, y_train)

y_pred = tree.predict(X_test)
print(f"Accuracy: {accuracy_score(y_test, y_pred):.3f}")
```

Output:

```
Accuracy: 0.967
```
Now let's actually read the tree it built:
```python
# Print the tree as text - you can read every decision
rules = export_text(tree, feature_names=iris.feature_names)
print(rules)
```

Output:

```
|--- petal length (cm) <= 2.45
|   |--- class: setosa
|--- petal length (cm) >  2.45
|   |--- petal width (cm) <= 1.75
|   |   |--- petal length (cm) <= 4.95
|   |   |   |--- class: versicolor
|   |   |--- petal length (cm) >  4.95
|   |   |   |--- class: virginica
|   |--- petal width (cm) >  1.75
|   |   |--- class: virginica
```
Read this out loud. It's literally asking questions and following answers.
"Is petal length <= 2.45? Yes? That's setosa. No? Check petal width..."
You can explain every single prediction this model makes. That's called interpretability and it's a big deal in real-world ML.
## Visualizing the Tree Properly

```python
from sklearn.tree import plot_tree
import matplotlib.pyplot as plt

plt.figure(figsize=(15, 6))
plot_tree(
    tree,
    feature_names=iris.feature_names,
    class_names=iris.target_names,
    filled=True,   # color nodes by majority class
    rounded=True,
    fontsize=10
)
plt.title('Decision Tree - Iris Dataset (max_depth=3)')
plt.savefig('decision_tree_viz.png', dpi=100, bbox_inches='tight')
plt.show()
```
The colors show which class dominates each node. Darker color means purer node. Light color means mixed.
## Why Trees Overfit So Badly

A decision tree with no depth limit will keep splitting until every leaf has exactly one training example. That's 100% training accuracy and terrible test accuracy.

```python
# Watch the overfitting happen
print(f"{'Depth':<8} {'Train Acc':<12} {'Test Acc':<12} {'Gap'}")
print("-" * 45)

for depth in [1, 2, 3, 4, 5, 10, None]:
    t = DecisionTreeClassifier(max_depth=depth, random_state=42)
    t.fit(X_train, y_train)
    train_acc = accuracy_score(y_train, t.predict(X_train))
    test_acc = accuracy_score(y_test, t.predict(X_test))
    print(f"{str(depth):<8} {train_acc:<12.3f} {test_acc:<12.3f} {train_acc - test_acc:.3f}")
```

Output:

```
Depth    Train Acc    Test Acc     Gap
---------------------------------------------
1        0.675        0.667        0.008
2        0.942        0.900        0.042
3        0.975        0.967        0.008   <- sweet spot
4        0.983        0.933        0.050
5        0.992        0.933        0.059
10       1.000        0.933        0.067
None     1.000        0.933        0.067
```
Depth 3 gives the best test accuracy on this dataset. Beyond that, the gap grows. The tree is memorizing training examples, not learning the pattern.
## Controlling Tree Complexity

You have several ways to stop a tree from going too deep.

```python
# The main hyperparameters that control overfitting
tree_controlled = DecisionTreeClassifier(
    max_depth=4,            # max levels in the tree
    min_samples_split=10,   # need at least 10 samples to split a node
    min_samples_leaf=5,     # leaf nodes must have at least 5 samples
    max_features='sqrt',    # only consider sqrt(n_features) features per split
    random_state=42
)
tree_controlled.fit(X_train, y_train)

print(f"Controlled tree test accuracy: {accuracy_score(y_test, tree_controlled.predict(X_test)):.3f}")
print(f"Number of leaves: {tree_controlled.get_n_leaves()}")
print(f"Tree depth: {tree_controlled.get_depth()}")
```
Each parameter limits how much the tree can memorize:

- `max_depth` is the most direct control
- `min_samples_split` stops splits that affect very few examples
- `min_samples_leaf` ensures leaf nodes aren't based on single examples
- `max_features` adds randomness, which reduces overfitting
## Feature Importance

Decision trees tell you which features were most useful for making decisions. This is called feature importance.

```python
from sklearn.datasets import load_breast_cancer
import pandas as pd

data = load_breast_cancer()
X_bc = pd.DataFrame(data.data, columns=data.feature_names)
y_bc = data.target

X_train_bc, X_test_bc, y_train_bc, y_test_bc = train_test_split(
    X_bc, y_bc, test_size=0.2, random_state=42
)

tree_bc = DecisionTreeClassifier(max_depth=5, random_state=42)
tree_bc.fit(X_train_bc, y_train_bc)

# Feature importances
importance_df = pd.DataFrame({
    'Feature': data.feature_names,
    'Importance': tree_bc.feature_importances_
}).sort_values('Importance', ascending=False)

print("Top 10 most important features:")
print(importance_df.head(10).to_string(index=False))
```

Output:

```
Top 10 most important features:
             Feature  Importance
worst concave points       0.731
        worst radius       0.094
        mean texture       0.056
     worst perimeter       0.048
     mean smoothness       0.031
...
```
Features with importance close to 0 contribute almost nothing. You could drop them with little effect on accuracy.
## A Complete Example With the Breast Cancer Dataset

```python
from sklearn.datasets import load_breast_cancer
from sklearn.tree import DecisionTreeClassifier
from sklearn.model_selection import train_test_split, cross_val_score
from sklearn.metrics import accuracy_score, classification_report
import pandas as pd

data = load_breast_cancer()
X = pd.DataFrame(data.data, columns=data.feature_names)
y = data.target

# Split
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42, stratify=y
)

# Find best depth with cross-validation
best_depth, best_score = None, 0
for depth in range(1, 15):
    t = DecisionTreeClassifier(max_depth=depth, random_state=42)
    score = cross_val_score(t, X_train, y_train, cv=5).mean()
    if score > best_score:
        best_score = score
        best_depth = depth

print(f"Best depth: {best_depth}, CV accuracy: {best_score:.3f}")

# Train final model with best depth
final_tree = DecisionTreeClassifier(max_depth=best_depth, random_state=42)
final_tree.fit(X_train, y_train)
y_pred = final_tree.predict(X_test)

print(f"\nTest accuracy: {accuracy_score(y_test, y_pred):.3f}")
print("\nClassification Report:")
print(classification_report(y_test, y_pred, target_names=data.target_names))
```
## The Things Everyone Gets Wrong

**Mistake 1: Not limiting tree depth**

An unlimited tree always overfits. Always set `max_depth` or at least `min_samples_leaf`.

**Mistake 2: Trusting a single tree too much**

Decision trees are unstable. Change a few training examples and you can get a completely different tree. This is why Random Forest (Post 58) exists: it builds many trees and averages them.
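You can see that instability for yourself by refitting the same model on two bootstrap resamples of the same data and comparing the printed rules. How different the two trees end up depends on the resamples, but on most draws the thresholds (and sometimes the features) shift:

```python
import numpy as np
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier, export_text

iris = load_iris()
X, y = iris.data, iris.target

# Fit the same model on two bootstrap resamples of the same dataset
rules = []
for seed in (0, 1):
    idx = np.random.default_rng(seed).integers(0, len(X), size=len(X))
    t = DecisionTreeClassifier(max_depth=2, random_state=42).fit(X[idx], y[idx])
    rules.append(export_text(t, feature_names=iris.feature_names))

print(rules[0])
print(rules[1])
print("Identical rules?", rules[0] == rules[1])
```

Bagging, which Random Forest is built on, is exactly this trick taken seriously: fit on many resamples, then average away the instability.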
**Mistake 3: Ignoring class imbalance**

If 95% of your data is class 0, the tree will just predict class 0 everywhere and look accurate. Set `class_weight='balanced'` to fix this.
```python
tree = DecisionTreeClassifier(max_depth=5, class_weight='balanced', random_state=42)
```
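To see what the flag changes, here's a sketch on a synthetic 95/5 imbalanced dataset (built with `make_classification`; the exact numbers will depend on the data, so compare recall on the rare class rather than overall accuracy):

```python
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier
from sklearn.metrics import recall_score

# Synthetic problem where class 1 is only ~5% of the samples
X, y = make_classification(
    n_samples=2000, weights=[0.95, 0.05], random_state=42
)
X_tr, X_te, y_tr, y_te = train_test_split(
    X, y, test_size=0.3, random_state=42, stratify=y
)

plain = DecisionTreeClassifier(max_depth=3, random_state=42).fit(X_tr, y_tr)
balanced = DecisionTreeClassifier(
    max_depth=3, class_weight='balanced', random_state=42
).fit(X_tr, y_tr)

# Recall on the rare class is where the difference shows up
rec_plain = recall_score(y_te, plain.predict(X_te))
rec_balanced = recall_score(y_te, balanced.predict(X_te))
print(f"Plain recall (minority):    {rec_plain:.3f}")
print(f"Balanced recall (minority): {rec_balanced:.3f}")
```

`class_weight='balanced'` reweights samples inversely to class frequency, so splits that isolate the rare class earn more impurity reduction than they would by raw counts.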
**Mistake 4: Not visualizing the tree**
The whole point of a decision tree is interpretability. Always visualize it, at least for shallow trees. If you don't look at it, you're missing half the value.
## Quick Cheat Sheet

| Task | Code |
|---|---|
| Train | `DecisionTreeClassifier(max_depth=5).fit(X_train, y_train)` |
| Print rules | `export_text(tree, feature_names=names)` |
| Visualize | `plot_tree(tree, filled=True)` |
| Feature importance | `tree.feature_importances_` |
| Limit overfitting | `max_depth`, `min_samples_leaf`, `min_samples_split` |
| Imbalanced classes | `class_weight='balanced'` |
| Tree depth | `tree.get_depth()` |
| Number of leaves | `tree.get_n_leaves()` |
## Practice Challenges

**Level 1:** Train a decision tree on `load_wine()`. Print the rules using `export_text`. Read them out loud and make sure each question makes sense.

**Level 2:** Try `max_depth` from 1 to 20 on the breast cancer dataset. Plot train accuracy and test accuracy on the same graph. Find the depth where test accuracy peaks.

**Level 3:** Train two trees on the iris dataset with `random_state=0` and `random_state=99`. Print both trees and see how different they are from each other on the same data. That instability is exactly why Random Forest was invented.
## References
- Scikit-learn: DecisionTreeClassifier
- Scikit-learn: Tree visualization
- StatQuest: Decision Trees (YouTube)
- Visual intro to decision trees
Next up, Post 58: Random Forest: Why One Tree Isn't Enough. We take 100 imperfect trees and combine them into something much better. Bagging, feature randomness, and why ensemble methods dominate real-world ML.