DEV Community

Abzal Seitkaziyev


Random Forest (part 2)

In the previous post, we outlined how the Random Forest algorithm works. Here I would like to explore the concepts from part 1 on a real dataset using scikit-learn. Dataset details can be found here.

1) Bootstrap aggregating.
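Each tree is trained on its own bootstrap sample: n rows drawn at random, with replacement, from the n rows of the training set. A given row is left out of one sample with probability (1 - 1/n)^n, which approaches 1/e ≈ 0.368 for large n. So on average about 1 - 1/e ≈ 63.2% of the original rows appear in each sample, and the remaining ~36.8% of the sample's rows are duplicates. The number of these bootstrap samples equals the number of trees in the Random Forest.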
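The 63.2% figure is easy to check with a quick simulation. This is a minimal NumPy-only sketch, not part of the original analysis; the sample size n = 10,000 and the 100 samples (one per "tree") are arbitrary choices, and `unique_fraction` is a hypothetical helper name.

```python
import numpy as np

def unique_fraction(n, rng):
    """Draw one bootstrap sample of size n and return the share of distinct rows."""
    idx = rng.integers(0, n, size=n)  # n row indices drawn with replacement
    return np.unique(idx).size / n

rng = np.random.default_rng(42)
n = 10_000
# One bootstrap sample per tree, as in a 100-tree forest.
fractions = [unique_fraction(n, rng) for _ in range(100)]

print(f"mean unique fraction: {np.mean(fractions):.4f}")
print(f"theory (1 - 1/e):     {1 - np.exp(-1):.4f}")  # 0.6321
print(f"exact for this n:     {1 - (1 - 1 / n) ** n:.4f}")
```

The simulated mean should land very close to the theoretical value, and the spread across individual samples shrinks as n grows.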

[Figure: Tree #1 out of 100, max_depth=2]

[Figure: Tree #71 out of 100, max_depth=2]

[Figure: Tree #100 out of 100, max_depth=2]

As we can see, the number of unique data points (rows) in each sampled data set is close to the expected 63.2%.
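The unique-row counts shown in the tree plots can also be read directly from a fitted model. The sketch below is illustrative only: it uses a synthetic stand-in dataset from `make_classification` with 4 features (an assumption, since the post's dataset is not reproduced here). It relies on the fact that scikit-learn implements bootstrapping through per-row sample weights (draw counts), so each tree's root node records the number of distinct drawn rows in `n_node_samples`.

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier

# Hypothetical stand-in data: the original dataset is not reproduced here.
X, y = make_classification(n_samples=1000, n_features=4, random_state=0)

rf = RandomForestClassifier(n_estimators=100, max_depth=2, random_state=0)
rf.fit(X, y)

n = X.shape[0]
# Bootstrapping is done with sample weights (draw counts): the root node's
# n_node_samples counts each distinct drawn row once, while
# weighted_n_node_samples adds up the draw counts (equal to n here).
unique_fracs = np.array([t.tree_.n_node_samples[0] / n for t in rf.estimators_])

for i in (0, 70, 99):  # trees #1, #71 and #100
    print(f"Tree #{i + 1}: {unique_fracs[i]:.3f} of rows are unique")
print(f"mean over all trees: {unique_fracs.mean():.3f}")
```

This is the same "samples" value that `sklearn.tree.plot_tree` prints in each tree's root node.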

2) Random selection of the features at each step.
Here we used max_features='auto', which for the classifier is the same as max_features='sqrt', i.e. sqrt(n_features), per the scikit-learn RandomForestClassifier documentation. In this example n_features=4, so at each node the algorithm picks 2 features at random and then uses the specified criterion (e.g. Gini impurity) to choose the best split among those two. This random feature selection is repeated at every node until the tree reaches a stopping condition, e.g. the maximum depth. (In newer scikit-learn releases 'auto' has been deprecated and removed, and 'sqrt' is the default for RandomForestClassifier.)
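To make the per-node feature sampling concrete, here is a simplified from-scratch sketch of a single node split: it draws int(sqrt(n_features)) candidate features and keeps the threshold with the lowest weighted Gini impurity. The function names `gini` and `split_node` and the toy data are hypothetical. The sketch ignores details of scikit-learn's real splitter; for example, scikit-learn keeps inspecting additional features if none of the sampled ones yields a valid split.

```python
import numpy as np

def gini(y):
    """Gini impurity of a label vector: 1 - sum_k p_k^2."""
    _, counts = np.unique(y, return_counts=True)
    p = counts / counts.sum()
    return 1.0 - np.sum(p ** 2)

def split_node(X, y, rng):
    """Pick the best (feature, threshold) among sqrt(n_features) random features."""
    n_samples, n_features = X.shape
    k = max(1, int(np.sqrt(n_features)))  # sqrt(4) -> 2 candidate features
    candidates = rng.choice(n_features, size=k, replace=False)

    best_feature, best_threshold, best_score = None, None, np.inf
    for f in candidates:
        values = np.unique(X[:, f])
        # Candidate thresholds: midpoints between consecutive distinct values.
        for t in (values[:-1] + values[1:]) / 2:
            mask = X[:, f] <= t
            left, right = y[mask], y[~mask]
            # Child impurities weighted by the share of rows in each child.
            score = (left.size * gini(left) + right.size * gini(right)) / n_samples
            if score < best_score:
                best_feature, best_threshold, best_score = f, t, score
    return candidates, best_feature, best_threshold, best_score

rng = np.random.default_rng(0)
X = rng.normal(size=(200, 4))
y = (X[:, 0] + 0.5 * X[:, 2] > 0).astype(int)

candidates, f, t, score = split_node(X, y, rng)
print("candidate features:", candidates)
print(f"best split: feature {f} <= {t:.3f}, weighted Gini {score:.3f} "
      f"(parent Gini {gini(y):.3f})")
```

A full tree would call `split_node` recursively on the left and right subsets, drawing a fresh random pair of candidate features at every node, until a stopping condition such as max_depth is met.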

3) Making prediction.
As we know, Random Forest builds multiple trees, and for a classification problem each tree gets an equal vote. The class with the most votes is returned to the user as the predicted answer.
In this case we have 100 trees, and each test data point (row) X_test[i] is passed through every tree. For example, suppose 60 trees predict class 0 and 40 trees predict class 1. Since 60 > 40, class 0 is the Random Forest's final answer for that data point. Note that scikit-learn's RandomForestClassifier implements a soft version of this vote: it averages the trees' predicted class probabilities and picks the class with the highest mean probability, which usually, but not always, matches the hard majority vote.
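The voting step can be reproduced by hand from the individual trees. This sketch again assumes a synthetic 4-feature stand-in dataset rather than the post's data. It computes the hard majority vote described above, then the probability-averaging (soft) vote that scikit-learn's `predict` actually uses, so the two can be compared.

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split

# Hypothetical stand-in data with 4 features; the post's dataset is not reproduced.
X, y = make_classification(n_samples=1000, n_features=4, random_state=0)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)

rf = RandomForestClassifier(n_estimators=100, max_depth=2, random_state=0)
rf.fit(X_train, y_train)

# Hard voting: every tree predicts a class index for every test row.
# (Sub-trees are fit on encoded labels, so indices map back via rf.classes_.)
tree_preds = np.array([t.predict(X_test) for t in rf.estimators_]).astype(int)
votes = np.array([np.bincount(col, minlength=rf.classes_.size) for col in tree_preds.T])
hard_vote = rf.classes_[votes.argmax(axis=1)]

i = 0
print(f"row {i}: votes per class = {votes[i]}, majority -> {hard_vote[i]}")

# Soft voting (what rf.predict does): average the trees' class probabilities.
mean_proba = np.mean([t.predict_proba(X_test) for t in rf.estimators_], axis=0)
soft_vote = rf.classes_[mean_proba.argmax(axis=1)]

print("soft vote matches rf.predict:", np.array_equal(soft_vote, rf.predict(X_test)))
print(f"hard vs soft agreement: {np.mean(hard_vote == soft_vote):.3f}")
```

With shallow trees (max_depth=2) the leaves are rarely pure, so a few rows may get different answers from the two schemes. With fully grown trees, each tree's probabilities are typically one-hot and the two votes coincide.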
