Bootstrapping a multimodal project using MMF, a PyTorch powered MultiModal Framework

Published in PyTorch · Jun 11, 2020

A solid foundation for your next vision and language research/production project

Authored by: Facebook AI Research (Amanpreet Singh and Vedanuj Goswami)

Reasoning across modalities is critical to intelligence. There is a growing need to model interactions between modalities (e.g., vision, language) — both to improve AI predictions on existing tasks and to enable new applications. Multimodal AI problems range from visual question answering, image captioning and visual dialogue to embodied AI, virtual assistants and detecting hateful content on social media. Better tools — both for researchers to develop novel ideas and for practitioners to productionize use cases — have potential to accelerate progress in multimodal AI.

MMF (short for “a MultiModal Framework”) is a modular framework built on PyTorch. MMF comes packaged with state-of-the-art vision and language pretrained models, a number of out-of-the-box standard datasets, common layers and model components, and training + inference utilities. MMF is also used for multimodal understanding use cases by several Facebook product teams as it facilitates pushing research to production quickly.

Key characteristics of MMF are:

1. Usability:

  • Built on PyTorch 1.5
  • A model zoo with 12+ state-of-the-art (including BERT-like) models
  • A dataset zoo with ~20 datasets with automatic downloads
  • Comprehensive documentation and tutorials
  • A clean, easily extensible API
  • Starter code for several multimodal challenges

2. Modularity and Configurability:

  • Modular components like encoders, decoders, embeddings, layers and processors to build models and datasets from scratch
  • A new configuration system based on OmegaConf
  • Commonly used metrics and losses

3. Scalability:

  • Distributed training support along with best practices for maximum performance
  • Sweep scripts for launching large scale SLURM jobs
  • Checkpoint, early stopping and other functionality for making training and evaluation easier

What follows is a tutorial on MMF. It has two parts. In the first part, we will use a pretrained model to train, evaluate, and make a submission to the Hateful Memes Challenge. In the second part, we will learn how to build a custom Hateful Memes detection model in MMF from basic building blocks, and then train and evaluate it. The Hateful Memes Challenge is a multimodal integrity task for classifying which memes are hateful. You can also run this tutorial on Google Colab by using this notebook.

Part 1 : Getting Started

Step 1 — Install MMF

First, we will install MMF, which downloads and installs all the required dependencies. We then check that the installation was successful.

Prerequisites: Python 3.7+, Linux, macOS or Windows

pip install --pre mmf
python -c "import mmf; print(mmf.__version__)"

It should show the version of mmf installed.

Step 2 — Download the Hateful Memes Challenge dataset

To get the Hateful Memes dataset, follow these steps:

  1. Go to the DrivenData challenge page for Hateful Memes.
  2. Register, then read and acknowledge the agreements for data access.
  3. Go to the Data Download page and download the dataset by clicking the “Hateful Memes challenge dataset” link.
  4. Take note of the password provided for the zip file.

Once downloaded, we convert the dataset to MMF format:

mmf_convert_hm --zip_file <zip_file_path> --password <password>

Step 3 — Visualize Samples

Let’s now try to visualize a few samples from the dataset to understand what the data and annotations look like.

Note: Some of the images in the Hateful Memes dataset are sensitive and may not be suitable for all audiences. Please keep this in mind when running the next step.

build_dataset("hateful_memes") builds the dataset and loads the annotation files and images. dataset.visualize(num_samples=8) will visualize 8 samples from the dataset in a grid.
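A minimal sketch of those two calls, as in the accompanying Colab notebook (the import path for build_dataset comes from MMF's build utilities and may differ slightly across versions):

from mmf.utils.build import build_dataset

# Build the Hateful Memes dataset; this loads the annotation files and images.
dataset = build_dataset("hateful_memes")

# Show 8 samples (image + meme text) in a grid.
dataset.visualize(num_samples=8)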

Step 4 — Evaluate Pretrained models

We will use a pretrained model (MMBT) to classify some randomly selected memes from the dataset to see if they are hateful or not.

MMBT.from_pretrained initializes the model and loads pretrained weights from our model zoo; mmbt.hateful_memes.images is the model zoo key for the pretrained model. The .classify method takes a path to an image and the accompanying meme text and generates a prediction with the model. It outputs the label and confidence: the label is 1 if the meme is hateful and 0 if it is benign.
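A rough sketch of these calls (the import path and the returned fields follow MMF's model zoo examples and may vary slightly by version; the image path is a placeholder):

from mmf.models.mmbt import MMBT

# Initialize MMBT and load weights pretrained on Hateful Memes (model zoo key from above).
model = MMBT.from_pretrained("mmbt.hateful_memes.images")

# Classify one meme: pass an image path (or URL) plus the meme text.
output = model.classify("path/to/meme.png", "text written on the meme")
print("label:", output["label"], "confidence:", output["confidence"])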

Step 5 — Submit predictions to DrivenData

Next, we will submit the predictions generated by our model to the Hateful Memes Challenge hosted on DrivenData:

mmf_predict config=projects/hateful_memes/configs/mmbt/defaults.yaml \
    model=mmbt \
    dataset=hateful_memes \
    run_type=test \
    checkpoint.resume_zoo=mmbt.hateful_memes.images

This will generate a CSV file with predictions and output its path. You can submit this CSV file to DrivenData for results.

Step 6 — Training Models from Scratch

Now that we have generated results with a pretrained model, next we will learn how to train models from scratch. We will train and evaluate a model that has already been implemented in MMF. Specifically, we will train a model from the Hateful Memes paper, MMBT trained with grid features:

mmf_run config=projects/hateful_memes/configs/mmbt/defaults.yaml \
    model=mmbt \
    dataset=hateful_memes

Here we use the mmf_run command, the CLI command for running training + validation. We specify the config file used for this training, along with the model and dataset that we train on. Training will log the training loss and other metrics every 100 iterations and will log evaluation metrics every 1,000 iterations. This training will run for 22,000 iterations. These are hyperparameters defined in the configuration file and can be modified or overridden.

Part 2 : Creating your own models

We will now go through the step-by-step process of creating new models using MMF. In this case, we will create a fusion model for the Hateful Memes challenge.

The fusion model that we will create concatenates embeddings from a text encoder and an image encoder and passes them through a two-layer classifier. The diagram below shows the model architecture.

To implement this model in PyTorch, we will build a class like:
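A minimal sketch of what such a class could look like (assuming precomputed text and image feature vectors; the layer sizes and names are illustrative, not the exact code from the notebook):

import torch
from torch import nn

class LanguageAndVisionConcat(nn.Module):
    def __init__(self, text_dim=300, image_dim=2048, hidden_dim=512, num_classes=2):
        super().__init__()
        # Project each modality with a fully-connected layer.
        self.text_projection = nn.Linear(text_dim, hidden_dim)
        self.image_projection = nn.Linear(image_dim, hidden_dim)
        # Two-layer classifier over the concatenated features.
        self.classifier = nn.Sequential(
            nn.Linear(2 * hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, num_classes),
        )

    def forward(self, text_features, image_features):
        text = torch.relu(self.text_projection(text_features))
        image = torch.relu(self.image_projection(image_features))
        fused = torch.cat([text, image], dim=-1)
        return self.classifier(fused)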

This is a very simple model that uses only fully-connected layers and is therefore not expected to perform particularly well. MMF provides standard image and text encoders out of the box. For the image encoder, we will use a ResNet101 encoder, and for the text encoder, we will use FastText embeddings. FastText embeddings cannot be trained end-to-end with the model in this case, so we will load the embeddings in the dataset itself by creating an MMF processor, and pass them through a fully-connected layer as a proxy for an encoder. We will now follow the steps below to create our new model:

  1. Create a FastText sentence vector processor.
  2. Build a model using a classifier, image encoder and text encoders from MMF.
  3. Create an extensible hyperparameter config for the model.
  4. Create an experiment config for your training hyperparameters.
  5. Train a model and submit predictions.

Step 1 — Creating the processor

Processors can be thought of as torchvision transforms: they transform a sample into a form usable by the model. Each processor takes in a dictionary and returns a dictionary. Processors are initialized as member variables of the dataset and can be used while generating samples. A FastText processor is available in MMF, but it returns word embeddings instead of a sentence embedding, so we will create a FastText sentence processor here.
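A rough sketch of such a processor is below; it assumes the fasttext Python package and a model_file entry in the processor config, and the base-class import path may differ across MMF versions:

import torch
import fasttext
from mmf.common.registry import registry
from mmf.datasets.processors import BaseProcessor

@registry.register_processor("fasttext_sentence_vector")
class FastTextSentenceVectorProcessor(BaseProcessor):
    def __init__(self, config, *args, **kwargs):
        # Load the pretrained FastText model given in the processor config.
        self.model = fasttext.load_model(config.model_file)

    def __call__(self, item):
        # Replace the raw text with a single sentence-level vector.
        vector = self.model.get_sentence_vector(item["text"])
        return {"text": torch.tensor(vector, dtype=torch.float)}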

Step 2 — Using MMF to build the model

We will start building our model, LanguageAndVisionConcat, using the various building blocks available in MMF. Helper builder methods like build_image_encoder for building image encoders, build_classifier_layer for classifier layers, etc., take configurable parameters that are defined in the config we will create in the next section.

The model’s forward method takes a SampleList and outputs a dict containing the logit scores predicted by the model. Different losses and metrics can then be computed on these output scores.
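A skeleton of how this model might be registered and built in MMF is sketched below; the config fields (image_encoder, classifier, text_hidden_size, modal_hidden_size) and the builder import paths follow MMF's model-addition tutorial but are illustrative rather than exact:

import torch
from mmf.common.registry import registry
from mmf.models.base_model import BaseModel
from mmf.utils.build import build_classifier_layer, build_image_encoder

@registry.register_model("concat_vl")
class LanguageAndVisionConcat(BaseModel):
    @classmethod
    def config_path(cls):
        # Default hyperparameter config for this model (defined in Step 3).
        return "configs/models/concat_vl/defaults.yaml"

    def build(self):
        # Encoders and classifier are built from the configurable model config.
        self.vision_module = build_image_encoder(self.config.image_encoder)
        self.language_module = torch.nn.Linear(
            self.config.text_hidden_size, self.config.modal_hidden_size
        )
        self.classifier = build_classifier_layer(self.config.classifier)

    def forward(self, sample_list):
        # "text" holds the FastText sentence vector produced by our processor.
        text = torch.relu(self.language_module(sample_list["text"]))
        image = torch.relu(self.vision_module(sample_list["image"]))
        fused = torch.cat([text, torch.flatten(image, start_dim=1)], dim=-1)
        # MMF expects a dict with "scores"; losses and metrics are computed on it.
        return {"scores": self.classifier(fused)}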

Step 3 — Defining Configs

We define the two configs needed for our experiments: (i) a model config providing the defaults for the model’s hyperparameters, and (ii) a user/experiment config that defines and overrides the defaults needed for our particular experiment. You can check out the configs with detailed comments in the Colab notebook or here and here. Note that, to use the fasttext processor we created, we update the text processor of the Hateful Memes dataset’s config in our experiment config. You can read more about MMF’s configuration system here.

Step 4 — Training the model

Now we are ready to train our model with the experiment config we created in Step 3.

mmf_run config="configs/experiments/defaults.yaml" \
    model=concat_vl \
    dataset=hateful_memes \
    training.num_workers=0

Step 5 — Submitting your model’s prediction to DrivenData

mmf_predict config="configs/experiments/defaults.yaml" \
    model=concat_vl \
    dataset=hateful_memes \
    training.run_type=test \
    checkpoint.resume=True \
    checkpoint.resume_best=True \
    training.num_workers=0

This command will generate a CSV file which you can submit to DrivenData, as we saw earlier in Part 1. checkpoint.resume specifies that we want to resume from our already-trained model, and checkpoint.resume_best specifies that we want the best checkpoint among those saved.

Multimodal Challenges on MMF

MMF comes with starter code, baseline models and detailed tutorials for various challenges in the multimodal vision and language space. Some of these challenges are:

  1. Hateful Memes [blog] [link]
  2. TextVQA [link]
  3. VQA [link]
  4. TextCaps [link]

At Facebook AI, we will continuously improve and expand on the multimodal capabilities available through MMF, and we welcome contributions from the community as well to build this resource. We hope MMF will be the framework of choice and be a catalyst for research in this area by providing a powerful, versatile platform for multimodal research.

Check out our source code here and documentation here.
