Why Attention became a huge topic in Neural Networks

Dhruv Kabra · Version 1 · May 19, 2023

“Let’s discuss what made attention a hot topic in the neural network world. Before we dive in, note that the intuition behind attention is based on how humans perceive the world. If we are asked how convolution is equivariant with respect to translation and we have the deep learning book at hand, instead of reading the entire book and trying to remember all of the information, a better approach is to flip to the chapter on convolutional neural networks and find the part where equivariance is explained. If that same task were given to a computer, it might process the whole book to answer that one question, which is quite counter-intuitive.”

Let’s start with a simple feedforward neural network. There once was a musical prodigy by the name of Melody, who made lovely tunes and set out to build a machine that could produce music much like she did. She was aware that the solution to realizing her dream might lie in a feedforward neural network.

Melody created a feedforward neural network in her musical adventure that resembled a series of musical notes. Each note added to the overall melody and had a unique significance. The input neurons, hidden layers, and output layers of the network were analogous to the various parts of a musical composition.

Melody supplied the network with her musical input in the hopes that it would learn to produce songs that were harmonious. The network handled the input layer by layer, much like a musical instrument. But a feedforward network has no memory: it processes each input independently, with no record of anything it has seen before.

This limitation became evident when Melody wanted to create music with a sense of continuity and flow. Without memory, the network couldn’t capture the connections between different musical phrases or remember the patterns it had learned. It was like trying to compose a symphony with a composer who had no recollection of the previous movements.
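To see this concretely, here is a minimal sketch of a feedforward pass in Python (the layer sizes and random weights are purely illustrative, not Melody’s actual network): each call processes its input in isolation, with nothing carried over from earlier inputs.

```python
import numpy as np

# A minimal feedforward pass (illustrative sizes): each input is mapped
# through the layers independently -- nothing is remembered between calls.

rng = np.random.default_rng(0)
W1, b1 = rng.normal(size=(3, 8)), np.zeros(8)   # input (3 features) -> hidden (8 units)
W2, b2 = rng.normal(size=(8, 2)), np.zeros(2)   # hidden -> output (2 units)

def forward(x):
    hidden = np.tanh(x @ W1 + b1)   # processed layer by layer
    return hidden @ W2 + b2

note_a = rng.normal(size=3)
note_b = rng.normal(size=3)
forward(note_a)
forward(note_b)   # the result for note_b is unaffected by note_a ever existing
```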

Hence RNNs became a hot topic: an RNN is a neural network with a feedback loop, so it can hold temporal information.

This works well when sentences are short, but what if a sentence is 1,000 words long? That is where the problems with RNNs arise. Let’s discuss them. Imagine an RNN that looks back 4 steps in time: it is essentially the same neural network unrolled 4 times, with the hidden state passed from one unrolled step to the next while the weights are shared across all steps.

Example of an RNN unrolled over 4 time steps. (source: StatQuest, via pixabay (CC0))
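To make the unrolling concrete, here is a minimal NumPy sketch of one RNN cell applied over 4 time steps (the layer sizes and the toy sequence are illustrative assumptions, not taken from the figure). The key point is that the same weights are reused at every step, while the hidden state carries information forward:

```python
import numpy as np

# One RNN cell unrolled over 4 time steps. The SAME weights (W_x, W_h) are
# reused at every step -- this weight sharing is what "unrolling" means.

rng = np.random.default_rng(0)
W_x = rng.normal(size=(3, 8))     # input-to-hidden weights (input dim 3, hidden dim 8)
W_h = rng.normal(size=(8, 8))     # hidden-to-hidden (feedback) weights
b = np.zeros(8)

inputs = rng.normal(size=(4, 3))  # a toy sequence of 4 time steps
h = np.zeros(8)                   # initial hidden state

for t in range(4):                # "unrolled" 4 times
    # the hidden state carries information forward from earlier steps
    h = np.tanh(inputs[t] @ W_x + h @ W_h + b)

print(h.shape)  # (8,) -- the final hidden state summarises the whole sequence
```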

To simplify the explanation of the problem, let’s focus on a specific weight, W2, in the network. During training, we calculate gradients for each parameter and use them in the gradient descent algorithm to update the parameter values and minimize the loss function.

In the case of the exploding gradient problem, when we set W2 to a value larger than one, the input values get increasingly amplified with each RNN step. For example, if we unroll the network four times and W2 = 2, the contribution is multiplied by 2 raised to the power of 4, i.e. by 16, which is already a significant amplification. The effect becomes more pronounced the more times the network is unrolled, leading to exploding gradients. When these large gradients enter the training process, it becomes difficult to take small, stable steps towards optimal weights and biases, and training progress becomes unstable.

On the other hand, in the vanishing gradient problem, when W2 is set to a value less than one (e.g., 0.5), the input values progressively diminish with each RNN step. As a result, the gradients become extremely small, approaching zero. With such small gradients, it becomes challenging to make meaningful updates to the parameters, leading to slow convergence or hitting the maximum number of allowed steps without reaching the optimal solution.

To mitigate the exploding gradient problem, one approach is to limit the values of W2 to less than one. However, this leads to the vanishing gradient problem. Both issues are challenging because they hinder the ability of the network to effectively learn and update its parameters during training.
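A tiny numeric sketch of this scaling argument (the 50-step sequence and the example values 2.0 and 0.5 are illustrative): the recurrent weight is applied once per unrolled step, so its contribution compounds multiplicatively.

```python
# A toy illustration of the two failure modes described above. The recurrent
# weight W2 is applied once per unrolled step, so its contribution to the
# gradient scales roughly like W2 ** n_steps.

for w2 in (2.0, 0.5):             # > 1 explodes, < 1 vanishes
    contribution = 1.0
    for step in range(1, 51):     # assume a 50-step sequence (illustrative)
        contribution *= w2
        if step in (4, 10, 50):
            print(f"W2={w2}, steps={step}: factor ~ {contribution:.6g}")

# W2=2.0 -> 16, 1024, ~1.13e15   (exploding)
# W2=0.5 -> 0.0625, ~0.000977, ~8.9e-16   (vanishing)
```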

Another problem with RNNs is parallelization: because the model processes each input step sequentially, the computation within a sequence cannot be parallelized.

Let’s say we want to predict monthly car sales. Just like many other products, car sales experience significant variations throughout the year, with different seasons affecting consumer demand. For instance, car sales tend to be lower during the winter months and higher during the summer.

To capture this seasonal pattern that repeats every 12 periods, we can employ an LSTM network. Unlike traditional models that struggle to capture long-term dependencies, LSTMs excel at retaining information from earlier time steps and incorporating it into their predictions. This ability to consider longer-term context helps overcome the challenges faced by other models when dealing with complex patterns over extended periods.

While this example specifically focuses on car sales, the application of LSTMs becomes even more valuable when dealing with longer intervals, such as analyzing trends in extensive textual data or predicting sales patterns in industries with seasonal fluctuations.

By utilizing the memory and contextual understanding capabilities of LSTMs, we can effectively model and forecast complex patterns in car sales, enabling us to make more accurate predictions and informed business decisions.
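As a hedged sketch of what such a forecasting model might look like, here is a small PyTorch example (the framework choice, the synthetic seasonal data, the 12-month input window, and the layer sizes are all illustrative assumptions, not prescriptions from the article):

```python
import math

import torch
import torch.nn as nn

# Illustrative only: synthetic "monthly car sales" with a 12-month seasonal
# pattern, and a small LSTM that predicts the next month from the previous 12.
torch.manual_seed(0)

months = torch.arange(120, dtype=torch.float32)
sales = 100 + 20 * torch.sin(2 * math.pi * months / 12) + 2 * torch.randn(120)

window = 12
X = torch.stack([sales[i:i + window] for i in range(len(sales) - window)]).unsqueeze(-1)
y = sales[window:].unsqueeze(-1)             # target: the month right after each window

class SalesLSTM(nn.Module):
    def __init__(self, hidden=32):
        super().__init__()
        self.lstm = nn.LSTM(input_size=1, hidden_size=hidden, batch_first=True)
        self.head = nn.Linear(hidden, 1)

    def forward(self, x):                    # x: (batch, 12 months, 1 feature)
        out, _ = self.lstm(x)                # out: (batch, 12, hidden)
        return self.head(out[:, -1])         # predict from the last hidden state

model = SalesLSTM()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)
loss_fn = nn.MSELoss()

for epoch in range(200):                     # a short, illustrative training loop
    optimizer.zero_grad()
    loss = loss_fn(model(X), y)
    loss.backward()
    optimizer.step()
```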

To solve this issue, LSTMs were introduced: instead of plain, memoryless neurons, we start using units that carry a memory. Each recurrent unit is replaced with a cell like the one shown below.

An LSTM Cell (Photo by Rian Dolphin on Medium.com)

1. Forget gate:
Imagine you are in a library and you have a shelf full of memories. However, not all memories are equally important to the current task. The forget gate is like a librarian deciding which books to keep on the shelf and which to throw away. The librarian looks at the relevance of each book and assigns it a number from 0 to 1 to indicate its importance. This assignment is done by a small neural network with a sigmoid activation (σ), which guarantees values between 0 and 1. Books with values close to 0 will be forgotten, and those closer to 1 will be kept. The librarian then multiplies these forget values element-wise with the previous memory collection (the cell state), erasing the less important entries and preserving the more important ones.
2. Input gate:
Now imagine you have a new set of books to add to your collection, but you want to be selective and choose only what is truly valuable. The input gate helps you make this decision. It has two parts:
a new-memory network and an input filter. The new-memory network analyzes the new books together with your previous collection and creates a “new memory update” that represents the knowledge gained from them. However, not all of the information in this update is worth remembering. This is where the input filter comes in. Acting as a gatekeeper, it assigns a value between 0 and 1 to each element of the update (again via a sigmoid), with values close to 1 indicating importance. The filtered new memory is then ready to be added to your stored collection.
3. Updating the cell state:
Now that you have new memories to add, you update your memory collection, called the cell state. The candidate memory update is multiplied element-wise by the input-gate values, which scales the impact of the new information and suppresses it where necessary. The result is then added to your existing (forget-gated) collection, enriching your long-term memory.
4. Output gate:
With your updated memory collection, you want to produce a useful output for the current task. Imagine you have a talent for storytelling and want to share a summary of your memories. The output gate helps you decide which memories to include in that summary. It looks at the current input and the previous output, and decides which parts of the updated cell state are relevant, assigning an importance value to each element of the memory, just like the forget and input gates do. The selected memories are passed through a tanh function, which squashes the values to between -1 and 1 while preserving their sign, and the result is multiplied by the output-gate values to produce the cell’s output, i.e. its prediction for this step.

By going through these stages, the LSTM cell manages its memory effectively: filtering out irrelevant information, capturing valuable new information, and producing meaningful predictions or outputs based on the given context.

The major problem with RNNs, the vanishing gradient problem, is mitigated by LSTMs, but they are still very slow to train.

The encoder-decoder architecture then rose to prominence to deal with the limitations of RNNs and LSTMs; let’s take neural machine translation as an example.

In the picture below, you can see the encoder on the left side and the decoder on the right side. The encoder takes in three German words, and the decoder receives the final hidden state h3 and decodes it into English words.

photo taken from https://dennybritz.com

Imagine you have a picture in front of you. It’s a picture of a sentence written in a different language that you don’t understand. Your task is to translate that sentence into English. But here’s the catch: you can only look at the picture once and then you have to close your eyes.

To tackle this challenge, you come up with a clever strategy. You decide to break down the task into two parts: understanding the sentence and generating the translation.

First, you use an encoder, which is like your brain, to analyze the picture. The encoder takes each word in the sentence and processes it, keeping track of the information it gathers along the way. The encoder’s job is to capture the meaning of the sentence and encode it into a single vector, called the sentence embedding. It’s like condensing all the important information from the picture into a compact representation.

Now, with the sentence embedding in hand, you’re ready to generate the translation. You employ a decoder, which is like your mouth, to produce the English translation word by word. The decoder starts with the sentence embedding and generates the first word of the translation. Then, based on that word, it generates the next word, and so on, until it reaches the end of the sentence.

But here’s the interesting part. The decoder relies heavily on the last hidden state of the encoder, represented by the vector h3. This vector is expected to contain all the essential information about the source sentence. It’s like a summary of the picture you saw. The decoder uses this summary to guide its translation process.

However, there’s a concern. You might wonder if it’s reasonable to assume that a single vector can encode all the information from a potentially long sentence. Let’s say the source sentence has 50 words. The first word of the English translation is likely related to the first word of the source sentence. But that means the decoder needs to consider information from 50 steps ago, and somehow that information needs to be encoded in the vector.

It’s indeed a challenge to encode and retain all the details of a long sentence in a single vector. However, the remarkable thing is that, in practice, this approach often works well. By carefully training the encoder-decoder model, the sentence embedding tends to capture the essential information and semantic meaning of the source sentence. It’s like compressing the knowledge into a compact representation.
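Here is a minimal sketch of this encoder-decoder setup in PyTorch (the use of GRUs, the vocabulary sizes, dimensions, and token IDs are illustrative assumptions made for brevity). Note that the only thing the decoder ever receives about the source sentence is the encoder’s final hidden state, i.e. the single context vector (h3) discussed above:

```python
import torch
import torch.nn as nn

# Minimal encoder-decoder sketch: the decoder is conditioned ONLY on the
# encoder's final hidden state (the "sentence embedding" / h3 bottleneck).

class Encoder(nn.Module):
    def __init__(self, vocab=1000, emb=64, hidden=128):
        super().__init__()
        self.embed = nn.Embedding(vocab, emb)
        self.rnn = nn.GRU(emb, hidden, batch_first=True)

    def forward(self, src):                 # src: (batch, src_len) of token IDs
        _, h = self.rnn(self.embed(src))    # h: (1, batch, hidden) -- final state only
        return h                            # everything else is thrown away

class Decoder(nn.Module):
    def __init__(self, vocab=1000, emb=64, hidden=128):
        super().__init__()
        self.embed = nn.Embedding(vocab, emb)
        self.rnn = nn.GRU(emb, hidden, batch_first=True)
        self.out = nn.Linear(hidden, vocab)

    def forward(self, prev_tokens, h):      # generate one step at a time
        out, h = self.rnn(self.embed(prev_tokens), h)
        return self.out(out), h             # logits over the target vocabulary

enc, dec = Encoder(), Decoder()
src = torch.randint(0, 1000, (2, 50))       # a batch of two 50-word "sentences"
context = enc(src)                          # one vector per sentence, however long it is
tok = torch.zeros(2, 1, dtype=torch.long)   # assume index 0 is a <start> token
logits, context = dec(tok, context)         # the first translated word depends only on context
```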

These NMT models worked well for shorter sentences but performed poorly as sentences grew longer.

As we discussed at the start of this article, humans use an attention mechanism to grab the piece of information they want rather than processing all of the information.

Convolutional Neural Networks (CNNs) came to the rescue for problems that need to be solved efficiently. These networks offer several advantages:

  1. Easy parallelization: CNNs can be parallelized effortlessly, allowing us to process multiple inputs simultaneously at each layer. This speeds up the computation and makes it more efficient.
  2. Local dependencies: CNNs excel at capturing local dependencies within the data. They focus on small patches or neighbourhoods of the input and learn patterns and features specific to those regions. This ability helps them recognize spatial or temporal patterns in data effectively.
  3. Logarithmic distance: with dilated convolutions, the number of layers needed to connect two positions grows only logarithmically with the distance between them. This lets CNNs relate elements that are far apart far more cheaply than an RNN, which needs one step per intervening position, and helps them learn and generalize from data with a wide spatial or temporal extent (see the sketch after the next paragraph).

Two well-known examples of CNNs used for sequence transduction tasks are Wavenet and Bytenet. These networks leverage the power of convolutional operations to process sequential data and generate desired outputs. By taking advantage of the inherent properties of CNNs, they can handle complex patterns and generate accurate results for tasks like speech synthesis or machine translation.
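As a small illustration of the “logarithmic distance” idea from point 3 above, here is a sketch of a WaveNet-style stack of dilated 1-D convolutions in PyTorch; the channel sizes and layer count are illustrative assumptions:

```python
import torch
import torch.nn as nn

# With kernel size 2 and dilations 1, 2, 4, 8, ..., each extra layer doubles
# the receptive field, so covering a distance of N positions needs only about
# log2(N) layers -- and every layer can be computed in parallel over positions.

kernel, dilations = 2, [1, 2, 4, 8, 16, 32]
layers = [nn.Conv1d(16, 16, kernel_size=kernel, dilation=d) for d in dilations]

receptive_field = 1
for d in dilations:
    receptive_field += (kernel - 1) * d
print(receptive_field)        # 64: six layers already connect positions 64 steps apart

x = torch.randn(1, 16, 200)   # (batch, channels, sequence length)
for layer in layers:
    x = torch.relu(layer(x))  # no padding, so the sequence shrinks by receptive_field - 1
print(x.shape)                # torch.Size([1, 16, 137])  (200 - 63)
```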

Unlike an RNN, which has to step through every previous position one by one, a CNN only needs a number of layers that grows logarithmically with the distance between positions. However, it still needs a stack of layers to relate distant positions, and convolutions alone do not give us the content-based attention we are after.

This gave rise to Transformers, which will be discussed in depth in coming articles.

About the author:
Dhruv Kabra is a Python Developer here at Version 1.
