DEV Community



Creating a workout tracker using Teachable Machine and SashiDo

I always used to think that in order to use machine learning models you had to be a genius with a PhD in mathematics, but with transfer learning it has become easier than ever!

In this tutorial, I will show you how you can get started with machine learning by creating a simple web app to count your push-ups and squats.

This article is aimed at people with some experience of JavaScript but no experience of machine learning or AI.

Table of Contents

What we will be making
Part 1: Model Training
Part 2: Creating the API
A word of caution…
Part 3: Creating the Front End
Part 4: Deployment
Closing Thoughts

What we will be making

This tutorial is based on a small workout tracker I created to count push-ups and squats. The web app is hosted on SashiDo and also uses SashiDo’s cloud code for the Express API. The machine learning model was trained with Teachable Machine. The front end uses the Svelte framework, but the code should translate easily to your framework of choice, so if you are more familiar with React or Vue, feel free to follow along with that.

Part 1: Model Training

The first step when creating a machine learning model is to decide on an architecture. This is the really challenging part, which researchers spend years creating and refining. Luckily, with Teachable Machine we can leverage an existing model (namely PoseNet).

If you are interested in how PoseNet works under the hood, I recommend this article from the TensorFlow blog.

The next step is model training. This is where we teach our model to perform our task.

The pose model from Teachable Machine is already pre-trained on thousands of hours of training data, which makes it ideal for general-purpose pose estimation.

But we don’t just want general-purpose pose estimation; we want to count push-ups and squats.

To solve this problem we can leverage a technique called transfer learning. It involves tweaking the pre-trained model so that it works for our task, which we do by training it on a much smaller number of examples.

Although I say “smaller”, don’t be fooled: if you have only a handful of training examples, performance will be lackluster. My model used around 1000 images per class, which may seem like a lot but is really a drop in the ocean compared to the millions that PoseNet was originally trained on.

Getting the training data for this task was quite easy: I simply went to YouTube and recorded videos of people doing workouts. You can also use your webcam directly. The most important thing is to get a range of samples with as much variation in lighting and position as possible.

After that I extracted the video frames using ffmpeg.

ffmpeg -i input.mp4 -r 5 output_%04d.png

This command extracts frames from the video input.mp4 and saves them as numbered PNG images. I specified a frame rate of 5 frames per second, as this strikes a balance between having enough data and not having frames that are too similar to each other.
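As a rough guide to how much footage you need: at 5 frames per second, 1000 frames corresponds to 200 seconds of video. If you collect several clips, a small Node.js script can run the same ffmpeg command over each of them. This is a sketch that assumes ffmpeg is on your PATH and the output folder already exists; the function names, and the idea of prefixing each frame with its video's name so frames from different clips don't overwrite each other, are illustrative assumptions.

```javascript
const { spawnSync } = require('child_process');
const path = require('path');

// Build the arguments for the same command as above, but name the output
// frames after the input video (e.g. squats_0001.png) inside outDir.
function buildFfmpegArgs(inputFile, outDir, fps = 5) {
  const base = path.parse(inputFile).name;
  return ['-i', inputFile, '-r', String(fps), path.join(outDir, `${base}_%04d.png`)];
}

// Run ffmpeg once per video; assumes ffmpeg is installed and outDir exists.
function extractAll(inputFiles, outDir, fps = 5) {
  for (const file of inputFiles) {
    const result = spawnSync('ffmpeg', buildFfmpegArgs(file, outDir, fps), { stdio: 'inherit' });
    // status is null if ffmpeg could not be started at all
    if (result.status !== 0) {
      throw new Error(`ffmpeg failed on ${file}`);
    }
  }
}

// Rough estimate of how many frames a clip yields: seconds of footage * fps.
function estimateFrames(durationSeconds, fps = 5) {
  return Math.floor(durationSeconds * fps);
}
```

For example, extractAll(['squats1.mp4', 'squats2.mp4'], 'frames') would write frames from both clips into the same folder without name clashes.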

After that, the data needs to be categorized. I split the data into three categories: up, middle and down. Strictly speaking, only the up and down categories are needed, but a third category for the middle position helps the model distinguish up from down more clearly.

Be warned: this is the most tedious part of the whole process, but it is crucial if you want your model to perform well.
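Since lopsided classes can skew the model, it can help to check how many frames ended up in each category before uploading. This sketch assumes a hypothetical layout with one folder per class (data/up, data/middle, data/down) and only counts PNG and JPEG files; the 50% threshold for flagging a class is an arbitrary choice.

```javascript
const fs = require('fs');
const path = require('path');

// Count image files in each class folder, e.g. rootDir/up, rootDir/middle, rootDir/down.
function countImagesPerClass(rootDir, classes = ['up', 'middle', 'down']) {
  const counts = {};
  for (const className of classes) {
    const dir = path.join(rootDir, className);
    const files = fs.existsSync(dir) ? fs.readdirSync(dir) : [];
    counts[className] = files.filter((f) => /\.(png|jpe?g)$/i.test(f)).length;
  }
  return counts;
}

// Return the classes that have fewer than minRatio times as many images as the largest class.
function findUnderrepresented(counts, minRatio = 0.5) {
  const largest = Math.max(...Object.values(counts));
  return Object.keys(counts).filter((c) => counts[c] < largest * minRatio);
}
```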

Folders containing images

Once that’s done, we’re ready to train our model. Head over to the Teachable Machine website, create a class for each category, and upload your frames.

Teachable Machine

Click the Train Model button to start transfer learning on the samples you’ve added. Depending on how many samples you’ve added and your computer’s resources, this may take a while, since model training happens in your browser.

Once that's done, click Export Model and copy the model URL that’s created; we’ll need it for the next stage.

Part 2: Creating the API

Training the model used client-side processing: all of the computation was performed in your browser. Inference could be performed on the client side too; however, this would give users on low-powered devices such as smartphones a greatly diminished experience. As a result, for this tutorial we will perform inference on a server and deliver the predictions through a REST API.

The Teachable Machine Node.js package (@sashido/teachablemachine-node) does most of the heavy lifting here, so the entire Express API is only about 30 lines of code.

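const express = require('express')
const TeachableMachine = require('@sashido/teachablemachine-node')
const cors = require('cors')

// Create the model once, when the server starts (see "A word of caution…" below)
const model = new TeachableMachine({
  modelUrl: '', // paste the model URL you copied from Teachable Machine
})

const app = express()

app.use(cors()) // allow the front end to call the API from another origin
app.use(express.json()) // parse JSON request bodies into req.body

// Classify the image sent by the front end and return the predictions as JSON
app.post('/image/classify', async (req, res) => {
  const { url } = req.body

  return model
    .classify({
      imageUrl: url,
    })
    .then((predictions) => {
      console.log(predictions)
      return res.json(predictions)
    })
    .catch((e) => {
      console.error(e)
      res.status(500).send('Something went wrong!')
    })
})

// This snippet does not start a server itself; to try it locally with plain
// Node.js you could add app.listen(3000)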

This example should get you started, but if you are looking to deploy this in a production environment, I recommend adding features such as rate limiting and validation of the incoming data.
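As a starting point, here is one way to add both without extra dependencies: a check that the body contains the JPEG or PNG data URL the front end sends, and a naive in-memory, fixed-window rate limiter. The function names and limits are hypothetical, and for a real deployment a maintained middleware package is probably a better choice than a hand-rolled limiter.

```javascript
// Hypothetical size limit in characters. express.json() already rejects
// bodies larger than 100kb by default, so this mainly documents the intent.
const MAX_DATA_URL_LENGTH = 100 * 1024;

// Returns an error message, or null if the body looks like a valid request.
function validateClassifyBody(body, maxLength = MAX_DATA_URL_LENGTH) {
  if (!body || typeof body.url !== 'string') return 'Expected a JSON body with a string "url" field';
  if (!/^data:image\/(jpeg|png);base64,/.test(body.url)) return 'Expected a JPEG or PNG data URL';
  if (body.url.length > maxLength) return 'Image is too large';
  return null;
}

// Naive fixed-window limiter: each key (e.g. a client IP) may make at most
// `max` requests per `windowMs`. State lives in this process only and
// entries are never evicted, so this is a sketch, not production code.
function createRateLimiter({ windowMs, max }) {
  const windows = new Map(); // key -> { start, count }
  return function allow(key, now = Date.now()) {
    const entry = windows.get(key);
    if (!entry || now - entry.start >= windowMs) {
      windows.set(key, { start: now, count: 1 });
      return true;
    }
    entry.count += 1;
    return entry.count <= max;
  };
}

// Express middleware wrappers around the two checks above.
function rateLimitMiddleware(allow) {
  return (req, res, next) => (allow(req.ip) ? next() : res.status(429).send('Too many requests'));
}

function validationMiddleware(req, res, next) {
  const error = validateClassifyBody(req.body);
  return error ? res.status(400).send(error) : next();
}
```

You could then register them on the route, e.g. app.post('/image/classify', rateLimitMiddleware(createRateLimiter({ windowMs: 1000, max: 10 })), validationMiddleware, handler). Keep the limit comfortably above the number of requests one client makes per second, and note that behind a proxy req.ip may be the proxy's address unless Express's trust proxy setting is configured.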

A word of caution…

You may be tempted to create a new model instance dynamically on each request, but this is a bad idea. Instantiating the TeachableMachine class is computationally expensive and will usually take several seconds to complete. Consequently, creating a new instance on each request would drastically harm performance.
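The safe pattern is to construct each model once and reuse it, which is what the module-level model in the API above already does. If you ever need more than one model (say, a hypothetical one per exercise), the same rule applies, and a small cache keyed by model URL keeps it to one instance per URL. In this sketch createModel is a stand-in for whatever builds the model, for example (url) => new TeachableMachine({ modelUrl: url }).

```javascript
// Return a getter that builds each model at most once per URL and then
// hands back the cached instance on every later call.
function createModelCache(createModel) {
  const cache = new Map(); // modelUrl -> model instance
  return function getModel(modelUrl) {
    if (!cache.has(modelUrl)) {
      cache.set(modelUrl, createModel(modelUrl));
    }
    return cache.get(modelUrl);
  };
}
```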

If you do expect to change your model, I recommend storing the model URL in an environment variable that is read when the Node app is launched.
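A minimal sketch of reading the URL at start-up, assuming a variable named MODEL_URL (the name is hypothetical; use whatever you configure on your host). Failing fast when the value is missing or malformed is usually easier to debug than a failed classification on the first request.

```javascript
// Read and sanity-check the model URL once, when the app starts.
function readModelUrl(env = process.env) {
  const value = env.MODEL_URL;
  if (!value) {
    throw new Error('MODEL_URL environment variable is not set');
  }
  let parsed;
  try {
    parsed = new URL(value); // throws if value is not an absolute URL
  } catch {
    throw new Error(`MODEL_URL is not a valid URL: ${value}`);
  }
  if (parsed.protocol !== 'https:' && parsed.protocol !== 'http:') {
    throw new Error('MODEL_URL must be an http(s) URL');
  }
  return value;
}
```

The model would then be created once with new TeachableMachine({ modelUrl: readModelUrl() }).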

SashiDo makes this particularly easy: environment variables can be set in the Runtime section of the dashboard.

SashiDo console

Part 3: Creating the Front End

Using a JS framework for a project as small as this could be considered overkill, but I used Svelte here because I think it makes the code slightly more readable and modular, and it gave me an opportunity to try a framework I had not used before.

There are three main requirements for the front end code. It must:

  • Request access to the user’s camera and set it up
  • Make requests to the API
  • Display a counter showing the number of repetitions of the current exercise

Getting access to the user’s webcam can easily be achieved with the navigator.mediaDevices.getUserMedia method.

We assign the returned stream to the srcObject property of a video element so that its frames can be accessed later.

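async function setup() {
  try {
    // Ask the browser for a video-only stream from the user's webcam
    stream = await navigator.mediaDevices.getUserMedia({
      audio: false,
      video: true,
    });
    // Attach the stream to the hidden video element so frames can be read later
    // (stream and video are variables declared elsewhere in the component)
    video.srcObject = stream;
  } catch (error) {
    // Most likely cause: the user denied camera access
    console.log("error setting up video", error);
  }
}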

The catch clause is most likely to be triggered if the user denies access to their webcam. In this case I have just logged the error to the console; I will leave it as an exercise for the reader to modify the code to show the user an informative error message.
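One possible approach to that exercise: getUserMedia rejects with a DOMException whose name describes what went wrong, so the error can be mapped to a message the user can act on. The names below are standard getUserMedia error names; the function name and the message wording are only suggestions.

```javascript
// Turn a getUserMedia error into a user-facing message.
function describeCameraError(error) {
  switch (error && error.name) {
    case 'NotAllowedError': // the user (or browser policy) denied permission
      return 'Camera access was blocked. Please allow camera access in your browser settings and reload the page.';
    case 'NotFoundError': // no video input device is available
      return 'No camera was found. Please connect a webcam and try again.';
    case 'NotReadableError': // the device exists but could not be started
      return 'Your camera could not be started. It may be in use by another application.';
    case 'OverconstrainedError': // no device satisfies the requested constraints
      return 'Your camera does not support the requested settings.';
    default:
      return 'Something went wrong while setting up your camera.';
  }
}
```

In the catch clause you could store describeCameraError(error) in a component variable and render it on the page instead of only logging it.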

Before we can send a request to our REST API, we must first capture an image. We do this by drawing a frame from the hidden video element onto a canvas. Lines 2-12 scale and crop the frame to a centered 217px by 217px square, as the pose model can only perform inference on square images of that size.

It would be possible to perform this scaling on our server, but that would mean sending a full-size image across the internet, which would unnecessarily slow things down for users with slow connections.

After that, we POST the image to our API using JavaScript's fetch function. (The video, canvas, context, DIMENSION and apiEndpoint variables are defined elsewhere in the component.)

 1 async function getPrediction() {
 2   const videoScale = Math.min(video.videoHeight, video.videoWidth) // shorter side of the frame
 3   const xOffset = (((videoScale - video.videoWidth) / 2) * DIMENSION) / videoScale
 4   const yOffset = (((videoScale - video.videoHeight) / 2) * DIMENSION) / videoScale
 5
 6   context.drawImage( // center-crop the frame onto the square canvas
 7     video,
 8     xOffset,
 9     yOffset,
10     DIMENSION - 2 * xOffset,
11     DIMENSION - 2 * yOffset
12   )
13
14   return await fetch(apiEndpoint, {
15     method: 'POST',
16     headers: {
17       'Content-Type': 'application/json',
18     },
19     body: JSON.stringify({
20       url: canvas.toDataURL('image/jpeg'), // send the cropped frame as a data URL
21     }),
22   })
23     .then((response) => response.json())
24     .then(interpret)
25 }
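The offset arithmetic in getPrediction is easier to check when pulled out into a pure function. This sketch (centerCrop is a hypothetical name) computes the same drawImage arguments: the frame is scaled so its shorter side matches the canvas, and a negative offset pushes the overflow of the longer side off both edges equally. For a 640×480 webcam and a 217px canvas it gives x ≈ -36.17, y = 0, a drawn width of 640 × 217 / 480 ≈ 289.33 and a drawn height of 217.

```javascript
// Compute drawImage(video, x, y, width, height) arguments that center-crop a
// videoWidth x videoHeight frame onto a dimension x dimension canvas.
function centerCrop(videoWidth, videoHeight, dimension) {
  const videoScale = Math.min(videoWidth, videoHeight); // shorter side
  // Negative for the longer side: half of its overflow, in canvas pixels
  const x = (((videoScale - videoWidth) / 2) * dimension) / videoScale;
  const y = (((videoScale - videoHeight) / 2) * dimension) / videoScale;
  return { x, y, width: dimension - 2 * x, height: dimension - 2 * y };
}
```

Inside getPrediction you could then write const crop = centerCrop(video.videoWidth, video.videoHeight, DIMENSION) and pass crop.x, crop.y, crop.width and crop.height to context.drawImage.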

Making one request is nice, but we really need to make several per second.

You might be tempted to write something like this:

setInterval(getPrediction, 200)

But this falls short in two ways. For users with a fast internet connection, we are artificially limiting them to 5 frames per second. Worse still, for users with a slow connection, a request may not have completed by the time the next one is fired 200 milliseconds later.

Browsers only open a handful of concurrent HTTP/1.1 connections to the same host (typically six), so with this approach requests queue up behind each other and the most recent frames are starved while they wait for stale ones to complete. If the user is lucky they will just get a very laggy experience; if they’re not, their browser may even crash.

To solve this issue, we can use a recursive function that waits for the previous request to finish before making a new one. The boolean variable isTracking lets us stop making requests when the user ends their session.

const sequentialTrack = () =>
  getPrediction().then(() =>
    // only schedule the next request once this one has finished
    isTracking ? setTimeout(sequentialTrack, 0) : null
  )
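One limitation of this chain is that if getPrediction rejects (for example after a network error), the .then callback never runs and tracking silently stops. The sketch below is a hypothetical variant with explicit start and stop controls that reports failures and carries on. Like the original, it waits for each request to settle and yields to the event loop between requests, as setTimeout(..., 0) does.

```javascript
// Call `task` repeatedly, one call at a time, until stop() is called.
// Errors are passed to onError instead of ending the loop.
function createSequentialLoop(task, { delayMs = 0, onError = console.error } = {}) {
  let token = null; // identifies the active loop; null when stopped

  async function loop(myToken) {
    while (token === myToken) {
      try {
        await task(); // wait for the request to finish before starting another
      } catch (error) {
        onError(error);
      }
      // Yield to the event loop (and optionally pause) before the next call
      await new Promise((resolve) => setTimeout(resolve, delayMs));
    }
  }

  return {
    // Starts the loop and returns a promise that resolves once it has stopped
    start() {
      if (token) return undefined; // already running
      token = {};
      return loop(token);
    },
    stop() {
      token = null; // the current loop exits after its in-flight call settles
    },
    get isRunning() {
      return token !== null;
    },
  };
}
```

In the component you might create tracker = createSequentialLoop(getPrediction), call tracker.start() when a session begins and tracker.stop() when the user ends it, in place of the isTracking flag.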

To finish off, we need to update the counter.

 1 function interpret(predictions) {
 2   const { probability, className } = likeliest(predictions);
 3   if (probability > 0.7) { // ignore predictions the model isn't confident about
 4     switch (className) {
 5       case "up":
 6       case "down":
 7         if (className !== status) { // only count a change of position
 8           count += 0.5;              // each up/down transition is half a rep
 9           status = className;
10         }
11         break;
12       default:                       // "middle" leaves the counter unchanged
13         break;
14     }
15   }
16 }

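Although there are three classes (“up”, “middle”, and “down”), we only change the counter when the current likeliest prediction is “up” or “down”. Here likeliest is a small helper (not shown) that returns the prediction with the highest probability.

The if statement on line 7 ensures that the counter only changes on a state transition: a transition from down to up should increment the counter, whereas several consecutive down predictions should not change it. Each transition adds 0.5, so a full repetition (down and back up) adds 1.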
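A minimal likeliest, assuming each prediction has the { className, probability } shape that interpret destructures, is a single reduce. The counting logic can also be pulled out of the component so it can be tested against a sequence of predictions. One deliberate difference in this sketch: the first confident up or down only records the starting position, so a session that starts in the up position reaches a count of 1 after one full down-and-up repetition rather than counting the initial reading as a half rep.

```javascript
// Return the prediction with the highest probability (assumes a non-empty array).
function likeliest(predictions) {
  return predictions.reduce((best, p) => (p.probability > best.probability ? p : best));
}

// Component-free version of the counting logic in interpret.
function createRepCounter({ threshold = 0.7 } = {}) {
  let status = null; // last confident position: "up", "down" or null
  let halfReps = 0;
  return {
    update(predictions) {
      const { className, probability } = likeliest(predictions);
      if (probability <= threshold) return; // not confident enough
      if (className !== 'up' && className !== 'down') return; // e.g. "middle"
      if (status !== null && className !== status) halfReps += 1; // real transition
      status = className;
    },
    get count() {
      return halfReps / 2; // two transitions make one repetition
    },
  };
}
```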

Part 4: Deployment

Deploying with SashiDo is straightforward. Every project comes with a private GitHub repository, so to deploy your version to the web you simply push your changes to that repo.

Since Svelte requires a build step, I used two repositories and configured Rollup to write its build output into the public directory of the SashiDo repo. There are other ways to structure this, so feel free to organize it in whatever way works for your project.
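For the two-repository setup, the Rollup side might look roughly like the following rollup.config.js in the Svelte project. This is a sketch: the relative path to the SashiDo repository is hypothetical, it assumes the SashiDo repo's public/index.html loads build/bundle.js (as the standard Svelte template's index.html does), and plugin options can differ between Rollup and plugin versions.

```javascript
// rollup.config.js in the Svelte project (sketch)
import svelte from 'rollup-plugin-svelte';
import resolve from '@rollup/plugin-node-resolve';

export default {
  input: 'src/main.js',
  output: {
    format: 'iife',
    name: 'app',
    // Hypothetical path: write the bundle straight into the SashiDo repo,
    // so committing and pushing that repo deploys the new build
    file: '../my-sashido-app/public/build/bundle.js',
  },
  plugins: [
    // emitCss: false keeps component styles inside the JS bundle,
    // so this minimal config needs no separate CSS plugin
    svelte({ emitCss: false }),
    resolve({ browser: true, dedupe: ['svelte'] }),
  ],
};
```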

Closing Thoughts

Creating this tutorial was a fun learning experience. Although I have experience using JavaScript, I was new to Teachable Machine and Svelte, not to mention SashiDo.

I’ve tried to keep this tutorial as accessible as possible by sticking to only the essential features of the web app.

As a result, I have only just scratched the surface of what SashiDo can do. Using Parse Server, for instance, you could let users log in to save and track their progress over time.

I hope you found this interesting. If you’d like to have a go at the live version, it can be found here.

Useful Links:

GitHub Repository
Official Parse Guide for Javascript
SashiDo’s Getting Started Guide
Teachable Machine on NodeJS
Teachable Machine Community repo on GitHub
