Memory Augmented Neural Network for Meta Learning — Case Study

Mohamed Afham · Published in the-ai.team · May 11, 2020
Figure 01: Memory Augmented Neural Network Architecture [1]

Meta Learning, in simple words "Learning to Learn", is one of the fastest growing research domains in the field of Artificial Intelligence, especially in Reinforcement Learning. Conventional Deep Learning architectures such as DNNs, CNNs and RNNs are built to work well for one specified task: they optimize their parameters (weights and biases) given a training data set for that particular task. The idea of Meta Learning is to allow a Neural Network to learn across previous tasks so that it can accomplish a new, unseen task. A great deal of research has been carried out and novel architectures have been proposed to accomplish Meta Learning tasks. The Memory Augmented Neural Network (MANN) [1] is one of them; it borrows the idea of an external memory from the Neural Turing Machine [2].

Figure 02: Meta Training and Meta Testing data set sample

In this article, we’ll discuss the following:

  • A gentle introduction to MANN.
  • Using MANN for few-shot classification on the Omniglot data set [3].

Introduction to MANN

In the usual context, we learn the parameters (phi) of a network (its weights and biases) such that the expected value of the loss function (L) over a particular data set (D) is minimized. This is illustrated by the following equation.
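In symbols, one way to write this objective is

\phi^{*} = \arg\min_{\phi}\; \mathbb{E}_{(x,\,y)\sim D}\Big[\, L\big(f_{\phi}(x),\, y\big)\Big]

where f_\phi denotes the network with parameters \phi.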

However, in Meta Learning we are required to optimize the expected value of the loss function by learning meta parameters (theta) across a distribution of data sets (p(D)).
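In the same notation, the meta objective can be written as

\theta^{*} = \arg\min_{\theta}\; \mathbb{E}_{D\sim p(D)}\big[\, L(\theta,\, D)\big]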

The task setup proposed in the original paper is as follows: for a data set D = {d_t} = {(x_t, y_t)}, the sequence (x_1, null), (x_2, y_1), …, (x_t, y_{t-1}) is fed to the network as input, where y_i is the class label of the input vector x_i. Given this sequence of inputs concatenated with time-offset labels, the network is supposed to output y_t, and the loss function is defined on the class label y_t [1]. This task setup is repeated over a distribution of data sets so that the network learns a labelling strategy it can reuse on a new, unseen task.
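As a small illustration of this time-offset labelling, the helper below pairs each input with the previous step's one-hot label (the name shift_labels and its signature are mine, not from the paper or the repository):

```python
import numpy as np

def shift_labels(x_seq, y_seq, num_classes):
    # x_seq: (T, input_dim) inputs; y_seq: (T,) integer class labels.
    # Pair each x_t with the PREVIOUS step's one-hot label y_{t-1};
    # the first step gets a null (all-zero) label, as in the MANN setup [1].
    one_hot = np.eye(num_classes)[y_seq]                              # (T, num_classes)
    shifted = np.vstack([np.zeros((1, num_classes)), one_hot[:-1]])   # y_0 is null
    return np.concatenate([x_seq, shifted], axis=1)                   # (T, input_dim + num_classes)
```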

MANN contains three major parts.

  • Controller Network
  • External Memory Module
  • Read/Write Heads.

The Controller Network is the usual feed-forward network used for prediction. Since the input to the network is a sequence of data, it is preferable to use an RNN or an LSTM. As a note, remember that the external memory module of MANN is not the memory cell of the LSTM.

The External Memory Module is basically a matrix which retains important representations from one task to the next. Communication between the memory module and the controller network is done by the read/write heads: while reading, representations are retrieved from memory, and while writing, a new memory is encoded. The governing equations of the reading and writing processes are as follows:

Figure 03: Equations governing the External Memory Module. Extracted from [1] section 3
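In brief (my own transcription of Section 3 of [1]): while reading, the controller emits a key k_t that is compared against every memory row M_t(i) by cosine similarity; the read weights are a softmax over these similarities, and the read vector is their weighted sum:

K(\mathbf{k}_t, \mathbf{M}_t(i)) = \frac{\mathbf{k}_t \cdot \mathbf{M}_t(i)}{\lVert\mathbf{k}_t\rVert\,\lVert\mathbf{M}_t(i)\rVert},\qquad w_t^{r}(i) = \operatorname{softmax}\!\big(K(\mathbf{k}_t, \mathbf{M}_t(i))\big),\qquad \mathbf{r}_t = \sum_i w_t^{r}(i)\,\mathbf{M}_t(i)

While writing (Least Recently Used Access), the usage weights are decayed and accumulated, the write weights interpolate between the previous read weights and the least-used locations, and the key is added to the memory at the written locations:

\mathbf{w}_t^{u} = \gamma\,\mathbf{w}_{t-1}^{u} + \mathbf{w}_t^{r} + \mathbf{w}_t^{w},\qquad \mathbf{w}_t^{w} = \sigma(\alpha)\,\mathbf{w}_{t-1}^{r} + \big(1-\sigma(\alpha)\big)\,\mathbf{w}_{t-1}^{lu},\qquad \mathbf{M}_t(i) = \mathbf{M}_{t-1}(i) + w_t^{w}(i)\,\mathbf{k}_t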

For each input, the memory matrix is updated based on the usage weights and the previous read weights (eqn. 7 in [1]). The retrieved memory is then concatenated with the input vector and fed through the controller network. The controller network parameters (weights and biases) are then updated by gradient descent, using the chosen optimizer to minimize the categorical cross-entropy loss.

MANN for Few Shot Classification in Omniglot Dataset

The Omniglot data set is a collection of 1,623 handwritten characters from 50 different alphabets [3]. Our objective is to build a neural network that classifies handwritten characters given only a few labeled examples of each. We'll solve this task in the following way:

  • Split the 1,623 characters into a distribution of Meta train data sets and Meta test data sets.
  • Use a MANN with an LSTM controller network to train on the distribution of Meta training data sets.
  • Evaluate the performance of the model on the distribution of Meta test data sets.
Code 01: Input and True label generator for a given batch_size [5]
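A minimal sketch of such an episode generator is given below; the function name episode_generator, the images_by_class dictionary, and the 28×28 flattened image size are my assumptions for illustration, not the original code from [5]:

```python
import numpy as np

def episode_generator(images_by_class, batch_size, num_classes, samples_per_class):
    """Yield (inputs, labels) episodes for fit_generator.

    images_by_class: dict mapping a character class to an array of its
    flattened 28x28 images, assumed to be loaded beforehand.
    inputs:  (batch_size, T, 784 + num_classes), each step being a flattened
             image concatenated with the previous step's one-hot label.
    labels:  (batch_size, T, num_classes), with T = num_classes * samples_per_class.
    """
    class_ids = list(images_by_class.keys())
    T = num_classes * samples_per_class
    img_dim = 28 * 28
    while True:
        batch_x = np.zeros((batch_size, T, img_dim + num_classes), dtype=np.float32)
        batch_y = np.zeros((batch_size, T, num_classes), dtype=np.float32)
        for b in range(batch_size):
            # Sample the episode's classes and a few examples of each.
            chosen = np.random.choice(class_ids, num_classes, replace=False)
            xs, ys = [], []
            for label, cid in enumerate(chosen):
                imgs = images_by_class[cid]
                idx = np.random.choice(len(imgs), samples_per_class, replace=False)
                xs.append(imgs[idx])
                ys.extend([label] * samples_per_class)
            xs, ys = np.concatenate(xs, axis=0), np.array(ys)
            order = np.random.permutation(T)                  # shuffle within the episode
            xs, ys = xs[order], ys[order]
            one_hot = np.eye(num_classes)[ys]
            shifted = np.vstack([np.zeros((1, num_classes)), one_hot[:-1]])  # time-offset labels
            batch_x[b] = np.concatenate([xs, shifted], axis=1)
            batch_y[b] = one_hot
        yield batch_x, batch_y
```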

We now have a generator which can be fed to the Keras built-in fit_generator function.

The MANN cell implementation is inspired by the book "Hands-On Meta Learning with Python" [4], and the detailed code can be seen in the GitHub repository (link provided at the end of the article).

The controller I have used is an LSTM of 128 units with "tanh" activation, followed by a fully connected layer whose size equals the number of classes, with "softmax" activation. The initial state of the memory is defined as a zero state (all-zero tensors). After each training step, the updated memory is fed back in as the current memory state for the next step.

The flow of work can be depicted by the following lines of code:

Code 02: Few Shot Classification using MANN
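As a sketch, the snippet below wires up only the controller and the classification head described above, producing a prediction at every time step; the external memory read/write, which in the repository happens inside the MANN cell adapted from [4], is omitted here for brevity, and only the layer sizes come from the text:

```python
import tensorflow as tf
from tensorflow.keras import layers, Model

num_classes = 10                      # 10-way classification
samples_per_class = 2                 # 2-shot
seq_len = num_classes * samples_per_class
input_dim = 28 * 28 + num_classes     # flattened image + time-offset one-hot label

inputs = layers.Input(shape=(seq_len, input_dim))

# Controller: 128-unit LSTM with tanh activation; return_sequences=True so
# that a prediction is produced at every time step of the episode.
hidden = layers.LSTM(128, activation="tanh", return_sequences=True)(inputs)

# Classification head: per-time-step softmax over the episode's classes.
outputs = layers.TimeDistributed(
    layers.Dense(num_classes, activation="softmax"))(hidden)

model = Model(inputs, outputs)
model.compile(optimizer="adam",
              loss="categorical_crossentropy",
              metrics=["accuracy"])
```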

Once the model is defined in terms of tensors, we can run the usual Keras built-in fit_generator function to carry out the training.

Code 03: Model Training
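A sketch of the training call is shown below; model and episode_generator are the ones sketched above, while train_images_by_class, the batch size and the step count are placeholders rather than values from the original code:

```python
batch_size = 16          # placeholder value
steps_per_epoch = 100    # placeholder value

history = model.fit_generator(
    episode_generator(train_images_by_class, batch_size,
                      num_classes=10, samples_per_class=2),
    steps_per_epoch=steps_per_epoch,
    epochs=110,
)
```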

I have implemented the MANN algorithm for 10-way, 2-shot classification on the Omniglot data set. The model was able to achieve 99% training accuracy in around 110 epochs. The accuracy and loss vs. epochs diagrams are shown below.

Figure 04: Accuracy vs Epoch in MANN
Figure 05: Loss vs Epoch in MANN

All the dimensions and tensor shapes related to the memory cell of the MANN are hyperparameters which we can tune ourselves. Furthermore, the controller network is also modifiable, with whatever layers and hidden units we prefer.

Feel free to refer to the GitHub repository for the detailed code, and try implementing different few-shot classification settings yourself. Cheers! 😉

References

[1] Adam Santoro, Sergey Bartunov, Matthew Botvinick, Daan Wierstra, and Timothy Lillicrap. Meta-Learning with Memory-Augmented Neural Networks. ICML, 2016.

[2] Alex Graves, Greg Wayne, and Ivo Danihelka. Neural Turing Machines. arXiv:1410.5401, 2014.

[3] Brenden M. Lake, Ruslan Salakhutdinov, and Joshua B. Tenenbaum. Human-level concept learning through probabilistic program induction. Science, 2015.

[4] Sudharsan Ravichandiran. Hands-On Meta Learning with Python. Packt Publishing, 2018.

[5] Stanford CS330: Deep Multi-Task and Meta-Learning, Homework 01.
