Álvaro N.

How to Train a Model in TensorFlow.js on Massive Datasets (Out-of-Core Learning).

You wake up one day and decide: "I'm going to create a machine learning model".

On the first day, you have a JSON file with 100 entries and 100 WebP images to train a Convolutional Neural Network (CNN) for detection. You read the JSON, load the 100 images, put everything in memory (which should consume about 100 MB), and everything works perfectly.

After a week, you've taken a lot of photos, and your dataset has grown tenfold. With 1000 entries, your model performs much better. But then you glance at your task manager: your Node.js training process is already consuming 1 GB. "Not a problem," you think, "my computer has 16 GB of RAM," and you carry on happily with your model.

However... how far will that RAM take you?

After a month of collecting images and filling out your JSON, you end up with a 10 GB .zip file of photos and thousands of entries. The modus operandi is the same: you extract the files to the usual folder and run the training script.

Kaboom!!!

Your computer freezes, the fan screams, and the process quickly exhausts the RAM. The result? The dreaded OOM (Out of Memory) that crashes your script.

You then think of two alternatives: spend a fortune to upgrade your setup to 64 GB of RAM and move on with life, or... ask yourself: "is there a way to read a file, train the model on it, remove this file from memory, and pull the next one, in a continuous cycle?".

This way, you could feed the model with a dataset that wouldn't fit entirely in memory through a continuous flow (data stream), without ever blowing up your machine's RAM.

The answer to your problem lies in generators and their ability to create a data stream.

Generators, Iterators and Lazy Evaluation

In short, generators (function*) are functions that manufacture "iterators" in JavaScript. An iterator works much like a paginated list: want the next item? Click next, and the item is fetched and processed. Want another? Click again. When you click next after the last item, the list tells you it's over.

In technical terms, this is called lazy evaluation. Data isn't allocated in memory until you explicitly ask for it. The code only runs on demand.

Let's take an example:

export function* someGenerator() {
  console.log("Inside: First log");
  yield 1;

  console.log("Inside: Second log");
  yield 2;

  console.log("Inside: Third log");
  yield 3;

  console.log("Inside: Fourth log");
}

const iterator = someGenerator();

console.log("Outside:", iterator.next().value);
console.log("Outside:", iterator.next().value);
console.log("Outside:", iterator.next().value);
console.log("Outside:", iterator.next().value);

With the mindset of a normal function, you would think: "The code will run to the end, store all these yield instructions in an array and return everything to me". But the output proves otherwise:

Inside: First log
Outside: 1
Inside: Second log
Outside: 2
Inside: Third log
Outside: 3
Inside: Fourth log
Outside: undefined

On the first next() call, execution runs to the first yield and pauses. The function's state is frozen. When we call next() again, it resumes exactly where it left off. On the fourth call there are no more yields, so the function runs to the end and the iterator returns { value: undefined, done: true }.
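Because a generator honors the iterator protocol, for...of can drive it for us and stop automatically at done: true. A small self-contained sketch (same generator as above, minus the logs):

```javascript
function* someGenerator() {
  yield 1;
  yield 2;
  yield 3;
}

// for...of consumes the iterator until done: true,
// so we never see the undefined sentinel value
const values = [];
for (const value of someGenerator()) {
  values.push(value);
}
console.log(values); // [1, 2, 3]

// Calling next() past the last yield returns the done signal
const it = someGenerator();
it.next(); it.next(); it.next();
console.log(it.next()); // { value: undefined, done: true }
```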

Whoa, powerful. But how do we use it in practice?

Generators are our golden goose for creating effectively infinite processing pipelines without clogging memory. Here is the concept applied to reading image files:

import fs from 'fs/promises';
import path from 'path';

export async function* datasetGenerator(dataPath) {
  const dir = await fs.opendir(dataPath);

  for await (const dirent of dir) {
    if (!dirent.isFile() || (!dirent.name.endsWith(".webp") && !dirent.name.endsWith(".jpg"))) {
      continue;
    }

    // Read one file at a time; getLabelFromName is your own helper
    // that derives the label from the file name
    const buffer = await fs.readFile(path.join(dataPath, dirent.name));
    yield { 
      buffer, 
      labelValue: getLabelFromName(dirent.name) 
    };
  }
}

The fact that it's an asynchronous function doesn't change the essence of the generator; the difference is that internally we iterate with promises. The magic here is that only one image buffer is held in memory at a time.
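To see that one-at-a-time behavior in action, here's a self-contained sketch consuming such an async generator with for await...of. The file reads are stubbed with resolved promises (fakeDatasetGenerator is a stand-in, not the real datasetGenerator), so it runs without any dataset on disk:

```javascript
// Stub async generator standing in for datasetGenerator:
// the awaited promise simulates fs.readFile
async function* fakeDatasetGenerator() {
  const files = ["cat_1.webp", "dog_0.webp", "cat_2.webp"];
  for (const name of files) {
    const buffer = await Promise.resolve(Buffer.from(name)); // simulated I/O
    yield { buffer, labelValue: name.startsWith("cat") ? 1 : 0 };
  }
}

async function collectLabels() {
  const labels = [];
  // Each iteration holds exactly one { buffer, labelValue } pair;
  // the previous buffer becomes garbage as soon as we advance
  for await (const { labelValue } of fakeDatasetGenerator()) {
    labels.push(labelValue);
  }
  return labels;
}

collectLabels().then((labels) => console.log(labels)); // logs [1, 0, 1]
```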

How to use Generators in TensorFlow.js

The TensorFlow library has a native and direct method to convert our custom generator into an object of the tf.data.Dataset class. This object is the heart of the data pipeline in TF.js.

// TF.js requires the function reference, not the instantiated iterator
const dataset = tf.data.generator(() => datasetGenerator('./my-dataset'));

"But I need to process the data first!"

It's not necessary to pollute the generator function with heavy preprocessing logic. The Dataset class contains powerful utility methods for us to process (or map) the data as it flows through the pipeline:

const preparedDataset = dataset.map(({ buffer, labelValue }) => {
    return tf.tidy(() => {
      const imageTensor = tf.node.decodeImage(buffer, 3);
      const xs = imageTensor.div(255);
      const ys = tf.tensor1d([labelValue]);
      return { xs, ys };
    });
});

The Mandatory {xs, ys} Format

TensorFlow.js needs to know who is who in your data.

  • xs (features): What we are going to use to predict.
  • ys (labels): What we want to predict.

The model.fitDataset(dataset) method requires that each item flowing out of your pipeline be an object with exactly this shape: xs holding the input tensors and ys the output tensors. If an earlier stage produced differently named fields, a final map can rename them:

// Hypothetical pipeline whose items are { normalizedImage, label }
const trainableDataset = somePipeline.map(({ normalizedImage, label }) => ({
  xs: normalizedImage,
  ys: label
}));

The Dataset Class Arsenal

Building a pipeline is like running a conveyor belt in a factory: you line up the workers (functions), and each one does its job at its station and passes the item forward.

With this, you can apply various processing steps to your dataset before it reaches the model, such as:

  • .shuffle(1000): Creates a fixed-size buffer in memory (e.g., 1000 items). As your generator processes items, they fill this buffer. When the model needs an item, it randomly picks an item from this buffer, and the generator immediately replaces it with the next fetched item.

A VERY important note about the shuffle method: if your dataset is ordered, for example 1000 photos of cats followed by 1000 photos of dogs, the buffer will only ever mix within the first 1000 cat photos, which can cause catastrophic forgetting. To avoid this, you must randomize your data or the order in which you read it.

  • .batch(32): Groups a number of samples into a single batch. Your {xs: Tensor, ys: Tensor} object containing 1 image becomes a group containing 32 images in the first dimension. (Essential for taking advantage of TensorFlow's matrix calculations).
  • .prefetch(2): Prepares up to 2 elements in the background while the model trains on the current one, so the next batch is already waiting in the queue. (Basically, processing while training.)
  • .take(10) / .skip(10): Keeps only the first 10 items, or ignores the first 10 items.
  • .filter(fn): Works like the native array filter, removing unwanted items from the stream before they reach the model.
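One practical way to honor the shuffle advice above without loading anything extra into memory: list the file names first, shuffle that lightweight array with Fisher–Yates (shuffling strings costs almost nothing, unlike shuffling the images themselves), and only then read the files lazily in the shuffled order. A minimal sketch; the stubbed random source is just there to make the result reproducible:

```javascript
// Fisher–Yates shuffle over an array of file names
function shuffleInPlace(items, random = Math.random) {
  for (let i = items.length - 1; i > 0; i--) {
    const j = Math.floor(random() * (i + 1));
    [items[i], items[j]] = [items[j], items[i]];
  }
  return items;
}

// Demonstration with a deterministic random source
const names = ["cat_0.webp", "cat_1.webp", "dog_0.webp", "dog_1.webp"];
const sequence = [0.1, 0.9, 0.5];
const shuffled = shuffleInPlace([...names], () => sequence.shift());
console.log(shuffled); // ["dog_1.webp", "cat_1.webp", "dog_0.webp", "cat_0.webp"]
```

Inside datasetGenerator, you could collect the matching file names with fs.readdir, run the array through shuffleInPlace, and only then start the fs.readFile loop: the reads stay lazy, but their order is randomized across the whole dataset, so .shuffle(1000) no longer sees class-sorted data.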

Putting it all together (pipeline example):

import fs from 'fs/promises';
import path from 'path';
import * as tf from '@tensorflow/tfjs-node';

export async function* datasetGenerator(dataPath) {
  const dir = await fs.opendir(dataPath);

  for await (const dirent of dir) {
    if (!dirent.isFile() || (!dirent.name.endsWith(".webp") && !dirent.name.endsWith(".jpg"))) {
      continue;
    }

    const buffer = await fs.readFile(path.join(dataPath, dirent.name));
    yield { 
      buffer, 
      labelValue: getLabelFromName(dirent.name) 
    };
  }
}

const finalDataset = tf.data
  .generator(() => datasetGenerator("./train_data"))
  .shuffle(1000) 
  .map(({ buffer, labelValue }) => {
    return tf.tidy(() => {
      const xs = tf.node.decodeImage(buffer, 3);
      const ys = tf.tensor1d([labelValue]);
      return { xs, ys };
    });
  })
  .batch(32) 
  .map(({ xs, ys }) => {
    return tf.tidy(() => {
      const normalizedXs = xs.div(255);
      return { xs: normalizedXs, ys };
    });
  })   
  .prefetch(2);

// model: a compiled tf.LayersModel (e.g., your CNN)
await model.fitDataset(finalDataset, { 
  epochs: 32,
});
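A note on splitting: since everything is a stream, .take() and .skip() can carve a validation set out of the same pipeline. To build intuition for what those two methods do, here is a plain-generator analogue (hypothetical helpers for illustration, not the TF.js internals):

```javascript
// Plain-generator analogue of Dataset.take(n)
function* take(iterable, n) {
  let count = 0;
  for (const item of iterable) {
    if (count++ >= n) return; // stop pulling once we have n items
    yield item;
  }
}

// Plain-generator analogue of Dataset.skip(n)
function* skip(iterable, n) {
  let count = 0;
  for (const item of iterable) {
    if (count++ < n) continue; // discard the first n items
    yield item;
  }
}

function* numbers() {
  for (let i = 0; i < 10; i++) yield i;
}

const validation = [...take(numbers(), 2)]; // first 2 items
const training = [...skip(numbers(), 2)];   // everything after them
console.log(validation); // [0, 1]
console.log(training);   // [2, 3, 4, 5, 6, 7, 8, 9]
```

In TF.js terms that would look like finalDataset.take(10) for validation and finalDataset.skip(10) for training, with the validation set passed through fitDataset's validationData option; keep in mind that each iteration over each split re-runs the generator from the start.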

But is that it, only advantages? (Some trade-offs of the approach)

Although this approach reads only the files needed to compose the current training batch, it will likely increase your model's training time. When the dataset fit entirely in memory, the model could iterate over the data immediately, with no I/O wait.

Now the cycle is read, prepare, train, discard, keeping in memory only the data currently in use. Those reads take time, and the same file may be read once per epoch if the model is configured to iterate over the data more than once.

In exchange for scaling to far more data, this approach taxes training time, I/O, and CPU.

By swapping in-memory loading for lazily evaluated generators, we trade raw I/O speed for memory scalability, ensuring our models can learn from massive datasets without melting our machines.

Thanks for reading!!!
