In this article, we will take a look at the type of Recurrent Neural Network (RNN) that can overcome the vanishing gradient problem that simple RNNs suffer from.
Information is the new electricity, gold, petrol, diamond…
It has become the most prized possession of the 21st century, one that tech giants are stealing, buying or asking for without a proper disclaimer about what they are going to do with it. Funnily enough, information has always played a big role, not only in our current society but throughout the animal kingdom.
“As a general rule, the most successful man in life is the man who has the best information.” — Benjamin Disraeli
Information started developing its first-ever GMaps as early as 10,000 years ago, with the family of birds known as Columbidae (pigeons and doves), one of whose species evolved into the early domestic pigeon (Columba livia domestica).
The domestic pigeon is in turn derived from the rock pigeon, the world’s oldest domesticated bird: Mesopotamian cuneiform tablets mention the domestication of birds 5,000 years ago. Domestic pigeons were widely used, in peacetime and in wartime, to carry messages thanks to their spectacular homing abilities, or perhaps, by today’s standards, I should say their Google Maps abilities. Interestingly, put side by side with Google Maps, nature’s homing system is a few thousand years ahead and would definitely win on accuracy, UI and UX, only to lose on sales and marketing.
Homing: the ability to find its way home over extremely long distances. Flights as long as 1,800 km have been recorded by birds in competitive racing. This entire system works without a smartphone or any other electronic device (brain aside), using only an inner compass built and refined over thousands of years.
Researchers believe that pigeons’ homing abilities rely on magnetoreception and a “map and compass” model: the compass lets the birds orient themselves, and the map lets them determine their location relative to a goal site (the home loft). In short, birds can detect a magnetic field to help them find their way home.
Magnetoreception: a sense that allows an organism to detect a magnetic field to perceive direction, altitude or location.
Therefore, we can safely say that these birds run on information: they capture it through their sensory inputs, then store, process and recall it using the intriguing processing power and memory of their tiny brains.
Information is a great tool, and it can look like pure sorcery when proper systems are developed to store, process and analyse data in ways never thought of before. It lets animals know a nest location, landmarks for navigation in a home range, where food and water have been found in the past, and how previous social interactions with other animals have turned out; all of these are critical pieces of information in shaping future behaviour.
How does memory work?
Memory starts as a biochemical response in the brain following sensory inputs.
There are 3 types of memory:
- Short-term — It is triggered by any transient change in your brain’s neurotransmitter (chemical messenger) levels at synapses (structures that permit a neuron to pass an electrical or chemical signal to another neuron).
- Mid-term — Lies between short- and long-term memory; this is where the consolidation phase of memory takes place (normally during sleep).
- Long-term — Information that has critical future importance is moved to long-term memory through reinforcement of the learned event.
Now, coming back to AI, let’s see how all of this relates to the topic of this article.
How does it all relate?
In my previous article, we had a brief introduction to a few fundamental deep learning (DL) algorithms and methods for sequence processing, namely:
- Word Embeddings — a method used to map human language into a geometric space
- Recurrent Neural Networks (RNN) — A type of neural network that has memory. It processes sequences of data while keeping track of what it has seen so far; some examples of RNNs are LSTM, GRU and Bi-LSTM.
- 1D Convnets — the one-dimensional version of the 2D/3D convnets used in the computer-vision domain.
Understanding Long Short-Term Memory
In my last article, we talked about the vanilla RNN, which is available in the TensorFlow Keras framework as the SimpleRNN layer (or as the generic RNN layer wrapping a SimpleRNNCell). And if you are a big fan of PyTorch, I’ve got you covered: it lives in the neural network (nn) package under the name RNN (torch.nn.RNN).
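For reference, here is a minimal NumPy sketch of what such a vanilla RNN layer computes. It is not the Keras or PyTorch implementation. At every timestep it computes output_t = tanh(dot(input_t, W) + dot(state_t, U) + b), and each output becomes the state for the next step. The sizes and random weights are illustrative assumptions, and so is the tanh activation, although tanh is also the default activation of Keras’s SimpleRNN and torch.nn.RNN.

```python
import numpy as np

def simple_rnn(inputs, W, U, b):
    """Run a vanilla RNN over `inputs` of shape (timesteps, input_dim)."""
    state_t = np.zeros(U.shape[0])  # initial state: all zeros
    outputs = []
    for input_t in inputs:
        # output_t = activation(dot(input_t, W) + dot(state_t, U) + b)
        output_t = np.tanh(input_t @ W + state_t @ U + b)
        outputs.append(output_t)
        state_t = output_t  # the output becomes the next step's state
    return np.stack(outputs)

# Toy sizes and random weights, chosen only for illustration
rng = np.random.default_rng(42)
timesteps, input_dim, units = 5, 3, 4
inputs = rng.normal(size=(timesteps, input_dim))
W = rng.normal(size=(input_dim, units))
U = rng.normal(size=(units, units))
b = np.zeros(units)

outputs = simple_rnn(inputs, W, U, b)
print(outputs.shape)  # one output vector per timestep
```

The Python loop makes the "same cell, reused at every timestep" idea explicit. Framework layers do the same computation with optimized kernels.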
This is not the only RNN layer that exists; let me tell you why.
The vanilla RNN has a major issue. In theory, it should be able to retain, at time t, information about inputs seen many timesteps before. In practice, such long-term dependencies are impossible to learn. This is due to the vanishing gradient problem, which also affects non-recurrent networks that are many layers deep (many layers stacked): as we keep adding layers to a network, it eventually becomes untrainable. To put it in simple terms, the ability to learn decreases as the number of layers grows. An RNN unrolled over many timesteps behaves like such a deep network, with one “layer” per timestep. As a result, the learning signal from early inputs fades away before it can update the weights.
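To make the "many timesteps act like many layers" idea concrete, here is a minimal sketch. It assumes a hypothetical one-unit (scalar) vanilla RNN with a fixed recurrent weight `u` and no new input after the first step. By the chain rule, the gradient of the last state with respect to the first is a product of one factor per timestep, u * (1 - h_t²). When every factor is below 1, that product shrinks geometrically.

```python
import math

def state_gradient(u: float, h0: float, steps: int) -> float:
    """Return d h_T / d h_0 for the hypothetical recurrence h_t = tanh(u * h_{t-1})."""
    h = h0
    grad = 1.0
    for _ in range(steps):
        h = math.tanh(u * h)
        # Chain rule: d h_t / d h_{t-1} = u * (1 - tanh(u * h_{t-1})**2) = u * (1 - h_t**2)
        grad *= u * (1.0 - h * h)
    return grad

for steps in (1, 10, 50):
    print(steps, state_gradient(u=0.9, h0=0.5, steps=steps))
```

With u = 0.9, every factor is at most 0.9, so after 50 timesteps the gradient is already below 0.9 ** 50 (roughly 0.005). The earliest inputs therefore barely influence the weight updates.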
It is just like learning a new topic: we need to constantly revise past information in order to learn better and to keep ourselves from forgetting it.
The solution is to inject some past information into later layers so they can revise it. This propagates the learning signal through the network better, which makes bigger networks easier to train and prevents the knowledge (weights) acquired by the deeper layers from degrading. Yet, when it comes to RNNs, we must compute how much we want the past information to affect the present. Remembering everything doesn’t help, and neither does remembering nothing, so we need a function that determines how much to remember from the past.
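Here is a minimal sketch of such a "how much to remember" function. It assumes a single sigmoid gate that mixes a past value with a present one. This is a simplification, closer to the GRU’s update gate than to the LSTM’s separate gates, and in a real layer the gate value would be computed from learned weights rather than passed in by hand.

```python
import math

def sigmoid(x: float) -> float:
    """Squash any real number into (0, 1), so it can act as a percentage."""
    return 1.0 / (1.0 + math.exp(-x))

def blend(past: float, present: float, gate_logit: float) -> float:
    """Remember a fraction of the past and fill the rest with the present."""
    keep = sigmoid(gate_logit)  # fraction of the past to remember
    return keep * past + (1.0 - keep) * present

print(blend(10.0, 0.0, gate_logit=4.0))   # mostly past: about 9.82
print(blend(10.0, 0.0, gate_logit=0.0))   # half and half: 5.0
print(blend(10.0, 0.0, gate_logit=-4.0))  # mostly present: about 0.18
```

The sigmoid is a natural fit here because its output always stays between 0 and 1, whatever number comes in.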
“We must welcome the future, remembering that soon it will be the past; and we must respect the past, remembering that it influences the future” — P.Canuma
The theoretical reasons for this vanishing-gradient effect in RNNs were studied by Hochreiter, Schmidhuber and Bengio in the early 1990s.
The Long Short-Term Memory (LSTM) and Gated Recurrent Unit (GRU) layers were designed to solve this problem.
The SimpleRNN is the starting point for the LSTM, but the LSTM layer adds a few design considerations on top of it. Essentially, the LSTM saves information for later, preventing older signals from gradually vanishing. It does so by calculating (conceptually) what percentage of past information to inject.
Now, let’s dive into the nuts and bolts of Fig. 1 to get a better sense of how data flows inside the LSTM cell.
Inside the cell we have two weight matrices, W and U (containing the “knowledge”, as I like to call it), which are indexed with the letter o (Wo and Uo) for output. We also have state_t, the output of the previous timestep, which is fed back in as the state at timestep t.
Next, let’s add an additional data flow to the picture, one that carries information across timesteps (t). We’ll call its value at timestep t Ct, where C stands for carry. The information stored in Ct (a running summary of past information) is combined with the input (the sequence data) and the recurrent connection (the previous output) through a dense transformation: a dot product with a weight matrix, followed by a bias add and an activation function. The carry then affects the state sent to the next timestep (via an activation function and a multiplication operation).
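One common way to realize this last step is the standard LSTM formulation used by Keras’s LSTM layer and torch.nn.LSTM. An output gate, computed by a dense transformation of the input and the previous state, is multiplied element-wise with tanh of the carry. In this formulation the carry does not enter the output gate itself; it only enters through the tanh and the multiplication. Below is a minimal NumPy sketch of just that step. The toy sizes, the random weights and the function name `next_state` are assumptions for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)
input_dim, units = 3, 4  # toy sizes, chosen only for illustration

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

# Output-gate weights, indexed with "o" as in the text (random, for illustration)
Wo = rng.normal(size=(input_dim, units))
Uo = rng.normal(size=(units, units))
bo = np.zeros(units)

def next_state(input_t, state_t, carry):
    """Hypothetical helper: the state sent to the next timestep, modulated by the carry."""
    # Dense transformation: dot products with the weights, bias add, sigmoid activation
    o_t = sigmoid(input_t @ Wo + state_t @ Uo + bo)
    # The carry enters through an activation (tanh) and an element-wise multiplication
    return o_t * np.tanh(carry)

input_t = rng.normal(size=input_dim)
state_t = np.zeros(units)
carry = rng.normal(size=units)
print(next_state(input_t, state_t, carry))
```

Because the gate lies in (0, 1) and tanh lies in (-1, 1), the carry can only be passed on, scaled down or blocked. An all-zero carry always produces an all-zero state.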
Bear with me…
Conceptually, the carry dataflow is a way to modulate the next output and/or state.
Calculating the carry value for the next timestep involves three distinct transformations. All three have the form of a SimpleRNN cell:
y = activation(dot(input_t, W) + dot(state_t, U) + b)
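But each of the three transformations has its own weight matrices, which we’ll index with the letters i, f and k. The new carry is then computed as c_t+1 = i_t * k_t + c_t * f_t. Roughly, f decides how much of the old carry to keep, while i and k decide what new information to write into it.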
Simple! See, it’s not that hard.
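Putting the three transformations together, here is a minimal NumPy sketch of the carry update. It assumes the standard LSTM choices of sigmoid activations for the i and f transformations and tanh for k, which are the same defaults as Keras’s LSTM layer. The sizes, the random weights and the helper names `dense_params`, `transform` and `next_carry` are hypothetical, chosen for illustration.

```python
import numpy as np

rng = np.random.default_rng(7)
input_dim, units = 3, 4  # toy sizes, chosen only for illustration

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def dense_params():
    """Hypothetical helper: one (W, U, b) triple for one SimpleRNN-shaped transformation."""
    return (rng.normal(size=(input_dim, units)),
            rng.normal(size=(units, units)),
            np.zeros(units))

# Separate weights for each transformation, indexed with i, f and k
params = {name: dense_params() for name in ("i", "f", "k")}

def transform(name, activation, input_t, state_t):
    """activation(dot(input_t, W) + dot(state_t, U) + b) with the weights of `name`."""
    W, U, b = params[name]
    return activation(input_t @ W + state_t @ U + b)

def next_carry(input_t, state_t, c_t):
    i_t = transform("i", sigmoid, input_t, state_t)  # how much new information to write
    f_t = transform("f", sigmoid, input_t, state_t)  # how much of the old carry to keep
    k_t = transform("k", np.tanh, input_t, state_t)  # candidate values to write
    return i_t * k_t + c_t * f_t                     # c_t+1 = i_t * k_t + c_t * f_t

input_t = rng.normal(size=input_dim)
state_t = np.zeros(units)
c_t = np.ones(units)
print(next_carry(input_t, state_t, c_t))
```

The old carry is only multiplied by f_t and never squashed by an activation on its way forward. This is what lets information survive many timesteps when f_t stays close to 1.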
Anatomy of the LSTM
We can interpret what each of these operations is meant to do, but interpretations don’t do much. What these operations actually do is determined by the contents of the weights (knowledge) parameterizing them, and those weights are learned end-to-end, starting over with each training round. That makes it impossible to credit this or that operation with a specific purpose.
The same cell with different weights can be doing very different things.
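To illustrate, here is a hypothetical one-unit carry in pure Python. The update equation c = i * k + c * f stays fixed, and only one weight, the forget-gate bias, changes. With zero inputs, a large positive bias keeps most of the carry over 20 steps, while a large negative one wipes it out almost immediately. The values are picked by hand for illustration, not learned.

```python
import math

def sigmoid(x: float) -> float:
    return 1.0 / (1.0 + math.exp(-x))

def run_carry(forget_bias: float, steps: int, c0: float = 1.0) -> float:
    """Hypothetical scalar carry update c = i*k + c*f with zero inputs and a zero state.

    With zero inputs, each gate reduces to its activation applied to its bias.
    The candidate bias is 0, so k = tanh(0) = 0 and nothing new is written.
    """
    c = c0
    for _ in range(steps):
        i = sigmoid(0.0)          # input gate (irrelevant here since k = 0)
        f = sigmoid(forget_bias)  # forget gate: the only weight we change
        k = math.tanh(0.0)        # candidate value
        c = i * k + c * f
    return c

print(run_carry(forget_bias=5.0, steps=20))   # most of the initial value survives (about 0.87)
print(run_carry(forget_bias=-5.0, steps=20))  # collapses towards 0: the cell forgets
```

The code path is identical in both runs. Only the weight differs, and it alone decides whether this cell behaves as a memory or as a shredder.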
An RNN cell determines your hypothesis space: the space in which you’ll search for a good model configuration during training. So the combination of operations making up an RNN cell is better interpreted as a set of constraints on your search, not as a design in an engineering sense.
To a researcher, it seems that the choice of such constraints (the question of how to implement RNN cells) is better left to optimization algorithms than to human engineers. Examples are genetic algorithms or reinforcement learning processes that use all the available computational power to find the set of constraints (hyperparameters) that yields the best accuracy for a specific problem. And in the future, that’s how neural networks will be built.
- The LSTM cell is meant to allow past information to be reinjected at a later time. It calculates what percentage of past information is allowed to affect the present, thus fighting the vanishing-gradient problem that the vanilla RNN has.
Note: I have prepared a Colab notebook especially for you, so you can understand all the concepts by seeing how they are implemented from scratch, and how to use the production-ready layers built into popular frameworks like TF and Torch.
Thank you for reading. If you have any thoughts, comments or critiques, please leave a comment below.
Follow me on Twitter at Prince Canuma so you can always stay up to date with the AI field.
- François Chollet, Deep Learning with Python