DEV Community

monkeymore studio
Building a Browser-Based AI Image Inpainting Tool

Introduction

In this article, we'll explore how to implement a browser-based AI image inpainting tool that allows users to remove unwanted objects, watermarks, or people from photos. The tool runs entirely in the browser using ONNX Runtime Web with WebGPU acceleration, providing professional-grade results without sending any images to a server.

Why Browser-Based Inpainting?

1. Privacy Protection

When users inpaint images in the browser, their photos never leave their device. This is essential for:

  • Personal photos with sensitive content
  • Business documents
  • Medical images
  • Any content users want to keep private

2. Zero Server Costs

Running the AI model in the browser eliminates the need for:

  • GPU servers for AI inference
  • Bandwidth for uploading/downloading images
  • Storage for temporary files

3. Offline Capability

Once the model is loaded, users can inpaint images without an internet connection.

Technical Architecture

Core Implementation

1. Data Structures

interface InpaintClientProps {
  lang: Locale;
}

// Key state variables
const [imageSrc, setImageSrc] = useState<string>("");
const [resultImage, setResultImage] = useState<string | null>(null);
const [isProcessing, setIsProcessing] = useState(false);
const [brushSize, setBrushSize] = useState(30);
const [modelLoaded, setModelLoaded] = useState(false);

// Refs for canvas operations
const canvasRef = useRef<HTMLCanvasElement>(null);
const maskCanvasRef = useRef<HTMLCanvasElement>(null);
const imageRef = useRef<HTMLImageElement | null>(null);
const sessionRef = useRef<ort.InferenceSession | null>(null);

// Drawing state
const isDrawing = useRef(false);
const lastPos = useRef({ x: 0, y: 0 });
const history = useRef<string[]>([]);
const historyIndex = useRef(-1);

2. Loading the AI Model

The tool uses the MIGAN inpainting model, downloaded from HuggingFace:

const MODEL_URL = "https://huggingface.co/andraniksargsyan/migan/resolve/main/migan_pipeline_v2.onnx";

const loadModel = async () => {
  if (sessionRef.current) return;

  setIsModelLoading(true);
  const totalSize = 80 * 1024 * 1024; // ~80MB model

  // Download model with progress tracking
  const response = await fetch(MODEL_URL);
  if (!response.ok || !response.body) {
    throw new Error(`Failed to fetch model: ${response.status}`);
  }
  // Prefer the real size from the Content-Length header when available
  const contentLength = Number(response.headers.get("Content-Length")) || totalSize;
  const reader = response.body.getReader();
  const chunks: Uint8Array[] = [];
  let receivedLength = 0;

  while (true) {
    const { done, value } = await reader.read();
    if (done) break;
    chunks.push(value);
    receivedLength += value.length;
    const percent = Math.round((receivedLength / contentLength) * 100);
    setDownloadProgress(Math.min(percent, 99));
  }

  // Copy the chunks into a single contiguous buffer
  const modelBuffer = new Uint8Array(receivedLength);
  let offset = 0;
  for (const chunk of chunks) {
    modelBuffer.set(chunk, offset);
    offset += chunk.length;
  }

  // Create ONNX Runtime session with WebGPU preferred
  const session = await ort.InferenceSession.create(modelBuffer, {
    executionProviders: ["webgpu", "wasm"]
  });

  sessionRef.current = session;
  setModelLoaded(true);
  console.log("Model loaded, input names:", session.inputNames);
};
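The chunk-combining step can also be factored into a small, testable helper. This is a sketch (`concatChunks` is an illustrative name, not part of the original code):

```typescript
// Concatenate streamed fetch chunks into one contiguous buffer.
// Note: concatChunks is a hypothetical helper, not from the article's code.
function concatChunks(chunks: Uint8Array[]): Uint8Array {
  const total = chunks.reduce((sum, c) => sum + c.length, 0);
  const out = new Uint8Array(total);
  let offset = 0;
  for (const c of chunks) {
    out.set(c, offset); // copy each chunk at its running offset
    offset += c.length;
  }
  return out;
}
```

Keeping this logic pure makes it easy to unit-test without touching `fetch` or the DOM.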

3. Image and Canvas Setup

useEffect(() => {
  if (!imageSrc || !canvasRef.current) return;

  const img = new Image();
  img.onload = () => {
    imageRef.current = img;
    const canvas = canvasRef.current;
    const maskCanvas = maskCanvasRef.current;
    if (!canvas || !maskCanvas) return;

    const ctx = canvas.getContext("2d");
    const maskCtx = maskCanvas.getContext("2d");
    if (!ctx || !maskCtx) return;

    // Limit max size to 1024px for performance
    const maxSize = 1024;
    let width = img.width;
    let height = img.height;

    if (width > maxSize || height > maxSize) {
      const ratio = Math.min(maxSize / width, maxSize / height);
      width = Math.round(width * ratio);
      height = Math.round(height * ratio);
    }

    canvas.width = width;
    canvas.height = height;
    maskCanvas.width = width;
    maskCanvas.height = height;

    // Draw the original image and start with an empty mask
    ctx.drawImage(img, 0, 0, width, height);
    maskCtx.clearRect(0, 0, maskCanvas.width, maskCanvas.height);

    // Initialize history for undo/redo
    history.current = [canvas.toDataURL("image/png")];
    historyIndex.current = 0;
  };
  img.src = imageSrc;
}, [imageSrc]);
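The downscale math above can be isolated as a pure function, which makes it trivial to test. A sketch (`fitWithin` is an illustrative name, not in the original code):

```typescript
// Scale (width, height) to fit within maxSize while preserving aspect ratio.
// Returns the input unchanged when it already fits.
function fitWithin(
  width: number,
  height: number,
  maxSize: number
): { width: number; height: number } {
  if (width <= maxSize && height <= maxSize) return { width, height };
  const ratio = Math.min(maxSize / width, maxSize / height);
  return { width: Math.round(width * ratio), height: Math.round(height * ratio) };
}
```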

4. Mask Drawing (User Interaction)

Users draw on the image to mark areas they want to remove:

const draw = (e: React.MouseEvent<HTMLCanvasElement> | React.TouchEvent<HTMLCanvasElement>) => {
  const canvas = canvasRef.current;
  const maskCanvas = maskCanvasRef.current;
  if (!canvas || !maskCanvas || !isDrawing.current) return;

  const ctx = canvas.getContext("2d");
  const maskCtx = maskCanvas.getContext("2d");
  if (!ctx || !maskCtx) return;

  const coords = getCanvasCoords(e);
  if (!coords) return;

  // Draw red stroke on main canvas (visual guide)
  ctx.strokeStyle = "rgba(255, 0, 0, 0.5)";
  ctx.lineWidth = brushSize;
  ctx.lineCap = "round";
  ctx.lineJoin = "round";
  ctx.beginPath();
  ctx.moveTo(lastPos.current.x, lastPos.current.y);
  ctx.lineTo(coords.x, coords.y);
  ctx.stroke();

  // Draw white stroke on mask canvas (for AI model)
  maskCtx.strokeStyle = "white";
  maskCtx.lineWidth = brushSize;
  maskCtx.lineCap = "round";
  maskCtx.lineJoin = "round";
  maskCtx.beginPath();
  maskCtx.moveTo(lastPos.current.x, lastPos.current.y);
  maskCtx.lineTo(coords.x, coords.y);
  maskCtx.stroke();

  lastPos.current = { x: coords.x, y: coords.y };
};
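`getCanvasCoords` is referenced above but not shown. One plausible implementation maps pointer coordinates into canvas pixel space, accounting for CSS scaling of the canvas element. Here is a pure sketch (the signature and name are assumptions, not the article's actual helper):

```typescript
// Map a pointer position (clientX/clientY) into canvas pixel coordinates.
// rect comes from canvas.getBoundingClientRect(); because the canvas may be
// CSS-scaled, we rescale by the internal-size / displayed-size ratio.
function toCanvasCoords(
  clientX: number,
  clientY: number,
  rect: { left: number; top: number; width: number; height: number },
  canvasWidth: number,
  canvasHeight: number
): { x: number; y: number } {
  return {
    x: (clientX - rect.left) * (canvasWidth / rect.width),
    y: (clientY - rect.top) * (canvasHeight / rect.height),
  };
}
```

In the real handler you would extract `clientX`/`clientY` from either the mouse event or `e.touches[0]` before calling this.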

5. Image to Tensor Conversion

The model expects input as tensors in the format [1, 3, H, W] for image and [1, 1, H, W] for mask:

// Convert ImageData to model input tensor
const imageToTensor = (imageData: ImageData): Float32Array => {
  const { width, height, data } = imageData;
  const tensor = new Float32Array(width * height * 3);

  // Convert RGB to normalized float (0-1 range)
  // CHW format: R first, then G, then B
  for (let i = 0; i < width * height; i++) {
    tensor[i] = data[i * 4] / 255;                                    // R channel
    tensor[i + width * height] = data[i * 4 + 1] / 255;              // G channel
    tensor[i + width * height * 2] = data[i * 4 + 2] / 255;          // B channel
  }
  return tensor;
};

// Convert tensor back to ImageData for display
const tensorToImage = (tensor: Float32Array, width: number, height: number): ImageData => {
  const data = new Uint8ClampedArray(width * height * 4);

  // CHW to RGBA conversion
  for (let i = 0; i < width * height; i++) {
    data[i * 4] = Math.max(0, Math.min(255, Math.round(tensor[i] * 255)));
    data[i * 4 + 1] = Math.max(0, Math.min(255, Math.round(tensor[i + width * height] * 255)));
    data[i * 4 + 2] = Math.max(0, Math.min(255, Math.round(tensor[i + width * height * 2] * 255)));
    data[i * 4 + 3] = 255; // Alpha
  }
  return new ImageData(data, width, height);
};
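The CHW conversion is easy to sanity-check with a round trip. This sketch mirrors `imageToTensor`/`tensorToImage` but uses plain arrays in place of `ImageData`, so it runs outside the browser:

```typescript
// Pack RGBA pixel data (laid out like ImageData.data) into CHW floats.
function rgbaToChw(data: number[], width: number, height: number): number[] {
  const n = width * height;
  const t = new Array<number>(n * 3);
  for (let i = 0; i < n; i++) {
    t[i] = data[i * 4] / 255;             // R plane
    t[i + n] = data[i * 4 + 1] / 255;     // G plane
    t[i + n * 2] = data[i * 4 + 2] / 255; // B plane
  }
  return t;
}

// Unpack CHW floats back into RGBA, with alpha fixed at 255.
function chwToRgba(t: number[], width: number, height: number): number[] {
  const n = width * height;
  const data = new Array<number>(n * 4);
  for (let i = 0; i < n; i++) {
    data[i * 4] = Math.round(t[i] * 255);
    data[i * 4 + 1] = Math.round(t[i + n] * 255);
    data[i * 4 + 2] = Math.round(t[i + n * 2] * 255);
    data[i * 4 + 3] = 255;
  }
  return data;
}
```

A round trip through both functions should reproduce the original RGB values exactly.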

6. Running the Inpainting Model

const applyInpaint = async () => {
  const canvas = canvasRef.current;
  const maskCanvas = maskCanvasRef.current;
  if (!canvas || !maskCanvas) return;

  setIsProcessing(true);

  const ctx = canvas.getContext("2d");
  const maskCtx = maskCanvas.getContext("2d");
  if (!ctx || !maskCtx) {
    setIsProcessing(false);
    return;
  }

  const width = canvas.width;
  const height = canvas.height;

  // Read the current image and mask pixel data
  const imageData = ctx.getImageData(0, 0, width, height);
  const maskData = maskCtx.getImageData(0, 0, width, height);

  // Check if user has drawn a mask
  let whitePixels = 0;
  for (let i = 0; i < maskData.data.length; i += 4) {
    if (maskData.data[i] > 128) whitePixels++;
  }

  if (whitePixels === 0) {
    alert("Please draw on the areas you want to remove first");
    setIsProcessing(false);
    return;
  }

  if (sessionRef.current) {
    const modelSize = 512; // Model expects 512x512 input

    // Create separate canvases for image and mask
    const imgCanvas = document.createElement('canvas');
    imgCanvas.width = modelSize;
    imgCanvas.height = modelSize;
    const imgCtx = imgCanvas.getContext('2d')!;

    // Draw and resize image to 512x512
    const originalImg = new Image();
    originalImg.src = canvas.toDataURL('image/png');
    await new Promise((resolve) => { originalImg.onload = resolve; });
    imgCtx.drawImage(originalImg, 0, 0, modelSize, modelSize);
    const resizedImageData = imgCtx.getImageData(0, 0, modelSize, modelSize);

    // Create and resize mask
    const maskCanvas512 = document.createElement('canvas');
    maskCanvas512.width = modelSize;
    maskCanvas512.height = modelSize;
    const maskCtx512 = maskCanvas512.getContext('2d')!;

    const maskImg = new Image();
    maskImg.src = maskCanvas.toDataURL('image/png');
    await new Promise((resolve) => { maskImg.onload = resolve; });
    maskCtx512.drawImage(maskImg, 0, 0, modelSize, modelSize);
    const resizedMaskData = maskCtx512.getImageData(0, 0, modelSize, modelSize);

    // Prepare image data for model (CHW format)
    const imgData = new Uint8Array(modelSize * modelSize * 3);
    for (let y = 0; y < modelSize; y++) {
      for (let x = 0; x < modelSize; x++) {
        const srcIdx = (y * modelSize + x) * 4;
        const dstIdx = y * modelSize + x;
        imgData[dstIdx] = resizedImageData.data[srcIdx];           // R
        imgData[modelSize * modelSize + dstIdx] = resizedImageData.data[srcIdx + 1]; // G
        imgData[modelSize * modelSize * 2 + dstIdx] = resizedImageData.data[srcIdx + 2]; // B
      }
    }

    // Prepare mask data (invert: white in mask = 0 in model)
    const maskDataArr = new Uint8Array(modelSize * modelSize);
    for (let i = 0; i < modelSize * modelSize; i++) {
      const maskPixel = resizedMaskData.data[i * 4];
      // Model expects: white (255) = keep, black (0) = inpaint
      maskDataArr[i] = maskPixel > 128 ? 0 : 255;
    }

    // Run inference
    const feeds: Record<string, ort.Tensor> = {};
    feeds[sessionRef.current.inputNames[0]] = new ort.Tensor("uint8", imgData, [1, 3, modelSize, modelSize]);
    feeds[sessionRef.current.inputNames[1]] = new ort.Tensor("uint8", maskDataArr, [1, 1, modelSize, modelSize]);

    const results = await sessionRef.current.run(feeds);
    const outputTensor = Object.values(results)[0] as ort.Tensor;
    const outputData = outputTensor.data as Uint8Array;

    // Convert output tensor to ImageData
    const resultImageData = new ImageData(modelSize, modelSize);
    for (let y = 0; y < modelSize; y++) {
      for (let x = 0; x < modelSize; x++) {
        const dstIdx = (y * modelSize + x) * 4;
        const srcIdx = y * modelSize + x;
        resultImageData.data[dstIdx] = outputData[srcIdx];
        resultImageData.data[dstIdx + 1] = outputData[modelSize * modelSize + srcIdx];
        resultImageData.data[dstIdx + 2] = outputData[modelSize * modelSize * 2 + srcIdx];
        resultImageData.data[dstIdx + 3] = 255;
      }
    }

    // Scale back to original size and blend
    // ... (blend logic in actual code)
  }
};
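The elided blend step could work like this (a sketch under assumptions — the actual code may differ): upscale the 512×512 result back to the working size via a temporary canvas, then copy only the masked pixels into the original so untouched areas keep their full original quality. The per-pixel selection, on plain RGBA arrays:

```typescript
// Copy inpainted pixels into the original wherever the mask is set.
// All arrays are RGBA with the same dimensions; the mask is checked on
// its R channel. blendByMask is a hypothetical helper, not the article's code.
function blendByMask(
  original: number[],
  inpainted: number[],
  mask: number[]
): number[] {
  const out = original.slice();
  for (let i = 0; i < mask.length; i += 4) {
    if (mask[i] > 128) { // user-drawn (white) region
      out[i] = inpainted[i];
      out[i + 1] = inpainted[i + 1];
      out[i + 2] = inpainted[i + 2];
    }
  }
  return out;
}
```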

7. Fallback Algorithm (Non-AI)

If the AI model fails to load, a traditional inpainting algorithm is used:

// Traditional inpainting using nearby pixel averaging
const mask = maskData.data;
const src = imageData.data;
// Start from a copy of the source so unmasked pixels pass through unchanged
const dst = new ImageData(new Uint8ClampedArray(src), width, height);

for (let i = 0; i < mask.length; i += 4) {
  if (mask[i] > 128) { // Inpainting region
    const x = (i / 4) % width;
    const y = Math.floor((i / 4) / width);
    const searchRadius = 20;
    let count = 0;
    let r = 0, g = 0, b = 0;

    // Sample pixels from surrounding area
    for (let dy = -searchRadius; dy <= searchRadius; dy++) {
      for (let dx = -searchRadius; dx <= searchRadius; dx++) {
        const nx = x + dx;
        const ny = y + dy;
        if (nx >= 0 && nx < width && ny >= 0 && ny < height) {
          const ni = (ny * width + nx) * 4;
          if (mask[ni] < 128) { // Not in mask region
            const dist = Math.sqrt(dx * dx + dy * dy);
            if (dist <= searchRadius && dist > 0) {
              // Weight by inverse distance squared
              const weight = 1 / (dist * dist + 0.1);
              r += src[ni] * weight;
              g += src[ni + 1] * weight;
              b += src[ni + 2] * weight;
              count += weight;
            }
          }
        }
      }
    }

    if (count > 0) {
      dst.data[i] = Math.min(255, Math.max(0, r / count));
      dst.data[i + 1] = Math.min(255, Math.max(0, g / count));
      dst.data[i + 2] = Math.min(255, Math.max(0, b / count));
    }
  }
}
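The same averaging can be written as a pure, grayscale function for testing outside the browser. This is a sketch (`fallbackInpaint` is an illustrative name) that mirrors the loop above:

```typescript
// Fill masked pixels with an inverse-distance-squared weighted average
// of unmasked neighbors within searchRadius. Grayscale, row-major arrays.
function fallbackInpaint(
  src: number[],
  mask: boolean[], // true = pixel to fill
  width: number,
  height: number,
  searchRadius = 2
): number[] {
  const out = src.slice();
  for (let y = 0; y < height; y++) {
    for (let x = 0; x < width; x++) {
      const i = y * width + x;
      if (!mask[i]) continue;
      let sum = 0;
      let totalWeight = 0;
      for (let dy = -searchRadius; dy <= searchRadius; dy++) {
        for (let dx = -searchRadius; dx <= searchRadius; dx++) {
          const nx = x + dx;
          const ny = y + dy;
          if (nx < 0 || nx >= width || ny < 0 || ny >= height) continue;
          const ni = ny * width + nx;
          if (mask[ni]) continue; // only sample known pixels
          const d2 = dx * dx + dy * dy;
          if (d2 === 0 || d2 > searchRadius * searchRadius) continue;
          const w = 1 / (d2 + 0.1); // inverse distance squared
          sum += src[ni] * w;
          totalWeight += w;
        }
      }
      if (totalWeight > 0) out[i] = sum / totalWeight;
    }
  }
  return out;
}
```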

8. Undo/Redo System

const undo = () => {
  if (historyIndex.current > 0) {
    historyIndex.current--;
    const canvas = canvasRef.current;
    const ctx = canvas?.getContext("2d");
    if (ctx && canvas) {
      const img = new Image();
      img.onload = () => {
        canvas.width = img.width;
        canvas.height = img.height;
        ctx.drawImage(img, 0, 0);
      };
      img.src = history.current[historyIndex.current];
    }
  }
};

// Save state after each drawing stroke
const stopDrawing = () => {
  if (isDrawing.current) {
    isDrawing.current = false;
    const canvas = canvasRef.current;
    if (canvas) {
      history.current = history.current.slice(0, historyIndex.current + 1);
      history.current.push(canvas.toDataURL("image/png"));
      historyIndex.current = history.current.length - 1;
    }
  }
};
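The history logic can also be expressed as a small standalone class, which is a sketch mirroring the `history`/`historyIndex` refs above and supplying the `redo` the article omits (`HistoryStack` is a hypothetical name):

```typescript
// Undo/redo over canvas snapshots (data URLs).
class HistoryStack {
  private states: string[] = [];
  private index = -1;

  push(state: string): void {
    // Discard any redo branch, then append the new state
    this.states = this.states.slice(0, this.index + 1);
    this.states.push(state);
    this.index = this.states.length - 1;
  }

  undo(): string | null {
    if (this.index <= 0) return null;
    this.index--;
    return this.states[this.index];
  }

  redo(): string | null {
    if (this.index >= this.states.length - 1) return null;
    this.index++;
    return this.states[this.index];
  }
}
```

Pushing after an undo drops the redo branch, matching the `slice` in `stopDrawing` above.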

Service Worker for Model Caching

The ~80MB model is cached by a Service Worker for faster subsequent loads:

const CACHE_NAME = 'inpaint-model-cache-v1';
const MODEL_URL = 'https://huggingface.co/andraniksargsyan/migan/resolve/main/migan_pipeline_v2.onnx';

self.addEventListener('fetch', (event) => {
  const url = new URL(event.request.url);
  if (url.href.includes('migan_pipeline_v2.onnx')) {
    event.respondWith(
      caches.match(event.request).then((cachedResponse) => {
        if (cachedResponse) {
          return cachedResponse;
        }
        return fetch(event.request).then((networkResponse) => {
          const responseToCache = networkResponse.clone();
          caches.open(CACHE_NAME).then((cache) => {
            cache.put(event.request, responseToCache);
          });
          return networkResponse;
        });
      })
    );
  }
});

Processing Flow

Upload image → draw a mask over unwanted areas → resize image and mask to 512×512 → run MIGAN inference → convert the output tensor back to pixels → blend the result into the original image.

Key Technologies Used

| Technology | Purpose |
| --- | --- |
| ONNX Runtime Web | Run ONNX models in the browser |
| MIGAN Model | Deep learning inpainting |
| WebGPU | GPU acceleration for inference |
| Canvas API | Image manipulation & mask drawing |
| Service Worker | Cache the large AI model |
| Blob URLs | Handle images without a server |

Performance Characteristics

  1. Model Size: ~80MB (MIGAN v2)
  2. Model Loading: 10-30 seconds (depending on network)
  3. Inference Time: 2-10 seconds (depending on device)
  4. Input Size: Limited to 1024px max dimension
  5. Output Size: Maintains original dimensions

Browser Support

ONNX Runtime Web with WebGPU works in:

  • Chrome 113+
  • Edge 113+
  • Safari 17+ (limited WebGPU support)
  • Firefox (WebGPU coming)
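Because support varies, you can feature-detect WebGPU before creating the session and fall back to WASM. A sketch (`pickExecutionProviders` is a hypothetical helper; the object parameter stands in for the real `navigator`):

```typescript
// Choose ONNX Runtime execution providers based on WebGPU availability.
// Pass the real `navigator` in the browser; typed loosely for testability.
function pickExecutionProviders(nav: { gpu?: unknown }): string[] {
  return nav.gpu ? ["webgpu", "wasm"] : ["wasm"];
}
```

Usage in the browser might look like `ort.InferenceSession.create(modelBuffer, { executionProviders: pickExecutionProviders(navigator as { gpu?: unknown }) })`, since ONNX Runtime Web also falls back internally when a provider is unavailable.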

Use Cases

  1. Remove watermarks - Paint over text or logos
  2. Remove objects - Erase people, pets, or items
  3. Fix imperfections - Remove scratches or dust
  4. Photo restoration - Fill in damaged areas

Conclusion

Browser-based AI inpainting brings professional image editing capabilities to the web without privacy concerns. The implementation uses:

  • ONNX Runtime Web for running the MIGAN deep learning model
  • WebGPU for GPU-accelerated inference (with WASM fallback)
  • Canvas API for mask drawing and image blending
  • Service Worker for model caching

The tool provides a seamless experience where users can draw masks on unwanted objects and have them automatically removed by AI, all while their images never leave their device.


Try it yourself at Free Image Tools

Experience the power of browser-based AI inpainting. No upload required - your images stay on your device!
