DEV Community

Idrees Khan


Useful LSTM network example using brain.js

Problem

They say every problem has a solution (not necessarily). I use an app called Splitwise. When you type a description while adding an expense (see screenshot no. 3 on its Play Store page), it automatically selects a category for you. I wondered whether I could do the same, without writing complex code, using a JavaScript machine learning library. And guess what? I checked out brain.js.

Solution

Fortunately, the library solves this kind of problem easily with a common machine learning model: the RNN (Recurrent Neural Network). For this use case I used a special kind of RNN called an LSTM (Long Short-Term Memory) network. This article gives an excellent explanation of LSTMs, and it helped me a lot too.
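For reference, here is the textbook LSTM cell update, which shows what the "memory" in the name means. At step t the cell reads the input x_t (for string training data, one encoded character) together with the previous hidden state h_{t-1}; σ is the logistic sigmoid, ⊙ is element-wise multiplication, and the W, U and b terms are learned weights and biases. This is the standard formulation; brain.js's internals may differ in detail:

```latex
\begin{aligned}
f_t &= \sigma(W_f x_t + U_f h_{t-1} + b_f) \\
i_t &= \sigma(W_i x_t + U_i h_{t-1} + b_i) \\
o_t &= \sigma(W_o x_t + U_o h_{t-1} + b_o) \\
\tilde{c}_t &= \tanh(W_c x_t + U_c h_{t-1} + b_c) \\
c_t &= f_t \odot c_{t-1} + i_t \odot \tilde{c}_t \\
h_t &= o_t \odot \tanh(c_t)
\end{aligned}
```

The forget gate f_t decides how much of the old cell state c_{t-1} survives, the input gate i_t decides how much of the candidate \tilde{c}_t is written, and the output gate o_t decides how much of the cell is exposed as h_t. That gated cell state is what lets an LSTM keep track of earlier characters in a description when it picks a category.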

The Data

You can find a working example in this repo. First we need to build a model, and to build a model we need data. You could use data from your existing database or from any other source, but you definitely need some data. For this demonstration, I have added static data in JSON here. Note that I have deliberately left the data unorganized, because in the real world you will not always have organized data. A sample record is shown below.
[Image: Expense sample data record]
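Because the data is deliberately messy, it can help to look at how the categories are distributed before training. A small sketch, assuming each record has optional `description` and `category` fields (the other fields of the real records are not needed here); `ExpenseRecord` and `categoryCounts` are hypothetical names:

```typescript
// Hypothetical record shape: only the two fields used for training.
interface ExpenseRecord {
  description?: string | null;
  category?: string | null;
}

// Count how many records carry each category. Blank or missing categories
// are tallied under "(missing)" so gaps in the data are easy to spot.
function categoryCounts(records: ExpenseRecord[]): Record<string, number> {
  const counts: Record<string, number> = {};
  for (const { category } of records) {
    const key = category?.trim() || "(missing)";
    counts[key] = (counts[key] ?? 0) + 1;
  }
  return counts;
}
```

A category with very few examples, or many records under "(missing)", is a hint that the model will struggle with that part of the data.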

Enough! Show me the demo

First, install brain.js by running the following command:
$ npm i brain.js --save
Next, we need to prepare a training set from our data. In our case, each item in the training set needs an input and an output property. For the input I pass the description property of our data, and for the output I pass the category property. In simple words, I want to train my model on existing descriptions that already have the appropriate category selected. The final code looks like this:

[Image: training set code]
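A minimal sketch of that mapping in TypeScript, assuming records with optional `description` and `category` fields. Skipping records whose description or category is blank is an assumption added here to cope with messy data, not necessarily what the demo repo does; `ExpenseRecord` and `toTrainingSet` are hypothetical names:

```typescript
// Hypothetical record shape; any extra fields in the real data are ignored.
interface ExpenseRecord {
  description?: string | null;
  category?: string | null;
  [field: string]: unknown;
}

// One labelled example for the LSTM: description in, category out.
interface TrainingPair {
  input: string;
  output: string;
}

// Map raw records to { input, output } pairs, dropping records that
// lack a usable description or category.
function toTrainingSet(records: ExpenseRecord[]): TrainingPair[] {
  const pairs: TrainingPair[] = [];
  for (const record of records) {
    const input = record.description?.trim();
    const output = record.category?.trim();
    if (input && output) {
      pairs.push({ input, output });
    }
  }
  return pairs;
}
```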

Training

Now that we have the training data, it's time to build our model. In brain.js we do that by creating an instance of the LSTM network and calling its train() method.

[Image: brain.js train()]
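In code, that step amounts to handing the training pairs and options to the network's train() method. The sketch below keeps brain.js out of the runnable part by typing the network structurally; `Trainable` and `trainCategoryModel` are hypothetical names, not brain.js API. The commented usage shows where `new brain.recurrent.LSTM()` would plug in, and the option values there are purely illustrative:

```typescript
// One labelled example: the description is the input, the category the output.
interface TrainingPair {
  input: string;
  output: string;
}

// A few of the brain.js training options mentioned in this post.
interface TrainOptions {
  iterations?: number;
  errorThresh?: number;
  log?: boolean;
}

// Structural stand-in for the part of brain.js's LSTM this sketch uses,
// so the function can be exercised without installing brain.js.
interface Trainable {
  train(data: TrainingPair[], options?: TrainOptions): unknown;
}

// Train `net` on the prepared pairs and hand it back for export or prediction.
function trainCategoryModel<N extends Trainable>(
  net: N,
  data: TrainingPair[],
  options: TrainOptions = {},
): N {
  if (data.length === 0) {
    throw new Error("Training set is empty; check the data preparation step");
  }
  net.train(data, options);
  return net;
}

// With brain.js installed, the call might look like this (values illustrative):
//   const brain = require("brain.js");
//   const net = trainCategoryModel(new brain.recurrent.LSTM(), trainingData, {
//     iterations: 1000,
//     errorThresh: 0.011,
//   });
```

Failing fast on an empty training set is a design choice of this sketch: with messy data it is easy to filter everything out and only notice after a long, pointless training run.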

The train() method accepts two parameters: the data and an optional config object. The config object has several useful properties that you will need to tune, since the right values depend on your data; iterations and errorThresh in particular are worth experimenting with. In short, iterations is the maximum number of times training passes over the data, and errorThresh is the error level at which training is considered good enough and stops early. The full list of config options can be found here. If you want to know more about errorThresh and iterations, check out this great answer on Stack Overflow.
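To make the interplay of iterations and errorThresh concrete, here is a simplified model of the stopping rule the two options describe: training keeps going until it has run `iterations` times or the error has dropped to `errorThresh` or below. This is an illustration, not brain.js's actual training code; `runEpoch` is a hypothetical stand-in for one pass over the training data that returns its error:

```typescript
interface StopOptions {
  iterations: number; // maximum number of passes over the training data
  errorThresh: number; // training stops once the error is at or below this
}

interface TrainSummary {
  iterations: number; // passes actually run
  error: number; // error after the last pass
}

// Run passes until either limit is hit, whichever comes first.
function boundedTraining(runEpoch: () => number, opts: StopOptions): TrainSummary {
  let error = Infinity;
  let i = 0;
  while (i < opts.iterations && error > opts.errorThresh) {
    error = runEpoch();
    i++;
  }
  return { iterations: i, error };
}
```

Lowering iterations or raising errorThresh both end training sooner, which is why decreasing iterations gives faster, if rougher, results.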

Finally, since I don't want any delay when using this model, I export it with the brain.js helper method network.toJSON(). All we need to do now is save it somewhere and make our predictions. I will now run the demo project, open http://localhost:3000/app/build, and let it generate the .json file. Be patient, as building the model takes time. It's a good idea to decrease iterations if you want to see results sooner.
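Exporting boils down to serializing the object returned by toJSON(). A minimal sketch, assuming a hypothetical `save` callback for wherever you keep the model; `Serializable` and `exportModel` are hypothetical names, with the interface standing in for the brain.js network so the sketch runs without the library:

```typescript
// Stand-in for the brain.js method used here: a trained network's toJSON()
// returns a plain object describing it.
interface Serializable {
  toJSON(): unknown;
}

// Serialize the trained network once and pass the text to `save`
// (a file, a database column or an HTTP response, for example).
function exportModel(net: Serializable, save: (json: string) => void): string {
  const json = JSON.stringify(net.toJSON());
  save(json);
  return json;
}

// In Node, for example (file name is illustrative):
//   import { writeFileSync } from "node:fs";
//   exportModel(net, (json) => writeFileSync("model.json", json));
```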

Prediction

Now that our model is ready, we can use this .json file in our client project, i.e. a mobile or web app. For this demo, I will simply use it in an API.

[Image: Prediction]

Simply create an instance of the LSTM network here and call its fromJSON() method. This way you build the model once and reuse it wherever you want.
network.run<string>(description) will return the prediction we are expecting. Time to test it.
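The "build once, reuse everywhere" idea can be sketched as a factory that restores the network a single time and returns a prediction function. `Restorable` and `createPredictor` are hypothetical names; the interface mirrors the fromJSON() and run() calls described above so the sketch runs without brain.js:

```typescript
// Stand-in for the brain.js LSTM methods used at prediction time.
interface Restorable {
  fromJSON(json: unknown): unknown;
  run(input: string): string;
}

// Restore the network from the exported JSON once, then return a function
// that reuses it for every prediction instead of retraining.
function createPredictor(
  net: Restorable,
  modelJson: string,
): (description: string) => string {
  net.fromJSON(JSON.parse(modelJson));
  // Trimming stray whitespace is an assumption, not something brain.js requires.
  return (description: string) => net.run(description.trim());
}

// With brain.js (illustrative):
//   const predict = createPredictor(new brain.recurrent.LSTM(), modelJsonText);
//   predict("grocery"); // returns whatever category the model learned
```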

Testing

With the demo project running, open http://localhost:3000/app/predict?description=grocery in your browser and you will get the following:

[Image: Prediction Demo]
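The request above passes the description as a query parameter. A framework-free sketch of what such an endpoint's logic could look like; the response shape and status codes are assumptions, not the demo project's actual output, and `handlePredict` is a hypothetical name. The `predict` argument is any function from description to category, such as one built from the restored LSTM:

```typescript
type Predict = (description: string) => string;

interface PredictResponse {
  status: number;
  body: { description?: string; category?: string; error?: string };
}

// Hypothetical handler logic for GET /app/predict?description=<text>.
// `requestUrl` is the path plus query string, as a server would receive it.
function handlePredict(requestUrl: string, predict: Predict): PredictResponse {
  // The base URL only makes a relative path parseable; its host is irrelevant.
  const url = new URL(requestUrl, "http://localhost:3000");
  const description = url.searchParams.get("description")?.trim() ?? "";
  if (description === "") {
    return { status: 400, body: { error: "missing description query parameter" } };
  }
  return { status: 200, body: { description, category: predict(description) } };
}
```

URLSearchParams decodes the query for you, so a description such as `coffee%20beans` reaches the model as "coffee beans".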

Final Thoughts

Although brain.js makes everything easy for us, it is still good to know the basics of the algorithm you need for a particular problem. Also be aware that brain.js is still in alpha, so the example I have provided might not work in the future.
