ExBERT Image Captioning
Authors: Yuanbiao Wang, Shuyi Chen, Chunyi Li and Xinyi Li
This article was produced as part of the final project for Harvard’s AC215 Fall 2021 course.
Introduction
Recent advances in deep learning, especially in natural language processing and computer vision, have enabled complex cross-modal applications that span both fields, such as visual question answering, image generation from textual prompts, and image captioning. The image captioning task requires a model to generate text that describes an input image.
Image captioning has great real-world value. For example, it supports image indexing, which lays the foundation for text-based image search. It is also potentially helpful for automatic diagnosis and for assistive image reading for visually impaired people. With this in mind, we decided to build an application that provides image captioning to users with no background in deep learning.
Train a simple model on our own
We started by training a model ourselves. The framework has two major parts: an ImageNet-pretrained CNN that extracts visual features, and a transformer that encodes these features and decodes them into readable text. Because our computational resources were limited, we trained on the Flickr8k dataset in a Colab notebook. We call this model the baseline model.
Data preparation
The Flickr8k dataset is publicly available for download. It contains 8,091 images in total, each labeled with 5 captions. Below is a random sample from the dataset.
Based on the distribution of caption lengths (see figure below), we removed samples with a caption shorter than 5 or longer than 25, our target length for training. These invalid captions make up only a small portion of the dataset, and 7,643 images remained after filtering.
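The filtering rule can be sketched as follows. This is a minimal sketch, not the authors' code. It assumes length is counted in whitespace-separated words. It also assumes an image is dropped if any of its captions falls outside the 5 to 25 range. That is consistent with the image count shrinking, although the article does not spell out the exact rule. The names `caption_length` and `filter_by_caption_length` are hypothetical.

```python
from typing import Dict, List

MIN_LEN = 5   # shortest caption length kept
MAX_LEN = 25  # target caption length used for training


def caption_length(caption: str) -> int:
    # Assumption: length is measured in whitespace-separated words.
    return len(caption.split())


def filter_by_caption_length(
    captions: Dict[str, List[str]],
    min_len: int = MIN_LEN,
    max_len: int = MAX_LEN,
) -> Dict[str, List[str]]:
    """Keep an image only if every one of its captions lies within [min_len, max_len]."""
    return {
        image_id: caps
        for image_id, caps in captions.items()
        if all(min_len <= caption_length(c) <= max_len for c in caps)
    }
```

Given a mapping from image file name to its five captions, the function returns the subset of images whose captions all pass the length check.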
We then randomly assigned 80% of the cleaned dataset to training and the remaining 20% to validation.
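A reproducible version of the 80/20 split might look like the sketch below. It rests on two assumptions:

- The split is done over image IDs rather than individual captions.
- A fixed random seed is used.

`train_val_split` and the seed value are hypothetical.

```python
import random
from typing import List, Sequence, Tuple, TypeVar

T = TypeVar("T")


def train_val_split(
    items: Sequence[T], train_frac: float = 0.8, seed: int = 42
) -> Tuple[List[T], List[T]]:
    """Shuffle a copy of `items` and cut it into training and validation lists."""
    shuffled = list(items)
    random.Random(seed).shuffle(shuffled)  # a fixed seed makes the split reproducible
    cut = int(len(shuffled) * train_frac)
    return shuffled[:cut], shuffled[cut:]
```

Splitting by image rather than by caption keeps all five captions of an image on the same side. This matters because those captions describe the same scene, so splitting them would leak near-duplicate information into validation.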
We also standardized and tokenized the captions. We used tf.data to speed up training and applied simple image augmentations, such as random flips and random crops, inside the input pipeline.
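The caption standardization and tokenization step can be illustrated with a framework-free sketch. The authors built their pipeline with tf.data, and this is not their code. The sketch assumes standardization means lowercasing and stripping punctuation, and that captions are wrapped in `<start>`/`<end>` tokens and padded to a fixed length. The special tokens and all function names are hypothetical. For the augmentation half of the pipeline, TensorFlow's `tf.image.random_flip_left_right` and `tf.image.random_crop` are the usual building blocks.

```python
import re
import string
from collections import Counter
from typing import Dict, Iterable, List

START, END, PAD, UNK = "<start>", "<end>", "<pad>", "<unk>"  # hypothetical special tokens
_PUNCT = re.compile("[" + re.escape(string.punctuation) + "]")


def standardize(caption: str) -> str:
    """Lowercase, replace punctuation with spaces, and collapse whitespace."""
    caption = _PUNCT.sub(" ", caption.lower())
    return " ".join(caption.split())


def build_vocab(captions: Iterable[str], min_count: int = 1) -> Dict[str, int]:
    """Map each word seen at least `min_count` times to an integer id."""
    counts = Counter(tok for c in captions for tok in standardize(c).split())
    vocab = {PAD: 0, UNK: 1, START: 2, END: 3}
    for tok, n in sorted(counts.items()):  # sorted so ids are deterministic
        if n >= min_count:
            vocab[tok] = len(vocab)
    return vocab


def encode(caption: str, vocab: Dict[str, int], max_len: int = 25) -> List[int]:
    """Turn a caption into a fixed-length list of ids: <start> words <end> <pad>..."""
    words = standardize(caption).split()[: max_len - 2]  # leave room for <start> and <end>
    ids = [vocab.get(t, vocab[UNK]) for t in [START] + words + [END]]
    return ids + [vocab[PAD]] * (max_len - len(ids))
```

For example, after `vocab = build_vocab(["A dog runs.", "The dog, happily, runs!"])`, the call `encode("A dog runs.", vocab, max_len=6)` yields `<start> a dog runs <end> <pad>` as ids.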
Model Structure
In our baseline model, we first passed each image through a pretrained CNN, ResNet50V2, to obtain general image embeddings. We then fed these embeddings to a transformer encoder-decoder network that predicts the caption. The following figure shows the basic architecture of our model.
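The data flow from CNN features to the decoder can be sketched in NumPy. This is an illustration under stated assumptions, not the authors' model:

- Images are 224×224, for which Keras' ResNet50V2 without its classification head yields a 7×7×2048 feature map.
- That grid is kept as 49 visual tokens rather than pooled into a single vector.

The model width, the 12 decoder positions, and all weights are hypothetical random placeholders. A real transformer layer adds learned query/key/value projections, multiple heads, and residual connections.

```python
import numpy as np

rng = np.random.default_rng(0)


def grid_to_sequence(feature_map: np.ndarray) -> np.ndarray:
    """Flatten an (H, W, C) CNN feature map into H*W visual tokens of size C."""
    h, w, c = feature_map.shape
    return feature_map.reshape(h * w, c)


def softmax(x: np.ndarray, axis: int = -1) -> np.ndarray:
    x = x - x.max(axis=axis, keepdims=True)  # subtract the max for numerical stability
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)


def cross_attention(queries: np.ndarray, keys: np.ndarray, values: np.ndarray):
    """Single-head scaled dot-product attention: caption positions attend over image tokens."""
    d_k = queries.shape[-1]
    weights = softmax(queries @ keys.T / np.sqrt(d_k), axis=-1)  # (T, N)
    return weights @ values, weights


# Random stand-in for the ResNet50V2 feature map of one 224x224 image (not real features).
feature_map = rng.standard_normal((7, 7, 2048)).astype(np.float32)
d_model = 256  # hypothetical transformer width
w_img = rng.standard_normal((2048, d_model)) / np.sqrt(2048)  # hypothetical learned projection
image_tokens = grid_to_sequence(feature_map) @ w_img  # (49, d_model)

caption_states = rng.standard_normal((12, d_model))  # hypothetical states for 12 decoder positions
context, attn = cross_attention(caption_states, image_tokens, image_tokens)
print(context.shape, attn.shape)  # (12, 256) (12, 49)
```

Each row of `attn` is a probability distribution over the 49 image regions. This is what lets the decoder look at different parts of the image while generating different words.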
Result
We reached a ROUGE-1 score of 38% on our validation set. Sample captions generated for validation images are shown below.
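ROUGE-1 measures unigram overlap between a generated caption and a reference caption. The article does not say whether the 38% figure is precision, recall, or F1, or how the five reference captions per image were combined. The sketch below therefore computes all three and, as an assumption, takes the best F1 across references. It lowercases and splits on whitespace only. Established packages such as Google's `rouge-score` also offer stemming and more careful tokenization. Function names are hypothetical.

```python
from collections import Counter
from typing import Dict, Sequence


def rouge1(candidate: str, reference: str) -> Dict[str, float]:
    """ROUGE-1 precision, recall and F1 from clipped unigram overlap."""
    cand = Counter(candidate.lower().split())
    ref = Counter(reference.lower().split())
    overlap = sum((cand & ref).values())  # each word counts at most min(count_c, count_r) times
    if overlap == 0:
        return {"precision": 0.0, "recall": 0.0, "f1": 0.0}
    p = overlap / sum(cand.values())
    r = overlap / sum(ref.values())
    return {"precision": p, "recall": r, "f1": 2 * p * r / (p + r)}


def rouge1_multi(candidate: str, references: Sequence[str]) -> float:
    """Best ROUGE-1 F1 over several references (e.g. the five Flickr8k captions)."""
    return max(rouge1(candidate, ref)["f1"] for ref in references)
```

As a worked example, take the candidate "a dog runs through the grass" and the reference "a black dog is running through the grass". Five of the six candidate words match, giving precision 5/6, recall 5/8, and F1 5/7 ≈ 0.71.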
Predicted Caption: a black and white dog is running through a field grass
Predicted Caption: a man in a black jacket and a black jacket and a woman in a black coat is walking down a street
Predicted Caption: a man in a red shirt and blue jeans is standing in front of a crowd of people
This model still has many shortcomings, especially in the semantic quality of the generated text. The outputs contain meaningless or repeated words, such as “a black jacket and a black jacket” above. We also noticed an obvious “mode collapse” phenomenon: because of the limited dataset, the generated captions share similar patterns. For example, they tend to start with “A man in a XXX shirt is ….”
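The article describes repetition and mode collapse qualitatively. One common way to quantify them, which the authors do not report, is distinct-n: the fraction of n-grams across a set of captions that are unique. Values well below 1 indicate repetitive output. Below is a minimal sketch with hypothetical function names.

```python
from typing import List, Sequence, Tuple


def ngrams(tokens: Sequence[str], n: int) -> List[Tuple[str, ...]]:
    """All contiguous n-grams of a token sequence."""
    return [tuple(tokens[i : i + n]) for i in range(len(tokens) - n + 1)]


def distinct_n(captions: Sequence[str], n: int) -> float:
    """Unique n-grams divided by total n-grams across all captions (1.0 means no repetition)."""
    all_ngrams = [g for c in captions for g in ngrams(c.lower().split(), n)]
    return len(set(all_ngrams)) / len(all_ngrams) if all_ngrams else 0.0
```

Consider the second sample caption above. Its repeated “a black jacket and” phrase gives a distinct-2 of 15/21 ≈ 0.71, while the first sample caption scores 1.0. Computed over the whole validation set, a low score would reflect the shared “A man in a XXX shirt” openings.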
Image Captioning With Bottom-Up Features
To provide a better image captioning service, we found that a type of visual representation called bottom-up features brings a substantial improvement. Bottom-up features are closely related to object detection, which combines bounding-box coordinate regression with object classification: for each image, a well-trained object detection model identifies where the objects are and which classes they belong to. This information provides an abundance of human-understandable visual elements and greatly helps image captioning models. In the well-known paper Bottom-Up and Top-Down Attention for Image Captioning and Visual Question Answering, Peter Anderson et al. first applied visual representations extracted by an object detection model to image captioning.
Here we chose a state-of-the-art model, a ResNet101 trained on the COCO dataset, as the bottom-up feature extractor (we used the Detectron library). The extracted visual representations are then passed into a larger transformer model whose encoder-decoder part is the same as in our small baseline model. Most of our model architectures are adapted from this code repository: https://github.com/ruotianluo/ImageCaptioning.pytorch
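A detector returns a different number of regions for each image. Region features therefore have to be padded into a fixed-size batch before entering the transformer, with a mask so that attention ignores the padding. The NumPy sketch below shows that step. It assumes each image's features arrive as an (N, D) array with at least one region per image. The article does not say how the referenced repository batches regions. The function names and the 36/20-region example sizes are hypothetical.

```python
from typing import List, Tuple

import numpy as np


def pad_regions(region_feats: List[np.ndarray]) -> Tuple[np.ndarray, np.ndarray]:
    """Stack per-image (N_i, D) region features into (B, N_max, D) plus a validity mask."""
    batch = len(region_feats)
    dim = region_feats[0].shape[1]
    n_max = max(f.shape[0] for f in region_feats)
    padded = np.zeros((batch, n_max, dim), dtype=region_feats[0].dtype)
    mask = np.zeros((batch, n_max), dtype=bool)  # True = real region, False = padding
    for i, feats in enumerate(region_feats):
        n = feats.shape[0]
        padded[i, :n] = feats
        mask[i, :n] = True
    return padded, mask


def masked_softmax(scores: np.ndarray, mask: np.ndarray) -> np.ndarray:
    """Softmax over the last axis that assigns zero weight to padded regions."""
    scores = np.where(mask, scores, -np.inf)
    scores = scores - scores.max(axis=-1, keepdims=True)  # rows have >= 1 real region
    e = np.exp(scores)  # exp(-inf) == 0 for padded slots
    return e / e.sum(axis=-1, keepdims=True)
```

With two images carrying 36 and 20 regions, `pad_regions` returns a (2, 36, D) batch. Attention weights computed with `masked_softmax` are then exactly zero on the 16 padded slots of the second image.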
To help users understand the generated captions, we also provide a second service that visualizes the object detection results. We took the most confident proposed regions (ranked by their log-softmax-processed logits), removed regions with repeated labels, and visualized the remainder with d3.js, which supports rich interaction. Users can inspect either a single detected object or all of the selected objects in the image, which gives them a better understanding of the generated text.
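The region-selection logic for the visualization can be sketched as follows. It assumes the detector exposes a vector of raw class logits and a box as (x1, y1, x2, y2) for each proposed region. The function names, the `top_k=10` default, and the output dictionary layout are hypothetical. The layout was chosen so the result can be serialized with `json.dumps` and passed to a d3.js frontend.

```python
import math
from typing import Dict, List, Sequence


def log_softmax(logits: Sequence[float]) -> List[float]:
    """Numerically stable log-softmax over one region's class logits."""
    m = max(logits)
    log_sum = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_sum for x in logits]


def select_regions(
    region_logits: Sequence[Sequence[float]],
    labels: Sequence[str],
    boxes: Sequence[Sequence[float]],
    top_k: int = 10,
) -> List[Dict[str, object]]:
    """Rank regions by their best log-softmax class score and keep one region per label."""
    candidates = []
    for idx, logits in enumerate(region_logits):
        scores = log_softmax(logits)
        best = max(range(len(scores)), key=scores.__getitem__)  # most likely class
        candidates.append((scores[best], idx, best))
    candidates.sort(key=lambda c: c[0], reverse=True)  # most confident first

    selected: List[Dict[str, object]] = []
    seen = set()
    for score, idx, cls in candidates:
        if labels[cls] in seen:
            continue  # a more confident region already carries this label
        seen.add(labels[cls])
        selected.append({"label": labels[cls], "score": score, "box": list(boxes[idx])})
        if len(selected) == top_k:
            break
    return selected
```

Sorting before deduplicating means that when two regions share a label, the more confident one is kept. This mirrors the "remove repeated labels" step described above.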
Here we demonstrate some of the results from our SOTA model.
Framework and technical stack
We use FastAPI, a lightweight backend framework, to provide our model-related API services. On the frontend, we used React, one of the most popular frameworks, to build the interface, and we used Nginx as a reverse proxy. We deployed the service in Docker containers orchestrated by Kubernetes. Each container runs on a separate Kubernetes node, and the cluster is deployed on Google Cloud Platform (GCP).
Here are some screenshots from our service.
Future Work and Conclusion
We plan to add the following features:
- Support for GPU inference. Prediction with our SOTA model currently runs on CPU and takes a long time, which is not user-friendly. We aim to improve the model's efficiency or move to more capable hardware.
- Accessibility support. We want to add a simple feature that synthesizes audio of the caption with Google's text-to-speech service for visually impaired users.
- Support for concurrent requests. Right now, our model serves requests in a blocking fashion, meaning that it can process only one request at a time. If several users use our service simultaneously, they may have to wait a long time. One way to fix this is to hand each incoming request to a worker thread so that requests are processed concurrently.
- Support for past results. Right now our service is stateless and cannot present past results. To achieve this, we need to add a database that stores user-relevant data, as well as support for signup, login, and session IDs.
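The worker-thread idea can be sketched with Python's standard `concurrent.futures` module. In this sketch, `run_model` is a hypothetical stand-in for the blocking caption inference, simulated with `time.sleep`. The sketch also uses a bounded thread pool rather than a brand-new thread per request. This is a deliberate variation that caps how many heavy inferences run at once.

```python
import time
from concurrent.futures import ThreadPoolExecutor
from typing import List


def run_model(image_id: str) -> str:
    """Hypothetical stand-in for blocking caption inference."""
    time.sleep(0.3)  # simulate slow CPU inference
    return f"caption for {image_id}"


# A bounded pool: at most 4 inferences run at the same time.
executor = ThreadPoolExecutor(max_workers=4)


def handle_requests(image_ids: List[str]) -> List[str]:
    """Submit each request to the pool so slow requests no longer queue behind each other."""
    futures = [executor.submit(run_model, image_id) for image_id in image_ids]
    return [f.result() for f in futures]  # results come back in submission order
```

Four simulated requests finish in roughly 0.3 s instead of 1.2 s. Threads help real inference only when the underlying numerical library releases Python's GIL during heavy computation; otherwise, multiple worker processes are the alternative. In FastAPI specifically, endpoints declared with plain `def` already run in a thread pool. If the prediction endpoint is an `async def` that calls the model directly, it blocks the event loop.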
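Storing past results per session could start as simply as the sketch below, which uses Python's built-in `sqlite3` module. It is an illustration, not a design the article commits to. The table layout and function names are hypothetical, and session IDs are assumed to be issued by the signup/login system mentioned above. A deployment on GCP would more likely use a managed database server. Parameterized `?` queries keep user-supplied values out of the SQL text.

```python
import sqlite3
import time
from typing import List, Tuple


def init_db(conn: sqlite3.Connection) -> None:
    """Create a table of generated captions keyed by session."""
    conn.execute(
        """CREATE TABLE IF NOT EXISTS results (
               id INTEGER PRIMARY KEY AUTOINCREMENT,
               session_id TEXT NOT NULL,
               image_name TEXT NOT NULL,
               caption TEXT NOT NULL,
               created_at REAL NOT NULL
           )"""
    )
    conn.execute("CREATE INDEX IF NOT EXISTS idx_results_session ON results(session_id)")
    conn.commit()


def save_result(conn: sqlite3.Connection, session_id: str, image_name: str, caption: str) -> None:
    """Record one prediction for a session."""
    conn.execute(
        "INSERT INTO results (session_id, image_name, caption, created_at) VALUES (?, ?, ?, ?)",
        (session_id, image_name, caption, time.time()),
    )
    conn.commit()


def past_results(conn: sqlite3.Connection, session_id: str) -> List[Tuple[str, str]]:
    """Return (image_name, caption) pairs for a session, oldest first."""
    rows = conn.execute(
        "SELECT image_name, caption FROM results WHERE session_id = ? ORDER BY id",
        (session_id,),
    )
    return rows.fetchall()
```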