Visual question answering with multimodal transformers

PyTorch implementation of VQA models using text and image transformers from Hugging Face

Tezan Sahu
Data Science at Microsoft
10 min read · Mar 8, 2022


Recent years have seen significant advances not only within Natural Language Processing (NLP) and Computer Vision (CV) but also in tasks involving multiple modalities (text + image features), such as image captioning, visual question answering (VQA), cross-modal retrieval, and visual common-sense reasoning. Among these, VQA has drawn particular interest from researchers.

What is VQA?

VQA is a multimodal task in which, given an image and a natural language question about that image, the objective is to produce a correct natural language answer.

Source: OK-VQA: A Visual Question Answering Benchmark Requiring External Knowledge on arxiv.org.

It involves understanding the content of the image and correlating it with the context of the question asked. Because the semantics of both modalities (the image and the natural language question about it) must be compared, VQA entails a wide range of sub-problems in both CV and NLP, such as object detection and recognition, scene classification, and counting. Thus, it is considered an AI-complete task.

VQA with multimodal fusion models

Multimodal models can take various forms to capture information from the text and image modalities, along with some cross-modal interaction. In fusion models, the information from the text and image encoders is fused into a combined representation that is used to perform the downstream task.

A typical fusion model for a VQA system involves the following steps:

  • Featurization of image and question: We need to extract features from the image and obtain embeddings of the question after tokenization. The question can be featurized using simple word embeddings (like GloVe), recurrent models (like LSTMs), or transformers. Similarly, the image features can be extracted using simple CNNs (convolutional neural networks), early layers of object detection or image classification models, or image transformers.
  • Feature fusion: Since VQA involves a comparison of the semantic information present in the image and the question, there is a need to jointly represent the features from both modalities. This is usually accomplished through a fusion layer that allows cross-modal interaction between image and text features to generate a fused multimodal representation.
  • Answer generation: Depending on how the VQA task is modelled, the answer can either be generated with natural language generation (for longer, descriptive answers) or predicted by a simple classifier (for one-word or short-phrase answers drawn from a fixed answer space).

The following figure summarizes some methods used for the individual feature extraction and feature fusion steps:

Types of multimodal data fusion. Image created by the author.

In this article, I explore late fusion by fine-tuning pretrained text and image transformer models, because late fusion models are simpler to train.

With this background in place, it’s time to delve into the code and implement our multimodal model for VQA. First, we process the DAQUAR dataset. Because all the questions have single-word or short-phrase answers, we treat the entire vocabulary of answers (the answer space) as the set of class labels. This converts visual question answering into a multiclass classification problem. We then train our multimodal transformer model and evaluate it using established VQA metrics. Toward the end, we compare and explain the results for various combinations of textual and image transformers used for featurization.

TL;DR: This repository contains all the code mentioned in this article. Although GitHub gists are used as code snippets throughout this article, they may not work as intended if copied directly. Please refer to this notebook in the repository for working implementations.

Preliminaries

Installing required packages

We need to create a virtual environment and install the required packages:

datasets==1.17.0
nltk==3.5
pandas==1.3.5
Pillow==9.0.0
scikit-learn==0.23.2
torch==1.8.2+cu111
transformers==4.14.0

Note: GPU access is advisable for training the multimodal models, because they are large and would otherwise take considerable time to train.

Setting up the environment

To set up the environment for training our multimodal VQA model, we need to import the required modules and set the appropriate device for PyTorch.
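A minimal sketch of this setup step follows, assuming the packages above are installed. The seed value and the set_seed helper are illustrative choices of ours, not the article's code; the device logic simply prefers a CUDA GPU whenever PyTorch can see one.

```python
import random

import numpy as np
import torch

SEED = 12345  # arbitrary value chosen for this sketch


def set_seed(seed: int = SEED) -> None:
    """Seed Python, NumPy, and PyTorch RNGs so runs are reproducible."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)


set_seed()

# Prefer a CUDA GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
```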

Data preparation

For the VQA model training, we use the full DAtaset for QUestion Answering on Real-world images (DAQUAR) dataset, which contains approximately 12,500 question-answer pairs based on images from the NYU-Depth V2 Dataset.

Sample images, questions, and answers from the DAQUAR Dataset. Source: Ask Your Neurons: A Neural-based Approach to Answering Questions about Images. ICCV’15 (Poster).

Preprocessing the dataset

The raw dataset contains the actual images separately in the images/ directory. All the question-answer pairs are present on consecutive lines in a .txt file, as shown below:

what is on the desk and behind the black cup in the image4 ?
bottle
what is in front of the monitor in the image6 ?
keyboard
...

We run the following script to preprocess these question-answer pairs. It normalizes the questions by removing the image IDs present in them. The questions and answers, along with the corresponding image IDs extracted during normalization, are stored in a tabular (CSV) format. Moreover, because the original DAQUAR split provides only about 54 percent of the question-answer pairs for training (only around 6,700 samples, which is too few for effective training), we produce a custom split (80 percent training and 20 percent evaluation) from the overall data.

This script produces data_train.csv and data_eval.csv files, along with answer_space.txt, containing a vocabulary of all the answers.
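The actual script lives in the repository. As a rough sketch of the same steps, assuming the alternating question/answer line layout shown above, the following parses the pairs, strips phrases like "in the image4" while keeping the image ID, makes a seeded 80/20 split, and builds the answer space. The regular expressions, the seed, the splitting of comma-separated multi-object answers, and the helper names (parse_qa_lines, split_records, build_answer_space, write_csv) are all hypothetical.

```python
import csv
import random
import re
from typing import Dict, List, Sequence, Tuple

IMAGE_ID = re.compile(r"\bimage(\d+)\b")
# Hypothetical normalization rule: remove phrases like "in the image4" from the question.
IMAGE_PHRASE = re.compile(r"\s*\b(?:in|of|on)\s+the\s+image\d+\b")


def parse_qa_lines(lines: Sequence[str]) -> List[Dict[str, str]]:
    """Pair consecutive question/answer lines and pull out the image ID."""
    cleaned = [line.strip() for line in lines if line.strip()]
    records = []
    for question, answer in zip(cleaned[0::2], cleaned[1::2]):
        match = IMAGE_ID.search(question)
        records.append({
            "question": IMAGE_PHRASE.sub("", question).replace(" ?", "?").strip(),
            "answer": answer,
            "image_id": f"image{match.group(1)}" if match else "",
        })
    return records


def split_records(records, train_fraction=0.8, seed=12345) -> Tuple[list, list]:
    """Shuffle with a fixed seed, then split into training and evaluation sets."""
    shuffled = list(records)
    random.Random(seed).shuffle(shuffled)
    cut = int(len(shuffled) * train_fraction)
    return shuffled[:cut], shuffled[cut:]


def build_answer_space(records) -> List[str]:
    """Sorted answer vocabulary; comma-separated multi-object answers are split."""
    return sorted({part.strip() for r in records for part in r["answer"].split(",") if part.strip()})


def write_csv(path: str, records) -> None:
    """Write records to a CSV with question, answer, and image_id columns."""
    with open(path, "w", newline="") as f:
        writer = csv.DictWriter(f, fieldnames=["question", "answer", "image_id"])
        writer.writeheader()
        writer.writerows(records)
```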

These files are already available in the dataset/ folder of the repository for direct consumption and can also be found on Kaggle.

Loading the data

We can now load the processed data. For this, we use the datasets library from Hugging Face. Since we model this task as multiclass classification, we assign a label to every answer. These labels are derived from the indices of the answers in the answer space.
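A sketch of this loading step with datasets.load_dataset("csv", ...) follows. The file paths mirror the dataset/ folder mentioned above but are assumptions, as is the rule of keeping only the first object when an answer lists several comma-separated ones; build_label_map, answer_to_label, and load_vqa_dataset are hypothetical names.

```python
from typing import Dict, Sequence


def build_label_map(answer_space: Sequence[str]) -> Dict[str, int]:
    """Map every answer in the answer space to its index (the class label)."""
    return {answer: index for index, answer in enumerate(answer_space)}


def answer_to_label(answer: str, label_map: Dict[str, int]) -> int:
    # Assumption: if an answer lists several comma-separated objects, use the first one.
    return label_map[answer.split(",")[0].strip()]


def load_vqa_dataset(
    train_csv: str = "dataset/data_train.csv",
    eval_csv: str = "dataset/data_eval.csv",
    answer_space_path: str = "dataset/answer_space.txt",
):
    """Load the processed CSVs with Hugging Face datasets and attach integer labels."""
    from datasets import load_dataset  # lazy import so the helpers above work without it

    with open(answer_space_path) as f:
        answer_space = [line.strip() for line in f if line.strip()]
    label_map = build_label_map(answer_space)

    dataset = load_dataset("csv", data_files={"train": train_csv, "test": eval_csv})
    # Batched map: each call receives a dict of column lists.
    dataset = dataset.map(
        lambda batch: {"label": [answer_to_label(a, label_map) for a in batch["answer"]]},
        batched=True,
    )
    return dataset, answer_space
```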

We can also inspect specific or random entries in the training or evaluation dataset in a Jupyter notebook:

A random entry from the training dataset after loading and creating labels from the answer-space.

Defining a multimodal collator for data

Up to this point, we have just loaded the questions, answers, and corresponding image IDs, along with the labels. To feed the information about the question and actual images batchwise into our multimodal model, we need to define a data collator.

This collator will process the question (text) and the image and return the tokenized text (with attention masks) along with the featurized image (basically, the pixel values). These will be fed into our multimodal transformer model for question answering.

We use AutoTokenizer and AutoFeatureExtractor from Hugging Face transformers to convert the raw images and questions into inputs for featurization using the respective image and text transformers.
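One possible shape for such a collator is sketched below. It is not the repository's implementation: the class name, the image_dir default, and the assumption that images are stored as <image_id>.png are ours. The tokenizer and feature_extractor fields are meant to hold the objects returned by AutoTokenizer.from_pretrained(...) and AutoFeatureExtractor.from_pretrained(...).

```python
import os
from dataclasses import dataclass
from typing import Any, Dict, List

import torch
from PIL import Image


@dataclass
class MultimodalCollator:
    """Turn a list of dataset rows into tokenized text, pixel values, and labels."""

    tokenizer: Any  # e.g. AutoTokenizer.from_pretrained("bert-base-uncased")
    feature_extractor: Any  # e.g. AutoFeatureExtractor.from_pretrained(<image checkpoint>)
    image_dir: str = os.path.join("dataset", "images")  # assumed image location

    def load_image(self, image_id: str) -> Image.Image:
        # Assumes files are named <image_id>.png, e.g. image4.png
        path = os.path.join(self.image_dir, f"{image_id}.png")
        return Image.open(path).convert("RGB")

    def __call__(self, rows: List[Dict[str, Any]]) -> Dict[str, torch.Tensor]:
        # Tokenize all questions in the batch, padding to the longest one.
        text_inputs = self.tokenizer(
            [row["question"] for row in rows],
            padding="longest",
            truncation=True,
            return_attention_mask=True,
            return_tensors="pt",
        )
        # Resize and normalize the images into a (batch, channels, height, width) tensor.
        image_inputs = self.feature_extractor(
            images=[self.load_image(row["image_id"]) for row in rows],
            return_tensors="pt",
        )
        batch = {**text_inputs, "pixel_values": image_inputs["pixel_values"]}
        if "label" in rows[0]:
            batch["labels"] = torch.tensor([row["label"] for row in rows], dtype=torch.long)
        return batch
```

Because Hugging Face's Trainer hands the collator a list of row dictionaries, this design needs the raw question and image_id columns to survive until batching, which is why remove_unused_columns should be disabled in TrainingArguments for this kind of setup.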

Defining the multimodal VQA model architecture

As mentioned previously, we use the idea of late fusion to define our multimodal model comprising:

  • A text transformer to encode the question and generate embeddings
  • An image transformer to encode the image and generate features
  • A reasonably simple fusion layer that concatenates the textual and image features and passes them through a linear layer to generate an intermediate output
  • A classifier, which is a fully connected network whose output dimension equals the size of the answer space

We model VQA as a multiclass classification task, so cross-entropy is a natural choice of loss function to minimize.
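The architecture described above can be sketched as a small PyTorch module. This is an illustrative version, not the author's exact code: the encoders are passed in (so any text and image checkpoints loaded with AutoModel can be combined), both encoders are assumed to expose config.hidden_size and a pooler_output, and intermediate_dim=512, the ReLU, and dropout=0.5 are assumed defaults. Returning a dict with "loss" and "logits" lets Hugging Face's Trainer consume the outputs directly.

```python
from typing import Dict

import torch
from torch import nn


class MultimodalVQAModel(nn.Module):
    """Late fusion: encode text and image, concatenate, fuse, then classify."""

    def __init__(self, text_encoder: nn.Module, image_encoder: nn.Module, num_labels: int,
                 intermediate_dim: int = 512, dropout: float = 0.5):
        super().__init__()
        self.text_encoder = text_encoder
        self.image_encoder = image_encoder
        fused_dim = text_encoder.config.hidden_size + image_encoder.config.hidden_size
        # Fusion layer: concatenated features -> intermediate representation.
        self.fusion = nn.Sequential(nn.Linear(fused_dim, intermediate_dim), nn.ReLU(), nn.Dropout(dropout))
        # Classifier over the answer space.
        self.classifier = nn.Linear(intermediate_dim, num_labels)
        self.criterion = nn.CrossEntropyLoss()

    @classmethod
    def from_pretrained_names(cls, text_name: str, image_name: str, num_labels: int, **kwargs):
        from transformers import AutoModel  # lazy import; only needed for real checkpoints
        return cls(AutoModel.from_pretrained(text_name), AutoModel.from_pretrained(image_name), num_labels, **kwargs)

    def forward(self, input_ids, pixel_values, attention_mask=None, token_type_ids=None,
                labels=None) -> Dict[str, torch.Tensor]:
        text_inputs = {"input_ids": input_ids, "attention_mask": attention_mask}
        if token_type_ids is not None:  # RoBERTa tokenizers typically omit these
            text_inputs["token_type_ids"] = token_type_ids
        text_features = self.text_encoder(**text_inputs)["pooler_output"]
        image_features = self.image_encoder(pixel_values=pixel_values)["pooler_output"]
        fused = self.fusion(torch.cat([text_features, image_features], dim=1))
        logits = self.classifier(fused)
        outputs = {"logits": logits}
        if labels is not None:
            outputs["loss"] = self.criterion(logits, labels)
        return outputs
```

For BERT + ViT, a call such as MultimodalVQAModel.from_pretrained_names("bert-base-uncased", "google/vit-base-patch16-224-in21k", num_labels=...) would load both encoders; those Hub IDs are assumed, commonly used base checkpoints.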

Besides training a particular VQA model with multimodal transformers, we intend to experiment with various pre-trained model combinations and evaluate their performance on the DAQUAR dataset.

Pretrained models for textual encoding

Pretrained text transformers for experimentation to provide textual features.

Pretrained models for image encoding

Pretrained image transformers for experimentation to provide visual features.

Creating the collator and multimodal model

Because we aim to experiment with multiple combinations of text and image transformers, it is reasonable to implement a function for creating the corresponding collators with the respective models.

For demonstration in this article, we will create the collator and model using the tokenizer, feature extractor, and models from pretrained BERT and ViT.

Evaluation metrics

We approach the VQA task as a multiclass classification problem in this article. Hence, accuracy and macro F1 score are straightforward choices of evaluation metric. However, these metrics are often too restrictive: they penalize almost-correct answers (‘tree’ versus ‘plant’) as heavily as incorrect ones (‘tree’ versus ‘table’). We therefore use WUPS as our primary evaluation metric, because it considers the semantic similarity between the predicted answer and the ground truth.
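Accuracy and macro F1 can be computed from the model's logits with scikit-learn, as in the sketch below. The function name and the returned keys ("acc", "f1") are our own choices.

```python
import numpy as np
from sklearn.metrics import accuracy_score, f1_score


def classification_metrics(logits, labels) -> dict:
    """Accuracy and macro-averaged F1 from raw logits and integer labels."""
    predictions = np.argmax(np.asarray(logits), axis=-1)
    return {
        "acc": accuracy_score(labels, predictions),
        "f1": f1_score(labels, predictions, average="macro"),
    }
```

With Hugging Face's Trainer, this could be wrapped as compute_metrics=lambda p: classification_metrics(p.predictions, p.label_ids), since the Trainer passes compute_metrics an EvalPrediction object.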

Wu and Palmer Similarity (WUPS) Score

One option to evaluate open-ended natural language answers is to perform exact string matching. However, it is too stringent and cannot capture the semantic relatedness between the predicted answer and the ground truth. This prompts the use of other metrics that capture the semantic similarity of strings effectively. One such commonly used metric is the Wu and Palmer Similarity (WUPS) Score.

WUPS computes the semantic similarity between two words based on the depth of their least common subsumer (their deepest shared ancestor) in a taxonomy tree such as WordNet. This score works well for single-word answers (hence, we use it for our task), but may not work for phrases or sentences.
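For reference, the Wu-Palmer similarity of two synsets s1 and s2, and the thresholded WUPS score commonly used with DAQUAR, are usually written as follows. Here LCS is the least common subsumer, depth is measured from the taxonomy root, word-level similarity takes the maximum over the words' synsets, A^i and T^i are the predicted and ground-truth answer sets for the i-th of N questions, and τ is the threshold (0.9 and 0.0 are typical choices):

```latex
\mathrm{WUP}(s_1, s_2) = \frac{2\,\mathrm{depth}\big(\mathrm{LCS}(s_1, s_2)\big)}{\mathrm{depth}(s_1) + \mathrm{depth}(s_2)}

\mu_\tau(a, t) =
\begin{cases}
  \mathrm{WUP}(a, t) & \text{if } \mathrm{WUP}(a, t) \ge \tau \\
  0.1 \cdot \mathrm{WUP}(a, t) & \text{otherwise}
\end{cases}

\mathrm{WUPS}_\tau = \frac{1}{N} \sum_{i=1}^{N} \min\Big\{ \prod_{a \in A^i} \max_{t \in T^i} \mu_\tau(a, t),\; \prod_{t \in T^i} \max_{a \in A^i} \mu_\tau(a, t) \Big\}
```

For single-word answers both products reduce to a single term, so WUPS becomes the mean of μ_τ over the evaluation set.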

Although nltk implements Wu-Palmer similarity over the WordNet taxonomy, for our experimentation we use the implementation of Wu and Palmer similarity defined along with the DAQUAR dataset through the wup_measure(...) function.
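A compact sketch of a thresholded Wu-Palmer measure in the spirit of DAQUAR's wup_measure(...) follows; it is a re-implementation for illustration, not that function's code. It assumes nltk is installed with the WordNet corpus downloaded (nltk.download("wordnet")), compares noun synsets only, and uses a 0.9 threshold with 0.1 down-weighting; wup_similarity and wups_score are hypothetical names. The similarity function is a parameter so that the averaging can be checked without WordNet.

```python
from typing import Callable, Sequence


def wup_similarity(a: str, b: str, threshold: float = 0.9) -> float:
    """Thresholded Wu-Palmer similarity between two single-word answers.

    Assumes nltk is installed and nltk.download("wordnet") has been run.
    Scores below `threshold` are down-weighted by a factor of 0.1.
    """
    if a == b:
        return 1.0
    from nltk.corpus import wordnet  # lazy import: only needed for non-identical words

    synsets_a = wordnet.synsets(a, pos=wordnet.NOUN)
    synsets_b = wordnet.synsets(b, pos=wordnet.NOUN)
    if not synsets_a or not synsets_b:
        return 0.0
    # Word-level similarity is the best match over all pairs of noun senses.
    best = max((x.wup_similarity(y) or 0.0) for x in synsets_a for y in synsets_b)
    return best if best >= threshold else 0.1 * best


def wups_score(
    predictions: Sequence[str],
    references: Sequence[str],
    similarity: Callable[[str, str], float] = wup_similarity,
) -> float:
    """Mean similarity over (prediction, reference) pairs: WUPS for single-word answers."""
    if len(predictions) != len(references):
        raise ValueError("predictions and references must have the same length")
    if not predictions:
        return 0.0
    return sum(similarity(p, r) for p, r in zip(predictions, references)) / len(predictions)
```

In a Trainer compute_metrics function, predictions and references would be the answer strings obtained by indexing the answer space with the predicted and true label IDs.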

Training the multimodal VQA model

We finally come to the part where we use the previously defined functions to initialize our multimodal model and train it using the Trainer from Hugging Face to abstract away most of the code required for setting up a PyTorch training loop. The hyperparameters such as training epochs, batch size, and so on, are passed to the Trainer by setting the corresponding values in the TrainingArguments.

For this article, we use the following hyperparameters:

Hyperparameters used for training the multimodal model.

The training of this BERT + ViT model (and all other combinations of transformers) was carried out on an NVIDIA A100 SXM4 40 GB GPU.

The model checkpoints are saved periodically in the indicated output directory based on the information provided in the TrainingArguments.
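A sketch of how the Trainer could be wired up is shown below. The values in TRAINING_KWARGS are placeholders, not the hyperparameters from the table above, and train_vqa is a hypothetical helper. Two settings matter for this setup: remove_unused_columns=False keeps the question and image_id columns that a collator like the one described earlier needs, and metric_for_best_model="wups" assumes compute_metrics returns a "wups" key. The evaluation_strategy argument name matches the pinned transformers==4.14.0; newer releases renamed it eval_strategy.

```python
from typing import Any, Callable, Dict

# Placeholder hyperparameters for illustration only (not the article's values).
TRAINING_KWARGS: Dict[str, Any] = dict(
    output_dir="checkpoint/bert_vit",  # hypothetical output path
    seed=12345,
    num_train_epochs=5,
    per_device_train_batch_size=32,
    per_device_eval_batch_size=32,
    learning_rate=5e-5,
    evaluation_strategy="steps",  # named eval_strategy in newer transformers releases
    eval_steps=100,
    save_strategy="steps",
    save_steps=100,  # must be a multiple of eval_steps when load_best_model_at_end=True
    save_total_limit=3,
    logging_steps=100,
    metric_for_best_model="wups",  # requires compute_metrics to return a "wups" key
    load_best_model_at_end=True,
    remove_unused_columns=False,  # keep "question"/"image_id" for the collator
)


def train_vqa(model, collator, train_dataset, eval_dataset, compute_metrics: Callable, **overrides):
    """Build a Trainer from TRAINING_KWARGS (plus overrides) and run training."""
    from transformers import Trainer, TrainingArguments  # lazy import

    args = TrainingArguments(**{**TRAINING_KWARGS, **overrides})
    trainer = Trainer(
        model=model,
        args=args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        data_collator=collator,
        compute_metrics=compute_metrics,
    )
    trainer.train()
    return trainer
```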

Making inferences using trained model

To use any of the saved model checkpoints for inference, the question must be tokenized and the image features extracted in the same way as in the collator. These serve as input to the model, with weights loaded from the trained checkpoint. The label predicted by the model is an index into the answer space, which maps it back to the actual answer.
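Inference could look like the sketch below. It assumes the checkpoint holds a plain PyTorch state dict (the pinned transformers version writes pytorch_model.bin for a custom nn.Module; newer versions may default to model.safetensors instead) and that batch comes from the same collator used in training. load_checkpoint_weights and predict_answers are hypothetical names.

```python
from typing import Dict, List, Sequence

import torch
from torch import nn


def load_checkpoint_weights(model: nn.Module, weights_path: str, device: str = "cpu") -> nn.Module:
    """Load a saved state dict (e.g. <checkpoint dir>/pytorch_model.bin) into the model."""
    state_dict = torch.load(weights_path, map_location=device)
    model.load_state_dict(state_dict)
    return model.to(device)


@torch.no_grad()
def predict_answers(model: nn.Module, batch: Dict[str, torch.Tensor], answer_space: Sequence[str]) -> List[str]:
    """Run the model on a collated batch and map predicted label indices to answer strings."""
    model.eval()
    logits = model(**batch)["logits"]
    predicted_ids = logits.argmax(dim=-1).tolist()
    return [answer_space[i] for i in predicted_ids]
```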

Examples of answers predicted by our multimodal VQA model on certain evaluation instances.

Comparing the performance of various models

A similar approach is followed to train VQA models with various combinations of text and image transformers by changing the text and image arguments while calling the createMultimodalVQACollatorAndModel(...) function.
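The combinations can be enumerated programmatically, as sketched below. The Hub checkpoint IDs are assumed typical base-size models for the text and image transformers named in the results, and experiment_grid is a hypothetical helper; each tuple would supply the text and image arguments to createMultimodalVQACollatorAndModel(...).

```python
from itertools import product
from typing import Iterator, Tuple

# Assumed Hub checkpoints: typical base-size models for each family named in this article.
TEXT_MODELS = {
    "bert": "bert-base-uncased",
    "roberta": "roberta-base",
    "albert": "albert-base-v2",
}
IMAGE_MODELS = {
    "vit": "google/vit-base-patch16-224-in21k",
    "beit": "microsoft/beit-base-patch16-224-pt22k-ft22k",
}


def experiment_grid() -> Iterator[Tuple[str, str, str]]:
    """Yield (run_name, text_checkpoint, image_checkpoint) for every text/image pair."""
    for (text_key, text_name), (image_key, image_name) in product(TEXT_MODELS.items(), IMAGE_MODELS.items()):
        yield f"{text_key}_{image_key}", text_name, image_name
```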

The table below summarizes the performance of these different models. These details can also be found on this DVC Studio dashboard.

  • RoBERTa + BEiT performs the best in terms of both WUPS and accuracy.
  • RoBERTa-based models generally perform better than the rest. This can be attributed to the larger number of trainable parameters and the embeddings generated through more robust pre-training.
  • ALBERT-based models are expected to perform worse because ALBERT is much smaller than BERT and RoBERTa. Yet, the ALBERT + ViT model achieves scores comparable to the BERT + ViT model despite having only around half the number of parameters.
  • For BERT and RoBERTa-based text transformers, the best results are achieved using BEiT as the image transformer. However, BEiT does not perform as well when paired with ALBERT. This could indicate that higher-quality textual embeddings are required to complement the image embeddings generated by BEiT.

Concluding remarks

In summary, we successfully implemented, trained, and evaluated a late fusion type of multimodal transformer model in PyTorch for visual question answering using the DAQUAR dataset. We also learned how to use the model weights from a trained checkpoint to answer questions related to an image. Last, we compared the performance of several models using different text and image transformers to featurize the question and image before performing fusion.

I hope this article has given you a good overview of some of the concepts involved in visual question answering and helped you understand the nuances of training multimodal transformer models in PyTorch for such a task. Feel free to check out the References section below for more details regarding concepts and terms that I might have breezed through in this article. Please leave any feedback or suggestions in the comments section below.

All the code mentioned in this article is available in this repository. The performance metrics for different combinations of transformers can be found on this dashboard. Feel free to clone the repository, follow the steps mentioned in the README.md and tweak the params.yaml file to experiment with different model architectures for the VQA task.


Applied Scientist @Microsoft | #1 Best Selling Author | IIT Bombay '21 | Helping Students & Professionals Ace Data Science Roles | https://topmate.io/tezan_sahu