DEV Community

amananandrai


Pretrained Models for Transfer Learning in Keras for Computer Vision

TensorFlow is one of the most widely used libraries for machine learning. It has built-in support for Keras, which we can use through the tf.keras module. Computer vision is one of the most interesting branches of machine learning. The ImageNet dataset was a turning point for computer vision research because it provided a very large set of labelled images. It is now a standard benchmark for measuring the accuracy of deep learning models for image classification and object detection.

Transfer learning is another major development in deep learning for computer vision. In transfer learning, we take a model that has already been trained for classification on one dataset and reuse its learned weights for a different classification task, typically by replacing its final classification layers and training only the new layers on the new data. Transfer learning has two benefits:

  • It takes less time to train, because most of the model's weights have already been learned on a different task.
  • It works for tasks with smaller datasets, because the model was already trained on a larger dataset and those weights are transferred to the new task.

Figure: Illustration of transfer learning, in which a model trained to recognize objects such as cats and dogs is reused for cancer detection by transferring its weights.
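To see why training is faster, we can compare how many weights the optimizer actually updates when the pretrained part is frozen versus when it is not. This is a minimal sketch that assumes MobileNetV2 as the pretrained model, 160x160 inputs and a two-class head; build_model and count_trainable are hypothetical helper names. It passes weights=None only to skip the ImageNet download, which does not change the parameter counts:

```python
import numpy as np
import tensorflow as tf


def count_trainable(model):
    """Count the scalar weights that the optimizer will update."""
    return int(sum(np.prod(tuple(w.shape)) for w in model.trainable_weights))


def build_model(freeze_base, num_classes=2, weights="imagenet"):
    """Pretrained MobileNetV2 features plus a new classification head (assumed setup)."""
    base = tf.keras.applications.MobileNetV2(
        weights=weights, include_top=False, input_shape=(160, 160, 3)
    )
    # Frozen: the transferred weights stay fixed and only the new head is trained.
    base.trainable = not freeze_base
    return tf.keras.Sequential([
        tf.keras.Input(shape=(160, 160, 3)),
        base,
        tf.keras.layers.GlobalAveragePooling2D(),
        tf.keras.layers.Dense(num_classes, activation="softmax"),
    ])


# weights=None skips the ImageNet download; parameter counts are unaffected.
frozen = build_model(freeze_base=True, weights=None)
unfrozen = build_model(freeze_base=False, weights=None)
print("trainable weights, frozen base:  ", count_trainable(frozen))
print("trainable weights, unfrozen base:", count_trainable(unfrozen))
```

With the base frozen, only the new Dense layer (1280 x 2 weights plus 2 biases, 2,562 parameters) is trained; unfreezing the base adds over two million MobileNetV2 parameters to every update.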

TensorFlow's Keras module provides many pretrained models that can be used for transfer learning; they are all in the tf.keras.applications module. Details about them can be found here.

The modules in tf.keras.applications and the functions they provide for building each deep learning architecture are listed below:

| Module | Model functions |
|---|---|
| densenet | DenseNet121(), DenseNet169(), DenseNet201() |
| efficientnet | EfficientNetB0(), EfficientNetB1(), EfficientNetB2(), EfficientNetB3(), EfficientNetB4(), EfficientNetB5(), EfficientNetB6(), EfficientNetB7() |
| inception_resnet_v2 | InceptionResNetV2() |
| inception_v3 | InceptionV3() |
| mobilenet | MobileNet() |
| mobilenet_v2 | MobileNetV2() |
| nasnet | NASNetLarge(), NASNetMobile() |
| resnet | ResNet101(), ResNet152() |
| resnet50 | ResNet50() |
| resnet_v2 | ResNet101V2(), ResNet152V2(), ResNet50V2() |
| vgg16 | VGG16() |
| vgg19 | VGG19() |
| xception | Xception() |
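Each module in the table also provides a preprocess_input function that converts raw pixel values into the range and format the corresponding model was trained with, so images should be preprocessed with the function from the same module as the model. The sketch below pairs a few entries from the table with their preprocessing functions; the BACKBONES dictionary and the preprocess helper are hypothetical names, and the sizes listed are the models' default ImageNet input sizes:

```python
import numpy as np
import tensorflow as tf

apps = tf.keras.applications

# Hypothetical lookup table: model function, matching preprocess_input from
# the same module, and the model's default ImageNet input size.
BACKBONES = {
    "Xception": (apps.Xception, apps.xception.preprocess_input, (299, 299)),
    "ResNet50": (apps.ResNet50, apps.resnet50.preprocess_input, (224, 224)),
    "MobileNetV2": (apps.MobileNetV2, apps.mobilenet_v2.preprocess_input, (224, 224)),
    "VGG16": (apps.VGG16, apps.vgg16.preprocess_input, (224, 224)),
}


def preprocess(name, images):
    """Apply the preprocessing that matches the chosen backbone."""
    _, preprocess_fn, _ = BACKBONES[name]
    # preprocess_input may modify float arrays in place, so work on a copy.
    return preprocess_fn(np.array(images, dtype="float32", copy=True))


batch = np.full((1, 2, 2, 3), 255.0)  # a tiny all-white "image"
print(preprocess("Xception", batch).max())  # → 1.0
```

Xception and MobileNetV2 scale pixels to [-1, 1], while ResNet50 and VGG16 convert RGB to BGR and subtract the ImageNet channel means, so mixing them up silently feeds a model inputs on the wrong scale.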

A model built from scratch in TensorFlow looks like the example below:

from tensorflow import keras
from tensorflow.keras import layers

model = keras.Sequential([
    # Input: 128x128 RGB images
    keras.Input(shape=(128, 128, 3)),

    # First Convolutional Block
    layers.Conv2D(filters=32, kernel_size=5, activation="relu", padding="same"),
    layers.MaxPool2D(),

    # Second Convolutional Block
    layers.Conv2D(filters=64, kernel_size=3, activation="relu", padding="same"),
    layers.MaxPool2D(),

    # Third Convolutional Block
    layers.Conv2D(filters=128, kernel_size=3, activation="relu", padding="same"),
    layers.MaxPool2D(),

    # Classifier Head: a single sigmoid unit for binary classification
    layers.Flatten(),
    layers.Dense(units=6, activation="relu"),
    layers.Dense(units=1, activation="sigmoid"),
])

The structure of this deep learning model is shown in the following figure.

[Figure: structure of the CNN defined above]


In the same way, we can call the Xception() function from the tf.keras.applications module to add a pretrained model to our architecture. With weights='imagenet', the model is loaded with the weights it learned on ImageNet, and include_top=False drops its original ImageNet classification layers so that we can add our own. Because we want to keep those learned weights rather than train them again, trainable is set to False. A GlobalAveragePooling2D layer then reduces the feature maps to a single vector per image, and a Dense layer with softmax activation performs multiclass classification. For binary classification, a single output unit with sigmoid activation is used instead.

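import tensorflow as tf

# Placeholder values (assumed for illustration); set these for your own dataset.
IMAGE_SIZE = [299, 299]                      # Xception's default input size
CLASSES = ["class_a", "class_b", "class_c"]  # hypothetical class names

# Load Xception with ImageNet weights, without its ImageNet classifier head
pretrained_model = tf.keras.applications.Xception(
        weights="imagenet",
        include_top=False,
        input_shape=[*IMAGE_SIZE, 3]
    )
# Freeze the pretrained weights so they are not updated during training
pretrained_model.trainable = False

model = tf.keras.Sequential([
        pretrained_model,
        tf.keras.layers.GlobalAveragePooling2D(),
        tf.keras.layers.Dense(len(CLASSES), activation="softmax")
    ])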
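The choice between softmax and sigmoid also determines which loss the model is compiled with. A minimal sketch, assuming integer class labels and the Adam optimizer; build_classifier is a hypothetical helper name, not part of Keras:

```python
import tensorflow as tf


def build_classifier(num_classes, image_size=(299, 299), weights="imagenet"):
    """Frozen Xception base with a head chosen by the number of classes."""
    input_shape = (*image_size, 3)
    base = tf.keras.applications.Xception(
        weights=weights, include_top=False, input_shape=input_shape
    )
    base.trainable = False  # keep the ImageNet weights fixed

    if num_classes == 2:
        # Binary: one sigmoid unit that outputs the probability of class 1.
        head = tf.keras.layers.Dense(1, activation="sigmoid")
        loss = "binary_crossentropy"
    else:
        # Multiclass: one softmax unit per class; the outputs sum to 1.
        head = tf.keras.layers.Dense(num_classes, activation="softmax")
        loss = "sparse_categorical_crossentropy"  # assumes integer labels

    model = tf.keras.Sequential([
        tf.keras.Input(shape=input_shape),
        base,
        tf.keras.layers.GlobalAveragePooling2D(),
        head,
    ])
    model.compile(optimizer="adam", loss=loss, metrics=["accuracy"])
    return model
```

For example, build_classifier(2) returns a model with one sigmoid output compiled with binary cross-entropy, while build_classifier(10) returns a ten-way softmax model. If the labels are one-hot encoded, categorical_crossentropy would replace sparse_categorical_crossentropy.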

Any of the other models in the table can be used in the same way by replacing Xception() with the corresponding function, for example ResNet50() or MobileNetV2().
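Because every function in the table accepts the same weights, include_top and input_shape arguments, the pretrained model can be passed in as a parameter. A sketch, with build_transfer_model as a hypothetical helper and 224x224 inputs as an assumed size; the loop uses weights=None only to avoid downloading weights:

```python
import tensorflow as tf


def build_transfer_model(backbone_fn, num_classes, input_shape=(224, 224, 3),
                         weights="imagenet"):
    """Attach the same frozen-base classifier head to any model function."""
    base = backbone_fn(weights=weights, include_top=False, input_shape=input_shape)
    base.trainable = False  # only the new head below is trained
    return tf.keras.Sequential([
        tf.keras.Input(shape=input_shape),
        base,
        tf.keras.layers.GlobalAveragePooling2D(),
        tf.keras.layers.Dense(num_classes, activation="softmax"),
    ])


# Swapping architectures only means passing a different function.
for fn in (tf.keras.applications.Xception,
           tf.keras.applications.ResNet50,
           tf.keras.applications.MobileNetV2):
    model = build_transfer_model(fn, num_classes=3, weights=None)
    print(fn.__name__, model.output_shape)
```

The models' default input sizes differ (299x299 for Xception, 224x224 for ResNet50 and MobileNetV2), and images should still be prepared with the preprocess_input function from the same module as the chosen model.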
