Understanding Neural Arithmetic Logic Units
By Akil Hylton
DeepMind recently released a new paper titled Neural Arithmetic Logic Units (NALU). It’s an interesting paper that tackles an important problem in deep learning: teaching neural networks to count. Surprisingly, although neural networks achieve state-of-the-art results on many tasks, such as categorizing lung cancer, they struggle with simpler tasks like counting.
In one experiment demonstrating how networks struggle to extrapolate beyond the data they were trained on, the researchers found that the networks handled numbers in the training range of -5 to 5 with near-perfect accuracy, but failed to generalize to numbers outside that range.
The paper offers a solution in two parts. Below I’ll briefly describe how the NAC works and how it handles operations such as addition and subtraction. After that, I’ll introduce the NALU, which can handle more complex operations such as multiplication and division. I’ve included code you can try that demonstrates both, and you can read the paper for more details.
First Neural Network (NAC)
The Neural Accumulator (NAC for short) is a linear transformation of its inputs. What does this mean? Its transform matrix W is the element-wise product of tanh(W_hat) and sigmoid(M_hat), which keeps every entry of W between -1 and 1 and biases it toward -1, 0, or 1. The output is simply W multiplied by the input x, so each output is a combination of inputs that are added, subtracted, or ignored.
NAC in Python
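To make the transform concrete, here is a minimal NumPy sketch of the NAC forward pass. It is not the TensorFlow code from this article: the function name nac_forward, the (in_dim, out_dim) weight layout, and the hand-picked parameter values are my own assumptions, chosen so that W saturates close to +1 and -1.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def nac_forward(x, W_hat, M_hat):
    """NAC forward pass (sketch).

    x:            array of shape (batch, in_dim)
    W_hat, M_hat: arrays of shape (in_dim, out_dim), layout assumed here
    """
    # Element-wise product keeps every weight strictly inside (-1, 1)
    W = np.tanh(W_hat) * sigmoid(M_hat)
    return x @ W, W

# Hand-picked (not learned) parameters, for illustration only:
# output column 0 should add the inputs, column 1 should subtract them.
W_hat = np.array([[10.0, 10.0],
                  [10.0, -10.0]])
M_hat = np.full((2, 2), 10.0)

x = np.array([[3.0, 4.0],
              [100.0, 250.0]])
y, W = nac_forward(x, W_hat, M_hat)
print("W =\n", W)   # entries very close to +1 or -1
print("y =\n", y)   # column 0 close to x1 + x2, column 1 close to x1 - x2
```

Because the weights sit near +1 or -1 instead of arbitrary real values, the same matrix gives sensible answers for inputs far larger than anything seen in training.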
Second Neural Network (NALU)
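The Neural Arithmetic Logic Unit (NALU for short) consists of two NACs and a gate. The first NAC computes the additive path a = Wx. The second NAC operates in log space, m = exp(W(log(|x| + epsilon))), which turns the sums inside the NAC into products. A learned gate g = sigmoid(Gx) blends the two paths, so the output is y = g * a + (1 - g) * m.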
NALU in Python
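Here is a rough NumPy sketch of the NALU forward pass, following the formulas from the paper. The function name nalu_forward, the weight shapes, the random initialization scale, and the eps default are assumptions on my part rather than the settings used in this article's TensorFlow code.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def nalu_forward(x, W_hat, M_hat, G, eps=1e-7):
    """NALU forward pass (sketch); all weights have shape (in_dim, out_dim)."""
    W = np.tanh(W_hat) * sigmoid(M_hat)        # weights shared by both paths
    a = x @ W                                  # additive path: add / subtract
    m = np.exp(np.log(np.abs(x) + eps) @ W)    # log-space path: multiply / divide
    g = sigmoid(x @ G)                         # gate in (0, 1)
    return g * a + (1.0 - g) * m

# Untrained, randomly initialized parameters (scale 0.1 is an assumption)
rng = np.random.default_rng(0)
in_dim, out_dim = 2, 1
W_hat = rng.normal(scale=0.1, size=(in_dim, out_dim))
M_hat = rng.normal(scale=0.1, size=(in_dim, out_dim))
G = rng.normal(scale=0.1, size=(in_dim, out_dim))

x = np.array([[2.0, 3.0],
              [5.0, 7.0]])
y = nalu_forward(x, W_hat, M_hat, G)
print(y.shape)
```

The eps term only guards against log(0). Note that taking |x| discards the sign of the input, so the log-space path by itself cannot produce negative products.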
Test NAC by learning the addition➕
Now let’s run a test. We’ll start off by turning NAC into a function.
Next let’s create some toy data to use for training and testing. NumPy has a great function called numpy.arange that we will leverage to create our dataset.
Now we can define the boilerplate code to train our model. We start by defining the placeholders X and Y to feed the data at run time. Next we define our NAC network (y_pred, W = NAC(in_dim=X, out_dim=1)). For the loss we will use tf.reduce_sum(). We will have two hyper-parameters: alpha, which is the learning rate, and the number of epochs we want to train the network for. Before the training loop runs we need to define an optimizer, so we will use tf.train.AdamOptimizer() to minimize the loss.
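As a sketch of what such a dataset could look like, the snippet below builds two input columns with numpy.arange and a target column from an operator. The helper make_dataset is hypothetical. The step sizes (4 and 8) and the offset of 1000 are my guess, picked so that the test targets reproduce the "Actual sum" values printed in the results further down. The training range is not stated in this article, so the one here is arbitrary.

```python
import numpy as np

def make_dataset(n, offset=0.0, op=np.add):
    """Hypothetical helper: returns X of shape (n, 2) and y of shape (n, 1)."""
    x1 = offset + np.arange(0, 4 * n, 4, dtype=np.float32)  # offset, offset+4, ...
    x2 = offset + np.arange(0, 8 * n, 8, dtype=np.float32)  # offset, offset+8, ...
    X = np.column_stack([x1, x2])
    y = op(x1, x2).reshape(-1, 1)
    return X, y

# Training range is an assumption; the test range starts well beyond it
X_train, y_train = make_dataset(250)
X_test, y_test = make_dataset(10, offset=1000.0)

print(X_train.shape, y_train.shape)
print(y_test.ravel())
```

Passing op=np.multiply instead of np.add would turn the same helper into a multiplication dataset.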
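For readers without TensorFlow at hand, here is a dependency-light sketch of the same training idea in NumPy. Several things differ from the setup described above and are assumptions of mine: plain gradient descent with hand-derived gradients stands in for tf.train.AdamOptimizer, the loss is a mean squared error, the weights start at zero, and the training inputs are random values between 0 and 1. The names alpha and epochs mirror the two hyper-parameters mentioned above, but their values are arbitrary.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

rng = np.random.default_rng(0)
X_train = rng.uniform(0.0, 1.0, size=(256, 2))   # assumed training range
y_train = X_train.sum(axis=1, keepdims=True)     # target: x1 + x2

W_hat = np.zeros((2, 1))   # assumed zero initialization
M_hat = np.zeros((2, 1))
alpha = 1.0                # learning rate for plain gradient descent
epochs = 5000
losses = []

for epoch in range(epochs):
    t, s = np.tanh(W_hat), sigmoid(M_hat)
    W = t * s
    err = X_train @ W - y_train
    losses.append(float(np.mean(err ** 2)))          # mean squared error
    grad_W = 2.0 * X_train.T @ err / len(X_train)    # dL/dW
    W_hat -= alpha * grad_W * (1.0 - t ** 2) * s     # chain rule through tanh
    M_hat -= alpha * grad_W * t * s * (1.0 - s)      # chain rule through sigmoid

# Extrapolate to inputs far outside the training range
W = np.tanh(W_hat) * sigmoid(M_hat)
X_test = np.array([[1000.0, 1000.0], [1004.0, 1008.0], [1008.0, 1016.0]])
print("learned W:    ", W.ravel())
print("actual sum:   ", X_test.sum(axis=1))
print("predicted sum:", (X_test @ W).ravel())
```

Since tanh and sigmoid saturate, W approaches 1 without ever reaching it, and plain gradient descent slows down as it gets close. Expect the extrapolated sums from this sketch to be close to the true values but less tight than the Adam-trained results in this article.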
After training, this is how the cost plot looks:
Actual sum: [2000. 2012. 2024. 2036. 2048. 2060. 2072. 2084. 2096. 2108.]
Predicted sum: [1999.9021 2011.9015 2023.9009 2035.9004 2047.8997 2059.8992 2071.8984
2083.898 2095.8975 2107.8967]
While NAC can handle operations such as addition and subtraction, it cannot handle multiplication and division. That is where NALU comes in. It is able to handle more complex operations such as multiplication and division.
Test NALU by learning the multiplication✖️
For this we will add the pieces to make the NAC into a NALU.
The Neural Accumulator (NAC) is a linear transformation of its inputs. The Neural Arithmetic Logic Unit (NALU) uses two NACs with tied weights to enable addition/subtraction (smaller purple cell) and multiplication/division (larger purple cell), controlled by a gate (orange cell).
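To see how the gate switches between the two cells, here is a small worked example in NumPy with the weights and the gate set by hand instead of learned. The helper nalu_output is hypothetical, and taking the gate as a plain number (rather than computing sigmoid(Gx)) is a simplification for illustration.

```python
import numpy as np

def nalu_output(x, W, g, eps=1e-7):
    """Hypothetical helper: blends the additive and log-space paths with gate g."""
    a = x @ W                                  # add / subtract cell
    m = np.exp(np.log(np.abs(x) + eps) @ W)    # multiply / divide cell
    return g * a + (1.0 - g) * m

x = np.array([[6.0, 3.0]])
W_plus = np.array([[1.0], [1.0]])     # ideal weights [1, 1]
W_minus = np.array([[1.0], [-1.0]])   # ideal weights [1, -1]

added = nalu_output(x, W_plus, g=1.0)        # gate open:   6 + 3
subtracted = nalu_output(x, W_minus, g=1.0)  # gate open:   6 - 3
multiplied = nalu_output(x, W_plus, g=0.0)   # gate closed: approximately 6 * 3
divided = nalu_output(x, W_minus, g=0.0)     # gate closed: approximately 6 / 3
print(added, subtracted, multiplied, divided)
```

The same tied weights serve both cells: [1, 1] means "add" on the additive path and "multiply" on the log-space path, because exp(log(a) + log(b)) equals a * b.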
Now let’s create some toy data again; this time, however, we will change two lines.
Lines 8 and 20 are where the changes were made, switching the add operator to multiplication.
Now we can train our NALU network. The only change is where we define the network: instead of NAC we will use NALU (y_pred = NALU(in_dim=X, out_dim=1)).
Actual product: [1000000. 1012032. 1024128. 1036288. 1048512. 1060800. 1073152. 1085568.
Predicted product: [1000000.2 1012032. 1024127.56 1036288.6 1048512.06 1060800.8
1073151.6 1085567.6 1098047.6 1110592.8 ]
Full implementation in TensorFlow
I initially paid little attention to the release of this paper. My interest was piqued after watching Siraj’s YouTube video on NALU. I wrote this article to help others understand it, and I hope you find it useful!