This article provides a tutorial for the built-in methods of tf.GradientTape and how to use them. If you already have a firm understanding of tf.GradientTape and are looking for more advanced uses, feel free to skip ahead to the "Advanced Uses" section.
We will be using TensorFlow 2.0 in this tutorial; if you are currently running TensorFlow 1.x, you can refer to the installation page for help.
tf.GradientTape allows us to track TensorFlow computations and calculate gradients w.r.t. (with respect to) some given variables.
1.0 — Introduction
For example, we could track the following computations and compute gradients with
tf.GradientTape as follows:
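A minimal sketch of this idea, cubing a constant and asking the tape for the derivative (the value 3.0 is just an illustrative choice):

```python
import tensorflow as tf

x = tf.constant(3.0)

with tf.GradientTape() as tape:
    # Constants are not tracked automatically, so we watch x explicitly
    tape.watch(x)
    y = x ** 3

# dy/dx = 3x^2 = 27.0 at x = 3
dy_dx = tape.gradient(y, x)
print(dy_dx.numpy())  # 27.0
```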
- By default, GradientTape doesn't track constants, so we must instruct it to with tape.watch(x).
- Then we can perform some computation on the variables we are watching. The computation can be anything from cubing them, x**3, to passing them through a neural network.
- We calculate gradients of a calculation w.r.t. a variable with tape.gradient(target, sources). Note that tape.gradient returns an EagerTensor, which you can convert to ndarray format with .numpy().
If at any point, we want to use multiple variables in our calculations, all we need to do is give
tape.gradient a list or tuple of those variables. When we optimize Keras models, we pass
model.trainable_variables as our variable list.
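As a quick sketch (the names w and b are just illustrative), passing a list of sources returns one gradient per source:

```python
w = tf.Variable(2.0)
b = tf.Variable(1.0)
x = tf.constant(4.0)

with tf.GradientTape() as tape:
    y = w * x + b

# One gradient comes back for each source in the list
dw, db = tape.gradient(y, [w, b])
print(dw.numpy(), db.numpy())  # 4.0 1.0

# With a Keras model, the same pattern becomes:
# gradients = tape.gradient(loss, model.trainable_variables)
```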
1.1 — Automatically Watching Variables
If x were a trainable variable instead of a constant, there would be no need to tell the tape to watch it: GradientTape automatically watches all trainable variables.
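For example, a sketch of the earlier snippet with x as a trainable Variable:

```python
x = tf.Variable(3.0)  # trainable=True by default

with tf.GradientTape() as tape:
    y = x ** 3  # no tape.watch(x) needed

print(tape.gradient(y, x).numpy())  # 27.0
```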
If we were to re-run this, replacing the first line with x = tf.constant(3.0) or x = tf.Variable(3.0, trainable=False), tape.gradient would return None for x, and the code would raise an error when it tried to use that result, as GradientTape wouldn't be watching x.
1.2 — watch_accessed_variables=False
If we don't want GradientTape to watch all trainable variables automatically, we can set the tape's watch_accessed_variables parameter to False. With watch_accessed_variables disabled, we have to tape.watch each variable we want gradients for, which gives us fine control over what variables we want to watch.
If you have a lot of trainable variables and are not optimizing them all at once, you may want to disable watch_accessed_variables to protect yourself from mistakes.
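A small sketch of that behaviour (the variables a and b are just for illustration):

```python
a = tf.Variable(1.0)
b = tf.Variable(2.0)

with tf.GradientTape(watch_accessed_variables=False) as tape:
    tape.watch(a)        # only a is watched
    y = a ** 2 + b ** 2

da, db = tape.gradient(y, [a, b])
print(da.numpy())  # 2.0
print(db)          # None -- b was never watched
```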
1.3 — Higher-Order Derivatives
If you want to compute higher-order derivatives, you can use nested GradientTape objects. Higher-order derivatives are generally the only time you would want to compute gradients inside a GradientTape context. Otherwise, it will slow down computations, as the GradientTape is watching every computation done in the gradient.
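A minimal sketch of a second derivative using nested tapes (values chosen for illustration):

```python
x = tf.Variable(3.0)

with tf.GradientTape() as outer_tape:
    with tf.GradientTape() as inner_tape:
        y = x ** 3
    dy_dx = inner_tape.gradient(y, x)    # 3x^2 = 27.0

d2y_dx2 = outer_tape.gradient(dy_dx, x)  # 6x = 18.0
print(d2y_dx2.numpy())  # 18.0
```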
1.4 — persistent=True
If we were to run something like the following:
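Here is a minimal sketch of such a snippet, where one (non-persistent) tape is asked for two different gradients; the particular computations are just for illustration:

```python
x = tf.Variable(3.0)

with tf.GradientTape() as tape:
    y = x ** 2
    z = x ** 3

dy_dx = tape.gradient(y, x)  # 6.0
dz_dx = tape.gradient(z, x)  # raises an error (explained below)
```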
We might expect it to print both gradients (6.0 and 27.0 in the sketch above), but in reality, calling tape.gradient a second time will raise an error.
This is because, immediately after tape.gradient is called, the GradientTape releases all the information stored inside of it for computational purposes.
If we want to bypass this, we can create the tape with persistent=True:
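The same sketch with a persistent tape:

```python
x = tf.Variable(3.0)

with tf.GradientTape(persistent=True) as tape:
    y = x ** 2
    z = x ** 3

dy_dx = tape.gradient(y, x)  # 6.0
dz_dx = tape.gradient(z, x)  # 27.0

del tape  # release the tape's resources once we are done with it
```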
And now the code outputs both gradients, just as expected!
1.5 — stop_recording()
tape.stop_recording() temporarily pauses the tape's recording, leading to greater computation speed.
In my opinion, in long functions it is more readable to use stop_recording blocks multiple times to calculate gradients in the middle of a function than to calculate all the gradients at the end of the function.
For example, I prefer:
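Something along these lines; this is only a sketch, and it assumes a model, loss_fn, and optimizer defined elsewhere:

```python
def train_step(x, y):
    with tf.GradientTape(persistent=True) as tape:
        predictions = model(x, training=True)  # assumed model
        loss = loss_fn(y, predictions)         # assumed loss function

        with tape.stop_recording():
            # Recording is paused here, so none of this gradient
            # bookkeeping is itself traced by the tape
            gradients = tape.gradient(loss, model.trainable_variables)
            optimizer.apply_gradients(zip(gradients, model.trainable_variables))

        # ... the rest of a long function can carry on here, handling
        # further losses in the same way ...
    del tape
```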
The effect is less noticeable, and possibly even the opposite, for a small example; however, for a huge chunk of code, I believe stop_recording blocks go a long way toward improving readability.
Use it as you wish!
1.6 — Other Methods
Although I won’t go into detail here,
GradientTape has a few other handy methods, including:
.jacobian: “Computes the jacobian using operations recorded in context of this tape.”
.batch_jacobian: “Computes and stacks per-example jacobians.”
.reset: “Clears all information stored in this tape.”
.watched_variables: “Returns variables watched by this tape in order of construction.”
All of the above descriptions are quoted from the GradientTape documentation.
2.0 — Linear Regression
To start off the more advanced uses of GradientTape, let's look at a classic "Hello, World!" of ML: linear regression.
First, we start by defining a few essential variables and functions.
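A sketch of such a setup might look like this (the learning rate, epoch count, and data generation are illustrative choices; the target line 10x + 5 is the one we try to recover below):

```python
import random
import tensorflow as tf

LEARNING_RATE = 0.001   # illustrative hyperparameters
EPOCHS = 100_000

# Training data generated from the target line y = 10x + 5
xs = tf.random.uniform((100,), -1.0, 1.0)
ys = 10.0 * xs + 5.0

# The two variables we want to learn
a = tf.Variable(random.random(), trainable=True)
b = tf.Variable(random.random(), trainable=True)

def predict(x):
    return a * x + b

def squared_error(y_pred, y_true):
    return tf.reduce_mean(tf.square(y_pred - y_true))
```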
Then, we can go ahead and define our step function. The step function will be run every epoch to update the trainable variables.
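Continuing the sketch above, each step records the forward pass on a tape, asks for the gradients of the loss w.r.t. a and b, and applies a plain gradient-descent update:

```python
def step(x, y):
    with tf.GradientTape() as tape:
        predictions = predict(x)
        loss = squared_error(predictions, y)
    # Gradients of the loss w.r.t. our two trainable variables
    da, db = tape.gradient(loss, [a, b])
    a.assign_sub(LEARNING_RATE * da)
    b.assign_sub(LEARNING_RATE * db)
```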
And finally, we can call the step function 100,000 times or so and print the estimated variables.
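In the sketch, that is just a plain Python loop:

```python
for _ in range(EPOCHS):
    step(xs, ys)

print(f"y ≈ {a.numpy()}x + {b.numpy()}")
```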
Running this prints something like:
y ≈ 9.986780166625977x + 4.990530490875244
which is pretty close to our target, 10x + 5.
As a recap, the full example we just built is simply the three snippets above combined into a single script.
2.1 — Polynomial Regression
We can quickly expand the previous example to work with any polynomial.
Just change up the variables we are using and the equation we are optimizing, and we are set!
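A sketch of a quadratic version, reusing the helpers from the linear example (the target coefficients are whatever polynomial your data comes from; only the variables and the prediction change):

```python
a = tf.Variable(random.random(), trainable=True)
b = tf.Variable(random.random(), trainable=True)
c = tf.Variable(random.random(), trainable=True)

def predict(x):
    return a * x**2 + b * x + c

def step(x, y):
    with tf.GradientTape() as tape:
        loss = squared_error(predict(x), y)
    da, db, dc = tape.gradient(loss, [a, b, c])
    a.assign_sub(LEARNING_RATE * da)
    b.assign_sub(LEARNING_RATE * db)
    c.assign_sub(LEARNING_RATE * dc)
```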
The sample above outputs:
y ≈ 6.418105602264404x^2 + 7.572245121002197x + 2.0106215476989746
…after 10,000 epochs.
2.2 — Classifying MNIST
Polynomial regression is fun and all, but the real deal is optimizing neural networks.
Luckily for us, little has to change from the previous examples to do just that.
We start by following the standard procedure, loading the data, pre-processing it, and setting the hyperparameters.
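A sketch of that setup (batch size, epoch count, and learning rate are illustrative choices):

```python
import tensorflow as tf

BATCH_SIZE = 64
EPOCHS = 5
LEARNING_RATE = 1e-3

# Load MNIST and scale pixel values to [0, 1]
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
x_train = x_train.astype("float32") / 255.0
x_test = x_test.astype("float32") / 255.0

train_ds = (
    tf.data.Dataset.from_tensor_slices((x_train, y_train))
    .shuffle(len(x_train))
    .batch(BATCH_SIZE)
)
```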
Then we build the model:
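Any small classifier will do; as a sketch, a small convolutional network:

```python
model = tf.keras.Sequential([
    tf.keras.layers.Reshape((28, 28, 1), input_shape=(28, 28)),
    tf.keras.layers.Conv2D(32, 3, activation="relu"),
    tf.keras.layers.MaxPooling2D(),
    tf.keras.layers.Conv2D(64, 3, activation="relu"),
    tf.keras.layers.MaxPooling2D(),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(128, activation="relu"),
    tf.keras.layers.Dense(10, activation="softmax"),
])

loss_fn = tf.keras.losses.SparseCategoricalCrossentropy()
optimizer = tf.keras.optimizers.Adam(LEARNING_RATE)
```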
And after building the model, we define the step function. optimizer.apply_gradients(zip(gradients, variables)) directly applies calculated gradients to a set of variables.
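A sketch of the step function, following the same pattern as before but with the model's trainable variables and an optimizer (the @tf.function decorator is optional and only speeds things up):

```python
@tf.function
def train_step(images, labels):
    with tf.GradientTape() as tape:
        predictions = model(images, training=True)
        loss = loss_fn(labels, predictions)
    gradients = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(gradients, model.trainable_variables))
    return loss
```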
With the train step function in place, we can set up the training loop and calculate the accuracy of the model.
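In the sketch, that amounts to:

```python
for epoch in range(EPOCHS):
    for images, labels in train_ds:
        train_step(images, labels)

# Accuracy on the test set, evaluated batch by batch
accuracy = tf.keras.metrics.SparseCategoricalAccuracy()
test_ds = tf.data.Dataset.from_tensor_slices((x_test, y_test)).batch(BATCH_SIZE)
for images, labels in test_ds:
    accuracy.update_state(labels, model(images, training=False))
print("Test accuracy:", accuracy.result().numpy())
```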
Running this will print an accuracy of about
0.99 after a few minutes.
Not too bad!
You can effortlessly expand the same style of
GradientTape programming as above into complex mathematical equations, with multiple models and loss functions.
tf.GradientTape is one of the most potent tools a machine learning engineer can have in their arsenal — its style of programming combines the beauty of mathematics with the power and simplicity of TensorFlow and Keras.
For more examples of tf.GradientTape I would recommend checking out some of the advanced generative TensorFlow tutorials, and my own articles on adversarial attacks and CycleGAN (shameless self-promotion, I know) to get a better idea of some real-world uses.
If you would like the source code for the examples created above, please refer to my GitHub repository here.
And until next time,