DEV Community

Ruthvik Raja M.V


Plotting Decision Trees using Python

Hello folks,
To plot a Decision Tree as output using Python, the following code can be used:

Before executing the Python code, download the dataset from the following link:

# Decision Tree Classifier
import pandas as pd
from sklearn.model_selection import train_test_split
# This is used to split our data into training and testing sets
from sklearn import tree # Here tree is a module
from sklearn.metrics import accuracy_score
# Used to measure the accuracy of our model
import matplotlib.pyplot as plt
# Used to plot figures

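df1 = pd.read_excel("dataset.xlsx")   # storing our excel file in df1 (placeholder file name: use the file you downloaded)
print(df1.isnull().any())   # This function is used to check whether our data consists of any missing or null values
y = df1.pop("target")   # hypothetical label column name: replace "target" with your dataset's 0/1 class column
X = df1   # the remaining columns are the input features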
X_train, X_test, Y_train, Y_test=train_test_split(X, y, test_size=0.2, random_state=0)
# Here test_size=0.2 means 20% of our input data is used for testing and 80% for training
# random_state=0 fixes the random shuffle, so every run produces the same training and test sets

# Using Entropy for computing the Decision Tree
clftree1 = tree.DecisionTreeClassifier(criterion="entropy")
clftree1.fit(X_train, Y_train)
pred=clftree1.predict(X_test)    # Predicting the values for our test data
accuracy_score1=accuracy_score(Y_test, pred)   # Finding the accuracy score of our model

fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(10, 10), dpi=300)
# Let us create a figure of size 10x10 inches at 300 dots per inch (dpi)
tree.plot_tree(clftree1, feature_names=list(df1.columns), class_names=["0", "1"], filled=True, ax=ax)
# plot_tree is used to plot our decision tree. The parameters are our Decision Tree, the feature names,
# the class names to be displayed (as a string or as a list), filled=True colours each node by its
# majority class, and ax selects the axes to draw on

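# Using Gini Index for computing the Decision Tree
clftree2 = tree.DecisionTreeClassifier(criterion="gini")   # "gini" is also the default criterion
clftree2.fit(X_train, Y_train)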
pred=clftree2.predict(X_test)    # Predicting the values for our test data
accuracy_score2=accuracy_score(Y_test, pred)   # Finding the accuracy score of our model

fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(10, 10), dpi=300)
tree.plot_tree(clftree2, feature_names=list(df1.columns),
               class_names=["0", "1"], filled=True, ax=ax)
plt.show()   # display the figures
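To see what `test_size` and `random_state` do in `train_test_split`, here is a small sketch on toy arrays. The 100-row array and its labels are made up purely for illustration and are not taken from the post's dataset:

```python
import numpy as np
from sklearn.model_selection import train_test_split

# Toy data (invented for illustration): 100 samples, 3 features, binary labels
X = np.arange(300).reshape(100, 3)
y = np.array([0, 1] * 50)

# test_size=0.2 -> 20 rows go to the test set, 80 rows to the training set
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=0)
print(X_train.shape, X_test.shape)  # (80, 3) (20, 3)

# Calling it again with the same random_state reproduces exactly the same split
X_train2, X_test2, _, _ = train_test_split(X, y, test_size=0.2, random_state=0)
print(np.array_equal(X_test, X_test2))  # True
```

Changing `random_state` (or leaving it as `None`) would shuffle differently, so the accuracy score can vary from run to run.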

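The two trees above differ only in their split criterion. As a rough illustration, assuming the standard definitions (Gini impurity = 1 - sum(p_k^2) and entropy = -sum(p_k * log2(p_k)), where p_k is the fraction of samples of class k in a node), the hypothetical helpers `gini` and `entropy` below compute both measures from a node's class counts. These are the kind of impurity values `plot_tree` prints inside each node:

```python
import math

def gini(counts):
    """Gini impurity: 1 - sum(p_k^2) over the class proportions p_k."""
    total = sum(counts)
    return 1.0 - sum((c / total) ** 2 for c in counts)

def entropy(counts):
    """Shannon entropy in bits: -sum(p_k * log2(p_k)), skipping empty classes."""
    total = sum(counts)
    return -sum((c / total) * math.log2(c / total) for c in counts if c > 0)

# A node holding 8 samples of class 0 and 2 samples of class 1
print(round(gini([8, 2]), 3))     # 0.32
print(round(entropy([8, 2]), 3))  # 0.722

# An evenly mixed node is the worst case for both measures
print(gini([5, 5]), entropy([5, 5]))  # 0.5 1.0
```

Both measures are 0 for a pure node and largest for an evenly mixed one. The tree chooses the split that lowers impurity the most, so the two criteria often, but not always, produce similar trees.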

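If you want to try the plotting steps without downloading the Excel file, the sketch below swaps in scikit-learn's built-in breast-cancer dataset (binary labels 0 and 1, so the same `class_names` still fit). It assumes scikit-learn 0.23 or newer for `as_frame=True`; the `max_depth=3` limit and the output file names `tree_entropy.png` and `tree_gini.png` are choices made for this example to keep the figures readable, not part of the original post:

```python
import matplotlib
matplotlib.use("Agg")  # render off-screen so the script also runs without a display
import matplotlib.pyplot as plt
from sklearn.datasets import load_breast_cancer
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier, plot_tree

# Built-in binary dataset (0 = malignant, 1 = benign) standing in for the Excel file
X, y = load_breast_cancer(return_X_y=True, as_frame=True)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=0)

scores = {}
for criterion in ("entropy", "gini"):
    # max_depth=3 is an assumption to keep the plot readable; the post grows full trees
    clf = DecisionTreeClassifier(criterion=criterion, max_depth=3, random_state=0)
    clf.fit(X_train, y_train)
    scores[criterion] = accuracy_score(y_test, clf.predict(X_test))

    # Same figure settings as the post: 10x10 inches at 300 dpi
    fig, ax = plt.subplots(figsize=(10, 10), dpi=300)
    plot_tree(clf, feature_names=list(X.columns), class_names=["0", "1"], filled=True, ax=ax)
    fig.savefig(f"tree_{criterion}.png")
    plt.close(fig)  # free the figure before drawing the next one

print(scores)
```

Saving each figure with `savefig` instead of relying on `plt.show()` keeps the sketch usable both in a notebook and in a plain script.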