How to Use PyTorch Hooks
PyTorch hooks provide a simple, powerful way to hack your neural networks and increase your ML productivity.
What are Hooks?
Hooks are actually quite common in software engineering — not at all unique to PyTorch. In general, “hooks” are functions that automatically execute after a particular event. Some examples of hooks you may have encountered in the real world:
- Website displays an ad after you visit N different pages.
- Banking app sends a notification when funds are added to your account.
- Phone dims the screen brightness when ambient light decreases.
Each of these things could be implemented without hooks. But in many cases, hooks make the programmer’s life easier.
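To make the general idea concrete before turning to PyTorch, here is a minimal plain-Python sketch of the banking example. The BankAccount class and its register_deposit_hook method are hypothetical names invented for illustration; they are not part of any library.

```python
from typing import Callable, List


class BankAccount:
    """Toy account that runs registered hooks after every deposit (hypothetical example)."""

    def __init__(self) -> None:
        self.balance = 0.0
        self._deposit_hooks: List[Callable[[float, float], None]] = []

    def register_deposit_hook(self, hook: Callable[[float, float], None]) -> None:
        # Store the callback; it will run automatically after each deposit.
        self._deposit_hooks.append(hook)

    def deposit(self, amount: float) -> None:
        self.balance += amount
        # The "event" has happened, so fire every registered hook.
        for hook in self._deposit_hooks:
            hook(amount, self.balance)


notifications: List[str] = []
account = BankAccount()
account.register_deposit_hook(
    lambda amount, balance: notifications.append(
        f"Deposited {amount:.2f}, balance {balance:.2f}"
    )
)
account.deposit(25.0)
account.deposit(10.0)
print(notifications)
```

The deposit method never needs to know what the notification logic does, which is the same separation PyTorch hooks give you between a model and the code that observes it.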
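PyTorch hooks are registered for each nn.Module object and are triggered by either the forward or backward pass of the object. A forward hook is called as hook(module, input, output) after the module's forward pass, and a backward hook (registered with register_full_backward_hook) is called as hook(module, grad_input, grad_output) once the gradients with respect to the module's inputs have been computed.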
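Besides inspecting these values, a hook can change them: a forward hook may return a replacement output, a backward hook may return replacement input gradients, and any hook can modify the Module's internal parameters. Most commonly, hooks are used for debugging, but we will see that they have many other uses.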
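The sketch below registers one hook of each kind on a small nn.Linear layer and records when each one fires. The layer sizes and the calls list are arbitrary choices for illustration; register_forward_hook and register_full_backward_hook are the standard nn.Module methods, and both return a handle whose remove() method detaches the hook.

```python
import torch
import torch.nn as nn

layer = nn.Linear(4, 2)
calls = []


def forward_hook(module: nn.Module, inputs: tuple, output: torch.Tensor):
    # Called after module.forward(); `inputs` is the tuple of positional arguments.
    calls.append(("forward", tuple(output.shape)))
    return None  # None keeps the output; returning a tensor would replace it


def backward_hook(module: nn.Module, grad_input: tuple, grad_output: tuple):
    # Called once gradients w.r.t. the module's inputs have been computed.
    calls.append(("backward", tuple(grad_output[0].shape)))
    return None  # or a tuple of tensors to replace grad_input


fwd_handle = layer.register_forward_hook(forward_hook)
bwd_handle = layer.register_full_backward_hook(backward_hook)

x = torch.randn(3, 4, requires_grad=True)
layer(x).sum().backward()
print(calls)  # [('forward', (3, 2)), ('backward', (3, 2))]

# The handles let us detach the hooks once we no longer need them.
fwd_handle.remove()
bwd_handle.remove()
```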
Example #1: Verbose Model Execution
Ever find yourself inserting print statements in your model, trying to find the cause of an error message? (I’m certainly guilty of this.) It’s an ugly debugging practice, and in many cases, we forget to remove the print statements when we’re done. Now our code looks unprofessional, and users get strange information printed to the terminal whenever they use our code.
Never again! Let’s use hooks instead to debug models without modifying their implementation in any way. For example, suppose you want to know the shape of each layer’s output. We can create a simple wrapper that prints the output shapes using hooks.
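A minimal sketch of such a wrapper follows, assuming we only care about the model's top-level children (named_children() does not recurse into nested submodules). The class name VerboseExecution and the small nn.Sequential test model are illustrative choices, not a fixed API.

```python
import torch
import torch.nn as nn


class VerboseExecution(nn.Module):
    """Wraps a model and prints the output shape of each top-level child module."""

    def __init__(self, model: nn.Module):
        super().__init__()
        self.model = model
        self.handles = []
        for name, layer in self.model.named_children():
            # Bind `name` as a default argument so each hook keeps its own layer
            # name (a plain closure would only see the loop variable's last value).
            handle = layer.register_forward_hook(
                lambda module, inputs, output, name=name: print(
                    f"{name}: {tuple(output.shape)}"
                )
            )
            self.handles.append(handle)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.model(x)


model = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 4))
verbose_model = VerboseExecution(model)
out = verbose_model(torch.randn(2, 8))
# Prints:
# 0: (2, 16)
# 1: (2, 16)
# 2: (2, 4)
```

The wrapped model's own forward method is untouched; calling handle.remove() on each entry of verbose_model.handles turns the printing off again.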
The biggest benefit of this approach is that it works even for PyTorch modules that we didn’t create! We can quickly show this using ResNet50 and some dummy inputs.
Example #2: Feature Extraction
Commonly, we want to generate features from a pre-trained network, and use them for another task (e.g. classification, similarity search, etc.). Using hooks, we can extract features without needing to re-create the existing model or modify it in any way.
We can use the feature extractor exactly like any other PyTorch module, for example by running it on the same dummy inputs from before.
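One possible shape for such an extractor is sketched below. The FeatureExtractor name, its layers argument, and the small convolutional backbone are assumptions made for illustration. Layer names are looked up in named_modules(), so nested submodules (dotted names such as "layer4.0.conv1" in a ResNet) can be selected too.

```python
from typing import Dict, Iterable

import torch
import torch.nn as nn


class FeatureExtractor(nn.Module):
    """Returns the outputs of selected (possibly nested) submodules of `model`."""

    def __init__(self, model: nn.Module, layers: Iterable[str]):
        super().__init__()
        self.model = model
        self.layers = list(layers)
        self._features: Dict[str, torch.Tensor] = {}
        modules = dict(self.model.named_modules())
        for layer_id in self.layers:
            # Raises KeyError right away if a requested layer name does not exist.
            modules[layer_id].register_forward_hook(self._save_output(layer_id))

    def _save_output(self, layer_id: str):
        def hook(module: nn.Module, inputs: tuple, output: torch.Tensor) -> None:
            self._features[layer_id] = output
        return hook

    def forward(self, x: torch.Tensor) -> Dict[str, torch.Tensor]:
        self._features = {}
        self.model(x)  # final output is discarded; the hooks fill self._features
        return dict(self._features)


backbone = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.AdaptiveAvgPool2d(1),
    nn.Flatten(),
    nn.Linear(8, 5),
)
extractor = FeatureExtractor(backbone, layers=["0", "2"])
features = extractor(torch.randn(4, 3, 32, 32))
print({name: tuple(t.shape) for name, t in features.items()})
# {'0': (4, 8, 32, 32), '2': (4, 8, 1, 1)}
```

With a torchvision ResNet50, names such as "layer4" or "avgpool" would be passed in the same way, since they are keys of its named_modules().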
Example #3: Gradient Clipping
Gradient clipping is a well-known method for dealing with exploding gradients. PyTorch already provides utility functions for it (torch.nn.utils.clip_grad_norm_ and torch.nn.utils.clip_grad_value_), but we can also easily do it with hooks. Any other form of gradient clipping, normalization, or modification can be done the same way.
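A minimal sketch of value-based clipping with hooks follows. Note that it uses Tensor.register_hook on each parameter rather than a module hook: the hook receives a parameter's gradient and may return a replacement. This clamps each gradient element to [-clip_value, clip_value] (the same idea as clip_grad_value_, not norm-based clipping). The helper name gradient_clipper is hypothetical.

```python
import torch
import torch.nn as nn


def gradient_clipper(model: nn.Module, clip_value: float) -> nn.Module:
    """Register a hook on every parameter that clamps its gradient element-wise."""
    for param in model.parameters():
        # The tensor returned by the lambda replaces the incoming gradient
        # before it is accumulated into param.grad.
        param.register_hook(lambda grad: torch.clamp(grad, -clip_value, clip_value))
    return model


clipped_model = gradient_clipper(nn.Linear(10, 1), clip_value=0.01)
clipped_model(torch.randn(16, 10) * 100.0).pow(2).mean().backward()
print(clipped_model.weight.grad.abs().max())  # at most 0.01
```

Because the hooks live on the parameters, the clipping happens inside every subsequent backward() call, with no extra step in the training loop.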
This hook is triggered during the backward pass, so this time we also compute a dummy loss metric. After executing loss.backward(), we can manually inspect the parameter gradients to check that it worked.
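One way to run that check, sketched under the assumption that we compare against an identical copy of the model without hooks: both models see the same input and the same dummy loss, so the hooked gradients should equal the raw gradients clamped element-wise. The model size, seed, input scale, and clip value are arbitrary illustration choices.

```python
import copy

import torch
import torch.nn as nn

torch.manual_seed(0)
clip_value = 0.01
reference = nn.Linear(10, 1)
clipped = copy.deepcopy(reference)  # identical weights, but gets clipping hooks

for param in clipped.parameters():
    param.register_hook(lambda grad: grad.clamp(-clip_value, clip_value))

x = torch.randn(16, 10) * 100.0  # large inputs so raw gradients exceed clip_value

for model in (reference, clipped):
    loss = model(x).pow(2).mean()  # dummy loss metric
    loss.backward()

for (name, ref_param), clip_param in zip(reference.named_parameters(), clipped.parameters()):
    print(name, "raw max:", ref_param.grad.abs().max().item(),
          "clipped max:", clip_param.grad.abs().max().item())
    # The hooked gradient should be the raw gradient clamped element-wise.
    assert torch.allclose(clip_param.grad, ref_param.grad.clamp(-clip_value, clip_value))
```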