How exactly does torch.autograd.backward() work?

Saihimal Allu
Jun 17, 2018

Okay, I get it. No one writes blogs about individual functions in programming frameworks, particularly when the framework is PyTorch, which has some of the most extensive documentation out there. But this function holds a special place in my “heart”, and that is putting it lightly. The sheer amount of frustration I felt while using it, trying to understand what was going on under the hood, left me exhausted. I finally figured it out, hence this blog, to make sure nobody else has to go through the same painful process.

PyTorch is a library that provides abstractions to reduce the effort on the part of the developer, so that deep networks can be built with little cognitive effort. Why would anyone have trouble understanding what has already been simplified, you might think. On a personal note, I have always had a problem with libraries and APIs: they tend to abstract away a lot of details, some of which feel fundamental to understanding the core process going on underneath, and that makes me uneasy while using them. The level of control I have over the flow of a program when I hand-code something is lost when using APIs.

Okay, first let me give you some background about my point of view, to justify the existence of this blog and my motivation to blow another hour writing and editing it.

As part of my research internship, I had started getting familiar with PyTorch, since most of the code I was supposed to read was written in it. As is the standard workflow when getting acquainted with a new framework, I started working through the “Deep Learning with PyTorch: A 60 Minute Blitz” tutorial, and I ran into a snag when reading the following piece of code:

import torch

x = torch.randn(3, requires_grad=True)
y = x * 2
while y.data.norm() < 1000:
    y = y * 2
print(y)
gradients = torch.tensor([0.1, 1.0, 0.0001], dtype=torch.float)
y.backward(gradients)
print(x.grad)

The corresponding output that followed was this:

tensor([-590.4467,   97.6760,  921.0221])
tensor([ 51.2000, 512.0000, 0.0512])

I had no idea what happened when the tensor was passed as an argument to the backward() call. So, obviously, I googled the intuition behind it and went down a rabbit hole of information, none of which I was able to piece together. It was frustrating, to say the least, and it did not help that no one else seemed to have a problem understanding it. I finally found what I felt was the “missing piece”, and since I had spent close to two hours running my own tests and scouring the PyTorch developer forums, I wrote this blog so that anyone who faces the same problem down the line does not need to find the solution in some nondescript corner of the internet.
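In hindsight, this first snippet has a simple explanation. Every step only multiplies by 2, so y ends up as x scaled by some power of two (512 = 2^9 in the run above; the exact power depends on the random x). The derivative of each y_i with respect to x_i is therefore that same factor, and x.grad comes out as the factor times gradients. A minimal check, assuming a fixed seed for reproducibility and a hypothetical `factor` counter added only for illustration:

```python
import torch

torch.manual_seed(0)  # fixed seed only so the run is reproducible
x = torch.randn(3, requires_grad=True)

y = x * 2
factor = 2.0  # hypothetical counter: tracks the total scaling applied to x
while y.data.norm() < 1000:
    y = y * 2
    factor *= 2

gradients = torch.tensor([0.1, 1.0, 0.0001], dtype=torch.float)
y.backward(gradients)

# y = factor * x element-wise, so dy_i/dx_i = factor and every other
# partial derivative is 0; backward() therefore yields factor * gradients
print(factor)
print(x.grad)
print(torch.allclose(x.grad, factor * gradients))
```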

So let’s dive in. Suppose I want to execute the following snippet of code; I have added print statements for better visualization.

import torch

# x is the 1x2 row vector we want gradients for; M is a constant 2x2 matrix
x = torch.tensor([[2., 1.]], requires_grad=True)
print(x)
print('\n')
M = torch.tensor([[1., 2.], [3., 4.]])
print(M)
print('\n')
y = torch.mm(x, M)  # y = x @ M, shape 1x2
print(y)
print('\n')
# retain_graph=True keeps the graph alive so backward() can be called on y again
y.backward(torch.tensor([[1., 0.]]), retain_graph=True)
print(x.grad)
print('\n')
x.grad.zero_()  # gradients accumulate in x.grad, so reset them between calls
y.backward(torch.tensor([[0., 1.]]), retain_graph=True)
print(x.grad)
print('\n')
x.grad.zero_()
y.backward(torch.tensor([[1., 1.]]), retain_graph=True)
print(x.grad)
print('\n')

This gives the following output:

tensor([[ 2.,  1.]])
tensor([[ 1.,  2.],
        [ 3.,  4.]])
tensor([[ 5., 8.]])
tensor([[ 1., 3.]])
tensor([[ 2., 4.]])
tensor([[ 3., 7.]])

(Note: Example inspired by this discussion: https://discuss.pytorch.org/t/clarification-using-backward-on-non-scalars/1059)

The first three outputs are trivial; the last three are the items of interest. When [1,1] was passed as the gradient argument, I was (seemingly) getting the sum of the outputs obtained with [1,0] and [0,1], and I could not understand why, even though everyone in the discussion had figured it out. (Now that I know the reason, I feel pretty dumb.)
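The hunch can be tested directly, and for arbitrary weights rather than just [1,1]. The sketch below uses torch.autograd.grad, which returns the gradient instead of accumulating it into x.grad, so no zeroing is needed between calls. The helper name `grad_for` and the weights 2.5 and -1.0 are arbitrary choices for illustration:

```python
import torch

x = torch.tensor([[2., 1.]], requires_grad=True)
M = torch.tensor([[1., 2.], [3., 4.]])
y = torch.mm(x, M)

def grad_for(v):
    # Hypothetical helper: the gradient of x for the upstream vector v.
    # retain_graph=True lets us reuse the same graph for several calls.
    (g,) = torch.autograd.grad(y, x, grad_outputs=v, retain_graph=True)
    return g

u = torch.tensor([[1., 0.]])
w = torch.tensor([[0., 1.]])

print(grad_for(u))      # [1, 3], as in the output above
print(grad_for(w))      # [2, 4], as in the output above
print(grad_for(u + w))  # [3, 7], the sum of the two lines above

# The same additivity holds for any weights, not just 1 and 1
a, b = 2.5, -1.0
print(torch.allclose(grad_for(a * u + b * w), a * grad_for(u) + b * grad_for(w)))
```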

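When .backward() is called on a scalar, PyTorch implicitly sets the gradient argument (`gradient` in Tensor.backward(); `grad_variables`, now `grad_tensors`, in torch.autograd.backward()) to a tensor holding the value 1. The problem comes in when we call the method on a non-scalar (vector) output: there is no obvious default, so the gradient has to be supplied explicitly.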
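A minimal sketch of both cases, using made-up tensors. For a scalar (0-dimensional) output, calling backward() with no argument gives the same result as passing torch.tensor(1.) explicitly; for a non-scalar output, PyTorch refuses to guess and raises a RuntimeError:

```python
import torch

x = torch.tensor([2., 1.], requires_grad=True)

# Scalar output: the implicit gradient argument is 1
loss = (x ** 2).sum()
loss.backward()
implicit = x.grad.clone()

x.grad.zero_()
loss = (x ** 2).sum()            # rebuild the graph; the first one was freed
loss.backward(torch.tensor(1.))  # the same call with the 1 made explicit
explicit = x.grad.clone()
print(torch.equal(implicit, explicit))  # both equal 2 * x

# Non-scalar output: no default gradient exists, so backward() fails
raised = False
y = x * 3
try:
    y.backward()
except RuntimeError as err:
    raised = True
    print("RuntimeError:", err)
```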

To build some intuition, let’s assume we have a standard neural net with 6 fully connected layers. After these 6 layers, I compute the loss function of the network, which by definition is a scalar. So if I run .backward() on the loss, PyTorch behaves as we expect, setting the gradient argument to 1 by default. Running the method on a vector is equivalent to calling it on a layer in the middle of the network, say the 4th layer. Let the output of the 4th layer be represented by the vector X. There is a scalar loss downstream; right now we do not know its value, but we can reason that the loss is some function of the vector X, or more specifically of the elements of that vector: loss = f(x1, x2, x3, x4, …, xn, …), where f is some arbitrary function. The additional dots after xn indicate that the loss also depends on other quantities besides the elements of X.

Okay, that was the basis of this blog; things only get easier from here. Since PyTorch starts backpropagation from a scalar (the loss), it needs extra information when it is asked to start from a vector, so that it can run the same underlying machinery as in the scalar case. More specifically, it needs this information:

[d(loss)/dx1, d(loss)/dx2, …, d(loss)/dxn]

and this is exactly what we provide in the gradient argument: the gradient of the loss we are going to encounter downstream with respect to each element of the vector X. If we supply this information, the backpropagation algorithm implemented by PyTorch still works, and it can calculate the gradient of the loss with respect to everything upstream of X (the inputs and the weights of the earlier layers). This is the “missing” piece of the puzzle, and it provides an intuitive basis for passing a tensor as the argument to backward().
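The idea can be checked on a toy network. The sketch below (the layer sizes, the tanh activation, and the squared-sum loss are arbitrary assumptions) first runs ordinary backpropagation from the scalar loss. It then computes dL/dh for an intermediate activation h with torch.autograd.grad and calls h.backward(dL/dh). If the intuition is right, the gradients of the earlier weights W1 should match between the two routes:

```python
import torch

torch.manual_seed(0)
x = torch.randn(4, 3)
W1 = torch.randn(3, 5, requires_grad=True)
W2 = torch.randn(5, 2, requires_grad=True)

def forward():
    h = torch.tanh(x @ W1)        # the "middle layer", playing the role of X
    loss = ((h @ W2) ** 2).sum()  # a scalar loss further downstream
    return h, loss

# Route 1: ordinary backpropagation from the scalar loss
h, loss = forward()
loss.backward()
expected_W1 = W1.grad.clone()

# Route 2: compute dL/dh, i.e. [dL/dx1, ..., dL/dxn] for the middle layer,
# then start backpropagation at h with dL/dh as the gradient argument
W1.grad = None
h, loss = forward()
(dL_dh,) = torch.autograd.grad(loss, h, retain_graph=True)
h.backward(dL_dh)

print(torch.allclose(W1.grad, expected_W1))
```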

Coming back to the snippet that kicked things off, I will use this intuition to explain its outputs.

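Let’s suppose that L represents the value of the loss function further down the line. Let the first matrix be represented as

$$x = \begin{bmatrix} b_1 & b_2 \end{bmatrix}$$

with b1 = 2 and b2 = 1 in the above example.

The second matrix is given by

$$M = \begin{bmatrix} 1 & 2 \\ 3 & 4 \end{bmatrix}$$

and the multiplication is given by

$$y = xM = \begin{bmatrix} b_1 + 3b_2 & 2b_1 + 4b_2 \end{bmatrix}$$

which we express as [y1 y2], so y1 and y2 are the elements of the product vector.

Since we are interested in x.grad, we want the gradient of L with respect to b1 and b2. By the chain rule, this is equivalent to the equations given below:

$$\frac{\partial L}{\partial b_1} = \frac{\partial L}{\partial y_1}\frac{\partial y_1}{\partial b_1} + \frac{\partial L}{\partial y_2}\frac{\partial y_2}{\partial b_1} = \frac{\partial L}{\partial y_1} + 2\,\frac{\partial L}{\partial y_2}$$

$$\frac{\partial L}{\partial b_2} = \frac{\partial L}{\partial y_1}\frac{\partial y_1}{\partial b_2} + \frac{\partial L}{\partial y_2}\frac{\partial y_2}{\partial b_2} = 3\,\frac{\partial L}{\partial y_1} + 4\,\frac{\partial L}{\partial y_2}$$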

Now the expected outputs can be computed by plugging in the gradients of the loss with respect to y1 and y2, which are exactly what we pass as the argument to backward(). For example, passing [0,1] means dL/dy1 = 0 and dL/dy2 = 1; plugging these values into the equations gives dL/db1 = 2 and dL/db2 = 4, i.e. [2, 4], which is exactly the output we got. Plugging in [1,1] makes it clear why my hunch was correct: the equations are linear in dL/dy1 and dL/dy2, so the output is the sum of the outputs for [1,0] and [0,1].
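The chain-rule calculation in the previous paragraph is a vector-Jacobian product: x.grad = v·J, where v is the tensor passed to backward() and J[j, i] = ∂y_j/∂b_i. For y = xM, that Jacobian is simply M transposed. A sketch that builds J explicitly with torch.autograd.functional.jacobian (available in recent PyTorch releases; the reshape to 2×2 is specific to these 1×2 shapes) and compares v·J against what backward() produces:

```python
import torch
from torch.autograd.functional import jacobian

M = torch.tensor([[1., 2.], [3., 4.]])
x0 = torch.tensor([[2., 1.]])

# Full Jacobian of y = x @ M; its raw shape is y.shape + x.shape = (1, 2, 1, 2),
# so flatten it to J[j, i] = dy_j / db_i
J = jacobian(lambda x: torch.mm(x, M), x0).reshape(2, 2)
print(J)  # equals M.T for this linear map

vectors = [torch.tensor([[1., 0.]]), torch.tensor([[0., 1.]]), torch.tensor([[1., 1.]])]
results = []
for v in vectors:
    vjp = v @ J  # the vector-Jacobian product written out by hand

    # The same quantity via backward(), as in the original snippet
    x = x0.clone().requires_grad_(True)
    torch.mm(x, M).backward(v)
    results.append((vjp, x.grad))
    print(vjp, x.grad)
```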

So I hope this intuition improves your experience using PyTorch. PyTorch has some of the best documentation around: it is oriented toward the user, without unnecessary programming jargon, so it is easy to understand and follow.

Edit: I would like to thank Verena Haunschmid for pointing out the errors that the code snippet had in an earlier version of the story.
