Santosh Premi Adhikari
Fine-Tuning a Pre-Trained Model in PyTorch: A Step-by-Step Guide for Beginners

Fine-tuning is a powerful technique that allows you to adapt a pre-trained model to a new task, saving time and resources. This tutorial will guide you through fine-tuning a ResNet18 model for digit classification using PyTorch.

Step 1: Setting Up the Environment and Model

First, let's import the required libraries and set up the pre-trained model.

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torchvision import models

# Load a ResNet18 model pre-trained on ImageNet
model = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)  # use pretrained=True on older torchvision versions

# Modify the last layer to match MNIST classes (10 classes)
model.fc = nn.Linear(model.fc.in_features, 10)

# Move the model to the GPU if one is available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)


Explanation: We use ResNet18 pre-trained on ImageNet and replace its final fully connected layer (model.fc) so it outputs 10 classes, one for each digit 0-9.
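
If you would rather train only the new final layer and leave the pre-trained weights untouched (often called feature extraction), you could freeze the backbone first. This is optional and not part of the original walkthrough, which fine-tunes all layers:

# Optional: train only the new classification head (feature extraction).
# Skip this block if you want to fine-tune every layer as in the rest of this guide.
for param in model.parameters():
    param.requires_grad = False

# Re-enable gradients for the new final layer only
for param in model.fc.parameters():
    param.requires_grad = True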

Step 2: Preparing the Dataset

We resize the MNIST images to 224x224 and convert them to three channels, because ResNet18 was pre-trained on 224x224 RGB images.

# Convert to 3 channels, resize to 224x224, and normalize
transform = transforms.Compose([
    transforms.Grayscale(num_output_channels=3),  # ResNet18 expects 3-channel input
    transforms.Resize(224),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

# Load the MNIST dataset
trainset = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=64, shuffle=True)

testset = torchvision.datasets.MNIST(root='./data', train=False, download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=64, shuffle=False)


Explanation: MNIST images are 28x28 grayscale, so we convert them to three channels, resize them, and normalize them to match the input the pre-trained model expects. We also create data loaders for training and testing.
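
As a quick sanity check (optional), you can pull one batch from the loader and confirm the tensors have the 3-channel, 224x224 shape ResNet18 expects:

# Quick sanity check: one batch should have shape [batch_size, 3, 224, 224]
images, labels = next(iter(trainloader))
print(images.shape)   # e.g. torch.Size([64, 3, 224, 224])
print(labels.shape)   # e.g. torch.Size([64])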

Step 3: Setting Up the Loss Function and Optimizer
Define the loss function and optimizer. We use Adam, a popular choice for fine-tuning.

# Define loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Learning rate scheduler to adjust the learning rate
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)


Explanation: CrossEntropyLoss is the standard loss for multi-class classification, and Adam is a popular optimizer for fine-tuning. The StepLR scheduler multiplies the learning rate by gamma=0.1 every step_size=7 epochs, so later updates are smaller.
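
To see what the scheduler actually does, here is a small stand-alone illustration; demo_param, demo_optimizer, and demo_scheduler are throwaway names used only for this demo:

# Illustration only: with step_size=7 and gamma=0.1, StepLR keeps the learning
# rate at 0.001 for epochs 0-6 and drops it to 0.0001 from epoch 7 onward.
demo_param = torch.zeros(1, requires_grad=True)
demo_optimizer = optim.Adam([demo_param], lr=0.001)
demo_scheduler = optim.lr_scheduler.StepLR(demo_optimizer, step_size=7, gamma=0.1)
for epoch in range(10):
    print(f"epoch {epoch}: lr = {demo_scheduler.get_last_lr()[0]}")
    demo_optimizer.step()
    demo_scheduler.step()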

Step 4: Fine-Tuning the Model
Train the model for a few epochs to fine-tune it on the new task.

# Fine-tune the model
model.train()  # make sure the model is in training mode
num_epochs = 5
# Tip: if training feels slow on Colab, switch to a GPU runtime (e.g., a T4) and set num_epochs to 1; a single epoch still takes a few minutes.
for epoch in range(num_epochs):
    running_loss = 0.0
    for images, labels in trainloader:
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    # Step the scheduler after each epoch
    scheduler.step()

    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {running_loss/len(trainloader):.4f}")

print('Fine-tuning complete!')


Explanation: We loop through the dataset, perform forward and backward passes, and update the model's weights. This step adapts the pre-trained model to our specific task.
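
If you want to watch generalization during training, one option is a small helper like the evaluate function below (my addition, not part of the original loop) that you could call at the end of each epoch:

# Optional helper: accuracy on any loader, e.g. evaluate(model, testloader, device)
def evaluate(model, loader, device):
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)
            preds = model(images).argmax(dim=1)
            correct += (preds == labels).sum().item()
            total += labels.size(0)
    model.train()  # switch back to training mode before the next epoch
    return 100 * correct / total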

Step 5: Saving the Fine-Tuned Model
Save the fine-tuned model for later use.

# Save the fine-tuned model
torch.save(model.state_dict(), 'finetuned_resnet18_mnist.pth')
print('Model saved!')


Explanation: We save the model's state dictionary (parameters) to a file. This allows us to load it later without retraining.
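
If you also want to resume training later (not just run inference), a common variation is to save a fuller checkpoint; the filename finetune_checkpoint.pth here is just an example:

# Optional: save model weights, optimizer state, and epoch count together
checkpoint = {
    'epoch': num_epochs,
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
}
torch.save(checkpoint, 'finetune_checkpoint.pth')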

Step 6: Evaluating the Model
Check how well the model performs on unseen data.

# Set the model to evaluation mode
model.eval()

correct = 0
total = 0

with torch.no_grad():
    for images, labels in testloader:
        images, labels = images.to(device), labels.to(device)
        outputs = model(images)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

print(f'Accuracy of the fine-tuned model on the test images: {100 * correct / total:.2f}%')


Explanation: We set the model to evaluation mode and calculate accuracy on the test set to see how well the model generalizes to new data.
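
Beyond overall accuracy, you can optionally break the results down per digit to spot classes the model struggles with, for example:

# Optional: per-class accuracy on the test set
class_correct = [0] * 10
class_total = [0] * 10

with torch.no_grad():
    for images, labels in testloader:
        images, labels = images.to(device), labels.to(device)
        predicted = model(images).argmax(dim=1)
        for label, pred in zip(labels, predicted):
            class_total[label.item()] += 1
            class_correct[label.item()] += int(label.item() == pred.item())

for digit in range(10):
    if class_total[digit] > 0:
        print(f"Digit {digit}: {100 * class_correct[digit] / class_total[digit]:.2f}%")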

Step 7: Making Predictions
Finally, use the fine-tuned model to make predictions on a single image.

# Load the model for inference
model = models.resnet18()
model.fc = nn.Linear(model.fc.in_features, 10)
model.load_state_dict(torch.load('finetuned_resnet18_mnist.pth', map_location=device))
model.eval()
model = model.to(device)

# Make a prediction on a single image from the test set
test_image, _ = testset[0]  # Get the first image from the test set
test_image = test_image.unsqueeze(0).to(device)  # Add a batch dimension and move to device

output = model(test_image)
_, predicted = torch.max(output, 1)

print('Predicted label:', predicted.item())


Explanation: We load the saved model, set it to evaluation mode, and use it to predict the class of a new image.
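
Optionally, you can convert the raw logits into probabilities with softmax to see how confident the prediction is:

# Optional: turn the logits into probabilities to inspect model confidence
probabilities = torch.softmax(output, dim=1)
confidence, predicted = torch.max(probabilities, 1)
print(f"Predicted digit: {predicted.item()} (confidence: {confidence.item():.2%})")
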
Conclusion
Fine-tuning is an efficient way to adapt powerful pre-trained models to new tasks with minimal effort. In this guide, you learned how to:

1. Set up a pre-trained model and modify it for a new task.
2. Prepare and load your dataset.
3. Train and fine-tune the model with a new dataset.
4. Save and load the model for future use.
5. Make predictions using the fine-tuned model.

Hope you found this post helpful and enjoyable.
Thank you!
