The Reparameterization “Trick”

As Simple as Possible in TensorFlow

A worrying pattern I see when trying to learn about new machine learning concepts is that I search the internet, am immediately drowned in maths, and walk away none the wiser. Some time later, someone explains it to me, or I have an ‘a-ha’ moment, and I mutter to myself “OoooooH! Why didn’t they just SAY that then?”. I think this is what my articles are going to be about: explaining things the way I wish they had been explained to me when I was first trying to learn the subject. As I explained last time, I personally need concrete examples of an idea applied to a simple version of the problem for things to sink in.

The reparameterization trick is a perfect example: a basic idea that is often made overly complicated. For instance, look at these Stack Exchange answers, the top Google result for “reparameterization trick”. Did that help? Did we really need equations full of probabilities and integrals, or images of computational graphs full of calculus, to explain the idea?!

So here it is. I am going to show you a concrete example of the reparameterization trick in fewer than 20 lines of TensorFlow and zero equations, as concise as I could make it. We are going to look at an extremely simple model to learn what the reparameterization trick is.

Let’s get started.

import tensorflow as tf  # The code in this article uses the TensorFlow 1.x graph-mode API

The model is going to transmit a single real number over a noisy channel. To simulate Gaussian noise, the output of the model will be sampled from a Gaussian centered on the input. The only parameter of the model will be the standard deviation of this Gaussian.

So here’s the model, and the loss.

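delta = tf.get_variable('delta', initializer=1.)  # Our only param: the noise standard deviation
x = tf.random_uniform((), -1., 1.)  # Input: a random real number in [-1, 1)
y = tf.random_normal((), mean=x, stddev=delta)  # Output: the input plus Gaussian noise
loss = tf.losses.mean_squared_error(labels=x, predictions=y)  # Squared transmission error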

Now, clearly, if we are optimizing for accuracy when transmitting this real number to the output, then the standard deviation of the Gaussian should be low, right? That’s another way of saying that the noise should be low, meaning that the output number is likely to be close to the input number, which is what we are optimizing for. Let’s see if we can learn to do this using just pure backpropagation. Let’s optimize it:

apply_update = tf.train.GradientDescentOptimizer(0.1).minimize(loss)
ValueError Traceback (most recent call last)
<ipython-input-4-adda79011a87> in <module>()
----> 1 apply_update = tf.train.GradientDescentOptimizer(0.1).minimize(loss)

/usr/local/lib/python2.7/dist-packages/tensorflow/python/training/optimizer.pyc in minimize(self, loss, global_step, var_list, gate_gradients, aggregation_method, colocate_gradients_with_ops, name, grad_loss)
293 "No gradients provided for any variable, check your graph for ops"
294 " that do not support gradients, between variables %s and loss %s." %
--> 295 ([str(v) for _, v in grads_and_vars], loss))
297 return self.apply_gradients(grads_and_vars, global_step=global_step,

ValueError: No gradients provided for any variable, check your graph for ops that do not support gradients, between variables ['Tensor("delta/read:0", shape=(), dtype=float32)'] and loss Tensor("mean_squared_error/value:0", shape=(), dtype=float32).

Oh no! We have a problem: this doesn’t work, because we don’t have any gradients! The specific problem here is that we can’t backpropagate through the random_normal function. This makes sense, right? We are trying to backpropagate through a random (stochastic) node in the computational graph. It doesn’t make much sense to differentiate with respect to a stochastic node, since that means the gradient would technically be a random variable too! What is the gradient of a value that might have been different?

So let’s … reparameterize! That is, let’s change how the parameters are incorporated into the model. Since random_normal does not have gradients, we simply move the parameters outside of the function:

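tf.reset_default_graph()  # Fresh graph, so the 'delta' variable can be created again
delta = tf.get_variable('delta', initializer=1.)  # Noise standard deviation, as before
x = tf.random_uniform((), -1., 1.)  # Input
eps = tf.random_normal((), mean=0., stddev=1.)  # Parameter-free noise: no variables flow in
y = x + eps * delta  # Shift and scale the noise outside of random_normal
loss = tf.losses.mean_squared_error(labels=x, predictions=y)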

And that’s it, seriously, that’s it, that’s the Reparameterization Trick! Applied to this specific example, at least. We have moved all the parameters out of the normal distribution, but this does not change the behaviour of our model: adding x to a zero-centered (zero-mean) Gaussian is the same as having a Gaussian centered at x, and multiplying unit-variance noise by delta is the same as having a Gaussian with standard deviation delta. Now the gradients don’t need to flow through the random_normal function, since it has no learnable parameters, nor any variable inputs from the rest of the graph. No gradient calculations need to go through this stochastic node, so our model now has well-defined gradients.
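Both halves of that claim can be checked without TensorFlow. Below is a minimal pure-Python sketch using only the standard `random` and `statistics` modules; the helper names `reparam_sample` and `grad_wrt_delta` are hypothetical, made up for this illustration. It compares sampling directly around x with shifting and scaling standard-normal noise, then shows that once the noise `eps` is drawn, the sample is an ordinary function of `delta` whose derivative is simply `eps` (estimated here by finite differences rather than automatic differentiation).

```python
import random
import statistics

# Hypothetical helpers for illustration; not part of TensorFlow or the original article.
def reparam_sample(x, delta, eps):
    """Reparameterized sample: all the randomness lives in eps ~ N(0, 1)."""
    return x + delta * eps

def grad_wrt_delta(x, delta, eps, h=1e-6):
    """Central finite-difference derivative of the sample w.r.t. delta, with eps held fixed."""
    return (reparam_sample(x, delta + h, eps) - reparam_sample(x, delta - h, eps)) / (2 * h)

rng = random.Random(0)
x, delta, n = 0.3, 0.5, 50_000

# Sampling directly around x versus shifting and scaling standard-normal noise.
direct = [rng.gauss(x, delta) for _ in range(n)]
reparam = [reparam_sample(x, delta, rng.gauss(0.0, 1.0)) for _ in range(n)]
print("direct: ", statistics.mean(direct), statistics.stdev(direct))
print("reparam:", statistics.mean(reparam), statistics.stdev(reparam))

# Once eps is drawn, the sample is an ordinary function of delta with derivative eps.
eps = rng.gauss(0.0, 1.0)
print("d(sample)/d(delta):", grad_wrt_delta(x, delta, eps), "eps:", eps)
```

Both pairs of printed statistics should land close to 0.3 and 0.5, and the finite-difference derivative should match `eps` up to rounding error.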

Let’s try again:

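apply_update = tf.train.GradientDescentOptimizer(0.1).minimize(loss)
sess = tf.InteractiveSession()
sess.run(tf.global_variables_initializer())  # Initialize delta to 1.0
for _ in range(100):
    sess.run(apply_update)  # One gradient-descent step on a fresh sample of x and noise
print(sess.run(loss))
> 1.08802e-14

Success! Our model now learns to make delta small, reducing the noisiness of the channel to achieve a low loss.

print(sess.run(delta))
> 3.461055e-06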
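For readers without a TensorFlow 1.x install, here is a rough pure-Python analogue of the training loop above. It is an illustrative sketch, not a reproduction of the TensorFlow run: it assumes the same setup (x uniform in [-1, 1], delta starting at 1.0, learning rate 0.1, 100 steps) and replaces automatic differentiation with the hand-derived gradient. Since the loss is (y - x)² = (eps * delta)², its derivative with respect to delta is 2 * (y - x) * eps. `train_delta` is a hypothetical helper name.

```python
import random

def train_delta(steps=100, lr=0.1, delta=1.0, seed=0):
    """Plain gradient descent on delta for the reparameterized noisy channel (hypothetical helper)."""
    rng = random.Random(seed)
    for _ in range(steps):
        x = rng.uniform(-1.0, 1.0)   # input, as in the TensorFlow model
        eps = rng.gauss(0.0, 1.0)    # parameter-free standard-normal noise
        y = x + eps * delta          # reparameterized output
        # loss = (y - x)**2 = (eps * delta)**2, so d(loss)/d(delta) = 2 * (y - x) * eps
        grad = 2.0 * (y - x) * eps
        delta -= lr * grad           # gradient-descent update
    return delta

final_delta = train_delta()
print(final_delta)
```

Starting from delta = 1.0, the returned value should end up far closer to zero; the exact number depends on the random seed.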

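OK, actually, please forgive me: I lied, for pedagogical reasons. I don’t know which version of TensorFlow introduced this, but the clever folks at Google have actually made random_normal differentiable with respect to its mean and stddev arguments.

y = tf.random_normal((), mean=x, stddev=delta)  # Gradients do flow back to delta

This works perfectly well… but that rather ruined my story, didn’t it? They’ve already reparameterized this function internally: it draws a standard-normal sample, multiplies it by stddev and adds mean, just as we did by hand.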

The Maths Bit

Now, that is not the whole story, of course: this is the trick applied to one specific, simple example. Nevertheless, what I showed you above is the essential idea behind the trick, and you can build on this foundation. You should now be able to tease this idea out of the Stack Exchange answers to get a deeper understanding, along with the other, more mathematical motivations explained there.

You could stop reading now and leave knowing the basic idea behind the trick, but before I go I want to generalize the idea a bit by showing you what we did above from a mathematical point of view. So, I guess I lied about the equations too…

What we created above was a probabilistic model. That is, unlike a vanilla neural network, its output is non-deterministic: the output will be different every time we run it. This means that we can no longer just blindly apply gradient descent to the model, which is why the initial version of the code above (should have) failed. The reason should be clear: the gradient of a model is ‘how the output changes when we change the parameters’, but here the output is different whenever we run the model, even if we don’t change the parameters!

So let’s take another look at the model. The output of the model is described by a normal distribution, where 𝓍 is our input and 𝜎 is our standard deviation parameter (delta in the code):

$$y \sim \mathcal{N}(x, \sigma^2)$$

As mentioned, we can’t optimize against this random output directly, so what we do instead is optimize the model’s behaviour on average: the expected value, or expectation, of the loss over all the outputs the model might produce.

Now, the problem is that in most models we can’t calculate this expectation efficiently. (We can in our model, since it’s so simple: the expectation of a normal distribution is its mean by definition, which is the input x in our case, so the expected squared error is just the variance, σ².) Therefore we need a way to approximate it. The simplest way of doing this is to take a lot of samples: imagine running the model several times and then averaging the results. The more samples we take, the closer the average will be to the true expectation.
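The sampling idea can be sketched in plain Python with only the standard `random` module. `mc_estimates` is a hypothetical helper for illustration: it averages both the output y and the squared-error loss (y - x)² over `n` runs of the model, so you can watch the averages approach this model’s exact values (x for the output, σ² for the loss) as `n` grows.

```python
import random

def mc_estimates(x, sigma, n, rng):
    """Average the output and the squared-error loss over n runs of y ~ N(x, sigma^2)."""
    total_output = 0.0
    total_loss = 0.0
    for _ in range(n):
        y = rng.gauss(x, sigma)       # one run of the noisy channel
        total_output += y
        total_loss += (y - x) ** 2    # squared-error loss for that run
    return total_output / n, total_loss / n

rng = random.Random(42)
x, sigma = 0.3, 0.5
for n in (1, 10, 1_000, 100_000):
    mean_output, mean_loss = mc_estimates(x, sigma, n, rng)
    # Exact answers for this model: E[y] = x = 0.3 and E[(y - x)^2] = sigma**2 = 0.25
    print(n, mean_output, mean_loss)
```

With a single sample the estimates are very noisy; by 100,000 samples they sit close to 0.3 and 0.25.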

So how many samples do we need? The more the better, but in practice people usually use just one!

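That is exactly what we did in the model above, with a single sample. Writing ε for the standard-normal noise produced by random_normal, our loss was

$$\text{loss} = \big((x + \sigma\epsilon) - x\big)^2 = \sigma^2\epsilon^2, \qquad \epsilon \sim \mathcal{N}(0, 1)$$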

Another way of thinking about it is that we have moved the source of noise outside of the main flow of the network, and used that noise to draw the samples we average over when estimating the expectation.
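Why is one sample usually enough for training? Here is a short worked derivation for this article’s model, using the same symbols (x the input, σ the standard deviation, ε standard-normal noise). Because σ sits outside the expectation over ε, the derivative can be moved inside it:

$$\mathcal{L}(\sigma) = \mathbb{E}_{\epsilon \sim \mathcal{N}(0,1)}\Big[\big((x + \sigma\epsilon) - x\big)^2\Big] = \mathbb{E}_{\epsilon}\big[\sigma^2\epsilon^2\big] = \sigma^2$$

$$\frac{\partial \mathcal{L}}{\partial \sigma} = \mathbb{E}_{\epsilon}\Big[\frac{\partial}{\partial \sigma}\,\sigma^2\epsilon^2\Big] = \mathbb{E}_{\epsilon}\big[2\sigma\epsilon^2\big] = 2\sigma$$

Each training step uses one draw of 2σε², which is noisy but equals the true gradient 2σ on average, so gradient descent still pushes σ towards zero.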

You should now have an understanding of the basic idea behind the trick and how it applies to probabilistic models. A nice way to build on it might be to go and learn about Variational Auto-Encoders (VAEs), which are an example of where this trick is applied in practice.
