DEV Community

catasaurus

Useful tensorflow/keras callbacks for model training

Here are some callbacks that I have found very useful when training machine learning models with Python and TensorFlow:

Number one: Early stopping

Keras early stopping (https://keras.io/api/callbacks/early_stopping/) has to be my favorite callback. With it you can stop training automatically once a monitored metric stops improving. An example of usage:

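```python
import tensorflow as tf

# Stop once val_loss has gone 5 epochs in a row without improving by more than 0.001,
# then roll the model back to the weights from its best epoch.
earlystopping = tf.keras.callbacks.EarlyStopping(
    monitor="val_loss",
    min_delta=0.001,
    patience=5,
    verbose=1,
    restore_best_weights=True,
)
```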

This stops training once the validation loss (`val_loss`) has gone 5 epochs in a row without improving by more than 0.001, and then restores the model's weights from its best epoch. Just like any other callback, make sure to pass it to `model.fit()` during training. Because it monitors `val_loss`, `fit()` also needs validation data, for example via `validation_split`:

```python
model.fit(some_data_X, some_data_y, epochs=some_number, validation_split=0.2, callbacks=[earlystopping, some_other_callback])
```
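To make the `patience`/`min_delta` rule concrete, here is a simplified, pure-Python replay of the bookkeeping `EarlyStopping` does over a list of per-epoch validation losses. It is a sketch, not Keras's own code: it assumes `mode="min"`, ignores `baseline` and `start_from_epoch`, and does not touch any weights. The function name `simulate_early_stopping` and the loss values are made up for illustration.

```python
def simulate_early_stopping(val_losses, min_delta=0.001, patience=5):
    """Replay early-stopping bookkeeping over a list of per-epoch val losses.

    Simplified sketch: assumes mode="min" and ignores baseline/start_from_epoch.
    Returns (stop_epoch, best_epoch), both 0-indexed; stop_epoch is None
    if patience never runs out.
    """
    best = float("inf")
    best_epoch = None
    wait = 0  # epochs since the last real improvement
    for epoch, loss in enumerate(val_losses):
        wait += 1
        # Only a drop of more than min_delta below the best so far counts.
        if loss < best - min_delta:
            best = loss
            best_epoch = epoch
            wait = 0
        if wait >= patience:
            return epoch, best_epoch  # training would stop after this epoch
    return None, best_epoch


# Made-up losses: a plateau after epoch 2, then a late drop that is never reached.
losses = [0.90, 0.70, 0.60, 0.5995, 0.5994, 0.5996, 0.5993, 0.5992, 0.40]
stop_epoch, best_epoch = simulate_early_stopping(losses)
print(f"stopped after epoch {stop_epoch}, best epoch was {best_epoch}")
```

With these numbers, the tiny changes after epoch 2 are all smaller than `min_delta`, so patience runs out at epoch 7 and the best weights come from epoch 2. The big drop at epoch 8 is never seen, which is the trade-off of a small `patience`.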

Number two: Learning rate scheduler

Keras learning rate scheduler (https://keras.io/api/callbacks/learning_rate_scheduler/) can be very useful if your learning rate is too high or too low at some stage of training. With it you can lower or raise the learning rate from one epoch to the next according to any rule you write. An example:

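```python
import tensorflow as tf

def scheduler(epoch, lr):
    # Multiply the current rate by e^-0.5 (about 0.61) at the start of every epoch.
    return lr * tf.math.exp(-0.5)

learningratecallback = tf.keras.callbacks.LearningRateScheduler(scheduler)
```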

The `scheduler` function is where you define how the learning rate should change: at the start of every epoch, Keras calls it with the epoch index and the current learning rate, and uses the value it returns as the new learning rate. `learningratecallback` just wraps your function in `tf.keras.callbacks.LearningRateScheduler`. Don't forget to include it in `model.fit()`!
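Because the callback feeds the previous epoch's rate back into your function, the scheduler above compounds, and it is already applied before the very first epoch: epoch n trains at lr0 · e^(-0.5·(n+1)). The plain-Python sketch below mimics that feedback loop without TensorFlow. `simulate_schedule` is a hypothetical helper, `math.exp` stands in for `tf.math.exp`, and `step_decay` is a hypothetical alternative rule that uses the `epoch` argument instead of ignoring it.

```python
import math


def scheduler(epoch, lr):
    # Same rule as the post's scheduler, with math.exp standing in for tf.math.exp.
    return lr * math.exp(-0.5)


def step_decay(epoch, lr):
    # Hypothetical alternative that uses the epoch index: halve the rate every 10 epochs.
    if epoch > 0 and epoch % 10 == 0:
        return lr * 0.5
    return lr


def simulate_schedule(schedule, initial_lr, epochs):
    """Mimic the scheduler callback: call schedule(epoch, lr) at the start of
    each epoch and feed the returned rate back in on the next call."""
    lr = initial_lr
    rates = []
    for epoch in range(epochs):
        lr = schedule(epoch, lr)
        rates.append(lr)  # the rate this epoch actually trains with
    return rates


exp_rates = simulate_schedule(scheduler, initial_lr=0.01, epochs=5)
step_rates = simulate_schedule(step_decay, initial_lr=0.1, epochs=21)
for epoch, lr in enumerate(exp_rates):
    print(f"epoch {epoch}: lr = {lr:.6f}")
print(step_rates[9], step_rates[10], step_rates[20])
```

Starting from 0.01, the exponential rule is down to about 8% of the original rate (e^-2.5 ≈ 0.082) by the fifth epoch, so it is fairly aggressive. The step rule keeps the rate flat and only halves it at epochs 10, 20, and so on.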

Last but not least, number three: Custom callbacks

Custom callbacks (https://keras.io/guides/writing_your_own_callbacks/) are great if you need to do something during training that is not built into Keras/TensorFlow. I won't go in depth, as there is a lot you can do. Basically, you define a class that inherits from `keras.callbacks.Callback` and override whichever of its methods you need, such as `on_train_begin`, `on_epoch_end`, or `on_train_batch_end`; Keras calls each one at a specific point in the training (or testing and prediction) cycle. A simple example:

```python
from tensorflow import keras

class Catsarecoolcallback(keras.callbacks.Callback):
    def on_epoch_end(self, epoch, logs=None):  # gets the epoch index and a metrics dict
        print("cats are cool!")

callback = Catsarecoolcallback()
```

As you can probably tell, this prints "cats are cool!" every time an epoch ends.
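If hook methods still feel abstract, the sketch below is a stripped-down, hypothetical imitation of a training loop in plain Python (no TensorFlow) that fires callback methods in the same order Keras's `fit()` does: `on_train_begin`, then for each epoch `on_epoch_begin`, `on_train_batch_begin`/`on_train_batch_end` for every batch, `on_epoch_end`, and finally `on_train_end`. It skips validation, prediction, and real metrics; `ToyCallback`, `toy_fit`, and `HookRecorder` are illustrative names, not Keras classes.

```python
class ToyCallback:
    """Hypothetical stand-in for keras.callbacks.Callback: every hook is a no-op."""

    def on_train_begin(self, logs=None):
        pass

    def on_epoch_begin(self, epoch, logs=None):
        pass

    def on_train_batch_begin(self, batch, logs=None):
        pass

    def on_train_batch_end(self, batch, logs=None):
        pass

    def on_epoch_end(self, epoch, logs=None):
        pass

    def on_train_end(self, logs=None):
        pass


def toy_fit(callbacks, epochs, batches_per_epoch):
    """Hypothetical training loop that only fires hooks, in fit()'s order."""
    for cb in callbacks:
        cb.on_train_begin()
    for epoch in range(epochs):
        for cb in callbacks:
            cb.on_epoch_begin(epoch)
        for batch in range(batches_per_epoch):
            for cb in callbacks:
                cb.on_train_batch_begin(batch)
            # A real loop would run a forward/backward pass on this batch here.
            for cb in callbacks:
                cb.on_train_batch_end(batch)
        fake_logs = {"loss": 1.0 / (epoch + 1)}  # placeholder metric, not a real loss
        for cb in callbacks:
            cb.on_epoch_end(epoch, logs=fake_logs)
    for cb in callbacks:
        cb.on_train_end()


class HookRecorder(ToyCallback):
    """Overrides a few hooks and records when each one fires."""

    def __init__(self):
        self.calls = []

    def on_train_begin(self, logs=None):
        self.calls.append("train_begin")

    def on_epoch_begin(self, epoch, logs=None):
        self.calls.append(f"epoch_begin {epoch}")

    def on_train_batch_end(self, batch, logs=None):
        self.calls.append(f"batch_end {batch}")

    def on_epoch_end(self, epoch, logs=None):
        self.calls.append(f"epoch_end {epoch} loss={logs['loss']:.2f}")

    def on_train_end(self, logs=None):
        self.calls.append("train_end")


recorder = HookRecorder()
toy_fit([recorder], epochs=2, batches_per_epoch=2)
print("\n".join(recorder.calls))
```

Because `HookRecorder` overrides only some hooks, the rest fall back to the base class's no-ops, which mirrors how `keras.callbacks.Callback` lets you implement just the methods you care about.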

Hope you learned something while reading this!
