Making Meta-Learning Easily Accessible on PyTorch

Published in PyTorch · Nov 19, 2019

This post is authored by Ian Bunner, a student at the University of Southern California. He was one of the members of learn2learn, the first-place winning team at the PyTorch Summer Hackathon in Menlo Park. The team also includes Debajyoti Datta, Praateek Mahajan and Séb Arnold.

GitHub Repository: https://github.com/learnables/learn2learn

Meta-learning is often not production-ready. Current algorithms suffer from instability and high computational costs, and often require an arduous hyperparameter search. That's why our PyTorch Summer Hackathon in Menlo Park team, learn2learn, decided to design the library with these difficulties in mind. Learn2learn offers the core functionality needed for those unfamiliar with meta-learning to easily use it in their projects, as well as tools for researchers to create new meta-learning algorithms (or modifications to existing ones) that seek to address those problems.

The learn2learn team winning first place at the PyTorch Summer Hackathon in Menlo Park

Our team consists of researchers and industry professionals who recognize that meta-learning is a high-impact area in need of a simple tool to drive growth. This project was born from the realization that an open-source, modularly designed library with a user-friendly interface could go a long way towards increasing the usage and popularity of meta-learning and decreasing the overhead required to build and test modified or new meta-learning algorithms. Both of these factors are important in driving advances in the field. We'll introduce the motivation and meta-learning basics before demonstrating the learn2learn library's interface.

What is meta-learning and why should you use it?

As evidenced by our GitHub repo name, meta-learning is the process of teaching agents to "learn to learn". The goal of a meta-learning algorithm is to use training experience to update a learner so that it becomes increasingly effective at learning from new training input. The new learner requires fewer samples to adapt to a new task, at the cost of a possibly time- and experience-intensive meta-training process. The idea of meta-learning stems from the fact that many tasks we would like to train an agent to solve share common structure, which we can hopefully leverage to create agents capable of applying knowledge gained from previously learned tasks to new ones. Meta-learning is one powerful way to tackle problems in the few-shot learning domain, as well as problems that require an agent capable of performing many different but related tasks.

Meta-learning can advance your AI projects because it can improve performance and let you solve problems that might otherwise be hard to tackle using only traditional machine learning algorithms. Take the task of teaching an agent to run as an example. This is a general task, with different instantiations characterized by the physics of the learner and its environment. You might begin training your agent with randomly initialized weights, which almost invariably leads to poor initial performance. After many thousands of training loops, you eventually end up with an agent capable of running. If we instead apply a meta-learning algorithm to determine a good weight initialization, we get an agent capable of learning to run in just a few training loops. Shown below is an agent that, using meta-learning, learns to run after only one parameter update!

This gif shows the motion of an agent before and after using meta-learning to learn to run with only one parameter update

Introducing learn2learn

Applying meta-learning to your existing or new projects can be a great way to improve performance or solve new problems. However, this can be difficult due to high computational costs and to algorithms that are unstable and brittle with respect to their hyperparameters. Also, since meta-learning is relatively new, most supervised datasets are not properly formatted for use in a meta-learning algorithm. Learn2learn alleviates these issues by providing a simple user interface for loading datasets and for training with fast, robust implementations of common meta-learning algorithms. With learn2learn, applying meta-learning to your projects becomes plug-and-play. For researchers, any aspect of the project's core algorithms can be easily modified or completely rewritten and tested without all the overhead.

Getting Started with learn2learn

To get started with learn2learn, you can install the project using `pip install learn2learn` or clone and install it from source. Learn2learn maintains an array of high-quality example scripts that demonstrate how to use its core utilities and give researchers an easy way to test their algorithms.
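Once installed, a quick sanity check from a Python shell (assuming the package exposes a `__version__` attribute, as most PyPI packages do):

import learn2learn as l2l

# Prints the installed version string, e.g. the release you just installed.
print(l2l.__version__)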

Explaining the MAML Interface

Model-Agnostic Meta-Learning (MAML) is a popular gradient-based meta-learning algorithm that learns a weight initialization from which a model can adapt to a new task with only a few training samples. The paper introducing MAML can be found here, with links to the authors' open-source implementation. Some familiarity with the algorithm will certainly help in understanding the nuances behind the interface, though this should not be a barrier to understanding how to use it. The following is an example of using the high-level MAML implementation from learn2learn on the popular MNIST dataset. For those unfamiliar, MNIST is a popular benchmark task for image classification algorithms. It is a collection of handwritten decimal digits (i.e., in [0, 9]) that can be used to validate and compare image classification algorithms. For meta-learning, we can treat learning to classify a select subset of m digits as a task, with the goal of training an agent that can learn to classify any arbitrary subset of m digits from only k training samples per digit.

import torchvision
import learn2learn as l2l
from torch import optim

# Download MNIST and wrap it so learn2learn can sample examples by label.
mnist = torchvision.datasets.MNIST(root="/tmp/mnist", train=True)
mnist = l2l.data.MetaDataset(mnist)
task_generator = l2l.data.TaskGenerator(mnist,
                                        ways=3,
                                        classes=[0, 1, 4, 6, 8, 9],
                                        tasks=10)

# Net, compute_loss, num_iterations and adaptation_steps are user-defined.
model = Net()
maml = l2l.algorithms.MAML(model, lr=1e-3, first_order=False)
opt = optim.Adam(maml.parameters(), lr=4e-3)

for iteration in range(num_iterations):
    learner = maml.clone()  # Creates a clone of the model
    adaptation_task = task_generator.sample(shots=1)

    # Fast adaptation: take a few gradient steps on the sampled task
    for step in range(adaptation_steps):
        error = compute_loss(adaptation_task)
        learner.adapt(error)

    # Compute the evaluation loss on fresh samples from the same task
    evaluation_task = task_generator.sample(shots=1,
                                            task=adaptation_task.sampled_task)
    evaluation_error = compute_loss(evaluation_task)

    # Meta-update the original model's parameters
    opt.zero_grad()
    evaluation_error.backward()
    opt.step()

We start by loading the dataset from torchvision, the PyTorch package that provides popular datasets, model architectures, and common image transformations for computer vision. Then we convert the dataset into an instance of a learn2learn MetaDataset, an abstraction that allows us to randomly sample elements from the dataset for a given label. This wrapper works for any dataset that returns a (data, target) tuple. Once we have the dataset, it feeds into learn2learn's TaskGenerator, a wrapper class that lets us easily generate sample tasks for few-shot learning. The parameters follow the conventional names in few-shot learning. For those unfamiliar, 'ways' determines how many classes a task involves, and 'shots' is the number of data points available per class. In our example, we use MNIST with 3 ways and 1 shot, so each subtask is to classify a single input as one of three digits. 'classes' is a list of classes to sample from; if it is None, all classes of the dataset are considered. The effect of passing classes=[0, 1, 4, 6, 8, 9] is that these are the only six digits that can be used to create our tasks. When 'tasks' is an integer, it specifies the number of tasks to sample from; if it is not specified, all possible permutations of size 'ways' are considered. Once we have performed these steps, we can sample a task by simply calling the TaskGenerator's sample method.
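To make the sampling behavior concrete, here is a small standalone sketch of just the dataset and task-generation steps described above. The `sampled_task` attribute is the one used in the training loop earlier; what exactly it prints may vary with the library version you installed.

import torchvision
import learn2learn as l2l

# Wrap the torchvision dataset so examples can be sampled by label.
mnist = l2l.data.MetaDataset(
    torchvision.datasets.MNIST(root="/tmp/mnist", train=True))

# 3-way tasks drawn only from the digits 0, 1, 4, 6, 8 and 9,
# with a pool of 10 distinct tasks.
task_generator = l2l.data.TaskGenerator(mnist,
                                        ways=3,
                                        classes=[0, 1, 4, 6, 8, 9],
                                        tasks=10)

task = task_generator.sample(shots=1)  # 1 example per class, 3 classes
print(task.sampled_task)               # which 3 digits this task classifies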

Once we have our dataset squared away, we create a new neural network that will serve as the model to be meta-learned. We then pass this model into learn2learn's MAML class, which provides the functionality necessary to implement MAML. Finally, we create an optimizer over the wrapped model's weights, and we are ready to implement the training loop. A sketch of such a network follows.
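The snippet above leaves `Net` to the user: any PyTorch module that outputs one logit per way would do. A minimal, hypothetical sketch for our 3-way MNIST setting might look like this:

import torch.nn as nn

class Net(nn.Module):
    """A deliberately small MNIST classifier: one logit per way."""
    def __init__(self, ways=3):
        super().__init__()
        self.classifier = nn.Sequential(
            nn.Flatten(),              # 28x28 image -> 784-dim vector
            nn.Linear(28 * 28, 128),
            nn.ReLU(),
            nn.Linear(128, ways),      # one output logit per class
        )

    def forward(self, x):
        return self.classifier(x)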

Our outer loop is the training loop, which determines how many meta-training steps we take. As part of MAML, we need to clone the entire module so that each training task has its own computational graph for backpropagating through to the module's original graph. This functionality is provided via the clone method. As mentioned before, we can get a new task to train on using the TaskGenerator's sample method. The sample method lets you declare the number of shots for the task in case you did not set it when creating the TaskGenerator. Once we have the task and the cloned module, we perform a training loop on it just as we would in a single-task setting. The MAML class has an adapt method that performs an optimization step on our cloned module, so the user does not have to create a new optimizer at every meta-training step; to use it, pass in the computed loss. After training, we evaluate the effectiveness of our adaptation by sampling new inputs for the same task, which is done using the task argument of the TaskGenerator's sample method. This argument lets us specify the task to use instead of picking one at random. We then compute the loss on this evaluation task as usual and again perform an optimization step, this time on the parameters of our actual model (not our cloned model).
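`compute_loss` is also left to the user in the snippet above (there it implicitly closes over the learner; for clarity, this hypothetical version takes the learner as an explicit argument). Assuming the sampled task can be iterated as (input, label) pairs of tensors, it might look like:

import torch
import torch.nn.functional as F

def compute_loss(learner, task):
    # Average the cross-entropy loss over every example in the task.
    loss, count = 0.0, 0
    for x, y in task:
        x = x.unsqueeze(0)               # add a batch dimension
        y = torch.as_tensor(y).view(1)   # shape-(1,) label tensor
        loss = loss + F.cross_entropy(learner(x), y)
        count += 1
    return loss / count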

Benchmark Performance

The tables below show that our implementation matches the state of the art, comparing its performance on common benchmarks against the originally published results.

MAML Performance on mini-ImageNet for learn2learn compared to results published in the original mini-ImageNet paper.
MAML performance on Omniglot compared to results published in original Omniglot paper.
MAML mini-ImageNet performance as measured by prediction accuracy after meta-adaptation. Accuracy is shown for validation tasks to demonstrate that adaptation on train and test tasks improves performance on unseen tasks as well. Graphs were obtained using a CNN with 5 shots, 5 ways, and 5 adaptation steps.

Accomplishments that we are proud of

  • At the highest level, it allows anyone to use meta-learning.
  • Anyone can experiment with modifications to meta-learning algorithms.
  • Anyone can dive deep into the code to write new meta-learning algorithms, while the library maintains compatibility with other PyTorch libraries like torchvision and torchtext.
  • We created a TaskGenerator so that anybody can create meta-learning tasks from supervised datasets.

Thank you for reading our blog post; we hope you found it helpful,

Team learn2learn

Acknowledgements

  1. The RL environments are adapted from Tristan Deleu’s implementations and from the ProMP repository. Both shared with permission, under the MIT License.
  2. TorchMeta is a similar library, with a focus on supervised meta-learning. If learn2learn were missing a particular functionality, we would check whether TorchMeta has it. But we would also open an issue ;)
  3. higher is a PyTorch library that also enables differentiating through optimization inner-loops. Their approach differs from learn2learn's in that they monkey-patch nn.Module to be stateless. For more information, refer to their arXiv paper.
