How I Fine-Tuned a Vision Transformer to Spot Deepfakes

This project started out as a hackathon idea — we wanted to create a practical tool that could detect deepfakes in images with high confidence. The goal? Build a complete multi-model deepfake classification framework that doesn’t just sound cool, but works.

We named the image model Virtus (because hey, if you're fighting fakes, might as well sound noble). I handled the image classification side and DevOps, while my teammates tackled video detection and the frontend/backend.

This post dives into how I built and trained Virtus: the thinking behind the model choice, dataset, training strategies, evaluation, and pushing it to Hugging Face. I’ll also sprinkle in some tips and lessons I picked up along the way — stuff I wish I knew before starting.

Want to skip the reading and jump straight into the code? Here's the full training notebook on GitHub.


Choosing a Base Model: Why Vision Transformers?

Initially, I considered the usual CNN suspects — ResNet, EfficientNet, all the classics. But deepfakes are tricky. The difference between a real and a fake face can be insanely subtle — we're talking fine textures, lighting inconsistencies, the kind of cues that can get blurred out or overlooked by CNNs.

So I started digging into Vision Transformers (ViTs) — and let’s just say, I went down the rabbit hole.

Turns out, ViTs aren't just trendy — they're built different. While CNNs work with pixel grids and sliding filters, ViTs treat images like sequences, kind of like sentences. They split an image into patches (aka "visual tokens") and feed them into a transformer — the same architecture that powers modern NLP models like BERT and GPT.
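To make that concrete, here's a toy sketch (not the model's actual code) of how a 224x224 image becomes a sequence of 16x16 patch tokens:

import torch

# A 224x224 RGB image split into 16x16 patches gives (224/16)^2 = 196 "visual tokens"
img = torch.randn(1, 3, 224, 224)                   # (batch, channels, height, width)
patches = img.unfold(2, 16, 16).unfold(3, 16, 16)   # (1, 3, 14, 14, 16, 16)
tokens = patches.permute(0, 2, 3, 1, 4, 5).reshape(1, 14 * 14, 3 * 16 * 16)
print(tokens.shape)  # torch.Size([1, 196, 768]): each row is one flattened patch

In the real model, each of those flattened patch vectors gets a learned linear projection plus a position embedding before it hits the transformer blocks.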

CNN vs. ViT: FLOPs and throughput comparison of CNN and Vision Transformer models (source)

ViTs actually have a weaker inductive bias than CNNs — which sounds bad, but it just means they assume less about the structure of images. With enough data (or strong augmentations), they learn better generalizations. And here's the kicker: they can reportedly match or outperform CNNs with around 4x fewer computational resources, which was a huge win for my Kaggle GPU budget.

If you want a more in-depth comparison, check out this awesome article — seriously, worth a read.

Eventually, I ended up choosing facebook/deit-base-distilled-patch16-224, a ViT that's been distilled from a CNN teacher model. It's relatively lightweight (~87M parameters), fast to fine-tune, and surprisingly accurate — it even outperforms the standard ViT-Base when trained on just ImageNet-1k. Plus, it doesn't need massive compute or crazy pretraining to get good results, which made it perfect for our hackathon timeline.

Vision Transformer (ViT) architecture (source)

If you’ve got more GPU headroom or a larger dataset, there are beefier models out there like google/vit-large-patch16-224-in21k, vit-base-patch32, or even the 384px version of DeiT — but for this project, I wanted something fast, efficient, and reliable. DeiT hit that sweet spot.


Data Preparation: Deepfake Data Is Messy

I started with a Kaggle dataset that had around 190,000 labeled images of real and fake faces — a solid foundation to begin with. On top of that, I manually added a bunch of extra samples I’d collected from other sources to make things a bit more diverse. Everything was organized into Real/ and Fake/ folders, so loading them with Path.glob was smooth sailing.
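For the curious, here's roughly what that loading step looked like (the dataset root below is a placeholder; the Real/ and Fake/ folder layout is the real one):

from pathlib import Path
import pandas as pd

data_dir = Path("/kaggle/input/deepfake-faces")  # placeholder path, not the actual dataset slug

records = []
for label in ["Real", "Fake"]:
    for img_path in (data_dir / label).glob("*.jpg"):  # grab every image in the class folder
        records.append({"image": str(img_path), "label": label})

df = pd.DataFrame(records)
print(df["label"].value_counts())  # quick sanity check on the class balance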

After loading the dataset, I quickly noticed the class distribution was skewed — one of the classes (either real or fake) had noticeably more images than the other. That’s not great for training, since the model might just learn to always predict the majority class.

To fix that, I used RandomOverSampler to duplicate samples from the underrepresented class. It’s a quick and dirty way to balance things — works surprisingly well for binary classification.

from imblearn.over_sampling import RandomOverSampler
import gc

# Separate out the labels before resampling
y = df[['label']]
df = df.drop(['label'], axis=1)

# Oversample to balance the classes
ros = RandomOverSampler(random_state=83)
df, y_resampled = ros.fit_resample(df, y)

# Stick the labels back
df['label'] = y_resampled
gc.collect()  # Clean up some memory — just to be safe

Now with the classes balanced, I converted the DataFrame into a Hugging Face Dataset object. This makes everything later (transforms, batching, etc.) super smooth. Here's a quick sketch of that conversion (assuming the image column holds file paths, as in the loading sketch above):
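from datasets import Dataset, Image

# Build a Dataset from the balanced DataFrame and decode the image-path
# column into PIL images lazily when rows are accessed
dataset = Dataset.from_pandas(df, preserve_index=False).cast_column("image", Image())

With that in place, I mapped the string labels ("Real" / "Fake") to numeric IDs (0 / 1). Hugging Face has a ClassLabel feature built exactly for this: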

from datasets import ClassLabel

# Define the class order explicitly
labels_list = ['Real', 'Fake']
class_labels = ClassLabel(num_classes=2, names=labels_list)

# Label encoding function
def map_label2id(example):
    example["label"] = class_labels.str2int(example["label"])
    return example

# Apply the mapping to the dataset
dataset = dataset.map(map_label2id, batched=True)
dataset = dataset.cast_column("label", class_labels)  # Ensures label column behaves like an integer class

Finally, I split the dataset into 60% for training and 40% for testing. I also made sure the label distribution was preserved across both splits using stratify_by_column.

# Train-test split with stratified labels
dataset = dataset.train_test_split(test_size=0.4, shuffle=True, stratify_by_column="label")

train_data = dataset['train']
test_data = dataset['test']

How I Trained Virtus — Step-by-Step

Alright, now comes the fun part — training the model. We’re going to fine-tune facebook/deit-base-distilled-patch16-224 on our deepfake dataset using Hugging Face's Trainer API.

Quick side note: I trained everything inside a Kaggle Notebook using a single NVIDIA P100 GPU. After some trial runs, I found it performed noticeably better than the T4s (even the dual T4 setup Kaggle sometimes gives you). Turns out, the P100 has higher memory bandwidth and better raw compute — which really helps when you're fine-tuning ViTs.

If Kaggle isn’t your thing, no worries. I also tried Lightning AI Studio and Amazon SageMaker Studio Lab, and both were solid freemium options. Way more stable than Google Colab, honestly. With Colab’s free tier, I kept hitting runtime errors, couldn’t get a GPU some days, and the lack of persistent storage was a dealbreaker. The hardware quality also felt kinda... meh.

⚡ Bonus tip: If you’re training locally, try managing your Python environment with uv. It’s ridiculously fast. Plus, you won’t fall into dependency hell™, which is honestly half the battle when setting up ML projects. Follow this tutorial if you wanna give it a spin — highly recommend.

Step 1: Preprocessing and Augmentation

Before throwing data at the model, we need to make sure the input images are normalized exactly the way the pre-trained ViT expects.

from transformers import ViTImageProcessor
from torchvision.transforms import Compose, Resize, RandomRotation, RandomAdjustSharpness, ToTensor, Normalize

model_str = "facebook/deit-base-distilled-patch16-224"
processor = ViTImageProcessor.from_pretrained(model_str)

image_mean, image_std = processor.image_mean, processor.image_std
size = processor.size["height"]

Then I defined two sets of transforms — one for training (with augmentations) and one for validation.

_train_transforms = Compose([
    Resize((size, size)),
    RandomRotation(90),
    RandomAdjustSharpness(2),
    ToTensor(),
    Normalize(mean=image_mean, std=image_std)
])

_val_transforms = Compose([
    Resize((size, size)),
    ToTensor(),
    Normalize(mean=image_mean, std=image_std)
])

Why these augmentations? Deepfakes can vary a lot depending on the source. A bit of random rotation and sharpness tweaking helps the model generalize to those variations. No augmentation for validation though — we want to keep that clean.

Step 2: Applying Transforms

I used Hugging Face’s set_transform method to apply the preprocessing on-the-fly. This keeps RAM usage low and plays nicely with their Dataset objects.

# Note: set_transform's output is what __getitem__ returns, so keep the label column alongside pixel_values
train_data.set_transform(lambda x: {"pixel_values": [_train_transforms(img.convert("RGB")) for img in x["image"]], "label": x["label"]})
test_data.set_transform(lambda x: {"pixel_values": [_val_transforms(img.convert("RGB")) for img in x["image"]], "label": x["label"]})

Step 3: Custom Collate Function

The Trainer needs batches of images and labels. Here's a simple collate function to stack tensors correctly:

import torch

def collate_fn(examples):
    # Stack image tensors into a batch and gather labels into a tensor
    pixel_values = torch.stack([e["pixel_values"] for e in examples])
    labels = torch.tensor([e["label"] for e in examples])
    return {"pixel_values": pixel_values, "labels": labels}

Step 4: Loading the Model

Now we bring in the ViT model with the correct number of labels and label mappings:

from transformers import ViTForImageClassification

model = ViTForImageClassification.from_pretrained(
    model_str,
    num_labels=2
)

model.config.label2id = {'Real': 0, 'Fake': 1}
model.config.id2label = {0: 'Real', 1: 'Fake'}

Step 5: TrainingArguments

These settings worked great for me — small learning rate, a couple of epochs, early checkpoint saving, etc.

from transformers import TrainingArguments

args = TrainingArguments(
    output_dir="virtus",
    logging_dir="./logs",
    evaluation_strategy="epoch",
    save_strategy="epoch",
    learning_rate=1e-6,  # Tiny learning rate to avoid overshooting on a sensitive task like classification
    per_device_train_batch_size=32,
    per_device_eval_batch_size=8,
    num_train_epochs=2, # 2 epochs were enough to converge for my dataset; more can overfit
    weight_decay=0.02,  # Helps regularize and reduce overfitting
    warmup_steps=50, # Linearly ramps up LR at the start for more stable training
    load_best_model_at_end=True,
    save_total_limit=1,
    report_to="none"
)

Step 6: Train Time!

Let’s go!

from transformers import Trainer

trainer = Trainer(
    model=model,
    args=args,
    train_dataset=train_data,
    eval_dataset=test_data,
    data_collator=collate_fn,
    tokenizer=processor,  # Pass the image processor so it gets saved (and pushed) alongside the model
    compute_metrics=lambda p: {"accuracy": (p.predictions.argmax(-1) == p.label_ids).mean()}
)

trainer.train()

It trained in around 2 hours on Kaggle’s GPU runtime and reached ~99.2% accuracy. Not bad for just two epochs.


Evaluation: Did Virtus Actually Learn Anything?

After training wrapped up, I wanted to be sure the model wasn’t just memorizing the training data. Hugging Face's Trainer makes it dead simple to evaluate performance on the test set.

# Run evaluation on the test set
trainer.evaluate()

That gave me some solid metrics:

{'eval_loss': 0.0248, 'eval_accuracy': 0.9919, ...}

Yeah — over 99% accuracy. I double-checked this wasn’t a fluke by inspecting predictions manually:

# Make predictions on test data
outputs = trainer.predict(test_data)

# See predicted vs actual for the first 5 samples
preds = outputs.predictions.argmax(axis=1)
labels = outputs.label_ids
id2label = model.config.id2label  # {0: 'Real', 1: 'Fake'}

for i in range(5):
    print(f"Predicted: {id2label[preds[i]]} | Actual: {id2label[labels[i]]}")

And it matched up nicely.

To dig deeper, I calculated the macro F1 score and plotted the confusion matrix — just to visualize how well the model was doing on both classes. And here's the result:

Confusion matrix for virtus

The matrix was basically diagonal — which means the model was nailing both classes.
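For reference, here's roughly how those last two checks look with scikit-learn, reusing preds and labels from the prediction snippet above:

from sklearn.metrics import f1_score, confusion_matrix, ConfusionMatrixDisplay
import matplotlib.pyplot as plt

# Macro F1 averages the per-class F1 scores, so Real and Fake count equally
print("Macro F1:", f1_score(labels, preds, average="macro"))

# Rows are the true classes, columns are the predictions
cm = confusion_matrix(labels, preds)
ConfusionMatrixDisplay(cm, display_labels=["Real", "Fake"]).plot(cmap="Blues")
plt.show()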


Publishing to Hugging Face: Share It With the World

Once I was happy with Virtus, I wanted to make it public. Hugging Face Hub is the easiest way to share models — and you can even push directly from a Kaggle notebook using secrets.

First, install the CLI tools:

!pip install -q huggingface_hub

Then authenticate using a token (I stored mine using Kaggle secrets, but you can use huggingface-cli login locally):

from huggingface_hub import login, create_repo
from kaggle_secrets import UserSecretsClient

token = UserSecretsClient().get_secret("HF_TOKEN")
login(token)

Create your repo (this can be done via the website too, but I like automation):

create_repo(repo_id="agasta/virtus", private=False)

Finally, save the fine-tuned model locally (for example with trainer.save_model("./virtus")), then push both the model and its image processor (so others don’t have to guess your preprocessing steps):

from transformers import AutoModelForImageClassification, AutoFeatureExtractor

model = AutoModelForImageClassification.from_pretrained("./virtus")
extractor = AutoFeatureExtractor.from_pretrained("./virtus")

model.push_to_hub("agasta/virtus")
extractor.push_to_hub("agasta/virtus")

Boom. Your model’s live.

👉 Check it out here: https://huggingface.co/agasta/virtus
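And if you just want to try it, the pipeline API will pull the model straight from the Hub (the image path below is a placeholder):

from transformers import pipeline

# Downloads the model + processor from the Hub and runs inference in one call
clf = pipeline("image-classification", model="agasta/virtus")
print(clf("some_face.jpg"))  # e.g. [{'label': 'Fake', 'score': 0.98}, {'label': 'Real', 'score': 0.02}]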

But we’re not done yet.

In the next blog, I’ll show you how to wrap this model in a FastAPI-powered backend, make it production-ready, and deploy it like a real-world service — something you can actually integrate into an app or use in a real-time system.

If you're into this kind of stuff — AI, web3, backend dev, DevOps, hackathon builds — follow me on X @idkAgasta. I post cool projects, quick tips, and sometimes just chaos.

Until next time — keep building, keep shipping, nerds.
