DEV Community

Nithin Bharadwaj
GPU-Accelerated Machine Learning in the Browser: WebGL and JavaScript Techniques That Actually Work


Running machine learning directly in the browser used to feel like a distant dream. The very idea of a website performing complex, data-driven predictions seemed to clash with JavaScript's reputation as a slow, single-threaded language. Today, it's not only possible but practical, and the key to making it work smoothly is tapping into the raw power of your user's graphics card through WebGL.

I want to share some practical methods I've used to bring machine learning models to life in the browser. This isn't just about loading a pre-packaged model; it's about understanding how to structure, accelerate, and execute these models efficiently using JavaScript and WebGL together.

Let's start with the foundation: designing the model itself. You can't just copy a massive server-side neural network and expect it to run well in a browser. The architecture needs to be lean and purposeful.

Here is a basic structure for a neural network class in JavaScript. It sets up the layers, initializes weights intelligently, and prepares for WebGL acceleration.

class NeuralNetwork {
  constructor(config) {
    this.layers = [];
    this.weights = [];
    this.biases = [];
    this.compiled = false;
    this.glContext = null;
    this.shaderCache = new Map();

    this.buildFromConfig(config);
  }

  buildFromConfig(config) {
    const { layerSizes, activation = 'relu', weightInit = 'he' } = config;

    for (let i = 0; i < layerSizes.length - 1; i++) {
      const inputDim = layerSizes[i];
      const outputDim = layerSizes[i + 1];

      const weights = this.createWeights(inputDim, outputDim, weightInit);
      const biases = new Float32Array(outputDim).fill(0.01);

      const activationFn = this.selectActivation(
        i === layerSizes.length - 2 ? 'linear' : activation
      );

      this.layers.push({ inputDim, outputDim, weights, biases, activationFn });
      this.weights.push(weights);
      this.biases.push(biases);
    }
    this.prepareWebGL();
  }

  createWeights(inputDim, outputDim, initScheme) {
    const total = inputDim * outputDim;
    const weightArray = new Float32Array(total);

    if (initScheme === 'he') {
      const spread = Math.sqrt(2.0 / inputDim);
      for (let i = 0; i < total; i++) {
        weightArray[i] = this.gaussianRandom(0, spread);
      }
    } else if (initScheme === 'xavier') {
      const range = Math.sqrt(6.0 / (inputDim + outputDim));
      for (let i = 0; i < total; i++) {
        weightArray[i] = (Math.random() * 2 - 1) * range;
      }
    } else {
      for (let i = 0; i < total; i++) {
        weightArray[i] = (Math.random() * 2 - 1) * 0.05;
      }
    }
    return weightArray;
  }

  gaussianRandom(mean, stdev) {
    // Box-Muller transform; use 1 - random() so u1 is never 0 (log(0) = -Infinity)
    const u1 = 1 - Math.random();
    const u2 = Math.random();
    const randStdNormal = Math.sqrt(-2.0 * Math.log(u1)) * Math.cos(2.0 * Math.PI * u2);
    return mean + stdev * randStdNormal;
  }

  selectActivation(name) {
    const activations = {
      relu: {
        forward: (x) => Math.max(0, x),
        backward: (x) => (x > 0 ? 1 : 0),
      },
      sigmoid: {
        forward: (x) => 1 / (1 + Math.exp(-x)),
        backward: (x) => {
          const s = 1 / (1 + Math.exp(-x));
          return s * (1 - s);
        },
      },
      tanh: {
        forward: (x) => Math.tanh(x),
        backward: (x) => 1 - Math.tanh(x) ** 2,
      },
      linear: {
        forward: (x) => x,
        backward: (x) => 1,
      },
    };
    return activations[name] || activations.relu;
  }

  prepareWebGL() {
    const canvas = document.createElement('canvas');
    this.glContext = canvas.getContext('webgl2');
    if (!this.glContext) {
      console.warn('WebGL2 context could not be created. Using CPU fallback.');
      return;
    }
    // Rendering into RGBA32F textures requires this extension, even in WebGL2
    if (!this.glContext.getExtension('EXT_color_buffer_float')) {
      console.warn('EXT_color_buffer_float unavailable. Using CPU fallback.');
      this.glContext = null;
      return;
    }
    this.buildShaders();
  }
}
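To make "lean" concrete, it helps to budget parameters before building anything. A quick sketch (the helper name and example architecture are mine, not part of the class above):

```javascript
// Estimate parameter count and Float32 memory for a layerSizes config.
function parameterBudget(layerSizes) {
  let params = 0;
  for (let i = 0; i < layerSizes.length - 1; i++) {
    // Each layer contributes a weight matrix plus a bias vector
    params += layerSizes[i] * layerSizes[i + 1] + layerSizes[i + 1];
  }
  return { params, bytes: params * 4 }; // 4 bytes per float32
}

const budget = parameterBudget([784, 128, 64, 10]);
console.log(budget.params);                            // 109386 parameters
console.log((budget.bytes / 1024).toFixed(0) + ' KB'); // ~427 KB of float32 weights
```

A network this size downloads and uploads to the GPU in a fraction of a second; a server-scale model with hundreds of millions of parameters would not.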

The real performance gains come from moving calculations to the GPU. WebGL isn't designed for machine learning; it's designed for drawing triangles. The trick is to treat our data—weights, biases, inputs—as textures (images) and our mathematical operations as fragment shaders that run on every pixel of that texture in parallel.

Here is how you can create a shader program for the core operation of a neural network: matrix multiplication. This is where most of the computational work happens.

// Setting up a WebGL shader for matrix multiplication
function createMatrixMultProgram(gl) {
  const vertexShaderSource = `#version 300 es
    in vec2 a_position;
    out vec2 v_texCoord;
    void main() {
      v_texCoord = a_position * 0.5 + 0.5;
      gl_Position = vec4(a_position, 0.0, 1.0);
    }
  `;

  const fragmentShaderSource = `#version 300 es
    precision highp float;
    precision highp int;

    uniform sampler2D u_matrixA;
    uniform sampler2D u_matrixB;
    uniform int u_colsA; // shared inner dimension (cols of A = rows of B)

    in vec2 v_texCoord;
    out vec4 outColor;

    void main() {
      ivec2 outputCoord = ivec2(gl_FragCoord.xy);
      int row = outputCoord.y;
      int col = outputCoord.x;

      float sum = 0.0;
      for (int k = 0; k < u_colsA; ++k) {
        float a = texelFetch(u_matrixA, ivec2(k, row), 0).r;
        float b = texelFetch(u_matrixB, ivec2(col, k), 0).r;
        sum += a * b;
      }
      outColor = vec4(sum, 0.0, 0.0, 1.0);
    }
  `;

  const vertexShader = compileShader(gl, gl.VERTEX_SHADER, vertexShaderSource);
  const fragmentShader = compileShader(gl, gl.FRAGMENT_SHADER, fragmentShaderSource);
  if (!vertexShader || !fragmentShader) return null; // compile errors already logged

  const program = gl.createProgram();
  gl.attachShader(program, vertexShader);
  gl.attachShader(program, fragmentShader);
  gl.linkProgram(program);

  if (!gl.getProgramParameter(program, gl.LINK_STATUS)) {
    console.error('Shader program link failed:', gl.getProgramInfoLog(program));
    return null;
  }
  return program;
}

function compileShader(gl, type, source) {
  const shader = gl.createShader(type);
  gl.shaderSource(shader, source);
  gl.compileShader(shader);
  if (!gl.getShaderParameter(shader, gl.COMPILE_STATUS)) {
    console.error('Shader compile error:', gl.getShaderInfoLog(shader));
    gl.deleteShader(shader);
    return null;
  }
  return shader;
}

Once you have the shader, you need to manage the data. In WebGL, you store your matrix data in textures. A Float32Array of your weights or input data gets uploaded to the GPU as a texture, which the shader can then read from.

// Function to upload a matrix (stored as a Float32Array) to a WebGL texture
function matrixToTexture(gl, matrixData, rows, cols) {
  const texture = gl.createTexture();
  gl.bindTexture(gl.TEXTURE_2D, texture);

  // One value per texel, stored in the red channel of an RGBA float texture.
  // Dimensions must stay within gl.MAX_TEXTURE_SIZE on the target device.
  const width = cols;
  const height = rows;
  const dataArray = new Float32Array(width * height * 4); // RGBA format

  for (let row = 0; row < height; row++) {
    for (let col = 0; col < width; col++) {
      const srcIndex = row * cols + col;
      const dstIndex = (row * width + col) * 4;
      dataArray[dstIndex] = matrixData[srcIndex]; // Store value in Red channel
      dataArray[dstIndex + 1] = 0.0; // Green
      dataArray[dstIndex + 2] = 0.0; // Blue
      dataArray[dstIndex + 3] = 1.0; // Alpha
    }
  }

  gl.texImage2D(
    gl.TEXTURE_2D,
    0,
    gl.RGBA32F,
    width,
    height,
    0,
    gl.RGBA,
    gl.FLOAT,
    dataArray
  );

  gl.texParameteri(gl.TEXTURE_2D, gl.TEXTURE_MIN_FILTER, gl.NEAREST);
  gl.texParameteri(gl.TEXTURE_2D, gl.TEXTURE_MAG_FILTER, gl.NEAREST);
  gl.texParameteri(gl.TEXTURE_2D, gl.TEXTURE_WRAP_S, gl.CLAMP_TO_EDGE);
  gl.texParameteri(gl.TEXTURE_2D, gl.TEXTURE_WRAP_T, gl.CLAMP_TO_EDGE);

  gl.bindTexture(gl.TEXTURE_2D, null);
  return { texture, width, height };
}

With the data on the GPU, you execute the shader program. This involves setting all the uniforms (the variables you send to the shader, like matrix dimensions), binding the input textures, and then drawing to an output texture.

// Executing a matrix multiplication on the GPU
function executeGPUMatMul(gl, program, textureA, textureB, colsA, colsB) {
  gl.useProgram(program);

  // Bind input textures to specific texture units
  gl.activeTexture(gl.TEXTURE0);
  gl.bindTexture(gl.TEXTURE_2D, textureA.texture);
  const locA = gl.getUniformLocation(program, 'u_matrixA');
  gl.uniform1i(locA, 0);

  gl.activeTexture(gl.TEXTURE1);
  gl.bindTexture(gl.TEXTURE_2D, textureB.texture);
  const locB = gl.getUniformLocation(program, 'u_matrixB');
  gl.uniform1i(locB, 1);

  // Set the shared inner dimension used by the shader loop
  gl.uniform1i(gl.getUniformLocation(program, 'u_colsA'), colsA);

  // Create a framebuffer to draw the result into a texture
  const framebuffer = gl.createFramebuffer();
  gl.bindFramebuffer(gl.FRAMEBUFFER, framebuffer);

  const outputTexture = gl.createTexture();
  gl.bindTexture(gl.TEXTURE_2D, outputTexture);
  // Allocate space for the output (rows of A x cols of B)
  gl.texImage2D(gl.TEXTURE_2D, 0, gl.RGBA32F, colsB, textureA.height, 0, gl.RGBA, gl.FLOAT, null);
  gl.framebufferTexture2D(gl.FRAMEBUFFER, gl.COLOR_ATTACHMENT0, gl.TEXTURE_2D, outputTexture, 0);

  // Set up viewport and draw
  gl.viewport(0, 0, colsB, textureA.height);
  const vertices = new Float32Array([-1, -1, 1, -1, -1, 1, 1, 1]);
  const buffer = gl.createBuffer();
  gl.bindBuffer(gl.ARRAY_BUFFER, buffer);
  gl.bufferData(gl.ARRAY_BUFFER, vertices, gl.STATIC_DRAW);

  const positionLoc = gl.getAttribLocation(program, 'a_position');
  gl.enableVertexAttribArray(positionLoc);
  gl.vertexAttribPointer(positionLoc, 2, gl.FLOAT, false, 0, 0);

  gl.drawArrays(gl.TRIANGLE_STRIP, 0, 4);

  // Cleanup
  gl.bindFramebuffer(gl.FRAMEBUFFER, null);
  gl.deleteFramebuffer(framebuffer);
  gl.deleteBuffer(buffer);

  return outputTexture;
}
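Before trusting the GPU path, it pays to diff its output against a plain CPU reference on small matrices. A minimal sketch (the helper name is mine):

```javascript
// CPU reference for a row-major matmul: C (rowsA x colsB) = A (rowsA x colsA) * B (colsA x colsB).
// Useful for comparing against the values read back from the GPU on small test matrices.
function matmulCPU(a, b, rowsA, colsA, colsB) {
  const out = new Float32Array(rowsA * colsB);
  for (let row = 0; row < rowsA; row++) {
    for (let col = 0; col < colsB; col++) {
      let sum = 0;
      for (let k = 0; k < colsA; k++) {
        sum += a[row * colsA + k] * b[k * colsB + col];
      }
      out[row * colsB + col] = sum;
    }
  }
  return out;
}

// [[1,2],[3,4]] * [[5,6],[7,8]] = [[19,22],[43,50]]
const c = matmulCPU([1, 2, 3, 4], [5, 6, 7, 8], 2, 2, 2);
```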

After the GPU has done its work, you often need to get the data back to JavaScript to apply an activation function or prepare it for the next layer. Reading from a GPU texture back to a JavaScript array is a slow operation, so you want to design your pipeline to keep data on the GPU for as many sequential steps as possible.

// Reading data back from a GPU texture to a JavaScript array
function readTextureData(gl, texture, width, height) {
  const framebuffer = gl.createFramebuffer();
  gl.bindFramebuffer(gl.FRAMEBUFFER, framebuffer);
  gl.framebufferTexture2D(gl.FRAMEBUFFER, gl.COLOR_ATTACHMENT0, gl.TEXTURE_2D, texture, 0);

  const pixels = new Float32Array(width * height * 4); // 4 for RGBA
  gl.readPixels(0, 0, width, height, gl.RGBA, gl.FLOAT, pixels);

  gl.bindFramebuffer(gl.FRAMEBUFFER, null);
  gl.deleteFramebuffer(framebuffer);

  // Extract just the red channel where our data is stored
  const result = new Float32Array(width * height);
  for (let i = 0; i < width * height; i++) {
    result[i] = pixels[i * 4];
  }
  return result;
}
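If you must read back, WebGL2 can do it without stalling the CPU: issue `readPixels` into a pixel pack buffer, insert a fence, and fetch the bytes once the fence signals. A sketch under the assumption of a live WebGL2 context (`gl`) with the source framebuffer already bound; the helper name is mine:

```javascript
// Non-blocking readback: readPixels into a pixel pack buffer returns immediately,
// and getBufferSubData runs only after the GPU signals the fence.
function readPixelsAsync(gl, width, height, onData) {
  const buf = gl.createBuffer();
  gl.bindBuffer(gl.PIXEL_PACK_BUFFER, buf);
  gl.bufferData(gl.PIXEL_PACK_BUFFER, width * height * 16, gl.STREAM_READ); // RGBA32F = 16 bytes/pixel
  gl.readPixels(0, 0, width, height, gl.RGBA, gl.FLOAT, 0); // offset 0 into the bound PBO

  const fence = gl.fenceSync(gl.SYNC_GPU_COMMANDS_COMPLETE, 0);
  gl.flush();

  const poll = () => {
    // Proceed on ALREADY_SIGNALED / CONDITION_SATISFIED; re-poll while still pending
    if (gl.clientWaitSync(fence, 0, 0) === gl.TIMEOUT_EXPIRED) {
      return requestAnimationFrame(poll);
    }
    gl.deleteSync(fence);
    const pixels = new Float32Array(width * height * 4);
    gl.bindBuffer(gl.PIXEL_PACK_BUFFER, buf);
    gl.getBufferSubData(gl.PIXEL_PACK_BUFFER, 0, pixels);
    gl.bindBuffer(gl.PIXEL_PACK_BUFFER, null);
    gl.deleteBuffer(buf);
    onData(pixels);
  };
  requestAnimationFrame(poll);
}
```

The synchronous `readTextureData` above is simpler and fine for final outputs; the fence-based version matters when you read intermediate results while the page still needs to animate.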

To build a complete forward pass for a neural network, you chain these operations. Each layer involves a matrix multiply (weights * input), an addition of biases, and then an activation function. The activation function is often applied on the CPU after the GPU math, unless you write a more complex shader that includes it.

Here is a simplified forward pass method for our NeuralNetwork class, coordinating between GPU and CPU.

class NeuralNetwork {
  // ... previous constructor and setup methods ...

  forwardPassGPU(inputArray) {
    if (!this.glContext || !this.compiled) {
      return this.forwardPassCPU(inputArray); // Fallback
    }

    const gl = this.glContext;
    let currentInputTex = this.arrayToTexture(gl, inputArray, 1, inputArray.length);

    for (let i = 0; i < this.layers.length; i++) {
      const layer = this.layers[i];
      const weightTex = this.weightTextures[i]; // Pre-uploaded
      const biasTex = this.biasTextures[i];     // Pre-uploaded

      // 1. Matrix Multiply: input * weights
      const matmulResultTex = this.executeLayerMatMul(gl, currentInputTex, weightTex, layer.inputDim, layer.outputDim);

      // 2. Add Bias (using another custom shader)
      const biasedResultTex = this.executeAddBias(gl, matmulResultTex, biasTex, layer.outputDim);

      // 3. Read data back for activation on CPU (for simplicity)
      const layerOutputArray = this.readTextureData(gl, biasedResultTex, layer.outputDim, 1);

      // 4. Apply Activation Function
      for (let j = 0; j < layerOutputArray.length; j++) {
        layerOutputArray[j] = layer.activationFn.forward(layerOutputArray[j]);
      }

      // Cleanup old texture, prepare for next layer
      gl.deleteTexture(currentInputTex);
      if (i < this.layers.length - 1) {
        currentInputTex = this.arrayToTexture(gl, layerOutputArray, 1, layer.outputDim);
        gl.deleteTexture(matmulResultTex);
        gl.deleteTexture(biasedResultTex);
      } else {
        // Final output; free the last layer's intermediate textures too
        gl.deleteTexture(matmulResultTex);
        gl.deleteTexture(biasedResultTex);
        return layerOutputArray;
      }
    }
  }

  forwardPassCPU(inputArray) {
    let currentActivation = inputArray;
    for (const layer of this.layers) {
      const newActivation = new Float32Array(layer.outputDim);
      const weights = layer.weights;
      const biases = layer.biases;

      // Simple CPU matrix-vector multiplication.
      // Weights are indexed input-major (k * outputDim + j) to match the GPU texture layout.
      for (let j = 0; j < layer.outputDim; j++) {
        let sum = biases[j];
        for (let k = 0; k < layer.inputDim; k++) {
          sum += currentActivation[k] * weights[k * layer.outputDim + j];
        }
        newActivation[j] = layer.activationFn.forward(sum);
      }
      currentActivation = newActivation;
    }
    return currentActivation;
  }
}
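The `executeAddBias` step referenced above is not shown; here is a hedged sketch of what such a fragment shader could look like, fusing the bias add with ReLU so the CPU readback in step 3 can be skipped for hidden layers (the naming is mine):

```javascript
// Fragment shader fusing bias addition and ReLU. Pairs with the same fullscreen-quad
// vertex shader as the matmul program; biases live in a 1 x outputDim texture.
const addBiasReluShaderSource = `#version 300 es
  precision highp float;
  uniform sampler2D u_matmulResult; // output of the matrix-multiply pass
  uniform sampler2D u_biases;       // 1 x outputDim texture
  out vec4 outColor;

  void main() {
    ivec2 coord = ivec2(gl_FragCoord.xy);
    float pre = texelFetch(u_matmulResult, coord, 0).r
              + texelFetch(u_biases, ivec2(coord.x, 0), 0).r;
    outColor = vec4(max(pre, 0.0), 0.0, 0.0, 1.0); // ReLU applied in-shader
  }
`;
```

With a fused shader per activation, the readback happens once per forward pass instead of once per layer.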

Optimization is crucial. A naive implementation will be slower than a pure CPU version due to the overhead of talking to the GPU. The goal is to minimize the number of readbacks (getting data from GPU to CPU) and batch operations. For instance, if you're processing multiple inputs at once (a batch), you can structure your data as a wider texture and process it all in one shader execution.

// Example: A batch-forward shader that processes 4 inputs simultaneously
const batchForwardShader = `#version 300 es
  precision highp float;
  uniform sampler2D u_inputBatch; // 4 inputs stacked
  uniform sampler2D u_weights;
  uniform sampler2D u_biases;
  uniform int u_inputSize;
  uniform int u_outputSize;

  in vec2 v_texCoord;
  out vec4 outColor;

  void main() {
    ivec2 coord = ivec2(gl_FragCoord.xy);
    int outputNeuron = coord.x;
    int batchItem = coord.y; // 0, 1, 2, or 3

    float sum = texelFetch(u_biases, ivec2(outputNeuron, 0), 0).r;

    for (int k = 0; k < u_inputSize; ++k) {
      float inputVal = texelFetch(u_inputBatch, ivec2(k, batchItem), 0).r;
      float weightVal = texelFetch(u_weights, ivec2(outputNeuron, k), 0).r;
      sum += inputVal * weightVal;
    }
    // Apply ReLU activation
    outColor = vec4(max(sum, 0.0), 0.0, 0.0, 1.0);
  }
`;
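Packing the batch itself is plain JavaScript. A sketch of a hypothetical helper that lays out one input vector per texture row, mirroring the red-channel layout `matrixToTexture` uses:

```javascript
// Pack a batch of input vectors into one RGBA float array, one vector per texture row,
// so a single shader pass can process the whole batch.
function packBatch(inputs) {
  const batchSize = inputs.length;    // texture height
  const inputSize = inputs[0].length; // texture width
  const data = new Float32Array(batchSize * inputSize * 4);
  for (let row = 0; row < batchSize; row++) {
    for (let col = 0; col < inputSize; col++) {
      const dst = (row * inputSize + col) * 4;
      data[dst] = inputs[row][col]; // red channel carries the value
      data[dst + 3] = 1.0;          // alpha, as in matrixToTexture
    }
  }
  return { data, width: inputSize, height: batchSize };
}

const batch = packBatch([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]]);
// batch.width === 3, batch.height === 4; batch.data[0] === 1, batch.data[12] === 4
```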

Another important technique involves quantizing your model. Weights stored as 32-bit floating-point numbers are precise but large. You can often convert them to 16-bit floats or even 8-bit integers with minimal accuracy loss for inference. This reduces memory bandwidth, which is a major bottleneck for GPU performance.

// Function to quantize 32-bit float weights to 8-bit integers
function quantizeWeightsToUint8(floatWeights) {
  // Find the range of the weights
  let min = floatWeights[0];
  let max = floatWeights[0];
  for (const val of floatWeights) {
    if (val < min) min = val;
    if (val > max) max = val;
  }
  const range = max - min || 1; // avoid division by zero when all weights are equal
  const scale = 255.0 / range;
  const zeroPoint = Math.round(-min * scale);

  const uint8Weights = new Uint8Array(floatWeights.length);
  for (let i = 0; i < floatWeights.length; i++) {
    const quantized = Math.round(floatWeights[i] * scale + zeroPoint);
    uint8Weights[i] = Math.max(0, Math.min(255, quantized));
  }
  return {
    data: uint8Weights,
    scale: scale,
    zeroPoint: zeroPoint,
    originalMin: min,
    originalMax: max,
  };
}

// A shader that can de-quantize 8-bit data on the fly
const quantizedMatMulShader = `#version 300 es
  precision highp float;
  uniform usampler2D u_weightsQuantized; // Uint8 texture (upload as R8UI / RED_INTEGER)
  uniform sampler2D u_input;
  uniform float u_weightScale;
  uniform float u_weightZeroPoint;
  uniform int u_inputSize;

  in vec2 v_texCoord;
  out vec4 outColor;

  void main() {
    ivec2 coord = ivec2(gl_FragCoord.xy);
    int outputNeuron = coord.x;

    float sum = 0.0;
    for (int k = 0; k < u_inputSize; ++k) {
      float inputVal = texelFetch(u_input, ivec2(k, 0), 0).r;
      uint weightQuantized = texelFetch(u_weightsQuantized, ivec2(outputNeuron, k), 0).r;
      float weightVal = (float(weightQuantized) - u_weightZeroPoint) / u_weightScale;
      sum += inputVal * weightVal;
    }
    outColor = vec4(sum, 0.0, 0.0, 1.0);
  }
`;
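It is worth verifying the roundtrip error numerically: with the affine scheme above, dequantization should never be off by more than half a quantization step, (max - min) / 510. A self-contained check of that bound:

```javascript
// Quantize then dequantize with the same scale/zeroPoint math as above and
// measure the worst-case error across a handful of sample weights.
const values = [-0.9, -0.3, 0.0, 0.25, 0.7, 1.1];
const mn = Math.min(...values), mx = Math.max(...values);
const scale = 255.0 / (mx - mn);
const zeroPoint = Math.round(-mn * scale);

let maxError = 0;
for (const v of values) {
  const q = Math.max(0, Math.min(255, Math.round(v * scale + zeroPoint)));
  const dequantized = (q - zeroPoint) / scale; // same formula the shader uses
  maxError = Math.max(maxError, Math.abs(dequantized - v));
}
const halfStep = (mx - mn) / 510;
console.log(maxError <= halfStep); // true: error stays within half a step
```

For this 2.0-wide range the worst-case error is about 0.004, which is typically negligible for inference.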

Finally, integrating this into a real application requires thoughtful design. You need to load your pre-trained model weights, which could be fetched from a server. You then initialize the network, upload the weights to GPU textures, and are ready to run inference.

// Main integration example
async function initMLInference() {
  // 1. Fetch pre-trained model data (e.g., from a .bin or .json file)
  const response = await fetch('/model/weights.bin');
  const buffer = await response.arrayBuffer();
  const weightData = new Float32Array(buffer);

  // 2. Define the model architecture (e.g., a simple classifier)
  const modelConfig = {
    layerSizes: [784, 128, 64, 10], // MNIST-like classifier
    activation: 'relu',
    weightInit: 'he',
  };

  // 3. Create the neural network
  const model = new NeuralNetwork(modelConfig);

  // 4. Load the fetched weights into the model's structure
  model.loadPretrainedWeights(weightData);

  // 5. Compile the network for GPU execution
  await model.compileForGPU();

  console.log('Model is ready for GPU-accelerated inference.');

  // 6. Example usage: run inference on some input data
  const dummyInput = new Float32Array(784).fill(0.1);
  const startTime = performance.now();
  const predictions = model.forwardPassGPU(dummyInput);
  const endTime = performance.now();

  console.log(`Inference took ${(endTime - startTime).toFixed(2)} ms`);
  console.log('Predictions:', predictions);
}

// Add a method to load pre-trained weights
NeuralNetwork.prototype.loadPretrainedWeights = function (flatWeightArray) {
  let offset = 0;
  for (let i = 0; i < this.layers.length; i++) {
    const layer = this.layers[i];
    const weightSize = layer.inputDim * layer.outputDim;
    const biasSize = layer.outputDim;

    layer.weights.set(flatWeightArray.subarray(offset, offset + weightSize));
    offset += weightSize;

    layer.biases.set(flatWeightArray.subarray(offset, offset + biasSize));
    offset += biasSize;
  }
  // After loading, we can upload these to GPU textures
  this.uploadWeightsToGPU();
};

Working with WebGL for machine learning feels like a different kind of programming. You're managing two separate processors: the general-purpose CPU and the highly parallel GPU. Debugging can be challenging because you can't simply console.log a value inside a shader. Instead, you might write shaders that output debug colors or carefully read back intermediate textures.
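One trick that helps: a diagnostic fragment shader that paints problem texels in loud colors, since GLSL ES 3.00 provides `isnan` and `isinf`. A sketch (not part of the pipeline above):

```javascript
// Debug shader: paints NaN texels red and infinities yellow so numerical blow-ups
// in an intermediate texture are visible at a glance when rendered to the canvas.
const nanDebugShaderSource = `#version 300 es
  precision highp float;
  uniform sampler2D u_debugTarget; // any intermediate result texture
  out vec4 outColor;

  void main() {
    float v = texelFetch(u_debugTarget, ivec2(gl_FragCoord.xy), 0).r;
    if (isnan(v)) {
      outColor = vec4(1.0, 0.0, 0.0, 1.0); // red: NaN
    } else if (isinf(v)) {
      outColor = vec4(1.0, 1.0, 0.0, 1.0); // yellow: +/- infinity
    } else {
      outColor = vec4(vec3(clamp(v, 0.0, 1.0)), 1.0); // grayscale: the value itself
    }
  }
`;
```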

The payoff, however, is remarkable. Tasks that would choke the main thread, like running a convolutional network on a video stream, become feasible. You can build interactive experiences that respond intelligently in real-time, all within the security and portability of the web browser. By mastering these techniques—efficient model design, GPU data management, shader programming, operation batching, and model quantization—you can push the boundaries of what's possible in a web application. It turns the browser from a simple document viewer into a powerful, personal computing platform for intelligent software.
