Yu Fen Lin

Build a decision tree in R

Overview & Purpose

In this article, we will build a decision tree model on the Titanic data set that predicts whether a given passenger survived or not.

Steps:

  1. Initial data understanding and preparation
  2. Build, train, and test the model
  3. Evaluate the performance of the model

1. Understanding the data set

We will use the Titanic Passenger Survival Data Set. This data set provides information on the fate of passengers on the fatal maiden voyage of the ocean liner "Titanic", summarized according to economic status (class), sex, age and survival. Below is a brief description of the 12 variables in the data set:

  • PassengerId: <int> Serial number
  • Survived: <int> Contains binary values of 0 and 1
    • 0: Passenger did not survive
    • 1: Passenger survived
  • Pclass: <int> Ticket class - 1st, 2nd or 3rd class ticket
  • Name: <chr> Name of the passenger
  • Sex: <chr> Male or female
  • Age: <dbl> Age in years
  • SibSp: <int> No. of siblings/spouses aboard - brothers, sisters and/or husband/wife
  • Parch: <int> No. of parents/children aboard - mother/father and/or daughter/son
  • Ticket: <chr> Ticket number
  • Fare: <dbl> Passenger fare
  • Cabin: <chr> Cabin number
  • Embarked: <chr> Port of embarkation
    • C: Cherbourg
    • Q: Queenstown
    • S: Southampton

Load necessary data

Remove all objects in the Global Environment and load the Titanic data.

# clear the Global Environment
rm(list = ls())

# install the titanic package (only needed once)
install.packages("titanic")

# load necessary packages
library(tidyverse)
library(titanic)

# load the training data
titanic <- titanic_train

Take a look.

titanic %>%
  View(title = "Titanic")
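Besides View(), str() (or glimpse() from dplyr, which is loaded with tidyverse) prints each column together with its type, which is a quick way to check the variable types listed above. A small optional sketch:

# compact overview of the columns and their types
str(titanic)

# tidyverse alternative
glimpse(titanic)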

Produce summaries of the data.
summary() is an important function that helps summarise each attribute in the data set.

> summary(titanic)
  PassengerId       Survived          Pclass          Name          
 Min.   :  1.0   Min.   :0.0000   Min.   :1.000   Length:891        
 1st Qu.:223.5   1st Qu.:0.0000   1st Qu.:2.000   Class :character  
 Median :446.0   Median :0.0000   Median :3.000   Mode  :character  
 Mean   :446.0   Mean   :0.3838   Mean   :2.309                     
 3rd Qu.:668.5   3rd Qu.:1.0000   3rd Qu.:3.000                     
 Max.   :891.0   Max.   :1.0000   Max.   :3.000                     

     Sex                 Age            SibSp           Parch       
 Length:891         Min.   : 0.42   Min.   :0.000   Min.   :0.0000  
 Class :character   1st Qu.:20.12   1st Qu.:0.000   1st Qu.:0.0000  
 Mode  :character   Median :28.00   Median :0.000   Median :0.0000  
                    Mean   :29.70   Mean   :0.523   Mean   :0.3816  
                    3rd Qu.:38.00   3rd Qu.:1.000   3rd Qu.:0.0000  
                    Max.   :80.00   Max.   :8.000   Max.   :6.0000  
                    NA's   :177                                     
    Ticket               Fare           Cabin             Embarked        
 Length:891         Min.   :  0.00   Length:891         Length:891        
 Class :character   1st Qu.:  7.91   Class :character   Class :character  
 Mode  :character   Median : 14.45   Mode  :character   Mode  :character  
                    Mean   : 32.20                                        
                    3rd Qu.: 31.00                                        
                    Max.   :512.33   

There are two empty strings ("") in Embarked. Drop those rows.

> titanic$Embarked[grepl("^\\s*$", titanic$Embarked)]
[1] "" ""
> titanic <- droplevels(titanic[!grepl("^\\s*$", titanic$Embarked), , drop = FALSE])

There are also 177 NA's in Age. Use the mean age to fill the NA's.

> summary(titanic$Age) 
   Min. 1st Qu.  Median    Mean 3rd Qu.    Max.    NA's 
   0.42   20.00   28.00   29.64   38.00   80.00     177 

> titanic$Age[is.na(titanic$Age)] <- 
  round(mean(titanic$Age, na.rm = TRUE))
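As a quick optional check (not part of the original steps), confirm that no NA's remain after the imputation; the commented line shows a median-based alternative, which is less sensitive to outliers:

# confirm there are no NA's left in Age
sum(is.na(titanic$Age))

# alternative (not used here): impute with the median instead of the mean
# titanic$Age[is.na(titanic$Age)] <- median(titanic$Age, na.rm = TRUE)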

Set the categorical variables as factors. Variables can be classified as categorical or quantitative.

  • Categorical variables take on values that are names or labels, e.g. Embarked in our data set.
  • Quantitative variables are numerical; they represent a measurable quantity, e.g. Age in our data set.

# convert the categorical variables to factors
titanic$Survived <- as.factor(titanic$Survived)
titanic$Pclass <- as.factor(titanic$Pclass)
titanic$Embarked <- as.factor(titanic$Embarked)

# encode Sex as a numeric factor (1 = male, 0 = female)
titanic$Sex_num <- if_else(titanic$Sex == "male", 1, 0)
titanic$Sex_num <- as.factor(titanic$Sex_num)
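To double-check the conversion, a short optional sketch: str() on the recoded columns should now report factors with the expected levels.

# Survived, Pclass, Embarked and Sex_num should now be factors
str(titanic[, c("Survived", "Pclass", "Embarked", "Sex_num")])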

Okay, the data is now ready to use.

2. Build, train, and test the model

Choose the variables we would like to use: Survived, Pclass, Age, SibSp, Parch, Fare, Sex_num, and Embarked.

df <- 
  titanic %>% 
  select(Survived, Pclass, Age, SibSp, Parch, Fare, Sex_num, Embarked)

Check the target variable, Survived. Good, there is no huge class imbalance.

> df %>%count(Survived)
# A tibble: 2 x 2
  Survived     n
  <fct>    <int>
1 0          549
2 1          340
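If you prefer a visual check, here is a minimal optional sketch of the class balance using ggplot2 (loaded with tidyverse):

# bar chart of the target variable
df %>%
  ggplot(aes(x = Survived)) +
  geom_bar(fill = "steelblue") +
  labs(title = "Class balance of Survived", y = "Count")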

Check the distribution and correlation between variables.

library(psych)
pairs.panels(df,
             ellipses = FALSE,
             pch = 19,
             hist.col = "blue")

Distribution and correlation between variables

Split the data into training and test sets, using 75% of the rows as training data.

library(caret) 
set.seed(2019)
trainIndex <- createDataPartition(df$Survived, p=0.75, list = FALSE)
train <- df[trainIndex,]
test <- df[-trainIndex,]
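A quick optional sanity check that the split kept roughly 75% of the rows for training and preserved the class proportions (createDataPartition stratifies on the outcome):

# number of rows in each set
nrow(train)
nrow(test)

# class proportions should be similar in both sets
prop.table(table(train$Survived))
prop.table(table(test$Survived))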

Build decision tree model

library(rpart)
tree <- rpart(Survived ~ ., data = train, method = "class")

What does the decision tree look like?

library(rpart.plot)
prp(tree,
    faclen = 0,
    fallen.leaves = TRUE,
    shadow.col = "gray")

Decision Tree

Another, fancier way to take a look at the decision tree.

library(rpart.plot)
rpart.plot(tree)

rpart.plot output
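Beyond plotting, rpart can also print the fitted splits and the cross-validated error for each complexity parameter value, which is useful if you later want to prune the tree. A short optional sketch:

# text summary of the splits
print(tree)

# complexity parameter table with cross-validated error
printcp(tree)

# optional: prune at the cp with the lowest cross-validated error
# best_cp <- tree$cptable[which.min(tree$cptable[, "xerror"]), "CP"]
# pruned  <- prune(tree, cp = best_cp)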

3. Evaluate the performance of the model

Use test data to evaluate the performance of the model.

X_test <-
  test %>%
  select(Pclass, Age, SibSp, Parch, Fare, Sex_num, Embarked)
pred <- predict(tree, newdata = X_test, type = "class")

Calculate confusion matrix and plot it.

confus.matrix <- table(real=test$Survived, predict=pred)
fourfoldplot(confus.matrix, color = c("#CC6666", "#99CC99"),
             conf.level = 0, margin = 1, main = "Confusion Matrix")

Confusion matrix

The accuracy of the model:

> sum(diag(confus.matrix))/sum(confus.matrix)
[1] 0.8333333
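Since caret is already loaded, confusionMatrix() reports accuracy together with sensitivity, specificity, and Cohen's kappa in one call. This is an optional sketch; the exact numbers depend on the seed and the split.

# detailed evaluation; treat "1" (survived) as the positive class
confusionMatrix(data = pred,
                reference = test$Survived,
                positive = "1")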

Hope you found this article helpful.
