Recurrent neural networks (RNNs) operate like children playing a game of telephone (a.k.a. Chinese whispers). At each processing step, the RNN must encode the new information it has received and pass it through a set of feedback connections to the next processing step. The challenge in designing an RNN model is to make sure that the information does not degrade each time it passes through the feedback connections. It is also important to make sure that error-correcting information can backpropagate through the model. Hochreiter and Schmidhuber were the first to solve these issues by equipping an RNN with what they called long short-term memory (LSTM). Their approach introduced gating mechanisms into an RNN model that control when information is stored, updated, and erased. The LSTM model is still playing a game of telephone, but it can operate with photocopier precision. Since the invention of the LSTM, several RNN architectures using alternative gating mechanisms have been proposed.
To understand the limitations of gated RNN models, suppose you have a sequence of 100,000 symbols. The first symbol would have to pass through the gating mechanism 100,000 times. This would be fine if the gates could stay fully open, but the gates in an LSTM model never open completely, because each gate is a sigmoid output that approaches 1 without ever reaching it. Assuming the gates are 99.99% open, the signal from the first symbol will degrade to 0.9999¹⁰⁰⁰⁰⁰ ≈ 0.0000454 of its original value. So even though an LSTM model performs with photocopier precision, information from the beginning of an extremely long sequence still degrades.
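The arithmetic above can be checked with a few lines of Python. This is a deliberately simplified sketch: it assumes the gate acts as a constant multiplier of 0.9999 applied once per step, ignoring the inputs and the rest of the LSTM update. The function name `surviving_fraction` is hypothetical.

```python
def surviving_fraction(gate_open: float, steps: int) -> float:
    """Fraction of a signal left after passing through a gate `steps` times.

    Simplifying assumption: the gate is a constant multiplier in [0, 1]
    applied once per processing step.
    """
    return gate_open ** steps


# How much of the first symbol survives as the sequence grows.
for steps in (1_000, 10_000, 100_000):
    print(f"{steps:>7} steps: {surviving_fraction(0.9999, steps):.7f}")
```

Even a gate that leaks only 0.01% per step loses almost all of the signal over 100,000 steps.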
To overcome the limitations of existing RNN architectures, we need a model with feedback connections to every earlier processing step, not just the preceding one. One solution is to use the attention mechanism. Suppose we want to model time-series data using an RNN with attention. At each processing step, the output of the RNN is weighted by an attention model. The weighted outputs are then combined into a weighted average. The result of this weighted average is called a context vector. The context vector can represent information drawn from any time point in the data.
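A minimal NumPy sketch of this static attention step follows. It assumes the attention model has already produced one unnormalized score per time step and that the scores are turned into weights with a softmax (a common choice, assumed here rather than taken from the text). `attention_context` is a hypothetical name.

```python
import numpy as np


def attention_context(outputs: np.ndarray, scores: np.ndarray) -> np.ndarray:
    """Static attention: one context vector for the whole sequence.

    outputs: (T, D) RNN outputs, one row per time step.
    scores:  (T,) unnormalized attention scores, one per time step.
    """
    # Softmax turns scores into positive weights that sum to 1
    # (subtracting the max first avoids overflow in exp).
    weights = np.exp(scores - scores.max())
    weights = weights / weights.sum()
    # The weighted average over time is the context vector.
    return weights @ outputs


outputs = np.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
print(attention_context(outputs, np.zeros(3)))  # equal scores give the plain mean
```

Note that the function needs the entire sequence of outputs before it can return anything.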
A major constraint of the attention mechanism is that it produces only one context vector for the entire time series. The entire sequence must be read into the model before the context vector can be generated. In other words, the attention mechanism is static. To overcome this limitation, we recently proposed a new way to compute the attention mechanism: as a moving average. Because the attention mechanism is nothing more than a weighted average, it can be computed as a running calculation. This requires saving the numerator (the weighted sum of outputs so far) and the denominator (the sum of weights so far) from each processing step for use in the next iteration. By maintaining a moving average of the attention mechanism, a new context vector is produced at every time step. With this approach, the attention mechanism becomes dynamic and can be computed on the fly.
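The running calculation can be sketched as a Python generator that yields a context vector at every step while storing only a numerator and a denominator between steps. This is an illustrative sketch, not the implementation from our paper: it assumes softmax-style weights `exp(score)`, and it rescales both sums relative to the largest score seen so far (a standard log-sum-exp trick, added here as an assumption) so that large scores do not overflow. `running_attention` is a hypothetical name.

```python
import numpy as np


def running_attention(outputs: np.ndarray, scores: np.ndarray):
    """Yield a context vector at every step using a running weighted average.

    Only the numerator (weighted sum of outputs) and the denominator
    (sum of weights) are carried from one step to the next. Weights are
    exp(score); both sums are stored relative to the largest score seen
    so far to avoid overflow.
    """
    numerator = np.zeros(outputs.shape[1])
    denominator = 0.0
    running_max = -np.inf
    for output, score in zip(outputs, scores):
        new_max = max(running_max, score)
        # Rescale the stored sums if the reference maximum changed
        # (on the first step this factor is exp(-inf) = 0).
        scale = np.exp(running_max - new_max)
        weight = np.exp(score - new_max)
        numerator = numerator * scale + weight * output
        denominator = denominator * scale + weight
        running_max = new_max
        yield numerator / denominator


outputs = np.array([[1.0, 0.0], [0.0, 1.0], [2.0, -1.0]])
scores = np.array([0.5, -1.0, 2.0])
for t, context in enumerate(running_attention(outputs, scores)):
    print(t, context)
```

The last context vector equals the static weighted average over the full sequence, but every intermediate step also gets its own context vector.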
We decided to take our approach one step further. We realized that the output of the attention mechanism could be fed back into the attention mechanism at the next processing step. The result is a new kind of RNN model. Because the weighted average is defined recursively, we call this approach a recurrent weighted average (RWA) model.
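To make the feedback loop concrete, here is a minimal forward-pass sketch of a recurrent weighted average cell in NumPy. The class name `RWACell`, the weight matrices `W_a` and `W_z`, the tanh nonlinearities, one score per hidden unit, and the zero initial output are all assumptions for illustration, not necessarily the exact RWA parameterization; training code is omitted. The key point is that the previous output `h` enters the computation of both the scores and the values at the next step.

```python
import numpy as np


class RWACell:
    """Hypothetical recurrent weighted average cell (forward pass only)."""

    def __init__(self, input_size: int, hidden_size: int, rng: np.random.Generator):
        concat = input_size + hidden_size
        # Hypothetical weights: W_a produces attention scores, W_z produces values.
        self.W_a = rng.normal(0.0, 0.1, size=(concat, hidden_size))
        self.W_z = rng.normal(0.0, 0.1, size=(concat, hidden_size))
        self.hidden_size = hidden_size

    def run(self, xs: np.ndarray) -> np.ndarray:
        """Process a (T, input_size) sequence and return (T, hidden_size) outputs."""
        n = np.zeros(self.hidden_size)           # running numerator
        d = np.zeros(self.hidden_size)           # running denominator
        m = np.full(self.hidden_size, -np.inf)   # largest score seen so far, per unit
        h = np.zeros(self.hidden_size)           # previous output (assumed zero at start)
        outputs = []
        for x in xs:
            # Feed the previous output back in alongside the new input.
            xh = np.concatenate([x, h])
            a = xh @ self.W_a                    # one attention score per hidden unit
            z = np.tanh(xh @ self.W_z)           # value to be averaged
            # Rescale stored sums relative to the new maximum to avoid overflow.
            m_new = np.maximum(m, a)
            scale = np.exp(m - m_new)
            w = np.exp(a - m_new)
            n = n * scale + z * w
            d = d * scale + w
            m = m_new
            # New output: squashed weighted average, fed back at the next step.
            h = np.tanh(n / d)
            outputs.append(h)
        return np.stack(outputs)


cell = RWACell(input_size=3, hidden_size=4, rng=np.random.default_rng(0))
outputs = cell.run(np.random.default_rng(1).normal(size=(5, 3)))
print(outputs.shape)  # (5, 4)
```

Because each output is the tanh of a weighted average of values in (-1, 1), the outputs stay bounded however long the sequence is; how much early information is actually retained depends on the learned scores.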
We started testing the RWA model on several toy problems and compared its performance to the LSTM model. On every task except one, the RWA model learned much faster, requiring fewer training steps. Moreover, each training step for the RWA model took less wall-clock time. Here are selected results.
As you can see in the figure, the RWA model scales better to longer sequences. We don’t expect the RWA model to always outperform alternative RNN models like the LSTM. When recent information is more important than old information, the LSTM model may be the better choice (some examples). That said, there are many problems where we want an RNN model with memory reaching into the deep past, and that is where we expect the RWA model to outperform alternative methods.