DEV Community

Nam Phạm


An introduction to simple classification, and creating a neural network with Brainjs to do it

An introduction

I wrote this article especially for my students. Many of them have heard of topics such as machine learning, deep learning and classification, but still haven't figured out how to put them into practice, because the ideas, the math, the platform to run on, the languages and the libraries are all hard to learn. Deep learning takes time to learn and is a very broad topic in general, so in this article I want to show you how to do a classification task using a deep learning technique called a neural network, to give you a taste of how it is done in general.

So what is classification? In classification, you are given an input, and your job is to tell which of several known types that input belongs to. For example, in this article you are given the measurements of an iris flower (its sepal length, sepal width, petal length and petal width), and you need to tell which variety the flower is: setosa, versicolor or virginica.

The ideas

How can we do that? Basically, we build a function that takes the four measurements above and outputs the type of the iris flower. Such a function is hard to write by hand with classical programming techniques, and that is where we turn to the neural network technique of deep learning. The neural network plays the role of that function: we train it on measurements of iris flowers that have already been collected, each attached to a label naming its variety, and the trained network can then classify measurements it has never seen by generalizing from those examples.
Thus we have the following steps:

  • Collect the data and the corresponding labels
  • Build a neural network
  • Train the neural network on the collected dataset
  • Verify the results of the neural network
  • Use the trained neural network in practice
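
For the "Verify the results" step, the network should be checked on rows it did not see during training. Below is a minimal sketch of one simple way to set those rows aside, assuming the dataset is already loaded into an array; splitRows and the choice of k = 5 are hypothetical, for illustration only. Taking every k-th row (rather than, say, the last 20%) keeps all three varieties in both sets when the file is grouped by species, as the sample below is.

```javascript
// Hypothetical helper: hold back every k-th row for verification,
// keep the rest for training. Deterministic, so results are repeatable.
function splitRows(rows, k = 5) {
  const training = [];
  const verification = [];
  rows.forEach((row, i) => {
    // Rows at index k-1, 2k-1, ... go to the verification set
    if ((i + 1) % k === 0) {
      verification.push(row);
    } else {
      training.push(row);
    }
  });
  return { training, verification };
}

// Usage with ten placeholder rows (numbers stand in for real dataset rows)
const { training, verification } = splitRows([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], 5);
console.log(training.length, verification.length); // 8 2
```

A random shuffle before splitting is another common choice; the fixed pattern here just makes the example repeatable.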

This article uses the Iris flower dataset from Kaggle at

How do we create such a neural network? There are libraries dedicated to deep learning, such as TensorFlow and PyTorch, but they are mainly used from Python and can have high hardware requirements, so they are not a comfortable fit for people whose main programming language is JavaScript. That's why this article uses Brainjs, a library for creating simple neural networks in JavaScript that can use the GPU for training through the GPU.js library it is built on.

Before we use Brainjs to create and train a neural network, let's take a look at our dataset.

sepal_length sepal_width petal_length petal_width species
5.1 3.5 1.4 0.2 Iris-setosa
4.9 3 1.4 0.2 Iris-setosa
4.7 3.2 1.3 0.2 Iris-setosa
4.6 3.1 1.5 0.2 Iris-setosa
5 3.6 1.4 0.2 Iris-setosa
7 3.2 4.7 1.4 Iris-versicolor
6.4 3.2 4.5 1.5 Iris-versicolor
6.9 3.1 4.9 1.5 Iris-versicolor
5.5 2.3 4 1.3 Iris-versicolor
6.5 2.8 4.6 1.5 Iris-versicolor
5.7 2.8 4.5 1.3 Iris-versicolor
6.3 3.3 6 2.5 Iris-virginica
5.8 2.7 5.1 1.9 Iris-virginica
7.1 3 5.9 2.1 Iris-virginica
6.3 2.9 5.6 1.8 Iris-virginica
6.5 3 5.8 2.2 Iris-virginica
7.6 3 6.6 2.1 Iris-virginica
4.9 2.5 4.5 1.7 Iris-virginica
7.3 2.9 6.3 1.8 Iris-virginica

As you can see, the recorded tuple (5.1, 3.5, 1.4, 0.2) is labeled Iris-setosa, while (7, 3.2, 4.7, 1.4) is Iris-versicolor and (6.3, 3.3, 6, 2.5) is Iris-virginica. Our function, which in this case is the neural network, should be able to tell which variety an iris flower is for any given input tuple.

Before we dive into how to create such a network, we have to understand the form of the input we feed to the network and the output we get from it. The input is clearly a tuple of 4 numbers, but what about the output? We first number the labels Iris-setosa, Iris-versicolor and Iris-virginica as 0, 1 and 2 respectively. You may think that our function should output these numbers, but it doesn't. Instead, each number is the index of a slot in the output tuple, and the value in each slot is the probability that the input belongs to that variety. So the input (5.1, 3.5, 1.4, 0.2) should be mapped to the output (1, 0, 0), because it is 100% setosa and 0% for the others. This encoding is usually called one-hot encoding. So we have to transform our data into something like this:

sepal_length sepal_width petal_length petal_width Iris-setosa Iris-versicolor Iris-virginica
5.1 3.5 1.4 0.2 1 0 0
4.9 3 1.4 0.2 1 0 0
4.7 3.2 1.3 0.2 1 0 0
4.6 3.1 1.5 0.2 1 0 0
5 3.6 1.4 0.2 1 0 0
7 3.2 4.7 1.4 0 1 0
6.4 3.2 4.5 1.5 0 1 0
6.9 3.1 4.9 1.5 0 1 0
5.5 2.3 4 1.3 0 1 0
6.5 2.8 4.6 1.5 0 1 0
5.7 2.8 4.5 1.3 0 1 0
6.3 3.3 6 2.5 0 0 1
5.8 2.7 5.1 1.9 0 0 1
7.1 3 5.9 2.1 0 0 1
6.3 2.9 5.6 1.8 0 0 1
6.5 3 5.8 2.2 0 0 1
7.6 3 6.6 2.1 0 0 1
4.9 2.5 4.5 1.7 0 0 1
7.3 2.9 6.3 1.8 0 0 1
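
The transformation from the first table to the second can be written as a small function. This is a minimal sketch that assumes each row is an object with the column names shown above (sepal_length, ..., species); oneHot and toTrainingRow are hypothetical helper names, not part of Brainjs.

```javascript
// The three varieties, in the order of the output slots (0, 1, 2)
const LABELS = ["Iris-setosa", "Iris-versicolor", "Iris-virginica"];

// Turn a species name into a one-hot vector, e.g. "Iris-versicolor" -> [0, 1, 0]
function oneHot(species) {
  const index = LABELS.indexOf(species);
  if (index === -1) {
    throw new Error(`Unknown species: ${species}`);
  }
  return LABELS.map((_, i) => (i === index ? 1 : 0));
}

// Turn one dataset row into the { input, output } shape of the table above.
// Number() guards against CSV-to-JSON tools that keep numbers as strings.
function toTrainingRow(row) {
  return {
    input: [row.sepal_length, row.sepal_width, row.petal_length, row.petal_width].map(Number),
    output: oneHot(row.species),
  };
}

console.log(JSON.stringify(toTrainingRow({
  sepal_length: 7, sepal_width: 3.2, petal_length: 4.7, petal_width: 1.4,
  species: "Iris-versicolor",
})));
// {"input":[7,3.2,4.7,1.4],"output":[0,1,0]}
```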

Now we can train our network.


Brainjs is a JavaScript library that lets users create, train and reuse neural networks. Brainjs can run in the browser, and this article focuses on training a neural network in the browser. You should have Firefox or Google Chrome installed to run the example.

Understand how to work with Brainjs

Prepare the data

The data is a JavaScript array whose elements are the rows of the dataset, and each row must be an object of the form

    {
        input: [inputNumber0, inputNumber1, inputNumber2, ..., inputNumberM],
        output: [outputNumber0, outputNumber1, outputNumber2, ..., outputNumberN]
    }

For example, the row

sepal_length sepal_width petal_length petal_width Iris-setosa Iris-versicolor Iris-virginica
5.1 3.5 1.4 0.2 1 0 0

will be

    {
        input: [5.1, 3.5, 1.4, 0.2],
        output: [1, 0, 0]
    }
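
Every row needs an input array of the same length and an output array of the same length. As far as I know, Brainjs sizes the input and output layers from the training data, so a malformed row (for example a number that is still a string) is worth catching before training. A minimal sketch of such a check; checkTrainingData is a hypothetical helper, not a Brainjs function, and the default sizes 4 and 3 match the iris data.

```javascript
// Hypothetical helper: verify that every row has the { input, output } shape
// with numeric arrays of the expected lengths. Returns a list of problems found.
function checkTrainingData(rows, inputSize = 4, outputSize = 3) {
  const problems = [];
  // A valid field is an array of the expected length containing finite numbers
  const isNumberArray = (a, size) =>
    Array.isArray(a) && a.length === size && a.every((v) => Number.isFinite(v));

  rows.forEach((row, i) => {
    if (!isNumberArray(row.input, inputSize)) {
      problems.push(`row ${i}: input must be ${inputSize} numbers`);
    }
    if (!isNumberArray(row.output, outputSize)) {
      problems.push(`row ${i}: output must be ${outputSize} numbers`);
    }
  });
  return problems;
}

const problems = checkTrainingData([
  { input: [5.1, 3.5, 1.4, 0.2], output: [1, 0, 0] },
  { input: ["7", 3.2, 4.7, 1.4], output: [0, 1, 0] }, // "7" is a string
]);
console.log(problems); // [ 'row 1: input must be 4 numbers' ]
```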

Create a neural network

We create a neural network in Brainjs using the following code

let net = new brain.NeuralNetwork({
    binaryThresh: 0.5,
    hiddenLayers: [3, 3, 2],
    activation: "sigmoid",
});

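Here, the hiddenLayers parameter determines the number of hidden layers in the neural network and the number of neurons in each of them: [3, 3, 2] means three hidden layers with 3, 3 and 2 neurons. The sizes of the input and output layers come from the training data (4 inputs and 3 outputs here).
The activation parameter determines the activation function the neurons apply to their weighted inputs; besides "sigmoid", Brainjs also accepts values such as "relu", "leaky-relu" and "tanh". binaryThresh is kept at 0.5, which is also its default value.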
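
To make these two parameters concrete: with 4 inputs and 3 outputs, hiddenLayers: [3, 3, 2] gives a fully connected network with layer sizes 4, 3, 3, 2, 3. The sketch below is not Brainjs code; it is a plain JavaScript forward pass through such a network with sigmoid activations. The constant weights (0.1) and biases (0) are arbitrary assumptions so that it runs without training, and it also counts the parameters that training has to adjust.

```javascript
// Sigmoid squashes any number into the range (0, 1)
const sigmoid = (x) => 1 / (1 + Math.exp(-x));

// Layer sizes: 4 inputs, hidden layers [3, 3, 2], 3 outputs
const sizes = [4, 3, 3, 2, 3];

// Build layers with arbitrary constant weights (0.1) and biases (0),
// purely so the example runs; training would replace these values.
const layers = sizes.slice(1).map((size, l) => ({
  weights: Array.from({ length: size }, () => new Array(sizes[l]).fill(0.1)),
  biases: new Array(size).fill(0),
}));

// Forward pass: each neuron outputs sigmoid(weighted sum of previous layer + bias)
function forward(input) {
  return layers.reduce(
    (values, layer) =>
      layer.weights.map((row, j) =>
        sigmoid(row.reduce((sum, w, i) => sum + w * values[i], layer.biases[j]))
      ),
    input
  );
}

// Count trainable parameters: weights between layers plus one bias per neuron
const weightCount = layers.reduce((n, layer) => n + layer.weights.length * layer.weights[0].length, 0);
const biasCount = layers.reduce((n, layer) => n + layer.biases.length, 0);

console.log(forward([5.1, 3.5, 1.4, 0.2]).length); // 3
console.log(weightCount, biasCount); // 33 11
```

So this small network has 33 weights (4×3 + 3×3 + 3×2 + 2×3) and 11 biases to learn.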

Train the network

After creating the network, we can train the network using the following code

net.train(trainingData, {
    iterations: 1000,
    learningRate: 0.3,
});

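The iterations option sets the maximum number of training iterations (passes over the training data); training can stop earlier if the error drops below the errorThresh option.
The learningRate option determines how large each update to the network's parameters (its weights and biases) is.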
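
The effect of learningRate and iterations can be seen on a toy problem. This sketch uses plain gradient descent on the one-parameter loss (w - 3)², whose minimum is at w = 3; it is an illustration of the idea, not Brainjs's training code (which uses backpropagation and, as far as I know, a momentum term as well). The function name descend is hypothetical.

```javascript
// Plain gradient descent on the toy loss (w - 3)^2, minimized at w = 3.
// The gradient of the loss is 2 * (w - 3).
function descend(learningRate, iterations, start = 0) {
  let w = start;
  for (let i = 0; i < iterations; i++) {
    const gradient = 2 * (w - 3);
    w -= learningRate * gradient; // each update is learningRate times the gradient
  }
  return w;
}

console.log(descend(0.3, 1000)); // very close to 3
console.log(descend(1.1, 50)); // far from 3: the steps overshoot and grow
```

With a learning rate of 0.3 each step shrinks the distance to 3, while 1.1 makes every step overshoot further, which is why the learning rate usually has to stay small.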

Use the trained network to do classification task

You can use the network to classify a measurement by calling

    net.run([value0, value1, value2, value3]);

The output has one value between 0 and 1 for each variety; the higher the value, the more likely the network considers the input to be of that type.
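
Because each output neuron produces its own value, the three values are not guaranteed to add up to 1, so in practice you pick the slot with the highest value. A minimal sketch; toLabel is a hypothetical helper, and the scores in the usage line are made up, standing in for whatever net.run returns. Array.from lets it accept a plain array or a typed array such as Float32Array.

```javascript
// Output slot order used throughout this article
const LABELS = ["Iris-setosa", "Iris-versicolor", "Iris-virginica"];

// Pick the variety whose output slot has the highest score
function toLabel(scores) {
  const values = Array.from(scores);
  let best = 0;
  for (let i = 1; i < values.length; i++) {
    if (values[i] > values[best]) best = i;
  }
  return LABELS[best];
}

// Hypothetical scores, standing in for the result of net.run(...)
console.log(toLabel([0.02, 0.91, 0.07])); // Iris-versicolor
```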

Extract the trained network data

After training the network, you can extract the network data by running

let extracted = net.toJSON()

Reload trained network

With the extracted data, you can now recreate the network without training it:

    let net = new brain.NeuralNetwork();
    net.fromJSON(extracted);

Provided example

You should have a tool such as http-server or Vite installed and know how to use it from the command line. I use Vite here because I'm using it for other projects as well.


Create a directory for the project

Create a new, empty directory for the project.

Download and convert the csv data to json

Download the data from the Kaggle link mentioned earlier, use a tool like csv2json at to convert it to JSON, and save the result in your directory as data.json.
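
If you prefer not to use an online converter, a few lines of JavaScript can do the conversion as well. This is a minimal sketch that assumes a simple CSV file with a header row, commas as separators and no quoted fields, as in the sample rows above; csvToObjects is a hypothetical name. It also turns numeric fields into numbers, which some converters leave as strings.

```javascript
// Hypothetical converter for simple CSV: header row, comma separators,
// no quoted fields or embedded commas. Numeric fields become numbers.
function csvToObjects(text) {
  const lines = text.trim().split(/\r?\n/);
  const headers = lines[0].split(",").map((h) => h.trim());
  return lines.slice(1).map((line) => {
    const cells = line.split(",").map((c) => c.trim());
    const row = {};
    headers.forEach((h, i) => {
      const cell = cells[i] ?? "";
      // Keep the text if it is not a number (e.g. the species name)
      row[h] = cell !== "" && !Number.isNaN(Number(cell)) ? Number(cell) : cell;
    });
    return row;
  });
}

const csv = `sepal_length,sepal_width,petal_length,petal_width,species
5.1,3.5,1.4,0.2,Iris-setosa
7,3.2,4.7,1.4,Iris-versicolor`;

console.log(JSON.stringify(csvToObjects(csv)[1]));
// {"sepal_length":7,"sepal_width":3.2,"petal_length":4.7,"petal_width":1.4,"species":"Iris-versicolor"}
```

In Node, the string from JSON.stringify(csvToObjects(text)) could then be written to data.json.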

Create index.html

In your directory, create an index.html file with the following code:

<!DOCTYPE html>
<html>
    <head>
        <meta charset="utf-8" />
        <meta name="viewport" content="width=device-width, initial-scale=1, maximum-scale=1, minimum-scale=1, user-scalable=no, viewport-fit=cover" />
        <meta name="apple-mobile-web-app-capable" content="yes" />
        <title>Kaggle Iris dataset training</title>
        <!-- Brainjs browser build; it defines the global `brain` used below -->
        <script src=""></script>
    </head>
    <body>
        <h1>Kaggle Iris dataset training using brainjs</h1>
        <div>
            <button onclick="handleClick()">Click to train</button>
            <textarea id="output" rows="40" cols="80" readonly></textarea>
        </div>
        <script>
            let running = false;
            let trained = null; // the trained network, kept for use from the console

            // Train on click and show the trained network as JSON
            async function handleClick() {
                if (running) return; // ignore clicks while a training run is in progress

                running = true;

                try {
                    let net = train(await getTrainingData());

                    trained = net;
                    document.getElementById("output").value = JSON.stringify(net.toJSON(), null, 4);
                } finally {
                    running = false;
                }
            }

            // Load data.json and convert each row into the Brainjs { input, output } form
            async function getTrainingData() {
                return (await (await fetch("data.json")).json()).map((o) => ({
                    // Number() in case the converter stored the measurements as strings
                    input: [o.sepal_length, o.sepal_width, o.petal_length, o.petal_width].map(Number),
                    output: [o.species === "Iris-setosa" ? 1 : 0, o.species === "Iris-versicolor" ? 1 : 0, o.species === "Iris-virginica" ? 1 : 0],
                }));
            }

            // Create the network and train it on the prepared rows
            function train(trainingData) {
                let net = new brain.NeuralNetwork({
                    binaryThresh: 0.5,
                    hiddenLayers: [3, 3, 2],
                    activation: "sigmoid",
                });

                net.train(trainingData, {
                    iterations: 1000,
                    learningRate: 0.3,
                });

                return net;
            }
        </script>
    </body>
</html>

Run a web server from your directory

Start a web server in that directory using http-server or Vite.

Click to train

Open your local web server's address in the browser and click the button. The code downloads the data from the data.json file, transforms it into the Brainjs data format, creates a neural network, trains it on the data and finally writes the trained network into the textarea element as JSON.

Sorry for not implementing a UI to run the classification, but the trained network is stored in the global variable trained. You can easily classify a measurement by running trained.run([5.1, 3.5, 1.4, 0.2]) in the browser console.
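
For the "Verify the results" step, one simple measure is accuracy: the fraction of labeled rows whose highest output slot matches the slot marked 1. A minimal sketch, assuming rows in the { input, output } form; accuracy and argmax are hypothetical helpers, and stubNet is a made-up stand-in so the example runs without Brainjs. After pasting the functions into the console, accuracy(trained, rows) should work the same way with the real network, ideally on rows that were not used for training.

```javascript
// Index of the largest value (works for arrays and typed arrays)
const argmax = (values) =>
  Array.from(values).reduce((best, v, i, arr) => (v > arr[best] ? i : best), 0);

// Fraction of rows whose predicted slot matches the slot marked 1 in row.output.
// `net` is anything with a run(input) method, such as the global `trained`.
function accuracy(net, rows) {
  const correct = rows.filter((row) => argmax(net.run(row.input)) === argmax(row.output)).length;
  return correct / rows.length;
}

// A stand-in "network" so the sketch runs without Brainjs: it guesses setosa
// when the petal is short, versicolor otherwise (a made-up rule, not a trained model).
const stubNet = {
  run: ([, , petalLength]) => (petalLength < 2.5 ? [1, 0, 0] : [0, 1, 0]),
};

const rows = [
  { input: [5.1, 3.5, 1.4, 0.2], output: [1, 0, 0] },
  { input: [7, 3.2, 4.7, 1.4], output: [0, 1, 0] },
  { input: [6.3, 3.3, 6, 2.5], output: [0, 0, 1] },
];
console.log(accuracy(stubNet, rows)); // 0.6666666666666666
```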

This article doesn't cover every aspect of neural networks and deep learning in general, but I hope you now know what to do with a neural network, especially when you write JavaScript.

Have fun with Brainjs and have a good day.
