DEV Community



Forward hooks in PyTorch


Forward hooks are custom functions that get executed right after the forward pass. Among other things, one can use them together with TensorBoard to visualize activations of any layer.


When using torch.nn.Module, have you ever wondered what the difference is between the forward and the __call__ methods?
One can roughly say that __call__ = forward + execution of various hooks.
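A small sketch of that difference follows, using an arbitrary nn.Linear layer and a hypothetical hook called record_call. The hook fires when the layer is called like a function (which goes through __call__), but not when forward is invoked directly.

```python
import torch
import torch.nn as nn

calls = []  # records every time the hook fires


def record_call(module, inputs, output):
    # Append the layer's class name so we can see when the hook ran
    calls.append(type(module).__name__)


layer = nn.Linear(4, 2)
layer.register_forward_hook(record_call)

x = torch.randn(1, 4)

layer(x)          # __call__: runs forward() and then the registered forward hooks
print(calls)      # ['Linear']

layer.forward(x)  # calling forward() directly skips every hook
print(calls)      # still ['Linear']
```

This is also why the usual advice is to call model(x) rather than model.forward(x): only the former runs the hooks.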


Hooks are custom functions that get executed at specific moments during the forward or backward pass. They let us inspect what is going on inside the network. Typical use cases include:

  • Debugging
  • Logging
  • Visualizing (this post)
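As an illustration of the debugging use case, here is a minimal sketch of a forward hook that stops the forward pass as soon as a layer produces NaN or infinite values. The helper check_finite and the toy model are made up for this example.

```python
import torch
import torch.nn as nn


def check_finite(name):
    """Build a forward hook that raises as soon as a layer outputs NaN or inf."""
    def hook(module, inputs, output):
        if isinstance(output, torch.Tensor) and not torch.isfinite(output).all():
            raise RuntimeError(f"non-finite values in output of {name}")
    return hook


model = nn.Sequential(nn.Linear(4, 4), nn.ReLU(), nn.Linear(4, 1))

# Attach the check to every submodule; the root module has the empty name "" and is skipped
handles = [
    module.register_forward_hook(check_finite(name))
    for name, module in model.named_modules()
    if name
]

model(torch.randn(2, 4))  # finite input: no hook complains

try:
    model(torch.tensor([[float("nan"), 0.0, 0.0, 0.0]]))
    error = None
except RuntimeError as exc:
    error = str(exc)

print(error)  # non-finite values in output of 0

for handle in handles:
    handle.remove()  # detach the checks once debugging is done
```

Because the hook raises inside the offending layer's call, the error names the first layer whose output went bad, not just the final loss.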

If you are interested in learning more about hooks, check out the official docs.

Tell me more about forward hooks!

A forward hook is a function that accepts three arguments:

  • module_instance : instance of the layer you are attaching the hook to
  • input : tuple of the positional arguments (usually tensors) passed to the forward method
  • output : tensor (or other object) returned by the forward method

After defining the hook, you need to "register" it with the desired layer via the layer's register_forward_hook method, which returns a handle whose remove() method detaches the hook again.
Once registered, the hook is executed every time the layer is called, right after its forward method returns. You do not have to trigger it manually!
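Putting the three arguments and the registration step together, a minimal sketch might look like this. The save_output helper, the activations dictionary and the toy model are hypothetical; register_forward_hook and the remove() method of the handle it returns are standard PyTorch.

```python
import torch
import torch.nn as nn

activations = {}  # hypothetical store: layer name -> captured tensors


def save_output(name):
    """Build a forward hook that stores a layer's input and output under `name`."""
    def hook(module_instance, inputs, output):
        # inputs: tuple of the positional arguments that were passed to forward()
        # output: whatever forward() returned (a single tensor for nn.Linear)
        activations[name] = {"input": inputs[0].detach(), "output": output.detach()}
    return hook


model = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 3))

# register_forward_hook returns a handle that can detach the hook later
handle = model[0].register_forward_hook(save_output("first_linear"))

out = model(torch.randn(5, 8))  # the hook fires automatically during this call
first = activations["first_linear"]
print(first["input"].shape, first["output"].shape)  # torch.Size([5, 8]) torch.Size([5, 16])

handle.remove()      # later forward passes no longer call the hook
activations.clear()
model(torch.randn(2, 8))
print(activations)   # {}
```

Calling detach() keeps the stored tensors out of the autograd graph, so holding on to them does not keep the whole graph alive.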

Too vague,… I need to see an example!

I created a hands-on video tutorial in which I explain step by step how to use forward hooks together with TensorBoard. The goal is to visualize the activations of any layer of choice (that is, to create a histogram of its values for a given sample or batch).
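The core idea fits in a few lines. This is a minimal sketch, not the tutorial's code: it assumes the tensorboard package is installed, and make_histogram_hook, the toy model and the temporary log directory are made up for illustration. SummaryWriter.add_histogram from torch.utils.tensorboard is the actual logging call.

```python
import os
import tempfile

import torch
import torch.nn as nn
from torch.utils.tensorboard import SummaryWriter  # requires the `tensorboard` package


def make_histogram_hook(writer, tag):
    """Hypothetical helper: a forward hook that logs the layer's output as a histogram."""
    step = {"value": 0}  # counts forward passes so each one gets its own global_step

    def hook(module, inputs, output):
        # detach() so logging never touches the autograd graph
        writer.add_histogram(tag, output.detach().cpu(), global_step=step["value"])
        step["value"] += 1

    return hook


torch.manual_seed(0)
model = nn.Sequential(nn.Linear(10, 32), nn.ReLU(), nn.Linear(32, 2))

log_dir = tempfile.mkdtemp()  # assumption: a throwaway directory; e.g. "runs/hooks" in practice
writer = SummaryWriter(log_dir=log_dir)

# Log the activations of the ReLU layer (index 1) for every batch
handle = model[1].register_forward_hook(make_histogram_hook(writer, "relu_activations"))

with torch.no_grad():
    for _ in range(3):
        model(torch.randn(16, 10))  # each batch adds one histogram step

handle.remove()
writer.close()

event_files = [f for f in os.listdir(log_dir) if f.startswith("events.out.tfevents")]
print(len(event_files))  # 1
```

To inspect the result, point TensorBoard at the log directory (tensorboard --logdir followed by the path) and open the Histograms tab; each forward pass shows up as a separate step.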

The tutorial does not cover several related (and interesting) topics:

  • backward hooks
  • forward pre-hooks

I would encourage the reader to learn more about them:)
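As a starting point, here is a short sketch of both on a toy nn.Linear layer, with hypothetical hook functions. A forward pre-hook (register_forward_pre_hook) runs before forward and receives only the positional inputs; a backward hook registered with register_full_backward_hook runs during backward() and receives the gradients with respect to the module's inputs and outputs.

```python
import torch
import torch.nn as nn

log = []  # collects (hook kind, shape) pairs in the order the hooks fire


def pre_hook(module, args):
    # Runs BEFORE forward(); args is the tuple of positional inputs
    log.append(("pre", tuple(args[0].shape)))


def backward_hook(module, grad_input, grad_output):
    # Runs during backward(); grad_output[0] is the gradient w.r.t. the layer's output
    log.append(("backward", tuple(grad_output[0].shape)))


layer = nn.Linear(3, 2)
layer.register_forward_pre_hook(pre_hook)
layer.register_full_backward_hook(backward_hook)

x = torch.randn(4, 3, requires_grad=True)
loss = layer(x).sum()
loss.backward()
print(log)  # [('pre', (4, 3)), ('backward', (4, 2))]
```

Like forward hooks, both registration methods return a handle with a remove() method.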


Hooks are hidden gems of PyTorch. In particular, forward hooks allow you to debug and visualize what is going on inside your network. This post provided a first look at what they are and how to use them.


