DEV Community

Sarvagya Jaiswal


Catching Deepfakes in Real-Time: A Spatial-Temporal Approach with EfficientNet-B0 and Bi-LSTM


The problem with most early deepfake detection models is that they treat video as a collection of static images. They pass individual frames through a Convolutional Neural Network (CNN) and look for spatial artifacts—weird blurring around the jawline, mismatched skin tones, or pixelated boundaries.

But modern deepfakes (especially those produced by GANs and diffusion models) have largely eliminated the obvious static spatial artifacts. A single frame often looks flawless. What gives a deepfake away is often not space but time: the blink rate is unnatural, micro-expressions jitter, or the lip-sync drifts by a fraction of a second.

To catch a modern deepfake, you cannot just look at a picture. You have to understand the sequence. Here is how I built a Spatial-Temporal Deepfake Detector using PyTorch, combining an EfficientNet-B0 backbone for spatial feature extraction with a Bi-LSTM network for temporal sequence analysis.

1. The Architecture: Why Spatial-Temporal?

Processing raw video directly is computationally brutal. Instead, the architecture works in two distinct phases:

  1. Spatial Extraction (The "What"): We sample $N$ frames from a video and pass each frame through a pre-trained EfficientNet-B0. We discard the final classification layer, using the network strictly as a feature extractor that turns each frame into a 1280-dimensional vector. EfficientNet-B0 was chosen for its favorable accuracy-to-cost trade-off: roughly 5.3M parameters and about 0.39 GFLOPs per 224×224 frame.
  2. Temporal Analysis (The "When"): The sequence of extracted feature vectors is then fed into a Bidirectional Long Short-Term Memory (Bi-LSTM) network. The Bi-LSTM analyzes the sequence both forwards and backwards, searching for temporal inconsistencies and unnatural frame-to-frame transitions.

2. Building the Hybrid Model in PyTorch

Stitching a CNN to an RNN requires careful tensor dimension management. Here is the core PyTorch module that bridges the two networks:

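```python
import torch
import torch.nn as nn
from torchvision import models

class DeepfakeDetector(nn.Module):
    def __init__(self, sequence_length=20, hidden_dim=256, lstm_layers=2):
        super().__init__()
        self.sequence_length = sequence_length

        # 1. Spatial Feature Extractor (EfficientNet-B0, ImageNet-pretrained)
        # `weights=` replaces the deprecated `pretrained=True` (torchvision >= 0.13)
        efficientnet = models.efficientnet_b0(weights=models.EfficientNet_B0_Weights.DEFAULT)
        # Keep the conv stages and global pooling; drop the ImageNet classifier head
        self.feature_extractor = nn.Sequential(
            efficientnet.features,
            efficientnet.avgpool,  # (N, 1280, h', w') -> (N, 1280, 1, 1)
            nn.Flatten(),          # -> (N, 1280)
        )
        self.feature_dim = 1280  # channels in EfficientNet-B0's final feature map

        # 2. Temporal Sequence Modeler (Bi-LSTM)
        self.lstm = nn.LSTM(
            input_size=self.feature_dim,
            hidden_size=hidden_dim,
            num_layers=lstm_layers,
            batch_first=True,
            bidirectional=True
        )

        # 3. Final Classifier (input = forward + backward states, hence hidden_dim * 2)
        self.classifier = nn.Sequential(
            nn.Linear(hidden_dim * 2, 128),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(128, 1),
            # Sigmoid for a binary probability (pairs with nn.BCELoss)
            nn.Sigmoid()
        )

    def forward(self, x):
        # x shape: (Batch, Sequence_Length, C, H, W)
        batch_size, seq_len, c, h, w = x.size()

        # Fold time into the batch so the CNN processes every frame at once
        x = x.reshape(batch_size * seq_len, c, h, w)
        spatial_features = self.feature_extractor(x)  # (Batch * Seq, 1280)

        # Unfold back into a sequence for the LSTM: (Batch, Seq, 1280)
        spatial_features = spatial_features.view(batch_size, seq_len, self.feature_dim)
        _, (h_n, _) = self.lstm(spatial_features)

        # h_n: (lstm_layers * 2, Batch, hidden_dim). The last two entries are the top layer's
        # forward state (after the last frame) and backward state (after the first frame).
        # lstm_out[:, -1, :] would instead pair the forward state with a backward state
        # that has seen only the last frame.
        final_state = torch.cat([h_n[-2], h_n[-1]], dim=1)  # (Batch, hidden_dim * 2)
        return self.classifier(final_state)
```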
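A subtlety worth checking when a bidirectional LSTM feeds a classifier: PyTorch's `nn.LSTM` concatenates the forward and backward outputs along the last dimension, so taking `lstm_out[:, -1, :]` pairs a forward state that has read the whole clip with a backward state that has seen only the final frame. The sketch below (random inputs, sizes matching the defaults above) verifies where each direction's full-sequence summary lives: the forward half of the last timestep equals `h_n[-2]`, and the backward half of the first timestep equals `h_n[-1]`.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
batch, seq_len, feat_dim, hidden = 2, 20, 1280, 256

lstm = nn.LSTM(input_size=feat_dim, hidden_size=hidden, num_layers=2,
               batch_first=True, bidirectional=True)
features = torch.randn(batch, seq_len, feat_dim)  # stand-in for per-frame CNN features

lstm_out, (h_n, c_n) = lstm(features)
print(lstm_out.shape)  # torch.Size([2, 20, 512]): forward and backward halves concatenated
print(h_n.shape)       # torch.Size([4, 2, 256]): (num_layers * 2 directions, batch, hidden)

# Top layer: h_n[-2] is the forward direction after the last frame,
# h_n[-1] is the backward direction after it reaches the first frame.
forward_final = lstm_out[:, -1, :hidden]
backward_final = lstm_out[:, 0, hidden:]
print(torch.allclose(forward_final, h_n[-2]))   # True
print(torch.allclose(backward_final, h_n[-1]))  # True

# Both full-sequence summaries, sized for a Linear(hidden * 2, ...) classifier
summary = torch.cat([h_n[-2], h_n[-1]], dim=1)
print(summary.shape)  # torch.Size([2, 512])
```

Concatenating `h_n[-2]` and `h_n[-1]` is one common choice; mean- or max-pooling `lstm_out` over time is another option that lets every timestep contribute.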

3. The Deployment Challenge: Video Processing in Gradio

Deploying this model introduces a unique challenge: you aren't just handling images; you are handling video streams.

When deploying this to Hugging Face Spaces with Gradio, I had to write a custom preprocessing pipeline using OpenCV (cv2.VideoCapture) to extract the video frames, sample them evenly to match the model's sequence_length, and stack them into a 5D PyTorch tensor of shape (Batch, Sequence, Channels, Height, Width).

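```python
import cv2
import numpy as np
import torch
import gradio as gr
from torchvision import models

SEQ_LEN = 20  # must match the model's sequence_length
preprocess = models.EfficientNet_B0_Weights.DEFAULT.transforms()  # resize/crop/normalize; assumed to match training
model = DeepfakeDetector(sequence_length=SEQ_LEN).eval()  # load your trained weights before serving

def process_video(video_path):
    cap = cv2.VideoCapture(video_path)  # Gradio passes the uploaded video as a file path
    total = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    frames = []
    for idx in np.linspace(0, max(total - 1, 0), SEQ_LEN).astype(int):  # evenly spaced frames
        cap.set(cv2.CAP_PROP_POS_FRAMES, int(idx))
        ok, frame = cap.read()
        if ok:  # OpenCV decodes BGR; convert to an RGB (C, H, W) uint8 tensor, then preprocess
            rgb = torch.from_numpy(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)).permute(2, 0, 1)
            frames.append(preprocess(rgb))
    cap.release()
    if not frames:
        raise gr.Error("Could not read any frames from this video.")
    x = torch.stack(frames).unsqueeze(0)  # 5D: (Batch=1, Sequence, C, H, W)
    with torch.no_grad():
        p_fake = model(x).item()  # assumes the model was trained with label 1 = fake
    return {"Fake": p_fake, "Real": 1.0 - p_fake}  # gr.Label renders a dict of confidences

interface = gr.Interface(fn=process_video, inputs=gr.Video(),
                         outputs=gr.Label(label="Authenticity Analysis"),
                         title="Spatial-Temporal Deepfake Detector")
interface.launch()
```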
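Even sampling is what lets clips of arbitrary length fit a fixed-length model. Here is a minimal, testable sketch of just that step, assuming the frame count is already known (for example from `cv2.CAP_PROP_FRAME_COUNT`). The helper name `sample_frame_indices` is hypothetical; it spreads `sequence_length` indices uniformly from the first to the last frame and repeats indices when the clip is shorter than the sequence.

```python
import numpy as np

def sample_frame_indices(total_frames: int, sequence_length: int = 20) -> list[int]:
    """Hypothetical helper: `sequence_length` evenly spaced indices in [0, total_frames - 1].

    Short clips yield repeated indices, so the model always receives
    exactly `sequence_length` frames.
    """
    if total_frames <= 0:
        raise ValueError("video has no frames")
    positions = np.linspace(0, total_frames - 1, num=sequence_length)
    return positions.astype(int).tolist()  # truncate fractional positions to frame indices

print(sample_frame_indices(300))    # 20 indices from 0 to 299 (a 10 s clip at 30 fps)
print(sample_frame_indices(10, 5))  # [0, 2, 4, 6, 9]
print(sample_frame_indices(3, 5))   # [0, 0, 1, 1, 2]
```

Frames read at these indices, each converted to a (3, H, W) tensor, can then be combined with `torch.stack` and given a batch dimension with `unsqueeze(0)` to form the (Batch, Sequence, Channels, Height, Width) input.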

The Takeaway

Catching deepfakes is no longer just a computer vision problem; it is also a sequence modeling problem. By using EfficientNet-B0 to understand space and a Bi-LSTM to understand time, we can flag unnatural temporal micro-jitters that frame-by-frame analysis can miss.
