Attention is not quite all you need
Whilst working on MacGraph, a neural network that answers questions using knowledge graphs, I came across a problem: How to tell if something’s not present in a list.
In this article, we’ll present our solution, the focus signal, and show that it performs reliably across a range of datasets and architectures.
In traditional programming, it’s easy to tell if something is not in a list: run a for loop over the items, and if you get to the end without finding it, then it’s not there. However, neural networks are not quite so simple.
Neural networks are composed of differentiable functions so that they can be trained using gradient descent. Equality operators, loops and if conditions, the standard pieces of traditional programming used to solve this task, do not work well in neural networks:
- An equality operator is essentially a step function, which has zero gradient almost everywhere (and therefore breaks back-propagation)
- If conditions generally use a boolean signal to switch branches, which again is often the output of a problematic step function
- Loops can be inefficient on GPUs, and are sometimes not even usable, because neural network libraries often require all data to have the same dimensions (e.g. TensorFlow executes on a statically defined tensor graph).
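To make the first point concrete, here is a small Python sketch (our own illustration, not MacGraph code). It implements the traditional loop-plus-equality membership test for scalar items and estimates its gradient with respect to the query using central finite differences. The helper names `contains` and `numerical_grad`, and the tolerance `tol`, are hypothetical choices for this example.

```python
def contains(items, query, tol=1e-6):
    """Traditional membership test: a loop, an equality check and an if condition."""
    for item in items:
        if abs(item - query) < tol:  # "equality" is a step function of the distance
            return 1.0
    return 0.0

def numerical_grad(f, x, eps=1e-3):
    """Central finite-difference estimate of df/dx."""
    return (f(x + eps) - f(x - eps)) / (2 * eps)

items = [1.0, 2.0, 3.0]
# Probe the gradient of the membership output with respect to the query.
grads = [numerical_grad(lambda q: contains(items, q), q) for q in (0.5, 1.7, 2.0, 4.2)]
print(grads)  # [0.0, 0.0, 0.0, 0.0]
```

Every probe returns a gradient of zero, even at the query that matches an item, so gradient descent receives no signal about how to adjust the query.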
A popular neural-network technique for working with lists of items (e.g. translating a sentence by treating it as a list of words) is to apply “attention”. This is a function where a learnt “query”, representing what the network is looking for, is compared to each item in the list, and a sum of the items, weighted by their similarity to the query, is output:
In the above example,
- The dot product of the query and each item in the list is computed to give a “score”. This is done in parallel for all items.
- The scores are then passed through softmax, which transforms them into non-negative weights that sum to 1.0, so they can be used as a probability distribution.
- Finally, a weighted sum of the items is calculated, weighting each item by its softmaxed score.
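The three steps above can be sketched in a few lines of plain Python. This is a minimal illustration assuming dot-product scoring and small hand-written item vectors; real implementations operate on batched tensors, and the function names here are our own.

```python
import math

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

def softmax(scores):
    m = max(scores)  # subtract the max for numerical stability
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

def attention(query, items):
    scores = [dot(query, item) for item in items]  # 1. score every item against the query
    weights = softmax(scores)                       # 2. turn scores into weights summing to 1.0
    # 3. weighted sum of the items, one output dimension at a time
    output = [sum(w * item[d] for w, item in zip(weights, items)) for d in range(len(query))]
    return output, weights

items = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]
query = [5.0, 0.0, 0.0]  # "looking for" the first item
output, weights = attention(query, items)
print([round(w, 3) for w in weights])  # [0.987, 0.007, 0.007]
```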
Attention has been very successful and forms the basis of current best-in-class translation models. The mechanism has worked particularly well because:
- It’s fast and simple to implement
- Compared to a Recurrent Neural Network (e.g. LSTM), it is much better able to refer to past values in the input sequence. An LSTM has to learn to sequentially retain past values together in a single internal state across multiple RNN iterations, whereas attention can recall past sequence values at any point in a single forward pass.
- Many tasks can be solved by rearranging and combining list elements to form a new list (e.g. attention models have been important components in many current best-in-class translation, question-answering and reasoning models)
Despite attention’s versatility and success, it has a deficiency that plagued our work on graph question answering: attention does not tell us if an item is present in a list.
This first happened when we attempted to answer questions like “Is there a station called London Bridge?” and “Is Trafalgar Square station adjacent to Waterloo station?”. Our tables of graph nodes and edges contain all the information needed to answer these, but attention alone was failing to determine whether an item existed.
This happens because attention returns a weighted sum of the list. If the query matches (i.e. scores highly against) one item in the list, the output will be almost exactly that item. If the query does not match any items, a blend (weighted average) of all the items in the list is returned. Based on attention’s output alone, the rest of the network cannot easily differentiate between these two situations.
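A minimal sketch, using the same dot-product attention described earlier, shows the extreme case: with a one-item list, softmax always assigns that item a weight of 1.0, so the output is identical whether or not the query matches. The vectors are arbitrary illustrative values.

```python
import math

def attention(query, items):
    scores = [sum(q * x for q, x in zip(query, item)) for item in items]  # dot-product scores
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    weights = [e / sum(exps) for e in exps]  # softmax
    return [sum(w * item[d] for w, item in zip(weights, items)) for d in range(len(items[0]))]

items = [[0.0, 1.0, 0.0]]                    # a one-item list
present = attention([0.0, 1.0, 0.0], items)  # query matches the item
absent = attention([1.0, 0.0, 0.0], items)   # query matches nothing
print(present, absent)  # [0.0, 1.0, 0.0] [0.0, 1.0, 0.0]
```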
The simple solution we propose is to also output a scalar aggregate of the raw item-query scores (i.e. before applying softmax). This signal will be low if no items are similar to the query, and high if many items are.
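Extending the same kind of sketch, the focus signal is simply an aggregate of the raw scores, computed alongside the usual attention output. Here we assume dot-product scores and `sum` as the aggregation function; other aggregates can be passed in through the hypothetical `aggregate` parameter.

```python
import math

def attention_with_focus(query, items, aggregate=sum):
    # Raw item-query scores, before softmax.
    scores = [sum(q * x for q, x in zip(query, item)) for item in items]
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    weights = [e / total for e in exps]
    output = [sum(w * item[d] for w, item in zip(weights, items)) for d in range(len(items[0]))]
    focus = aggregate(scores)  # scalar summary of the raw scores
    return output, focus

items = [[0.0, 1.0, 0.0]]
_, focus_present = attention_with_focus([0.0, 1.0, 0.0], items)
_, focus_absent = attention_with_focus([1.0, 0.0, 0.0], items)
print(focus_present, focus_absent)  # 1.0 0.0
```

Unlike the attention output, the focus signal differs between the two cases, so downstream layers can use it (for example, by concatenating it to the attention output) to decide whether the item exists.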
In practice this has been very effective (indeed, the only robust solution of the many we’ve tested) at solving existence questions. From now on we will refer to this signal as “focus”. Here is an illustration of the attention network we showed earlier, with the focus signal added:
Having found that the focus signal was essential for MacGraph to succeed on some tasks, we tested this concept on a range of datasets and model architectures.
We constructed a network that takes a list of items and a desired item (the “query”), and outputs whether that item was in the list. The network takes the inputs, performs attention (optionally with our focus signal), transforms the outputs through a couple of residual layers, then outputs a binary distribution¹ of whether the item was found.
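The network described above might be sketched as follows. This is a hypothetical, untrained, pure-Python illustration rather than our experiment code. We assume the desired item is used directly as the attention query, the focus signal is the sum of the raw scores, both residual layers use ReLU and keep the width of their input, and weights are drawn from a small Gaussian. The sigmoid layer mentioned in the footnote is omitted.

```python
import math
import random

def dense(weights, bias, x):
    # Fully connected layer: weights is a list of rows, one per output unit.
    return [sum(w * xi for w, xi in zip(row, x)) + b for row, b in zip(weights, bias)]

def relu(v):
    return [max(0.0, a) for a in v]

def softmax(v):
    m = max(v)  # subtract the max for numerical stability
    e = [math.exp(a - m) for a in v]
    total = sum(e)
    return [a / total for a in e]

def attention_with_focus(query, items):
    scores = [sum(q * x for q, x in zip(query, item)) for item in items]  # dot-product scores
    weights = softmax(scores)
    read = [sum(w * item[d] for w, item in zip(weights, items)) for d in range(len(query))]
    return read, sum(scores)  # attention output and focus signal (sum aggregate)

def init_layer(n_out, n_in, rng):
    weights = [[rng.gauss(0.0, 0.1) for _ in range(n_in)] for _ in range(n_out)]
    return weights, [0.0] * n_out

class ExistenceNet:
    """Hypothetical toy model: attention (+ optional focus) -> 2 residual layers -> 2-way softmax."""

    def __init__(self, dim, use_focus=True, seed=0):
        rng = random.Random(seed)
        self.use_focus = use_focus
        width = dim + (1 if use_focus else 0)
        self.res1 = init_layer(width, width, rng)
        self.res2 = init_layer(width, width, rng)
        self.out = init_layer(2, width, rng)

    def __call__(self, items, query):
        read, focus = attention_with_focus(query, items)
        x = read + ([focus] if self.use_focus else [])
        x = [a + b for a, b in zip(x, relu(dense(*self.res1, x)))]  # residual layer 1
        x = [a + b for a, b in zip(x, relu(dense(*self.res2, x)))]  # residual layer 2
        return softmax(dense(*self.out, x))  # [P(absent), P(present)]

# Usage: a two-item list of 12-dimensional one-hot vectors.
items = [[1.0 if i == j else 0.0 for i in range(12)] for j in (3, 7)]
query = [1.0 if i == 3 else 0.0 for i in range(12)]
probs = ExistenceNet(dim=12)(items, query)
print(probs)  # untrained, so the two probabilities are arbitrary but sum to 1
```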
The network is trained with a softmax cross-entropy loss and the Adam optimizer. The ideal learning rate for each network variation is determined prior to training using a learning rate finder.
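For readers unfamiliar with it, a learning rate finder (the learning-rate range test popularised by fast.ai) trains briefly while increasing the learning rate exponentially and records the loss at each step. The sketch below is a toy version under our own assumptions: plain gradient descent on a one-parameter quadratic rather than Adam on the real network, stopping once the loss exceeds four times its initial value, and the common heuristic of choosing a rate one order of magnitude below the one with the lowest recorded loss.

```python
import math

def lr_range_test(loss_and_grad, w0, lr_min=1e-4, lr_max=10.0, steps=100):
    """Train while growing the learning rate exponentially; record (lr, loss) pairs."""
    factor = (lr_max / lr_min) ** (1.0 / (steps - 1))
    w, lr, history = w0, lr_min, []
    for _ in range(steps):
        loss, grad = loss_and_grad(w)
        history.append((lr, loss))
        if not math.isfinite(loss) or loss > 4 * history[0][1]:
            break  # the loss has blown up, so stop the sweep
        w -= lr * grad
        lr *= factor
    return history

def suggest_lr(history):
    # Heuristic: one order of magnitude below the rate with the lowest recorded loss.
    best_lr, _ = min(history, key=lambda pair: pair[1])
    return best_lr / 10.0

# Toy objective: loss(w) = (w - 3)^2 with gradient 2(w - 3).
def toy_loss_and_grad(w):
    return (w - 3.0) ** 2, 2.0 * (w - 3.0)

history = lr_range_test(toy_loss_and_grad, w0=0.0)
print(suggest_lr(history))
```

For this quadratic, gradient descent diverges once the learning rate exceeds 1.0, and the sweep suggests a rate of about 0.1, safely inside the stable region.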
In our experiments, we vary the network by:
- Including / removing the focus signal from the attention step (“Use focus” below)
- Varying the attention score function
- Varying the focus signal aggregation function
- Varying the first residual transformation layer’s activation function (“Output activation” below)
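To illustrate what varying the score and aggregation functions means, here is a sketch in which both are pluggable. The score function produces the raw, pre-softmax scores, and the aggregation function reduces them to the focus signal. The options shown (dot product and cosine similarity for scoring; sum, mean and max for aggregation) are our own illustrative choices and not necessarily the exact set used in our experiments.

```python
import math

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

def cosine(a, b):
    norm = math.sqrt(dot(a, a)) * math.sqrt(dot(b, b))
    return dot(a, b) / norm if norm > 0 else 0.0

# Illustrative options only; not necessarily the set used in our experiments.
SCORE_FNS = {"dot": dot, "cosine": cosine}
AGGREGATE_FNS = {"sum": sum, "mean": lambda s: sum(s) / len(s), "max": max}

def focus_signal(query, items, score="dot", aggregate="sum"):
    scores = [SCORE_FNS[score](query, item) for item in items]  # raw, pre-softmax scores
    return AGGREGATE_FNS[aggregate](scores)

items = [[1.0, 0.0, 0.0], [0.0, 2.0, 0.0]]
query = [0.0, 1.0, 0.0]
for s in SCORE_FNS:
    for a in AGGREGATE_FNS:
        print(s, a, focus_signal(query, items, s, a))
```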
We then apply these network variations to a few different datasets, listed in the next section.
Here are all of the network configurations we tested:
The focus signal was found to be essential to the network reliably achieving >99% accuracy across the range of datasets and network configurations.
Our experiments’ code is open-source in our GitHub.
Each dataset is a set of examples. Each example contains the input features “List of items” and “Desired item”, and the ground-truth output “Item is in list”. An item is an N-dimensional vector of floating-point numbers.
Each dataset was constructed so 100% accuracy is possible.
We tested on three different datasets, each with a different source of items:
- Orthogonal one-hot vectors of length 12
- Many-hot vectors (i.e. random strings of 1.0s and 0.0s) of length 12
- Word2vec vectors of length 300
Each dataset has balanced answer classes (i.e. an equal number of True and False answers)
The one-hot and many-hot vectors were randomly generated. Each training example has a list of two items. Each dataset has 5,000 training items. Here’s one of the training examples:
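The one-hot and many-hot datasets could be generated along the lines of the sketch below. The details are assumptions for illustration: we read “5,000 training items” as 5,000 examples, balance the classes by alternating labels before shuffling, and resample a False example’s desired item until it is genuinely absent from the list.

```python
import random

DIM = 12  # vector length used for both synthetic datasets

def one_hot(rng):
    # One of 12 mutually orthogonal unit vectors.
    v = [0.0] * DIM
    v[rng.randrange(DIM)] = 1.0
    return v

def many_hot(rng):
    # A random string of 1.0s and 0.0s.
    return [float(rng.randint(0, 1)) for _ in range(DIM)]

def make_example(sample_item, rng, label, list_len=2):
    items = [sample_item(rng) for _ in range(list_len)]
    if label:
        desired = rng.choice(items)
    else:
        desired = sample_item(rng)
        while desired in items:  # resample until the desired item is genuinely absent
            desired = sample_item(rng)
    return {"items": items, "desired": desired, "in_list": label}

def make_dataset(sample_item, n=5000, seed=0):
    rng = random.Random(seed)
    # Alternate labels so True and False are exactly balanced, then shuffle.
    examples = [make_example(sample_item, rng, label=(i % 2 == 0)) for i in range(n)]
    rng.shuffle(examples)
    return examples

one_hot_data = make_dataset(one_hot)
many_hot_data = make_dataset(many_hot, seed=1)
print(one_hot_data[0])
```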
For our word2vec dataset, we used pre-calculated word2vec embeddings trained on Google News, which can be downloaded here.
The word2vec dataset has item lists of up to length 77, each representing a shuffled, randomly chosen sentence from a 300-line Wikipedia article. Each desired item is a randomly chosen word2vec vector from the words within the article.
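A single word2vec example might be assembled roughly as follows. This sketch assumes the article is already tokenised into sentences and that `embeddings` maps words to vectors; toy 4-dimensional random vectors stand in for the real 300-dimensional Google News embeddings. Zero-padding each list to 77 items with an accompanying mask is our assumption about how to satisfy fixed tensor shapes, not a description of our pipeline, and `make_w2v_example` is a hypothetical helper.

```python
import random

MAX_LEN = 77  # maximum item-list length reported for the word2vec dataset

def make_w2v_example(sentences, embeddings, rng, label, max_len=MAX_LEN):
    """Build one example from an article given as a list of tokenised sentences."""
    vocab = sorted({w for s in sentences for w in s if w in embeddings})
    for _ in range(100):  # a bounded number of attempts to find a usable sentence
        words = [w for w in rng.choice(sentences) if w in embeddings]
        # Desired-item candidates: words in the sentence (True) or absent from it (False).
        candidates = [w for w in vocab if (w in words) == label]
        if words and candidates and len(words) <= max_len:
            break
    else:
        raise ValueError("could not build an example with the requested label")
    rng.shuffle(words)
    desired = embeddings[rng.choice(candidates)]
    dim = len(desired)
    pad = max_len - len(words)
    items = [embeddings[w] for w in words] + [[0.0] * dim for _ in range(pad)]
    mask = [1.0] * len(words) + [0.0] * pad  # 1.0 marks a real word, 0.0 marks padding
    return {"items": items, "mask": mask, "desired": desired, "in_list": label}

# Toy stand-ins: three short "sentences" and random 4-dimensional "embeddings".
rng = random.Random(0)
sentences = [["the", "cat", "sat"], ["a", "dog", "ran", "far"], ["the", "dog", "sat"]]
embeddings = {w: [rng.gauss(0.0, 1.0) for _ in range(4)]
              for w in sorted({w for s in sentences for w in s})}
ex_true = make_w2v_example(sentences, embeddings, rng, label=True)
ex_false = make_w2v_example(sentences, embeddings, rng, label=False)
```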
The focus signal was found to be essential for reliably determining item existence:
- Without the focus signal, no network configuration achieved >56% accuracy for all datasets
- For orthogonal dataset vectors, no network achieved >80% accuracy without the focus signal
- Networks with the focus signal achieved >98% accuracy in all datasets and configurations, whereas networks without the focus signal only achieved >98% in 4 out of 9 configurations
It must be noted that the network without the focus signal did achieve >98% accuracy in some configurations and datasets.
In this article we’ve introduced a new concept for attention networks, the “focus signal”. We’ve shown that it’s a robust mechanism for detecting the existence of items in a list, which is an important operation for machine reasoning and question answering. The focus signal successfully detected item existence in all datasets and network configurations tested.
We hope that this work helps other teams tackle the challenge of detecting item existence, and brings more clarity to choosing which architecture to use.
Octavian’s mission is to develop systems with human-level reasoning capabilities. We believe that graph data and deep learning are key ingredients to making this possible. If you’re interested in learning more about our research or contributing, check out our website or get in touch.
I would like to thank David Mack for the initial idea and for helping with much of the editing of this article. I would also like to thank Andrew Jefferson for writing much of the code and guiding me through this write-up. Finally, I would like to thank Octavian for giving me an opportunity to work on this project.
- The output transformation layers were copied from MacGraph and could possibly be simplified. The sigmoid (σ) layer is definitely unnecessary in this model, as it limits the range of the inputs to softmax and therefore prevents the model from training to zero loss (although it can still achieve 100% accuracy, since the output is argmax’ed to a one-hot vector and compared with the label). We don’t believe this complexity invalidates any of the results, but felt compelled to call it out.
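A quick calculation shows the size of this effect. Assuming the sigmoid outputs feed softmax directly as the two logits, each logit lies in (0, 1), so the cross-entropy for the correct class cannot fall below about log(1 + e^-1) ≈ 0.3133. The sketch below computes softmax cross-entropy by hand (the function name is our own) for that bounded case and for unbounded logits.

```python
import math

def softmax_cross_entropy(logits, label):
    # -log(softmax(logits)[label]), computed via a numerically stable log-sum-exp.
    m = max(logits)
    log_z = m + math.log(sum(math.exp(l - m) for l in logits))
    return log_z - logits[label]

# Logits squashed by a sigmoid lie in (0, 1), so the best case approaches [1, 0].
best_bounded = softmax_cross_entropy([1.0, 0.0], label=0)
print(round(best_bounded, 4))  # 0.3133

# Unbounded logits can push the loss arbitrarily close to zero.
print(softmax_cross_entropy([20.0, 0.0], label=0) < 1e-8)  # True
```

In both cases argmax still picks the correct class, which is why accuracy can reach 100% even though the loss cannot reach zero.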