Weight initialization — The Why

Lucas Vazquez
4 min read · Nov 25, 2019


Until some time ago (more specifically, until this part of the fast.ai course) I thought weight init was just a fancy thing to do. The weights are going to be changed anyway, right? So how much can it really matter?

Well… A LOT!! It can be (and will be) the difference between total failure and success (especially when you go deeeeep). Not convinced yet, huh? Think I’m exaggerating? Give me five minutes and I’ll change your mind; your future models will thank you.

Why does it matter?

Just admire the cover picture of this article for a while. It comes from this paper and represents the loss surface of a (bad) neural network. Very chaotic indeed…

Your SGD steps are basically traversing this surface in search of the big blue minimum. If you’re lucky, you get a good starting point that’s already very close to that region, and then you only need to surf down the big drop. But what if you’re out of luck? Then you start very far away, and I wish you the very best of luck getting there with all those bumps along the way.

These “bumpy” regions are a real problem for our optimisation algorithm: if your learning rate is too small, you’re going to be stuck forever in some local minimum; go too big and you may shoot off into space. In more “scientific” terms: the gradients in these bumpy regions are very uninformative, they “shatter”, and training is impossible.

So how can we avoid this? WEIGHT INITIALIZATION! Remember when I said “If you’re lucky you can get a good starting point…”? Well, we can use weight init to control our luck!

Just to reiterate: the beautiful plot you saw above represents the loss surface of a neural network. Neural networks contain millions of weights, and the loss is a function of all of those weights. The authors of the paper developed an awesome technique that takes this multi-million-dimensional mess and represents it using only 2 or 3 dimensions so our monkey brains can understand it! Amazing!!! (You should definitely read the paper; it’s relatively easy to follow and adds a huge amount of intuition.)

The 3D surface plot is very cool to look at, but not so easy to analyse, so let’s use a contour plot instead.

You can see that the regions close to the minimum are well behaved and smooth; these are all regions of low loss. The chaos increases the further we get from the center, and these sharp regions are generally regions of high loss.

Using good weight init increases your chances of starting in that smooth, shiny, low-loss region. Okay… but WHY, you may ask?! Let me use the power of very simple code to strengthen your intuition (stay away, rigorous mathematical proofs!).

Let’s simulate a simple NN with a bunch of linear layers stacked together. For our input, let’s take a single image from the famous MNIST dataset.
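Something along these lines (a minimal sketch, assuming PyTorch; a random vector stands in for the flattened 28x28 MNIST image so the snippet stays self-contained, and the exact numbers will differ a bit from a real image):

```python
import torch

torch.manual_seed(42)

# Stand-in for a single flattened 28x28 MNIST image (784 values),
# normalized to roughly zero mean and unit std.
x = torch.randn(784)
print(x.mean(), x.std())  # ~0 and ~1
```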

The first layer of our NN receives 784 inputs and outputs 64 activations. Let’s see what happens with the mean and std of our activations after the very first layer.
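Here’s a minimal sketch of that first layer, continuing the snippet above, with the weights naively drawn from a standard normal distribution (the exact mean and std depend on the input and the seed; the run in this article uses a real MNIST image):

```python
# First layer: 784 inputs -> 64 activations, naive standard-normal init,
# bias omitted for simplicity.
w1 = torch.randn(784, 64)
a1 = x @ w1
print(a1.mean(), a1.std())  # the std is already far bigger than 1
```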

Hm… we already see a variance of 6.11. It’s not too bad, you may think. Alright, let’s see what happens if we stack 50 layers.
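A sketch of that experiment, under the same assumptions as above (standard-normal weights everywhere, no bias, no non-linearity):

```python
# Stack 50 naively-initialized linear layers: 784 -> 64, then 64 -> 64.
a = x
for i in range(50):
    w = torch.randn(a.shape[0], 64)  # standard-normal weights, no scaling
    a = a @ w
print(a.mean(), a.std())  # nan (or inf): the activations overflowed along the way
```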

Familiar with “Not a Number”? It means our activations got so big that our computer could not keep track of them anymore. What happens is that the variance grows exponentially with each layer.
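A quick back-of-envelope check of that claim, continuing the snippet above: with standard-normal weights, each 64-wide layer multiplies the activation std by roughly sqrt(64) = 8, i.e. the variance by a factor of about 64.

```python
# After 50 layers the std is on the order of 8**50 ~ 1.4e45,
# far beyond what float32 can represent (~3.4e38), hence the NaNs.
print(8.0 ** 50)                       # ~1.4e45
print(torch.finfo(torch.float32).max)  # ~3.4e38
```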

But we want those deep nets, right? They’re cool and have sexy, confusing names like “ResNet101”. So what’s the magic???

Fear not, fellow readers: let’s take a step towards super convergence in the next episode, Weight initialization — The Fix.
