TensorFlow.js available on WebAssembly backend πŸ”₯

Yaser Adel Mehraban · 5 min read

TensorFlow.js is a library which lets you perform machine learning in the browser or in Node.js. It uses the GPU or CPU for training and inference, but recently the team has done a great job and brought a WebAssembly backend to its ecosystem so that you can run predictions faster. So without further ado, let's dive into this greatness.

Scenario

Let's assume you want to perform object detection on a given image. You can use several models for this, but for now let's focus on MobileNet.

We will use Parcel to set up our app.

Setup without WASM backend

Packages

We will need a package.json file with the setup below:

{
  "name": "wasm-parcel",
  "version": "1.0.0",
  "description": "Sample parcel app that uses the WASM backend",
  "scripts": {
    "watch": "parcel index.html --open",
    "build": "parcel build index.html"
  },
  "dependencies": {
    "@tensorflow/tfjs": "^1.4.0"
  },
  "browserslist": [
    "defaults"
  ],
  "devDependencies": {
    "@babel/core": "7.7.5",
    "@babel/plugin-transform-runtime": "^7.7.6",
    "@babel/preset-env": "^7.7.6",
    "parcel-bundler": "^1.12.4",
    "parcel-plugin-static-files-copy": "^2.2.1"
  },
  "keywords": []
}

HTML

Let's set up our app as we normally would with TensorFlow.js. We need to add a div to show the status, and another element which contains an image tag:

<div class="tfjs-example-container">
  <section class="title-area">
    <h1>TensorFlow.js running on WebAssembly backend</h1>
  </section>

  <section>
    <p class="section-head">Status</p>
    <div id="status"></div>
  </section>

  <section>
    <p class="section-head">Image used</p>

    <img id="img" src="./img/piano.jpg" width="224px" />
  </section>
</div>

And let's add our JavaScript file just before the closing body tag:

<script src="index.js"></script>

Altogether your HTML should look something like this:

<!DOCTYPE html>
<html lang="en">
  <head>
    <title>Hello!</title>
    <meta charset="utf-8" />
    <meta http-equiv="X-UA-Compatible" content="IE=edge" />
    <meta name="viewport" content="width=device-width, initial-scale=1" />
    <link rel="stylesheet" href="style.css" />
  </head>
  <body>
    <div class="tfjs-example-container">
      <section class="title-area">
        <h1>TensorFlow.js running on WebAssembly backend</h1>
      </section>

      <section>
        <p class="section-head">Status</p>
        <div id="status"></div>
      </section>

      <section>
        <p class="section-head">Image used</p>

        <img id="img" src="./img/piano.jpg" width="224px" />
      </section>

      <script src="index.js"></script>
    </div>
  </body>
</html>

CSS

Let's add some basic styling so that it's not an ugly app when we run it. We're doing some cool stuff here, and it deserves to look good πŸ˜‰:

/* CSS files add styling rules to your content */

body {
  font-family: helvetica, arial, sans-serif;
  margin: 2em;
}

p {
  max-width: 960px;
  line-height: 1.6em;
}

p.section-head {
  font-variant: small-caps;
  text-transform: uppercase;
  letter-spacing: 0.17em;
  line-height: 1.2em;
  font-weight: 500;
  margin-top: 2em;
  margin-bottom: 1em;
  border-left: 2px solid #ef6c00;
  padding-left: 24px;
  margin-left: -24px;
  color: #818181;
}


JavaScript

In the code we need to perform two main operations: load the model, and use it for prediction. So let's load the model first:

import * as tf from "@tensorflow/tfjs";

let model = await tf.loadGraphModel(
  "https://tfhub.dev/google/imagenet/mobilenet_v2_100_224/classification/2",
  { fromTFHub: true }
);

Then we get a reference to the image tag, use the browser.fromPixels method of TensorFlow.js to load the image into a tensor, and finally feed it into the predict method of our model (note that the await in the previous snippet needs to live inside an async function; the full listing below takes care of that):

const imgElement = document.getElementById("img");

let img = tf.browser
    .fromPixels(imgElement)
    .resizeBilinear([224, 224])
    .expandDims(0)
    .toFloat();

const prediction = model.predict(img);
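One thing worth flagging: toFloat only casts the pixel values, it doesn't scale them. If the TF-Hub MobileNet v2 module expects inputs in the [0, 1] range (an assumption worth checking against the module's documentation), you would add a div to the chain, roughly like this:

// A hedged variant of the preprocessing chain, assuming the model
// wants pixel values scaled to [0, 1] rather than raw 0-255 floats.
let img = tf.browser
  .fromPixels(imgElement)      // RGB pixels as an int tensor (0-255)
  .resizeBilinear([224, 224])  // MobileNet's expected input size
  .expandDims(0)               // add a batch dimension: [1, 224, 224, 3]
  .toFloat()
  .div(255);                   // scale to the [0, 1] range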

Since we want to compare the timing between the two approaches, let's add some timers to our code and measure how long the operation takes. Altogether, your JavaScript file should look like this:

import * as tf from "@tensorflow/tfjs";

// Write progress messages into the status div.
function status(text) {
  document.getElementById("status").textContent = text;
}

async function main() {
  // Load the MobileNet v2 graph model directly from TF-Hub.
  let model = await tf.loadGraphModel(
    "https://tfhub.dev/google/imagenet/mobilenet_v2_100_224/classification/2",
    { fromTFHub: true }
  );

  const startTime1 = performance.now();

  const imgElement = document.getElementById("img");
  status("Model loaded!");

  // Read the image into a tensor shaped the way the model expects.
  let img = tf.browser
    .fromPixels(imgElement)
    .resizeBilinear([224, 224])
    .expandDims(0)
    .toFloat();

  let startTime2 = performance.now();

  const logits = model.predict(img);

  const totalTime1 = performance.now() - startTime1;
  const totalTime2 = performance.now() - startTime2;
  status(`Done in ${Math.floor(totalTime1)} ms ` +
      `(not including preprocessing: ${Math.floor(totalTime2)} ms)`);

  // Download the raw prediction values from the backend.
  const values = await logits.data();
  console.log(values);
}

document.addEventListener("DOMContentLoaded", main);
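A caveat on these timers: model.predict returns a tensor handle immediately, and depending on the backend the actual computation may still be in flight; it's only guaranteed to have finished once the values are downloaded. So for a fairer comparison you could await the data() call inside the timed region, as in this minimal sketch:

// Timing sketch that waits for the backend to actually finish:
// logits.data() resolves only once the prediction has completed.
const start = performance.now();
const logits = model.predict(img);
await logits.data(); // forces any pending work to complete
console.log(`Prediction took ${Math.floor(performance.now() - start)} ms`);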

.babelrc

If you don't set up the Babel plugins properly, you will get an error like the one below:

❌ Uncaught ReferenceError: regeneratorRuntime is not defined at HTMLDocument.main (index.js:8)

That's because we're using an async function at the top level. Add the config below to your .babelrc file to get around the error:

{
  "presets": ["@babel/preset-env"],
  "plugins": ["@babel/plugin-transform-runtime"]
}
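Alternatively, a common workaround (not part of the original setup, so treat it as an assumption about your toolchain) is to import the regenerator runtime directly at the top of your entry file:

// Pulls in the regeneratorRuntime global that transpiled
// async/await code relies on.
import "regenerator-runtime/runtime";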

Running the app

Now you can run yarn followed by yarn watch, and a browser window will open with the app inside. You should see a page like this:

[Screenshot: TensorFlow.js prediction using the GPU backend in the browser to detect an object in a photo]

Note the time taken to predict what's in the picture. Now let's add the WebAssembly backend and run the app to see how it performs.

With WASM backend

In order to add the WebAssembly backend, you need to install the @tensorflow/tfjs-backend-wasm package:

{
  ...,
  "dependencies": {
    "@tensorflow/tfjs": "^1.4.0",
    "@tensorflow/tfjs-backend-wasm": "1.4.0-alpha3"
  },
  ...
}

The next step is to let Parcel know how to load the WASM file. When the WASM backend is initialized, there will be a fetch/readFile call for a file named tfjs-backend-wasm.wasm relative to the main JS file. That's why we need the section below when using a bundler:

{
  ...,
  "staticFiles": {
    "staticPath": "./node_modules/@tensorflow/tfjs-backend-wasm/dist",
    "excludeGlob": [
      "**/!(*.wasm)"
    ]
  },
  ...
}
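If copying static files doesn't fit your bundler, later releases of @tensorflow/tfjs-backend-wasm expose a setWasmPaths helper to point the backend at wherever you serve the .wasm file from; note this helper may not exist in the 1.4.0-alpha3 build used here, so check your version first:

// Assumption: setWasmPaths is exported by your version of the package
// (it landed in releases after 1.4.0-alpha3).
import { setWasmPaths } from "@tensorflow/tfjs-backend-wasm";

// Tell the backend which directory serves tfjs-backend-wasm.wasm.
setWasmPaths("/static/");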

And the last thing we need to do is tell TensorFlow.js to use this backend:

import "@tensorflow/tfjs-backend-wasm";

async function main() {  
  await tf.setBackend("wasm");
  //...
}
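It's worth sanity-checking that the switch actually took. A quick way is to log the active backend once setBackend resolves:

// After setBackend resolves, tf.getBackend() reports the active backend.
await tf.setBackend("wasm");
console.log(tf.getBackend()); // should print "wasm"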

And that's all you need to enable the WebAssembly backend. Now let's run the app and see the difference:

yarn && yarn watch

You should see the app compiled and a browser window open:

[Screenshot: TensorFlow.js running on the WebAssembly backend to detect an object in an image]

And boom 🀯, look at the difference there. The prediction time dropped by almost two seconds, and that's for a single operation. Imagine the benefit if we were doing many more operations.

Summary

You can find the full demo on my GitHub repo.

In my opinion, this feature is definitely one of the best things to happen since TensorFlow.js was introduced to let web developers get into ML and AI within the browser. So go ahead, use this feature, and benefit from the massive performance improvements.

However, there is a catch πŸ‘‡πŸΌ:

What's the catch?

The catch is that not all functions are implemented in the WASM backend yet, which means you can't run many of the demos on it until the team implements them. You can follow the progress on their GitHub repo to stay on top of the game.

Discussion

Jochem Stoel: This might be the ultimate stupid question, but what exactly does it predict? What is in the image? And what do I do with the resulting array of floats?

Yaser Adel Mehraban (author): There is no such thing as a stupid question. The point of this article was to show the difference between running the prediction on the normal and WebAssembly backends.

I suggest you look at the MobileNet demo in the tfjs GitHub repo to see what can be done with the results, but the TL;DR is that it gives you the probability of each object it detected.

Jochem Stoel: Could you elaborate on how to retrieve "piano" from this?

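A minimal sketch of that lookup (assuming the model's output indices line up with an ImageNet labels map, such as the hypothetical imagenet_classes module below, modelled on the one in the tfjs-models MobileNet demo; some model versions shift indices by one for a background class):

// Hypothetical labels map: an object keyed by class index, like the
// IMAGENET_CLASSES file shipped with the tfjs-models MobileNet demo.
import { IMAGENET_CLASSES } from "./imagenet_classes";

const values = await logits.data();

// The index of the highest logit is the model's best guess.
let best = 0;
for (let i = 1; i < values.length; i++) {
  if (values[i] > values[best]) best = i;
}
console.log(IMAGENET_CLASSES[best]); // e.g. "grand piano"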
Juan Carlos: PyTorch has run on WebAssembly on mobile via NimTorch, with the full experience, for a long time now.

Yaser Adel Mehraban (author): Does it run in the browser?

Yaser Adel Mehraban (author): That's cool, will check it out.

Juan Carlos: I recommend learning Nim; it's time well spent. Nim runs in the browser and everywhere there's a C/C++/JS API. I use it to dev Python; it just works, like Cython. It also replaced NodeJS for me; it works like Svelte.