Nithin Bharadwaj

How to Run Machine Learning Models Directly in the Browser With JavaScript

The shift is noticeable. We're moving the intelligence out of the data center and directly into the browser. This changes things. I can now build web applications that see, understand, and react without constantly asking a server for permission. It makes things feel instant, keeps user data private, and honestly, it's just more elegant. Let me show you how this works, from the ground up.

Think about uploading a photo. The old way sends that photo across the internet, a server analyzes it, and sends back a result. The new way? The analysis happens right here, on this device, the moment you select the file. Nothing leaves. That’s the core idea. The web browser becomes a capable machine learning engine.

This matters for privacy. When a model runs on your device, your data stays with you. Your camera feed, your documents, your voice—they don't need to travel to an external server to be understood by an AI. This local processing builds user trust. It also lets applications work offline once the model has been downloaded, and it removes the network round trip from every prediction, so feedback feels immediate.

Frameworks like TensorFlow.js made this practical. They act as a bridge, taking models trained in Python and letting them run inside JavaScript. The system handles the messy details: finding the GPU, converting data into the right format, and managing memory so your tab doesn't crash.

Here’s a basic setup for an image classifier. I start by creating a class that handles the lifecycle: loading, warming up, and running the model. Notice how it checks for the best available backend—WebGL for GPU power, WebAssembly for consistent CPU speed, or a plain JavaScript fallback.

class SimpleClassifier {
  async initialize(modelPath, labelsPath) {
    // Find the fastest backend for this device
    await this.setupBackend();

    // Load model and its labels simultaneously
    const [model, labelsRes] = await Promise.all([
      tf.loadGraphModel(modelPath),
      fetch(labelsPath)
    ]);

    this.model = model;
    this.labels = await labelsRes.json();

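    // A quick first run to prepare the system
    await this.warmup();
    return this;
  }

  async warmup() {
    // Run one prediction on zeros so backend setup cost is paid up front.
    // Assumes the 224x224 RGB input that classify() uses.
    const dummy = tf.zeros([1, 224, 224, 3]);
    const result = this.model.predict(dummy);
    await result.data();
    dummy.dispose();
    result.dispose();
  }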

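  async setupBackend() {
    // Try backends in order of performance.
    // Note: 'wasm' is only registered if @tensorflow/tfjs-backend-wasm is loaded.
    const backends = ['webgl', 'wasm', 'cpu'];
    for (const backend of backends) {
      try {
        // setBackend() can resolve to false or throw when a backend is unusable
        if (await tf.setBackend(backend)) {
          console.log(`Backend: ${backend}`);
          return;
        }
      } catch (e) {
        // Backend not registered or failed to initialize; try the next one
      }
    }
  }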

  async classify(imgElement) {
    // Convert image to a numerical tensor; tf.tidy() frees the intermediates
    const tensor = tf.tidy(() => tf.browser.fromPixels(imgElement)
      .resizeBilinear([224, 224])  // Model expects this size
      .toFloat()
      .div(255)                    // Normalize pixels to 0-1
      .expandDims(0));             // Add batch dimension

    // Make the prediction
    const prediction = this.model.predict(tensor);
    const scores = await prediction.data();

    // Clean up memory immediately
    tensor.dispose();
    prediction.dispose();

    // Match scores to labels
    return this.labels
      .map((label, idx) => ({ label, score: scores[idx] }))
      .sort((a, b) => b.score - a.score)
      .slice(0, 5); // Return top 5
  }
}

// Using it is straightforward
async function handleImageUpload(file) {
  const classifier = await new SimpleClassifier().initialize(
    '/model/model.json',
    '/model/labels.json'
  );

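  const img = document.createElement('img');
  const objectUrl = URL.createObjectURL(file);
  img.src = objectUrl;

  try {
    await img.decode(); // Wait for image to load
    const results = await classifier.classify(img);
    console.log(results);
  } finally {
    URL.revokeObjectURL(objectUrl); // Release the blob reference
  }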
}
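
One assumption is hidden in `classify()` above: it treats the model output as probabilities. Some exported models emit raw logits instead, and then the scores need a softmax before they can be read as confidences. Here is a minimal sketch in plain JavaScript; the helper name `softmax` is my own and not part of TensorFlow.js.

```javascript
// Hypothetical helper: convert raw logits into probabilities that sum to 1.
function softmax(logits) {
  // Subtract the maximum first so Math.exp() cannot overflow.
  let max = -Infinity;
  for (const v of logits) {
    if (v > max) max = v;
  }

  const exps = new Float32Array(logits.length);
  let sum = 0;
  for (let i = 0; i < logits.length; i++) {
    exps[i] = Math.exp(logits[i] - max);
    sum += exps[i];
  }

  // Normalize so the values add up to 1.
  for (let i = 0; i < exps.length; i++) {
    exps[i] /= sum;
  }
  return exps;
}
```

If your model does output logits, `softmax(await prediction.data())` would replace the raw `scores` array. Applying it to values that are already probabilities would distort them, so check the model's final layer first.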

This approach works, but for production, I need more. A real application must handle different device capabilities, manage memory carefully, and provide feedback. Let's build a more complete version. This one includes progress tracking, error handling, and memory monitoring.

class RobustImageClassifier {
  constructor() {
    this.model = null;
    this.labels = [];
    this.status = 'idle'; // idle, loading, ready, error
    this.memoryMonitor = null;
  }

  async initialize(modelConfig) {
    this.status = 'loading';

    try {
      // Dynamic backend selection with testing
      await this.selectOptimalBackend();

      // Load with progress events
      this.model = await this.loadModelWithProgress(modelConfig.url);
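      const labelsResponse = await fetch(modelConfig.labelsUrl);
      if (!labelsResponse.ok) {
        throw new Error(`Failed to load labels: ${labelsResponse.status}`);
      }
      this.labels = await labelsResponse.json();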

      // Initial warm-up run
      await this.performWarmup();

      this.status = 'ready';
      this.startMemoryMonitoring();
      return true;

    } catch (error) {
      this.status = 'error';
      console.error('Initialization failed:', error);
      return false;
    }
  }

  async selectOptimalBackend() {
    // Test each backend with a small computation
    const backends = ['webgl', 'wasm', 'cpu'];
    for (const backend of backends) {
      try {
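        // setBackend() resolves to false when the backend cannot initialize
        if (!(await tf.setBackend(backend))) {
          throw new Error(`${backend} unavailable`);
        }
        await tf.ready();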
        // Quick test
        const test = tf.tensor1d([1, 2, 3]).square();
        await test.data();
        test.dispose();
        console.log(`Selected backend: ${backend}`);
        return;
      } catch (e) {
        console.log(`${backend} failed, trying next...`);
      }
    }
    throw new Error('No suitable backend found');
  }

  async loadModelWithProgress(modelUrl) {
    // model.json is only a small manifest; the weight shards it references
    // make up most of the download. Loading it from a blob: URL would fail,
    // because shard paths are resolved relative to the model URL. The
    // onProgress load option reports overall progress as a fraction (0 to 1).
    return tf.loadGraphModel(modelUrl, {
      onProgress: (progress) => {
        // Emit progress (e.g., 0.75 for 75%)
        window.dispatchEvent(
          new CustomEvent('load-progress', { detail: { progress } })
        );
      }
    });
  }

  async performWarmup() {
    // Create dummy input matching model expectations.
    // Dynamic dimensions (usually the batch) are reported as -1 or null,
    // so substitute 1 for those.
    const shape = this.model.inputs[0].shape.map(d => (d > 0 ? d : 1));
    const dummyInput = tf.zeros(shape);

    // First run is often slower due to compilation
    const start = performance.now();
    const warmupResult = this.model.predict(dummyInput);
    await warmupResult.data();
    const time = performance.now() - start;

    // Clean up
    dummyInput.dispose();
    warmupResult.dispose();

    console.log(`Warmup completed in ${time.toFixed(1)}ms`);
  }

  startMemoryMonitoring() {
    // Check memory usage every 30 seconds
    this.memoryMonitor = setInterval(() => {
      const memory = tf.memory();
      if (memory.numBytes > 50 * 1024 * 1024) { // 50MB threshold
        console.warn('High memory usage:', memory);
        this.cleanupTensors();
      }
    }, 30000);
  }

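  cleanupTensors() {
    // An empty startScope()/endScope() pair frees nothing: a scope only
    // disposes tensors created while it is open. Leaks have to be fixed where
    // the tensors are created, so this reports the live count for debugging.
    console.debug(`Live tensors: ${tf.memory().numTensors}`);
  }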

  async classify(imageElement, options = {}) {
    if (this.status !== 'ready') {
      throw new Error('Model not ready');
    }

    const startTime = performance.now();

    let inputTensor;

    try {
      // Preprocess
      inputTensor = this.preprocessImage(imageElement);

      // Predict
      const prediction = this.model.predict(inputTensor);
      const results = await this.processOutput(prediction, options);

      // Timing
      results.inferenceTime = performance.now() - startTime;

      return results;

    } finally {
      // Dispose the input even if errors occur;
      // processOutput() disposes the prediction tensor
      if (inputTensor) inputTensor.dispose();
    }
  }

  preprocessImage(img) {
    // Convert to tensor, resize, normalize
    return tf.tidy(() => {
      return tf.browser.fromPixels(img)
        .resizeBilinear([224, 224])
        .toFloat()
        .div(255.0)
        .expandDims(0);
    });
  }

  async processOutput(predictionTensor, options) {
    const { topK = 3, threshold = 0.1 } = options;
    const scores = await predictionTensor.data();
    predictionTensor.dispose();

    // Process and filter results
    return this.labels
      .map((label, index) => ({
        label,
        confidence: scores[index],
        index
      }))
      .filter(item => item.confidence >= threshold)
      .sort((a, b) => b.confidence - a.confidence)
      .slice(0, topK)
      .map(item => ({
        ...item,
        confidence: Math.round(item.confidence * 10000) / 100 // As percentage
      }));
  }
}
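
The backend selection in `selectOptimalBackend()` is one instance of a general pattern: try candidates in order and keep the first one that passes a probe. Pulling that logic into a small function makes it testable without a GPU. This is a sketch under my own naming; `selectFirstWorking` and `probe` are hypothetical and not part of TensorFlow.js.

```javascript
// Hypothetical helper: return the first candidate whose probe succeeds.
// A probe may signal failure by throwing or by resolving to a falsy value.
async function selectFirstWorking(candidates, probe) {
  const failures = [];
  for (const name of candidates) {
    try {
      if (await probe(name)) return name;
      failures.push(`${name}: probe returned a falsy value`);
    } catch (err) {
      failures.push(`${name}: ${err.message}`);
    }
  }
  // Keep every failure reason so the final error is actionable.
  throw new Error(`No candidate worked (${failures.join('; ')})`);
}

// Browser usage with TensorFlow.js (not executed here):
// const backend = await selectFirstWorking(
//   ['webgl', 'wasm', 'cpu'],
//   (name) => tf.setBackend(name)
// );
```

Collecting the failure reasons is the main design choice here. When a user reports that nothing works, a message listing why each backend was rejected is far more useful than a bare "No suitable backend found".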

But what about newer, faster APIs? Browsers are starting to expose machine learning hardware more directly. The Web Neural Network API (WebNN), still a draft specification with limited browser support, gives me lower-level access to GPUs and AI accelerators. It's more verbose, and the API surface is still changing, but it can be significantly faster.

class WebNNDetector {
  async initialize() {
    // WebNN is a draft spec: the names below follow the W3C draft as I
    // understand it and may differ from what a given browser version ships.
    if (!('ml' in navigator)) {
      throw new Error('WebNN not available');
    }

    // Create a context for ML operations
    this.context = await navigator.ml.createContext();
    const builder = new MLGraphBuilder(this.context);

    // Input: image tensor [1, 300, 300, 3] (NHWC)
    const inputDesc = { dataType: 'float32', shape: [1, 300, 300, 3] };
    const input = builder.input('image', inputDesc);

    // Example detection pipeline (simplified). The filter holds zeros as a
    // placeholder; a real model would supply trained weights.
    const filter = builder.constant(
      { dataType: 'float32', shape: [3, 3, 3, 16] },
      new Float32Array(3 * 3 * 3 * 16)
    );
    const conv1 = builder.conv2d(input, filter, {
      strides: [2, 2],
      padding: [1, 1, 1, 1],
      inputLayout: 'nhwc',
      filterLayout: 'hwio'
    });

    const relu1 = builder.relu(conv1);

    // More layers would follow in a real model...
    const output = builder.softmax(relu1, 3); // Softmax over the channel axis

    // Build (compile) the graph
    this.graph = await builder.build({ output });

    // Device tensors for input and output; the conv output is [1, 150, 150, 16]
    this.inputTensor = await this.context.createTensor({
      ...inputDesc, writable: true
    });
    this.outputTensor = await this.context.createTensor({
      dataType: 'float32', shape: [1, 150, 150, 16], readable: true
    });

    return this;
  }

  async detectFromCanvas(canvasElement) {
    // Get image data
    const ctx = canvasElement.getContext('2d');
    const imageData = ctx.getImageData(0, 0, 300, 300);

    // Prepare tensor from image data (drop the alpha channel)
    const tensorData = new Float32Array(300 * 300 * 3);
    for (let i = 0; i < imageData.data.length; i += 4) {
      const pixelIndex = i / 4;
      tensorData[pixelIndex * 3] = imageData.data[i] / 255;         // R
      tensorData[pixelIndex * 3 + 1] = imageData.data[i + 1] / 255; // G
      tensorData[pixelIndex * 3 + 2] = imageData.data[i + 2] / 255; // B
    }

    // Upload the input, run the graph, and read the result back
    this.context.writeTensor(this.inputTensor, tensorData);
    this.context.dispatch(
      this.graph,
      { image: this.inputTensor },
      { output: this.outputTensor }
    );
    const buffer = await this.context.readTensor(this.outputTensor);

    return [{ name: 'output', data: new Float32Array(buffer) }];
  }
}

Raw speed is only part of the story. The models themselves need to be tailored for the browser. They must be small enough to download quickly and efficient enough to run smoothly on a phone. This is where model optimization comes in. Techniques like quantization reduce precision to shrink size and speed up computation.

Let's say I have a model that uses 32-bit floating point numbers. Quantization might convert it to use 8-bit integers. The file becomes about 75% smaller, and the math gets faster, with usually minor accuracy trade-offs.

async function optimizeModelForWeb(originalModelBuffer) {
  // This simulates a quantization process
  const original = new Float32Array(originalModelBuffer);
  const quantized = new Uint8Array(original.length);

  // Find range for scaling. A loop avoids Math.min(...array), which can
  // exceed the engine's argument limit on large weight arrays.
  let min = Infinity;
  let max = -Infinity;
  for (const value of original) {
    if (value < min) min = value;
    if (value > max) max = value;
  }
  // Fall back to 1 when all values are equal, to avoid dividing by zero
  const scale = (max - min) / 255 || 1;

  // Convert each value
  for (let i = 0; i < original.length; i++) {
    quantized[i] = Math.round((original[i] - min) / scale);
  }

  // Return both data and the parameters needed to de-quantize later
  return {
    data: quantized,
    quantizationParams: { min, scale, originalType: 'float32' }
  };
}

// When using the quantized model, I convert back
function dequantize(quantizedArray, params) {
  const { min, scale } = params;
  const dequantized = new Float32Array(quantizedArray.length);

  for (let i = 0; i < quantizedArray.length; i++) {
    dequantized[i] = quantizedArray[i] * scale + min;
  }

  return dequantized;
}
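
To make the trade-off concrete, it helps to measure it. Going from 4 bytes per value to 1 byte is where the 75% figure comes from, and with this affine scheme each restored value should be off by at most half a quantization step (`scale / 2`). The sketch below checks both on a single weight array. The helper names (`quantize8`, `dequantize8`, `quantizationReport`) are hypothetical, and real conversion tooling usually handles this per tensor for you.

```javascript
// Hypothetical helpers mirroring the affine 8-bit scheme shown above.
function quantize8(values) {
  let min = Infinity;
  let max = -Infinity;
  for (const v of values) {
    if (v < min) min = v;
    if (v > max) max = v;
  }
  // Guard against a constant array, where max === min.
  const scale = (max - min) / 255 || 1;
  const data = new Uint8Array(values.length);
  for (let i = 0; i < values.length; i++) {
    data[i] = Math.round((values[i] - min) / scale);
  }
  return { data, min, scale };
}

function dequantize8({ data, min, scale }) {
  const out = new Float32Array(data.length);
  for (let i = 0; i < data.length; i++) {
    out[i] = data[i] * scale + min;
  }
  return out;
}

// Report the size saving and worst-case round-trip error for one array.
function quantizationReport(weights) {
  const q = quantize8(weights);
  const restored = dequantize8(q);
  let maxError = 0;
  for (let i = 0; i < weights.length; i++) {
    maxError = Math.max(maxError, Math.abs(weights[i] - restored[i]));
  }
  return {
    originalBytes: weights.byteLength,
    quantizedBytes: q.data.byteLength,
    maxError,
    errorBound: q.scale / 2
  };
}
```

Running `quantizationReport()` on a layer's weights before shipping gives a quick sense of how much precision that layer gives up. Layers with a few extreme outliers tend to suffer most, because the outliers stretch the range and enlarge every step.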

In practice, I often use pre-optimized model architectures designed for mobile and web. Models like MobileNet, EfficientNet-Lite, or MediaPipe solutions are built with these constraints in mind. They provide a good balance of accuracy, size, and speed.

The real magic happens when I combine these pieces for real-time interaction. Consider a camera-based application that guides you through exercises, detects objects in real-time, or translates text through your phone's camera. The flow is continuous and immediate.

class CameraMLProcessor {
  constructor(model) {
    this.model = model;
    this.video = document.createElement('video');
    this.canvas = document.createElement('canvas');
    this.ctx = this.canvas.getContext('2d');
    this.active = false;
    this.frameQueue = [];
    this.processing = false;
  }

  async start() {
    // Access camera
    const stream = await navigator.mediaDevices.getUserMedia({
      video: { width: 640, height: 480 }
    });

    this.video.srcObject = stream;
    await this.video.play();

    this.active = true;
    this.processFrames();
  }

  async processFrames() {
    while (this.active) {
      // Capture frame
      this.canvas.width = this.video.videoWidth;
      this.canvas.height = this.video.videoHeight;
      this.ctx.drawImage(this.video, 0, 0);

      // Process if not already busy
      if (!this.processing) {
        this.processing = true;

        try {
          const results = await this.model.classify(this.canvas);
          this.onResults(results); // Handle results
        } catch (error) {
          console.error('Frame processing error:', error);
        } finally {
          this.processing = false;
        }
      }

      // Yield to browser
      await new Promise(resolve => requestAnimationFrame(resolve));
    }
  }

  onResults(results) {
    // Draw bounding boxes, labels, etc. This assumes a detection model whose
    // results carry x, y, width, and height in canvas pixels.
    this.ctx.clearRect(0, 0, this.canvas.width, this.canvas.height);
    this.ctx.drawImage(this.video, 0, 0);

    results.forEach(obj => {
      this.ctx.strokeStyle = '#00ff00';
      this.ctx.lineWidth = 2;
      this.ctx.strokeRect(obj.x, obj.y, obj.width, obj.height);

      this.ctx.fillStyle = '#00ff00';
      this.ctx.fillText(
        `${obj.label} (${obj.confidence}%)`, 
        obj.x, 
        obj.y > 10 ? obj.y - 5 : 10
      );
    });
  }

  stop() {
    this.active = false;
    if (this.video.srcObject) {
      this.video.srcObject.getTracks().forEach(track => track.stop());
    }
  }
}
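
Note that `onResults()` reads `obj.x`, `obj.y`, `obj.width`, and `obj.height` in canvas pixels, which the classifiers from earlier don't produce; you need a detection model for that. Many SSD-style detectors emit boxes as normalized `[ymin, xmin, ymax, xmax]` corners instead. Assuming that layout (check your model's documentation, since it varies), a small conversion step bridges the gap. The helper name `toCanvasRect` is hypothetical.

```javascript
// Hypothetical helper: map a normalized [ymin, xmin, ymax, xmax] box
// to the { x, y, width, height } pixel rectangle that onResults() draws.
function toCanvasRect(box, canvasWidth, canvasHeight) {
  // Detectors can overshoot slightly, so clamp coordinates to the 0-1 range.
  const clamp01 = (v) => Math.min(1, Math.max(0, v));
  const [ymin, xmin, ymax, xmax] = box.map(clamp01);

  return {
    x: Math.round(xmin * canvasWidth),
    y: Math.round(ymin * canvasHeight),
    width: Math.round((xmax - xmin) * canvasWidth),
    height: Math.round((ymax - ymin) * canvasHeight)
  };
}
```

Each detection can then be spread into the shape the drawing code expects, for example `{ ...toCanvasRect(box, canvas.width, canvas.height), label, confidence }`.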

Memory management is critical in these applications. Tensors—the multi-dimensional arrays that hold model data—can accumulate quickly. If I don't clean them up, the browser tab will eventually freeze or crash. TensorFlow.js provides tools to help, but I need to be disciplined.

// Good practice: Use tf.tidy() to auto-clean
function safePrediction(model, input) {
  return tf.tidy(() => {
    const tensor = tf.browser.fromPixels(input)
      .resizeBilinear([224, 224])
      .toFloat()
      .expandDims(0); // Add batch dimension
    return model.predict(tensor);
  });
  // Intermediate tensors created inside tidy() are disposed automatically.
  // The returned prediction survives, so the caller must dispose() it.
}

// Manual cleanup when tidy isn't possible (tf.tidy() cannot wrap async code)
async function classifyWithManualCleanup(model, image) {
  let tensor, prediction;

  try {
    // Preprocessing is synchronous, so tidy() can still free its intermediates
    tensor = tf.tidy(() => tf.browser.fromPixels(image)
      .resizeBilinear([224, 224])
      .toFloat()
      .expandDims(0));

    prediction = model.predict(tensor);
    const results = await prediction.data();

    // processResults is a placeholder for your own post-processing
    return processResults(results);

  } finally {
    // Always clean up, even if errors occur
    if (tensor) tensor.dispose();
    if (prediction) prediction.dispose();
  }
}

// Monitor memory usage
function setupMemoryMonitor() {
  setInterval(() => {
    const mem = tf.memory();
    console.log(`Tensors: ${mem.numTensors}, Memory: ${(mem.numBytes / 1024 / 1024).toFixed(2)}MB`);

    if (mem.numBytes > 100 * 1024 * 1024) { // 100MB threshold
      // Nothing reclaims leaked tensors after the fact; a rising count
      // means some code path is missing dispose() or tf.tidy()
      console.warn('High memory - check for undisposed tensors');
    }
  }, 10000);
}
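
Watching a global counter tells me that something leaks, but not what. A more targeted check is to compare the live count before and after a single operation. Here is a minimal sketch; `measureLeak` is a hypothetical helper, and the counter is passed in so the same function works with `tf.memory().numTensors` in the browser or a fake counter in a unit test.

```javascript
// Hypothetical helper: run fn and report how many tracked objects it left behind.
// In the browser, countFn would be () => tf.memory().numTensors.
async function measureLeak(countFn, fn) {
  const before = countFn();
  const result = await fn();
  const leaked = countFn() - before;
  return { result, leaked };
}

// Browser usage (not executed here):
// const { leaked } = await measureLeak(
//   () => tf.memory().numTensors,
//   () => classifier.classify(img)
// );
// if (leaked > 0) console.warn(`classify() leaked ${leaked} tensors`);
```

A function that intentionally returns a tensor will show up as a leak of one, so interpret the number with the function's contract in mind.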

What about loading these models? Some can be several megabytes. I don't want to block the main thread or make users wait. Progressive loading and caching strategies help.

class ModelLoader {
  constructor() {
    this.cache = new Map();
    this.pending = new Map();
  }

  async load(modelUrl, options = {}) {
    const { cacheKey = modelUrl, priority = 'high' } = options;

    // Return cached model if available
    if (this.cache.has(cacheKey)) {
      return this.cache.get(cacheKey);
    }

    // Join existing request if already loading
    if (this.pending.has(cacheKey)) {
      return this.pending.get(cacheKey);
    }

    // Create new loading promise
    const loadPromise = this.createLoadPromise(modelUrl, priority);
    this.pending.set(cacheKey, loadPromise);

    try {
      const model = await loadPromise;
      this.cache.set(cacheKey, model);
      return model;
    } finally {
      this.pending.delete(cacheKey);
    }
  }

  async createLoadPromise(modelUrl, priority) {
    // Let TensorFlow.js fetch model.json and its weight shards itself;
    // loading from a blob: URL would break the relative shard paths.
    // requestInit is passed through to fetch(), so the priority hint still
    // applies, and onProgress reports a fraction between 0 and 1.
    return tf.loadGraphModel(modelUrl, {
      requestInit: { priority },
      onProgress: (fraction) => this.updateProgress(fraction)
    });
  }

  updateProgress(fraction) {
    // Dispatch event or update UI
    const event = new CustomEvent('modelloadprogress', {
      detail: { fraction }
    });
    window.dispatchEvent(event);
  }
}
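
The `Map` cache above only lives as long as the page. To skip the download on later visits, TensorFlow.js can save a model to IndexedDB and load it back through an `indexeddb://` URL. The sketch below tries the stored copy first and falls back to the network. The helper name `loadWithPersistentCache` and the `key` parameter are my own, and the TensorFlow.js namespace is passed in as `tfLib` so the function can be exercised with a stub.

```javascript
// Hypothetical helper. `tfLib` is the TensorFlow.js namespace (normally `tf`);
// `key` is any name you choose for the stored copy.
async function loadWithPersistentCache(tfLib, key, modelUrl) {
  const cachedUrl = `indexeddb://${key}`;

  try {
    // Succeeds only if an earlier visit saved the model under this key.
    const model = await tfLib.loadGraphModel(cachedUrl);
    return { model, source: 'cache' };
  } catch (err) {
    // Nothing stored yet (or storage unavailable): fall back to the network.
  }

  const model = await tfLib.loadGraphModel(modelUrl);
  try {
    await model.save(cachedUrl);
  } catch (err) {
    // Saving can fail (private browsing, quota); the model is still usable.
    console.warn('Could not cache model:', err.message);
  }
  return { model, source: 'network' };
}

// Browser usage (not executed here):
// const { model, source } =
//   await loadWithPersistentCache(tf, 'classifier-v1', '/model/model.json');
```

Putting a version in the key, as in the usage comment, is a simple way to invalidate the stored copy when you ship a new model.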

Putting it all together, the development pattern becomes clear. I start with a use case that benefits from immediate, private processing. I select or train a model optimized for size and speed. I build an interface that loads the model efficiently, processes input locally, provides real-time feedback, and carefully manages resources.

The applications are growing. I've built tools that let artists apply style transfer to photos without uploading them. Educational apps that provide real-time pronunciation feedback. Accessibility tools that describe scenes for visually impaired users. All running completely in the browser.

The limitations are still there. Very large models or complex training still need server infrastructure. But for inference—applying already-trained knowledge—the browser has become remarkably capable. As device hardware improves and browser APIs mature, this boundary will keep expanding.

For developers entering this space, my advice is to start simple. Take a pre-optimized model from TensorFlow Hub or MediaPipe. Build something that works locally first. Understand the memory and performance characteristics. Then incrementally add complexity: real-time camera feeds, multiple model coordination, offline support.

The result is a different kind of web application. One that feels responsive in a fundamental way, respects user privacy by design, and works consistently regardless of network quality. It's not just about doing machine learning on the web. It's about making the web itself more intelligent, capable, and respectful of the people using it.
