PyTorch Hooks
Sometimes there are many ways to do the same task, and which one you choose depends on the tools available and how comfortable you are with them. PyTorch hooks are one such tool: you can build a whole neural network and train it without ever touching them, but once you know how powerful they are, you won't be able to keep your hands away from them.
So what are hooks? Hooks are functions that let us inspect or modify gradients, inputs or outputs dynamically. That is, I can change the behaviour of the neural network even while I am training it.
Hooks are used in two places
- On tensors
- On torch.nn.Modules
One more important thing: to apply a hook we first have to “register” it on the tensor or module we want it attached to. This may sound a little abstract now; it will become clear in the examples that follow.
A hook can be applied in 3 ways
- forward prehook (executing before the forward pass),
- forward hook (executing after the forward pass),
- backward hook (executing during the backward pass, once gradients have been computed).
Here the forward pass is the part where the inputs are used to compute the values of the next hidden neurons using the weights, and so on until the network reaches the end and returns an output. The backward pass happens after the loss is calculated from the output and the true value: the gradients of every weight and bias of every layer are computed in the direction from output to input (hence "backward") using the chain rule. Basically, it is the step where backpropagation happens.
On Tensors
Only a backward hook is possible for tensors. To register a hook on a tensor we can write:
x.register_hook(your_hook_func)  # x is a tensor
This hook function works with gradients: it will be called every time a gradient with respect to the tensor is computed. The hook either returns an updated gradient or None. We should not do any in-place operations inside it, as they might corrupt the gradients of tensors connected to it in the computational graph.
# Correct way
def func(grad):
    return grad + 100

# In-place (wrong)
def func(grad):
    grad += 100
Let's see a full example and look at the outputs. First, let's define a hook that adds 2 to the calculated gradient:
def hook(grad):
    return grad + 2
Now let's write a simple multiplication example. First, we will compute the gradients without using the hook:
import torch

# initializing two tensors (requires_grad=True is necessary to calculate gradients)
a = torch.tensor(7.0, requires_grad=True)
b = torch.tensor(13.0, requires_grad=True)

c = a * b
c.retain_grad()  # to store the gradient of c
c.backward()

print(a.grad)
print(b.grad)
print(c.grad)
The output will be
tensor(13.)
tensor(7.)
tensor(1.)
This makes sense: since c = a * b, the gradient of c with respect to a is b = 13, with respect to b it is a = 7, and with respect to itself it is 1. Now, if we use the hook, the gradient of c should increase by 2, that is, it should be 3, and the gradients of a and b will change accordingly: their new grads will be the old grads multiplied by 3, i.e. 39 and 21 respectively. Let's see if it matches.
c = a * b

# registering the tensor c with the hook
c.register_hook(lambda grad: hook(grad))
c.retain_grad()
c.backward()

print(a.grad)
print(b.grad)
print(c.grad)
The output
tensor(39.)
tensor(21.)
tensor(3.)
Hence it matches what we discussed earlier. (If you are confused about how these gradients are calculated, have a look at the Autograd library.)
To remove a hook, keep the handle that register_hook returns and call remove() on it:
handle = c.register_hook(hook)
handle.remove()
I will not be discussing the whole hook system on modules, but note that on modules we can use all three hooks: the forward pre-hook, the forward hook and the backward hook. A minimal sketch of registering them is shown below.
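As a rough sketch (the layer and hook names here are just placeholders I made up for illustration), the three kinds of module hooks can be registered like this:

import torch
import torch.nn as nn

layer = nn.Linear(4, 2)  # any nn.Module works

def pre_hook(module, inp):
    # runs before forward(); may return a modified input tuple or None
    print("forward pre-hook:", [t.shape for t in inp])

def fwd_hook(module, inp, out):
    # runs after forward(); may return a modified output or None
    print("forward hook:", out.shape)

def bwd_hook(module, grad_input, grad_output):
    # runs during the backward pass, when gradients for this module are computed
    print("backward hook:", [g.shape for g in grad_output])

h1 = layer.register_forward_pre_hook(pre_hook)
h2 = layer.register_forward_hook(fwd_hook)
h3 = layer.register_full_backward_hook(bwd_hook)

out = layer(torch.randn(3, 4))
out.sum().backward()

for h in (h1, h2, h3):
    h.remove()  # every register_* call returns a removable handle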
Let us see one great application of forward hooks on modules.
Finding Layer Activation using Hooks
If we ever want to look at the activations that the model learns, the forward hook can be very useful. Suppose you have built a model that can detect skin cancer; using the model's activations we can see where on the image the model is actually focusing. This makes hooks a great tool for model explainability, as we can visualise the activation maps.
Let's build a simple CNN model consisting of 3 layers: first a convolution layer, then an average pooling layer and finally a linear layer. We will try to get the activations from the pooling layer. (If you want, you can get activations from every layer; see the sketch at the end of this section.)
import torch
import torch.nn as nn
import torch.nn.functional as F

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 8, 2)
        self.pool = nn.AdaptiveAvgPool2d((4, 4))
        self.fc = nn.Linear(8 * 4 * 4, 1)

    def forward(self, x):
        x = F.relu(self.conv(x))
        x = self.pool(x)
        x = x.view(x.shape[0], -1)
        x = self.fc(x)
        return x
net = Net()
The forward hook function has 3 arguments: module, input and output. It returns a modified output or None. It should have the following signature:
hook(module, input, output) -> None or modified output
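For example, as a small illustrative sketch (the scaling factor of 0.5 is arbitrary, just to show the mechanism), a forward hook that returns a value replaces the layer's output downstream:

def scale_output_hook(module, inp, out):
    # whatever this returns replaces the pooling layer's output
    return out * 0.5

handle = net.pool.register_forward_hook(scale_output_hook)
# ... any forward pass run here sees the halved pooled features ...
handle.remove()  # detach the hook once we no longer need it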
Let's make a hook that can collect the activations. We will use a dictionary to store them.
feats = {}  # an empty dictionary

def hook_func(m, inp, op):
    feats['feat'] = op.detach()
Usually we would first train the model; I will not be doing that here (obviously). Let's assume the model we made earlier has been trained on some data and we now want the features it has learned.
Registering a forward hook on the Pooling layer
net.pool.register_forward_hook(hook_func)
Suppose we have fed in an image of dimension 1x3x10x10 (a single RGB image of size 10x10) and now we want the features.
x = torch.randn(1, 3, 10, 10)
output = net(x)
Doing this saves the activations in the feats dictionary. Let me show just the shape of the stored tensor, as printing its contents would take too much space.
print(feats['feat'].shape)
# output -> torch.Size([1, 8, 4, 4])
Hence our activations got saved.
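And if we want the activations of every layer rather than just the pooling layer, a rough sketch (the helper make_hook and the per-layer keys are my own naming, not part of the model above) could look like this:

feats = {}

def make_hook(name):
    # returns a hook that stores this layer's output under its name
    def hook(module, inp, out):
        feats[name] = out.detach()
    return hook

handles = [module.register_forward_hook(make_hook(name))
           for name, module in net.named_children()]

output = net(torch.randn(1, 3, 10, 10))
print({name: t.shape for name, t in feats.items()})

for h in handles:
    h.remove()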
So to sum it up, we have seen what hooks are and how they are used. They are a special tool that can be used in many ways: we can control the gradients while the network is being trained, store the activations of layers, change how the outputs are calculated, and more. I have not found a really good application for pre-hooks; if you have one, do tell me. I hope that after reading this you have attained some level of curiosity to explore further.
Happy Learning!