Davide Santangelo

Simple DecisionTreeClassifier in Python

Decision Trees (DTs) are a non-parametric supervised learning method used for classification and regression. The goal is to create a model that predicts the value of a target variable by learning simple decision rules inferred from the data features.

For instance, in the example below, a decision tree learns to predict a person's preferred music genre from their age and sex. The deeper the tree, the more complex the decision rules and the more closely the model fits the training data, which also increases the risk of overfitting.
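To make the point about depth concrete, here is a minimal sketch that trains trees of increasing depth on the same toy data used in this post and reports the accuracy on the training set. The `max_depth` values and `random_state=0` are choices made for this illustration, not part of the original example.

```python
from sklearn.tree import DecisionTreeClassifier

# Same toy data as the main example: [age, sex] -> preferred music genre
ages = [18, 19, 22, 25, 28, 31, 34, 40, 45]
features = [[age, 0] for age in ages] + [[age, 1] for age in ages]
labels = [
    'rap', 'rap', 'hip hop', 'hip hop',
    'rock', 'rock', 'rock', 'country', 'country',
    'dance', 'dance', 'hip hop', 'hip hop',
    'rap', 'rap', 'rap', 'classical', 'classical'
]

scores = {}
for depth in (1, 2, 3, None):  # None lets the tree grow until every leaf is pure
    clf = DecisionTreeClassifier(max_depth=depth, random_state=0)
    clf.fit(features, labels)
    # score() is measured on the training data here, so it shows fit, not generalization
    scores[depth] = clf.score(features, labels)
    print('max_depth:', depth, 'actual depth:', clf.get_depth(), 'training accuracy:', scores[depth])
```

A shallow tree cannot separate all six genres, while the unrestricted tree reproduces every training label. On real data, a held-out test set would be needed to tell a good fit from overfitting.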

DecisionTreeClassifier is a class capable of performing multi-class classification on a dataset.

As with other classifiers, DecisionTreeClassifier takes two arrays as input: an array X, sparse or dense, of shape [n_samples, n_features] holding the training samples, and an array y of shape [n_samples] holding the class labels for the training samples. The labels can be integers or, as in this example, strings:

from sklearn import tree
import sys

# age
# sex [0: male, 1: female]

features = [
    [18, 0], [19, 0], [22, 0], [25, 0], [28, 0], [31, 0], [34, 0], [40, 0], [45, 0],
    [18, 1], [19, 1], [22, 1], [25, 1], [28, 1], [31, 1], [34, 1], [40, 1], [45, 1]
]

# music genre

labels = [
    'rap', 'rap', 'hip hop', 'hip hop',
    'rock', 'rock', 'rock', 'country', 'country',
    'dance', 'dance', 'hip hop', 'hip hop',
    'rap', 'rap', 'rap', 'classical', 'classical'
]

clf = tree.DecisionTreeClassifier()

clf.fit(features, labels)

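# pass age and sex as script params with sys.argv;
# the arguments arrive as strings, so convert them to integers first
age, sex = int(sys.argv[1]), int(sys.argv[2])
prediction = clf.predict([[age, sex]])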

print(prediction)



Try it!


python3.7 decision_tree_classifier.py 18 1
['dance']
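The classifier can also score several samples at once and report a probability for each class. The sketch below assumes the same training data as above. The sample `[30, 0]` (a 30-year-old male, not in the training set) and `random_state=0` are picked only for illustration.

```python
from sklearn.tree import DecisionTreeClassifier

# Same toy data as the main example: [age, sex] -> preferred music genre
ages = [18, 19, 22, 25, 28, 31, 34, 40, 45]
features = [[age, 0] for age in ages] + [[age, 1] for age in ages]
labels = [
    'rap', 'rap', 'hip hop', 'hip hop',
    'rock', 'rock', 'rock', 'country', 'country',
    'dance', 'dance', 'hip hop', 'hip hop',
    'rap', 'rap', 'rap', 'classical', 'classical'
]

clf = DecisionTreeClassifier(random_state=0).fit(features, labels)

# One sample seen during training and one unseen sample
samples = [[18, 1], [30, 0]]
predictions = clf.predict(samples)
probabilities = clf.predict_proba(samples)  # shape: (n_samples, n_classes)

for sample, genre, probs in zip(samples, predictions, probabilities):
    # clf.classes_ gives the class order used for the probability columns
    print(sample, genre, dict(zip(clf.classes_, probs)))
```

Because the tree is grown until every leaf is pure, each probability row contains a single 1.0. Limiting the depth would produce mixed leaves and therefore fractional probabilities.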


The tree can also be exported as text with the export_text function:


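# export_text is imported from sklearn.tree; the sklearn.tree.export module
# was removed in later scikit-learn releases
from sklearn.tree import export_text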

decision_tree_text = export_text(clf, feature_names=['age', 'sex'])
print(decision_tree_text)


|--- age <= 37.00
|   |--- age <= 26.50
|   |   |--- age <= 20.50
|   |   |   |--- sex <= 0.50
|   |   |   |   |--- class: rap
|   |   |   |--- sex >  0.50
|   |   |   |   |--- class: dance
|   |   |--- age >  20.50
|   |   |   |--- class: hip hop
|   |--- age >  26.50
|   |   |--- sex <= 0.50
|   |   |   |--- class: rock
|   |   |--- sex >  0.50
|   |   |   |--- class: rap
|--- age >  37.00
|   |--- sex <= 0.50
|   |   |--- class: country
|   |--- sex >  0.50
|   |   |--- class: classical


