DEV Community

gdemarcq

Originally published at ikomia.ai

How to train a classification model on a custom dataset


In this blog post, we will cover the necessary steps to train a custom image classification model and test it on images.

The Ikomia API simplifies the development of Computer Vision workflows and provides an easy way to experiment with different parameters to achieve optimal results.

Get started with Ikomia API

You can train a custom classification model with just a few lines of code. To begin, you will need to install the API within a virtual environment.

(See: How to install a virtual environment.)

pip install ikomia

API documentation

API repo

In this tutorial, we will use the Rock Paper Scissors dataset from Roboflow.

Ensure that the dataset is organized in the correct format, as shown below:

(Note: The “validation” folder should be renamed to “val”.)

Folder tree
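The folder-tree figure is not reproduced in this text. The note above and the test image path used later in this post (test/rock/...) suggest a layout with one folder per split, each containing one subfolder per class. The helper below is a hypothetical sketch, not part of the Ikomia API. It assumes that layout, renames a "validation" folder to "val", and checks that "train" and "val" contain the same class folders.

```python
from __future__ import annotations

from pathlib import Path

# Splits the training task is assumed to need (hypothetical requirement)
REQUIRED_SPLITS = ("train", "val")


def prepare_dataset(root: str) -> dict[str, list[str]]:
    """Rename 'validation' to 'val' if needed and return class folders per split.

    Assumes the layout <root>/<split>/<class_name>/<image files>.
    """
    root_path = Path(root)
    validation_dir = root_path / "validation"
    val_dir = root_path / "val"
    if validation_dir.is_dir() and not val_dir.exists():
        validation_dir.rename(val_dir)

    classes: dict[str, list[str]] = {}
    for split in REQUIRED_SPLITS:
        split_dir = root_path / split
        if not split_dir.is_dir():
            raise FileNotFoundError(f"Missing split folder: {split_dir}")
        # Each class is a subfolder of the split folder
        classes[split] = sorted(p.name for p in split_dir.iterdir() if p.is_dir())

    if classes["train"] != classes["val"]:
        raise ValueError("train and val must contain the same class folders")
    return classes
```

For example, prepare_dataset("Path/To/Rock Paper Scissors.v1-raw-300x300.folder") should return the class folder names for train and val. If a split is missing or the class folders differ, it raises an error instead.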

Run the ResNet training algorithm

You can also load the open-source notebook we have prepared directly.

from ikomia.dataprocess.workflow import Workflow
from ikomia.utils import ik

# Init your workflow
wf = Workflow()
# Add the training task to the workflow
resnet = wf.add_task(ik.train_torchvision_resnet(
    model_name="resnet34",
    batch_size="16",
    epochs="5",
    output_folder="Path/To/Output/Folder"
),
    auto_connect=True
)

# Set the input path of your dataset
dataset_folder = "Path/To/Rock Paper Scissors.v1-hugggingface.folder"
# Launch your training on your data
wf.run_on(folder=dataset_folder)

After 5 epochs of training, you should see metrics similar to the following (exact values vary between runs):

  • train Loss: 0.3751 Acc:0.8468

  • val Loss: 0.5611 Acc:0.7231

  • val per class Acc:tensor([0.75806, 1.00000, 0.41129])

  • Training complete in 1m57s

  • Best accuracy: 0.838710
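The overall validation accuracy and the per-class accuracies are related. Suppose the three validation classes contain the same number of images (an assumption; the log does not say so). Then the overall accuracy is simply the mean of the per-class values, and the short check below confirms this for the numbers above. "Best accuracy" (0.8387) is also higher than the final val accuracy (0.7231). This suggests it reports the best validation accuracy reached across all epochs rather than the accuracy of the last epoch.

```python
# Per-class validation accuracies copied from the training log above
per_class_acc = [0.75806, 1.00000, 0.41129]

# With equally sized classes, overall accuracy = mean of per-class accuracies
mean_acc = sum(per_class_acc) / len(per_class_acc)
print(f"Mean per-class accuracy: {mean_acc:.4f}")  # 0.7231, matching "val Acc:0.7231"
```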

Image Classification

Before experimenting with TorchVision ResNet, let’s dive deeper into image classification and the characteristics of this particular algorithm.

What is Image Classification?

Image classification is a fundamental task in Computer Vision that involves categorizing images into predefined classes based on their visual content. It enables computers to recognize objects, scenes, and patterns within images. The importance of image classification lies in its various applications:

Object Recognition

It allows computers to identify and categorize objects in images, essential for applications like autonomous vehicles and surveillance systems.

Image Understanding

Classification helps machines interpret image content and extract meaningful information, enabling advanced analysis and decision-making based on visual data.

Visual Search and Retrieval

By assigning tags or labels to images, classification models facilitate efficient searching and retrieval of specific images from large databases.

Content Filtering and Moderation

Image classification aids in automatically detecting and flagging inappropriate or offensive content, ensuring safer online environments.

Medical Imaging and Diagnosis

Classification assists in diagnosing diseases and analyzing medical images, enabling faster and more accurate diagnoses.

Quality Control and Inspection

By classifying images, defects or anomalies in manufactured products can be identified, ensuring quality control in various industries.

Visual Recommendation Systems

Image classification enhances recommendation systems by analyzing visual content and suggesting related items or content.

Security and Surveillance

Classification enables the identification of objects or individuals of interest in security and surveillance applications, enhancing threat detection and public safety.

In summary, image classification is essential for object recognition, image understanding, search and retrieval, content moderation, medical imaging, quality control, recommendation systems, and security applications in computer vision.
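At its core, the final step of a classifier is simple. The network outputs one score (logit) per class, and the predicted class is the one with the highest probability after a softmax. The sketch below illustrates only that final step, using made-up logits and the three class names of this tutorial's dataset. softmax and classify are hypothetical helper names, not Ikomia or TorchVision functions.

```python
from __future__ import annotations

import math


def softmax(logits: list[float]) -> list[float]:
    # Subtract the max logit for numerical stability before exponentiating
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def classify(logits: list[float], class_names: list[str]) -> tuple[str, float]:
    # Pick the class with the highest probability (top-1 prediction)
    probs = softmax(logits)
    best = max(range(len(probs)), key=probs.__getitem__)
    return class_names[best], probs[best]


# Made-up logits for one image; the class names match the tutorial dataset
label, confidence = classify([2.0, 0.5, -1.0], ["paper", "rock", "scissors"])
print(label, round(confidence, 3))
```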

What is TorchVision ResNet?

A DCNN architecture

TorchVision is a popular Computer Vision library in PyTorch that provides pre-trained models and tools for working with image data. One of the widely used models in TorchVision is ResNet. ResNet, short for Residual Network, is a deep convolutional neural network architecture introduced by Kaiming He et al. in 2015. It was designed to address the challenge of training deep neural networks by introducing a residual learning framework.

Residual blocks to train deeper networks

ResNet uses residual blocks with skip connections to facilitate information flow between layers, mitigating the vanishing gradient problem and enabling the training of deeper networks.

The key idea behind ResNet is the residual block. Instead of learning a desired mapping H(x) directly, the stacked layers learn the residual F(x) = H(x) − x. Each block contains a skip connection that bypasses one or more layers, so information from earlier layers flows directly to later layers.

Skip connections

The residual connection creates a shortcut path by adding the value at the beginning of the block, x, directly to the end of the block (F(x) + x) [Source].

This allows information to pass through multiple layers without degradation, making training and optimization easier.
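The F(x) + x idea can be illustrated with a toy residual block written in plain Python. This is a simplified sketch, not TorchVision's implementation: dense layers without bias stand in for the convolutions and batch normalization used in real ResNet blocks. It shows that when the layers in F contribute nothing (all weights zero), the block still passes the non-negative input used here through unchanged, because of the skip connection.

```python
from __future__ import annotations


def relu(v: list[float]) -> list[float]:
    return [max(0.0, x) for x in v]


def matvec(w: list[list[float]], v: list[float]) -> list[float]:
    # Dense layer without bias: each output is the dot product of a row with v
    return [sum(wij * vj for wij, vj in zip(row, v)) for row in w]


def residual_block(x: list[float], w1: list[list[float]], w2: list[list[float]]) -> list[float]:
    # F(x): two stacked layers with a ReLU in between
    f_x = matvec(w2, relu(matvec(w1, x)))
    # Skip connection: add the block input x to F(x), then apply the final ReLU
    return relu([f + xi for f, xi in zip(f_x, x)])


x = [1.0, 2.0, 3.0]
zeros = [[0.0] * 3 for _ in range(3)]
# With all weights at zero, F(x) = 0 and the block returns x unchanged
print(residual_block(x, zeros, zeros))  # [1.0, 2.0, 3.0]
```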

The Microsoft Research team won the ImageNet 2015 (ILSVRC 2015) classification competition using these deep residual networks with skip connections. Their entry used the ResNet-152 architecture, a convolutional neural network with 152 layers.

ResNet34 architecture

ResNet34 Architecture [Source].

Various ResNet models

ResNet models are available in torchvision at different depths: ResNet-18, ResNet-34, ResNet-50, ResNet-101, and ResNet-152. Their pre-trained weights come from training on the large-scale ImageNet classification dataset, where these architectures achieved state-of-the-art performance when they were introduced.
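The number in each model name counts the layers with learnable weights. The sketch below reproduces those numbers from the block configuration given in Table 1 of the original ResNet paper, which is also the configuration used by TorchVision's resnet18 to resnet152 constructors. ResNet-18 and ResNet-34 stack basic blocks of two 3x3 convolutions. ResNet-50, ResNet-101 and ResNet-152 stack bottleneck blocks of three convolutions (1x1, 3x3, 1x1). Following the paper's convention, the count excludes pooling layers and the 1x1 projection convolutions on some shortcuts.

```python
# Residual blocks per stage (conv2_x ... conv5_x) and convolutions per block
CONFIGS = {
    "resnet18": ([2, 2, 2, 2], 2),   # basic blocks: two 3x3 convolutions
    "resnet34": ([3, 4, 6, 3], 2),
    "resnet50": ([3, 4, 6, 3], 3),   # bottleneck blocks: 1x1, 3x3, 1x1
    "resnet101": ([3, 4, 23, 3], 3),
    "resnet152": ([3, 8, 36, 3], 3),
}


def depth(name: str) -> int:
    blocks, layers_per_block = CONFIGS[name]
    # +2 for the initial 7x7 convolution and the final fully connected layer
    return sum(blocks) * layers_per_block + 2


for name in CONFIGS:
    print(name, depth(name))
```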

By using pre-trained ResNet models from torchvision, researchers and developers can leverage the learned representations for various Computer Vision tasks, including image classification, object detection, and feature extraction.

Step by step: Train ResNet Image Classification Model using Ikomia API

With the Rock Paper Scissors dataset you have downloaded, you can easily train a custom ResNet model using the Ikomia API. Let's go through the process together:

Step 1: import

from ikomia.dataprocess.workflow import Workflow
from ikomia.utils import ik
  • Workflow is the base object to create a workflow. It provides methods for setting inputs such as images, videos, and directories, configuring task parameters, obtaining time metrics, and accessing specific task outputs such as graphics, segmentation masks, and texts.

  • ik is an auto-completion system designed for convenient and easy access to algorithms and settings.

Step 2: create workflow

Initialize a workflow instance by creating a ‘wf’ object. This object will be used to add tasks to the workflow, configure their parameters, and run them on input data.

wf = Workflow()

Step 3: add the torchvision ResNet algorithm and set the parameters

Now, let’s add the train_torchvision_resnet task to train our custom image classifier. We also need to specify a few parameters for the task:

resnet = wf.add_task(ik.train_torchvision_resnet(
    model_name="resnet34",
    batch_size="16",
    epochs="5",
    output_folder="Path/To/Output/Folder"
),
    auto_connect=True
)
  • model_name: Name of the pre-trained model (here, resnet34).

  • batch_size: Number of samples processed before the model is updated.

  • epochs: Number of complete passes through the training dataset.

  • input_size: Input image size during training.

  • learning_rate: Step size at which the model’s parameters are updated during training.

  • momentum: Optimization technique that accelerates convergence by accumulating past gradient directions.

  • weight_decay: Regularization technique that reduces the magnitude of the model's weights.

  • output_folder: Path to where the trained model will be saved.
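For readers new to these hyperparameters, the sketch below shows how they typically interact during training. It is an illustration under assumptions, not the Ikomia implementation. It assumes stochastic gradient descent with momentum and weight decay, using the update rule of PyTorch's torch.optim.SGD without dampening or Nesterov, because the source does not name the optimizer. The update is applied to a single scalar parameter. optimizer_steps and sgd_step are hypothetical names.

```python
from __future__ import annotations

import math


def optimizer_steps(num_images: int, batch_size: int, epochs: int) -> int:
    # One parameter update per batch; assumes the last, smaller batch is kept
    return math.ceil(num_images / batch_size) * epochs


def sgd_step(
    param: float,
    grad: float,
    velocity: float,
    learning_rate: float,
    momentum: float,
    weight_decay: float,
) -> tuple[float, float]:
    """One SGD update on a scalar, following torch.optim.SGD's update rule."""
    # weight_decay: L2 penalty that pulls the parameter toward zero
    grad = grad + weight_decay * param
    # momentum: running combination of the current and past gradients
    velocity = momentum * velocity + grad
    # learning_rate: scales the size of the step
    return param - learning_rate * velocity, velocity


# Hypothetical dataset of 1,000 training images with batch_size=16 and epochs=5
print(optimizer_steps(1000, 16, 5))  # 315
```

With momentum, a gradient that keeps pointing the same way produces larger and larger steps. Weight decay shrinks the parameter toward zero even when the data gradient is zero.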

Step 4: set the input path of your dataset

Next, provide the path to the dataset folder for the task input.

dataset_folder = "Path/To/Rock Paper Scissors.v1-raw-300x300.folder"

Step 5: run your workflow

Finally, it’s time to run the workflow and start the training process.

wf.run_on(folder=dataset_folder)

Test your custom ResNet image classifier

First, let's run the pre-trained ResNet34 model on a rock/paper/scissors image:

from ikomia.dataprocess.workflow import Workflow
from ikomia.utils.displayIO import display
from ikomia.utils import ik

# Initialize the workflow
wf = Workflow()
# Add the image classification algorithm  
resnet = wf.add_task(ik.infer_torchvision_resnet(model_name="resnet34"), auto_connect=True)
# Run on your image
wf.run_on(path="Path/To/Rock Paper Scissors/Dataset/test/rock/rock8_png.rf.8b06573ed8208e085c3b2e3cf06c7888.jpg")
# Inspect your results
display(resnet.get_image_with_graphics())

Result knee pad

The pre-trained ResNet34 model does not recognize the rock sign. This is expected: it was trained on the ImageNet dataset, whose 1,000 classes do not include rock/paper/scissors hand signs.

To test the model we just trained, we specify the paths to our custom model weights and class names using the model_weight_file and class_file parameters. We then run the workflow on the same image we used previously.

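from ikomia.dataprocess.workflow import Workflow
from ikomia.utils.displayIO import display
from ikomia.utils import ik
wf = Workflow()  # new workflow for the custom model
# Add the custom ResNet model
resnet = wf.add_task(ik.infer_torchvision_resnet(
    model_name="resnet34",
    model_weight_file="path/to/output_folder/timestamp/06-06-2023T14h32m40s/resnet34.pth",
    class_file="path/to/output_folder/timestamp/classes.txt"),
    auto_connect=True
)
# Run on the same test image and display the result
wf.run_on(path="Path/To/Rock Paper Scissors/Dataset/test/rock/rock8_png.rf.8b06573ed8208e085c3b2e3cf06c7888.jpg")
display(resnet.get_image_with_graphics())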

Result rock

Here are some more examples of image classification using the pre-trained model (left) and our custom model (right):

Result scissors

Result paper

Build your own Computer Vision workflow

To learn more about the API, refer to the documentation. You can also browse the list of state-of-the-art algorithms on Ikomia HUB and try Ikomia STUDIO, which offers a user-friendly interface with the same features as the API.
