Here are some callbacks that I have found very useful when training machine learning models with Python and TensorFlow:
Number one: Early stopping
Keras early stopping (https://keras.io/api/callbacks/early_stopping/) has to be my favorite callback. It lets you define when training should stop because the model is no longer improving. An example of its usage:
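import tensorflow as tf

earlystopping = tf.keras.callbacks.EarlyStopping(
    monitor="val_loss",         # metric to watch
    min_delta=0.001,            # smallest change that counts as an improvement
    patience=5,                 # epochs to wait without improvement
    verbose=1,
    restore_best_weights=True,  # roll back to the best epoch's weights
)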
This stops training once val_loss has failed to improve by at least 0.001 for 5 epochs in a row. It then restores the model's weights from the best epoch. As with any callback, make sure to pass it to model.fit() during training, like this:
model.fit(some_data_X, some_data_y, epochs=some_number, callbacks=[earlystopping, some_other_callback])
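To make the min_delta and patience rule more concrete, here is a rough pure-Python sketch of the bookkeeping. This is only my simplified illustration, not Keras's actual implementation; the function name epochs_until_stop and the loss values are made up for the example.

```python
def epochs_until_stop(val_losses, min_delta=0.001, patience=5):
    """Simplified early-stopping rule (illustrative, not the Keras source).

    Returns (stop_epoch, best_epoch) as 0-based indices.
    stop_epoch is None if training would run to completion.
    """
    best = float("inf")
    best_epoch = None
    wait = 0  # epochs since the last real improvement
    for epoch, loss in enumerate(val_losses):
        if loss < best - min_delta:
            # Improved by more than min_delta: remember it and reset the counter.
            best = loss
            best_epoch = epoch
            wait = 0
        else:
            wait += 1
            if wait >= patience:
                return epoch, best_epoch
    return None, best_epoch


# Made-up validation losses: the tiny dips after epoch 2 are smaller than min_delta.
losses = [0.90, 0.70, 0.60, 0.5995, 0.5999, 0.61, 0.60, 0.62, 0.55]
print(epochs_until_stop(losses))  # (7, 2)
```

Note that the 0.55 at the end is never reached: training stops at epoch 7, and with restore_best_weights=True the weights from epoch 2 would be the ones restored.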
Number two: Learning rate scheduler
Keras learning rate scheduler (https://keras.io/api/callbacks/learning_rate_scheduler/) can be very useful if you are having problems with your learning rate. With it you can lower or raise the learning rate at the start of each epoch, based on the epoch index and the current learning rate. An example:
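def scheduler(epoch, lr):
    # Multiply the learning rate by exp(-0.5) every epoch; float() keeps the result a plain Python float.
    return float(lr * tf.math.exp(-0.5))

learningratecallback = tf.keras.callbacks.LearningRateScheduler(scheduler)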
The scheduler function is where you define the logic for how the learning rate should decrease or increase. learningratecallback just wraps your function in a tf.keras.callbacks.LearningRateScheduler(). Don't forget to include it in model.fit()!
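If you want to see what a schedule like this does before you train with it, you can apply the function by hand. The sketch below assumes the callback calls the function once per epoch with the current learning rate. The starting value of 0.01 and the helper name simulate_schedule are just placeholders of mine, and I use math.exp so it runs without TensorFlow.

```python
import math


def scheduler(epoch, lr):
    # Same rule as in the example: shrink the rate by a factor of exp(-0.5) per epoch.
    return lr * math.exp(-0.5)


def simulate_schedule(initial_lr, epochs):
    """Apply the scheduler once per epoch and collect the resulting rates."""
    lr = initial_lr
    rates = []
    for epoch in range(epochs):
        lr = scheduler(epoch, lr)
        rates.append(lr)
    return rates


# Hypothetical starting learning rate of 0.01 over 5 epochs.
for epoch, lr in enumerate(simulate_schedule(0.01, 5)):
    print(f"epoch {epoch}: lr = {lr:.6f}")
```

Each epoch cuts the rate to roughly 61% of the previous one, so after 5 epochs it is below a tenth of the starting value. That is quite aggressive, so a smaller exponent may be a better fit for long training runs.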
Last but not least, number three: Custom callbacks
Custom callbacks (https://keras.io/guides/writing_your_own_callbacks/) are great if you need to do something during training that is not built into Keras and TensorFlow. I won't go in depth as there is a lot you can do. Basically, you define a class that inherits from tf.keras.callbacks.Callback. There are many different methods you can override, and they are called at different points in the training (or testing and prediction) cycle. A simple example would be:
class Catsarecoolcallback(tf.keras.callbacks.Callback):
    # Keras passes the epoch index and a logs dict to this hook.
    def on_epoch_end(self, epoch, logs=None):
        print('cats are cool!')

callback = Catsarecoolcallback()
This (as you can probably tell) prints out cats are cool! every time an epoch ends.
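To give a feel for when the different hooks fire, here is a toy stand-in written in plain Python. ToyCallback and toy_fit are names I made up and are not part of Keras. Only the hook names and their arguments (on_train_begin, on_epoch_begin, on_epoch_end, on_train_end) mirror the ones in the Keras guide, and the loss values are fake.

```python
class ToyCallback:
    """Hypothetical stand-in for a callback base class: every hook is a no-op."""

    def on_train_begin(self, logs=None):
        pass

    def on_epoch_begin(self, epoch, logs=None):
        pass

    def on_epoch_end(self, epoch, logs=None):
        pass

    def on_train_end(self, logs=None):
        pass


class RecordCalls(ToyCallback):
    """Overrides a few hooks and records the order in which they are called."""

    def __init__(self):
        self.calls = []

    def on_train_begin(self, logs=None):
        self.calls.append("train_begin")

    def on_epoch_end(self, epoch, logs=None):
        self.calls.append(f"epoch_end:{epoch}")

    def on_train_end(self, logs=None):
        self.calls.append("train_end")


def toy_fit(epochs, callbacks):
    """Toy training loop: no model, it only invokes the hooks in order."""
    for cb in callbacks:
        cb.on_train_begin()
    for epoch in range(epochs):
        for cb in callbacks:
            cb.on_epoch_begin(epoch)
        logs = {"loss": 1.0 / (epoch + 1)}  # fake metric, for illustration only
        for cb in callbacks:
            cb.on_epoch_end(epoch, logs)
    for cb in callbacks:
        cb.on_train_end()


recorder = RecordCalls()
toy_fit(2, [recorder])
print(recorder.calls)  # ['train_begin', 'epoch_end:0', 'epoch_end:1', 'train_end']
```

Hooks you don't override simply do nothing, which is why a custom callback only needs to define the ones it cares about.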