Understanding Neural Arithmetic Logic Units

A recent paper from DeepMind researchers has introduced a new neural network module capable of counting objects, summing numbers, and keeping track of time.

In this piece, we'll find out why this is important, how it works, and what the authors achieved, and finally recreate some of the experiments in the paper using TensorFlow. This first part covers the problem the paper set out to solve, and the novel contribution it made.

The Problem

Neural networks are great, aren't they? Plug in some input data, give it appropriate target labels to match, and sit back as it learns to represent whatever function it needs to. But if you've ever trained your own nets, you may have noticed a little (really big) problem: test them on data outside the range covered by the training set, and they perform poorly.

In this regard, it can be said that neural networks aren't really learning the concepts which determine the output; they're just memorising the question-answer pairs they were shown. When a new data point arrives, it's mapped to the closest example seen during training, and the corresponding label is output.

An extreme example of this failure is an autoencoder trained to learn the identity function. This particular model is a neural network with a single input, three hidden layers of 8 neurons each, and a single output neuron, where the target is simply the input number itself. This seems like an easy problem to solve, but with any of the common activation functions, the net works fine within the training range and awfully outside of it.

Error of neural networks on an identity mapping task
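To make this concrete, here is a minimal sketch of the setup just described, written with tf.keras. The layer sizes follow the description above; everything else (the training range, activation choice, optimizer, and epoch count) is an assumption chosen purely for illustration.

```python
# A small MLP trained on the identity function over [-5, 5], then tested
# far outside that range. Swap the activation to compare behaviours.
import numpy as np
import tensorflow as tf

x_train = np.random.uniform(-5, 5, size=(10000, 1)).astype(np.float32)
y_train = x_train.copy()  # identity mapping: target equals input

model = tf.keras.Sequential([
    tf.keras.layers.Dense(8, activation="tanh", input_shape=(1,)),
    tf.keras.layers.Dense(8, activation="tanh"),
    tf.keras.layers.Dense(8, activation="tanh"),
    tf.keras.layers.Dense(1),
])
model.compile(optimizer="adam", loss="mse")
model.fit(x_train, y_train, epochs=10, batch_size=64, verbose=0)

# With a saturating activation like tanh, the predictions flatten out
# beyond the training range, so the error explodes for large inputs.
for value in [1.0, 4.0, 20.0, 100.0]:
    pred = model.predict(np.array([[value]], dtype=np.float32), verbose=0)[0, 0]
    print(f"input {value:>6.1f} -> prediction {pred:>8.2f}")
```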

This problem is clearly dependent on the activation function; the authors note that the sharply non-linear functions perform the worst. While non-linear activation functions are key to neural networks' success, it's evident that other types of layer are required if nets are to be capable of extrapolation.

Extrapolation is an important challenge for neural networks to overcome. After all, we extrapolate every single day of our lives: imagine you’ve just learned to read and explain pieces of TensorFlow code, say 100 lines long. Would you fall apart if I presented you with the same program, just spread over more lines? Even if you did, I’d wager you would perform less pitifully than a neural network would.

The Solution — part I

The authors present two modules: one capable of adding and subtracting quantities (the neural accumulator, or NAC), and one that extends it to multiplication, division, and powers (the eponymous neural arithmetic logic unit, or NALU). Let's start with the NAC.

Like most neural network layers, the NAC contains a weight matrix, W. This matrix, however, should only contain values equal to -1, 0, or 1. That way, its outputs are simple linear combinations of the inputs (additions and subtractions) and are never arbitrarily rescaled.

Unfortunately, if we hard-constrained W to the discrete set {-1, 0, 1}, it would not be continuously valued, and therefore not differentiable, so we couldn't learn what W should be via backpropagation. Instead, we make W a differentiable function of two learnable matrices (of course, we could try something other than backprop, but we'll save that for another post). We define:

W = tanh(V) * sigmoid(M)

where V and M are two weight matrices and the product is taken element-wise. Both tanh and sigmoid are differentiable functions, so we can now learn W! But why this representation? You can see below an image of this function for two-dimensional inputs. Importantly, it flattens out where its value is near -1, 0, or 1. These are the near-zero-gradient regions of the function: with an update algorithm like gradient descent, a point that drifts into one of these regions tends to stay there, since the tiny gradient makes it hard to escape, biasing the entries of W towards -1, 0, and 1.

Thanks to https://imgur.com/a/vyfE2PQ
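As a concrete illustration, here is a sketch of a NAC written as a custom tf.keras layer, assuming the W = tanh(V) * sigmoid(M) construction above (the variable names and the initializer are my choices, not the paper's):

```python
import tensorflow as tf

class NAC(tf.keras.layers.Layer):
    """Neural accumulator sketch: a bias-free linear layer whose effective
    weights are pushed towards {-1, 0, 1} by the tanh * sigmoid product."""

    def __init__(self, units, **kwargs):
        super().__init__(**kwargs)
        self.units = units

    def build(self, input_shape):
        in_dim = int(input_shape[-1])
        # Two unconstrained parameter matrices; their combination saturates
        # near -1, 0, and 1, as described above.
        self.V = self.add_weight(name="V", shape=(in_dim, self.units),
                                 initializer="glorot_uniform", trainable=True)
        self.M = self.add_weight(name="M", shape=(in_dim, self.units),
                                 initializer="glorot_uniform", trainable=True)

    def call(self, inputs):
        W = tf.tanh(self.V) * tf.sigmoid(self.M)  # element-wise product
        return tf.matmul(inputs, W)               # plain linear combination, no bias
```

Used in place of a standard Dense layer, this restricts the layer to additions and subtractions of its inputs.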

Taking it Up a Notch

Yeah, addition and subtraction are cool, but you know what’s really cool? Multiplication. Enter NALU. The NALU can learn multiplication in the exact same way as learning addition, using the simple trick of changing the input domain.

If you're something of a mathematician, then you'll know that adding natural logarithms is equivalent to taking the natural logarithm of the product:

ln(x) + ln(y) = ln(xy)

So if we take the log of the inputs, apply a NAC module, and then transform it back by taking the exponential, the inverse function of the natural logarithm, we can perform multiplication using only a weight matrix of 1s and 0s!

exp(ln(a) + ln(b)) = exp(ln(ab)) = ab
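Here is a quick NumPy check of the idea. The small epsilon inside the log is my own choice, added so inputs at or near zero don't blow up; the paper applies the log to the absolute value of the inputs for the same reason.

```python
# Apply a "NAC-like" weight vector to the logs of the inputs, then exponentiate.
# A weight of 1 multiplies by that input, -1 divides by it, and 0 ignores it.
import numpy as np

x = np.array([3.0, 7.0, 5.0])
eps = 1e-7

w_mul = np.array([1.0, 1.0, 0.0])   # selects 3 * 7
w_div = np.array([1.0, -1.0, 0.0])  # selects 3 / 7

print(np.exp(np.log(np.abs(x) + eps) @ w_mul))  # ~21.0
print(np.exp(np.log(np.abs(x) + eps) @ w_div))  # ~0.4286
```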

The complete NALU comprises one NAC for addition and subtraction, and the log-NAC-exp combination for multiplication and division. The outputs of these two parts are combined as:

y = g * a + (1 - g) * m
g = sigmoid(Gx)

Here a is the output of the addition module and m that of the multiplication module, while the gate g, computed from the input x through the learned matrix G, weights the output between the two.

Through these calculations, the NALU can learn functions built from addition, subtraction, multiplication, and division. Power functions can be modelled by stacking NALU layers. For example, imagine a NALU which takes a single input x up to a 2-neuron hidden layer, followed by a NALU which takes it back down to a single value, where the multiplication weight matrices are all 1s (ignoring the addition matrices): each hidden neuron reproduces x, and the output layer multiplies the two together, giving x * x = x². A sketch of a NALU layer follows below.
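Putting the pieces together, here is a sketch of a full NALU layer in the same tf.keras style. The NAC weight matrix is shared between the additive and multiplicative paths, as in the paper; the epsilon and the initializer are again illustrative assumptions.

```python
import tensorflow as tf

class NALU(tf.keras.layers.Layer):
    """NALU sketch: a learned gate g blends an additive NAC path with a
    log-space (multiplicative) NAC path, y = g * a + (1 - g) * m."""

    def __init__(self, units, eps=1e-7, **kwargs):
        super().__init__(**kwargs)
        self.units = units
        self.eps = eps  # keeps log() finite when inputs are near zero

    def build(self, input_shape):
        in_dim = int(input_shape[-1])
        init = "glorot_uniform"
        # V and M define the shared NAC weights W = tanh(V) * sigmoid(M);
        # G parameterises the gate g = sigmoid(Gx).
        self.V = self.add_weight(name="V", shape=(in_dim, self.units), initializer=init)
        self.M = self.add_weight(name="M", shape=(in_dim, self.units), initializer=init)
        self.G = self.add_weight(name="G", shape=(in_dim, self.units), initializer=init)

    def call(self, inputs):
        W = tf.tanh(self.V) * tf.sigmoid(self.M)
        a = tf.matmul(inputs, W)                          # addition / subtraction path
        log_x = tf.math.log(tf.abs(inputs) + self.eps)    # move inputs to log-space
        m = tf.exp(tf.matmul(log_x, W))                   # multiplication / division path
        g = tf.sigmoid(tf.matmul(inputs, self.G))         # gate between the two paths
        return g * a + (1.0 - g) * m
```

Stacking two of these (1 → 2 → 1) and driving the multiplicative weights towards 1 gives exactly the x² behaviour described above.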

Next Time

In the next post we'll look at the experiments in the paper, showing how a NALU can learn to count and sum numbers it sees in images, and how it can help a reinforcement learning agent learn to tell the time. We'll then implement some of these ourselves.

Props to Andrew Trask, Felix Hill, Scott Reed, Jack Rae, Chris Dyer, and Phil Blunsom for a great paper!