Compositional Learning and Analysis

Harsh-Sensei
Jul 10, 2022


Recently, many deep learning models have produced surprisingly good results on multi-modal tasks such as text-to-image generation. DALL-E, GLIDE, Imagen, and various other models take different approaches to this problem, but the performance of all of them hinges on one thing: how well the text model is able to create representations of the input prompt.

This is not limited to text-to-image generation. It applies to every model that must build compositional representations from primitives. Consider zero-shot learning, where the task is to predict object-attribute pairs (crushed-envelope, red-wine, etc.) at inference time that were never seen during training, because the model is not trained on all possible object-attribute pairs. For such tasks, the representation network must learn to represent both the primitives (object and attribute) and their composition (the joint object-attribute pair) appropriately in the latent space.
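As a rough illustration of this setting, the sketch below builds every attribute-object pair from a small, made-up set of primitives, holds some pairs out as "unseen", and checks that each primitive of an unseen pair still appears in at least one seen pair. That coverage is what makes zero-shot composition possible at all. The attribute and object lists are hypothetical, not taken from any dataset used in the paper.

```python
from itertools import product

# Hypothetical primitives for illustration only.
attributes = ["red", "crushed", "rusty"]
objects = ["wine", "envelope", "wire"]

# Every possible composition is an (attribute, object) pair.
all_pairs = list(product(attributes, objects))

# Held-out pairs: the model never sees these during training.
unseen_pairs = {("crushed", "envelope"), ("red", "wine"), ("rusty", "wire")}
seen_pairs = [p for p in all_pairs if p not in unseen_pairs]

# Zero-shot composition needs every primitive of an unseen pair
# to have been observed in *some* seen pair.
seen_attrs = {a for a, _ in seen_pairs}
seen_objs = {o for _, o in seen_pairs}
covered = all(a in seen_attrs and o in seen_objs for a, o in unseen_pairs)

print(len(seen_pairs), len(unseen_pairs), covered)  # 6 3 True
```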

Compositionality may be easy for humans, but it is daunting for data-driven deep learning models. It is an active research area in machine learning, and in this blog I would like to elaborate on one of the works on compositional learning: “Task-Driven Modular Networks for Zero-Shot Compositional Learning”.

The paper describes an interesting technique for zero-shot classification of object-attribute pairs using modular networks (a setting close to meta-learning), where the inputs to the modules are weighted by a gating network that is conditioned on the object-attribute pair. Before diving any further, let’s go through some important terminology:

  1. Gating network: Intuitively, a network that determines how much weight is given to each of the outputs of a bunch of modular networks. Such networks are also used in the Mixture of Experts (MoE) setting, where the gating (or routing) network decides which expert models to use for a particular input.
  2. ConceptDrop: Not a general term. The authors of the paper coined it for the process of randomly dropping some of the negative object-attribute pairs (pairs that do not match the training image) in each epoch, which has a regularizing effect.
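To make the gating idea from item 1 concrete, here is a minimal Mixture-of-Experts-style sketch in plain Python: a softmax turns gate logits into weights, and the combined output is the weighted sum of the experts' outputs. The expert outputs and logits are placeholder values, and the function names are illustrative rather than taken from the paper.

```python
import math

def softmax(logits):
    # Subtract the max before exponentiating for numerical stability.
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]

def gated_combination(expert_outputs, gate_logits):
    """Weight each expert's output vector by a softmax-normalized gate."""
    weights = softmax(gate_logits)
    dim = len(expert_outputs[0])
    combined = [0.0] * dim
    for w, out in zip(weights, expert_outputs):
        for d in range(dim):
            combined[d] += w * out[d]
    return combined, weights

# Three hypothetical experts, each producing a 2-D output vector.
experts = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]

# Equal logits give each expert a weight of 1/3.
combined, weights = gated_combination(experts, [0.0, 0.0, 0.0])
print(weights, combined)
```

With unequal logits the gate shifts weight toward some experts; a hard MoE router would keep only the top few, whereas the modular networks discussed here use the soft weighted sum.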

Now grab your gear and let’s go…

Below is the architecture of the model:

Architecture of task-driven modular nets (SOURCE: Paper)

The main idea behind the technique is to train a gating network together with the modular networks (which are ultimately followed by a linear layer that maps to a scalar score) on image features and object-attribute pairs. Let’s define our problem properly:

Framing the problem properly

Training:

During training, the input is as described in the figure above, with one caveat: to avoid excessive computation, the normalization factor of the softmax probability is approximated using a random sample of object-attribute pairs instead of scoring every pair.
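A minimal sketch of that approximation, assuming the compatibility scores of one image against every pair are already available in a dict: the cross-entropy for the correct pair is normalized over the correct pair plus a random subset of negative pairs. The function name and the `num_negatives` parameter are hypothetical, and the paper's exact sampling scheme may differ.

```python
import math
import random

def sampled_softmax_loss(scores, positive, num_negatives, rng):
    """Cross-entropy for `positive`, normalizing over it plus sampled negatives."""
    negatives = [p for p in scores if p != positive]
    sampled = rng.sample(negatives, min(num_negatives, len(negatives)))
    logits = [scores[positive]] + [scores[p] for p in sampled]
    # Stable log-sum-exp over the (approximate) normalization set.
    m = max(logits)
    log_z = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_z - scores[positive]

# Hypothetical scores of one image against every object-attribute pair.
scores = {
    ("red", "wine"): 2.0,
    ("red", "wire"): 0.5,
    ("rusty", "wire"): -1.0,
    ("crushed", "envelope"): 0.0,
}
rng = random.Random(0)

# Approximate loss with 2 sampled negatives vs. the exact loss with all of them.
approx = sampled_softmax_loss(scores, ("red", "wine"), num_negatives=2, rng=rng)
exact = sampled_softmax_loss(scores, ("red", "wine"), num_negatives=len(scores), rng=rng)
print(approx, exact)
```

Because the sampled normalizer sums over fewer terms, the approximate loss is never larger than the exact one; the saving is that only a handful of pairs need to be scored per image.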

The image is encoded into a feature space using a ResNet trunk that stays frozen throughout training. The features are then passed to the first layer of modular nets (all modules receive the same input), without any involvement of the gating network.

For the subsequent layers, the input to each module is a weighted sum of the outputs of all modules in the previous layer, with the weights determined by the gating network. Mathematically,

$$ x_j^{(i)} = \sum_k g_{k \to j}^{(i)} \, o_k^{(i-1)} $$

Inputs to the i^th layer. o -> output, x -> input, g -> gating (SOURCE: Paper)

where superscripts index the layer and subscripts index the module within that layer; g_{k→j}^{(i)} is the weight applied to the output of the k^th modular net in the (i-1)^th layer when forming the input of the j^th modular net in the i^th layer.
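The weighted sum above can be written directly in a few lines. This sketch assumes every module in a layer has the same input/output dimension and stores the gates as a matrix `gates[k][j]` (the weight from module k in layer i-1 to module j in layer i); the function name is illustrative.

```python
def gated_inputs(prev_outputs, gates):
    """x_j^(i) = sum_k g_{k->j}^(i) * o_k^(i-1).

    prev_outputs: list of K output vectors from layer i-1.
    gates: gates[k][j] weights module k (layer i-1) -> module j (layer i).
    Returns one input vector per module of layer i.
    """
    num_next = len(gates[0])
    dim = len(prev_outputs[0])
    inputs = []
    for j in range(num_next):
        x_j = [0.0] * dim
        for k, o_k in enumerate(prev_outputs):
            for d in range(dim):
                x_j[d] += gates[k][j] * o_k[d]
        inputs.append(x_j)
    return inputs

prev = [[1.0, 2.0], [3.0, 4.0]]  # outputs o_0, o_1 of layer i-1
gates = [[1.0, 0.5],             # g_{0->0}, g_{0->1}
         [0.0, 0.5]]             # g_{1->0}, g_{1->1}
xs = gated_inputs(prev, gates)
# x_0 = 1.0*o_0 + 0.0*o_1 = [1.0, 2.0]
# x_1 = 0.5*o_0 + 0.5*o_1 = [2.0, 3.0]
print(xs)
```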

Gating network: The gating network is a multi-layer neural network that takes the concatenated object and attribute embeddings (pre-trained GloVe embeddings) and outputs all the gating weights at once.
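A toy sketch of such a gating network, under several assumptions not stated in this post: a two-layer MLP with a ReLU, random vectors standing in for the GloVe embeddings, tiny layer sizes, and gate logits normalized with a softmax over the incoming edges of each module so that they sum to 1. The helper names (`linear`, `init`, `gating_network`) are hypothetical.

```python
import math
import random

def linear(x, W, b):
    # Affine map W @ x + b, with W stored as a list of rows.
    return [sum(w * xi for w, xi in zip(row, x)) + bi for row, bi in zip(W, b)]

def relu(v):
    return [max(0.0, z) for z in v]

def init(out_dim, in_dim, rng):
    W = [[rng.gauss(0.0, 0.1) for _ in range(in_dim)] for _ in range(out_dim)]
    return W, [0.0] * out_dim

def gating_network(attr_emb, obj_emb, params, layer_shapes):
    """Map a concatenated (attribute, object) embedding to every gate.

    layer_shapes: (modules in previous layer, modules in this layer) for
    each gated layer; the first modular layer is not gated.
    """
    h = attr_emb + obj_emb  # list concatenation of the two embeddings
    (W1, b1), (W2, b2) = params
    flat = linear(relu(linear(h, W1, b1)), W2, b2)
    gates, pos = [], 0
    for k_prev, m_cur in layer_shapes:
        block = flat[pos:pos + k_prev * m_cur]
        pos += k_prev * m_cur
        logits = [block[k * m_cur:(k + 1) * m_cur] for k in range(k_prev)]
        # Assumption: softmax over the incoming edges k of each module j.
        g = [[0.0] * m_cur for _ in range(k_prev)]
        for j in range(m_cur):
            col = [logits[k][j] for k in range(k_prev)]
            mx = max(col)
            exps = [math.exp(c - mx) for c in col]
            s = sum(exps)
            for k in range(k_prev):
                g[k][j] = exps[k] / s
        gates.append(g)
    return gates

rng = random.Random(0)
emb_dim, hidden = 4, 8
layer_shapes = [(3, 3), (3, 3)]  # two gated layers of 3 modules (toy sizes)
n_gates = sum(k * m for k, m in layer_shapes)
params = [init(hidden, 2 * emb_dim, rng), init(n_gates, hidden, rng)]
attr = [rng.gauss(0.0, 1.0) for _ in range(emb_dim)]  # stand-in for GloVe("red")
obj = [rng.gauss(0.0, 1.0) for _ in range(emb_dim)]   # stand-in for GloVe("wine")
gates = gating_network(attr, obj, params, layer_shapes)
```

The key point is that one forward pass of this network, driven only by the word embeddings of the pair, produces every gate of the modular network.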

Calculation of gating weights

Feature extraction network: The feature extraction network has several layers, and each layer contains several modules (the authors used 24 modules with 16-dimensional input and output vectors). The outputs of the last layer's modules are concatenated and passed through a linear layer that outputs a scalar score for the given image and object-attribute pair. A softmax is then applied over the scores of all object-attribute pairs for that image.
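Putting the pieces together, here is a toy end-to-end scoring sketch. It assumes each module is an affine layer followed by a ReLU, uses tiny sizes instead of the paper's, treats the image feature as already having the module dimension, and uses two hand-written gate settings in place of a trained gating network. All names are hypothetical. Note that the module weights are shared across pairs; only the gates change with the pair.

```python
import math
import random

def linear(x, W, b):
    return [sum(w * xi for w, xi in zip(row, x)) + bi for row, bi in zip(W, b)]

def relu(v):
    return [max(0.0, z) for z in v]

def make_module(dim, rng):
    # One module: a dim -> dim affine layer (followed by ReLU when used).
    W = [[rng.gauss(0.0, 0.5) for _ in range(dim)] for _ in range(dim)]
    return W, [0.0] * dim

def pair_score(feature, modules, gates, w_out):
    """Scalar compatibility of an image feature with one pair, given its gates."""
    # Layer 0: every module sees the same image feature, no gating.
    outputs = [relu(linear(feature, W, b)) for W, b in modules[0]]
    # Later layers: module j receives sum_k g[k][j] * o_k.
    for layer, g in zip(modules[1:], gates):
        dim = len(outputs[0])
        inputs = [
            [sum(g[k][j] * outputs[k][d] for k in range(len(outputs))) for d in range(dim)]
            for j in range(len(layer))
        ]
        outputs = [relu(linear(x, W, b)) for x, (W, b) in zip(inputs, layer)]
    # Concatenate the last layer's outputs, then map them to a scalar.
    concat = [v for o in outputs for v in o]
    return sum(w * v for w, v in zip(w_out, concat))

def softmax(zs):
    m = max(zs)
    e = [math.exp(z - m) for z in zs]
    s = sum(e)
    return [v / s for v in e]

rng = random.Random(0)
dim, n_modules, n_layers = 4, 3, 2  # toy sizes
modules = [[make_module(dim, rng) for _ in range(n_modules)] for _ in range(n_layers)]
w_out = [rng.gauss(0.0, 0.5) for _ in range(n_modules * dim)]
feature = [rng.gauss(0.0, 1.0) for _ in range(dim)]  # stand-in for a ResNet feature

# Hypothetical gates for two pairs (one gated layer, since n_layers = 2).
identity = [[1.0 if k == j else 0.0 for j in range(n_modules)] for k in range(n_modules)]
uniform = [[1.0 / n_modules] * n_modules for _ in range(n_modules)]
pair_gates = {("red", "wine"): [identity], ("rusty", "wire"): [uniform]}

scores = {p: pair_score(feature, modules, g, w_out) for p, g in pair_gates.items()}
probs = softmax(list(scores.values()))
print(scores, probs)
```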

Once the scores are obtained, the cross-entropy loss is minimized with the Adam optimizer.
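For reference, here is a from-scratch sketch of the Adam update applied to a softmax cross-entropy loss. To keep it small, the pair scores themselves are treated as the trainable parameters; in the actual model the gradient would be backpropagated into the gating and modular networks. The hyperparameters are the commonly used Adam defaults (apart from the learning rate), not values reported in the paper.

```python
import math

def softmax(zs):
    m = max(zs)
    e = [math.exp(z - m) for z in zs]
    s = sum(e)
    return [v / s for v in e]

def cross_entropy(logits, target):
    return -math.log(softmax(logits)[target])

def ce_grad(logits, target):
    # d CE / d logits = softmax(logits) - one_hot(target)
    p = softmax(logits)
    return [pi - (1.0 if i == target else 0.0) for i, pi in enumerate(p)]

class Adam:
    def __init__(self, n, lr=0.1, beta1=0.9, beta2=0.999, eps=1e-8):
        self.lr, self.b1, self.b2, self.eps = lr, beta1, beta2, eps
        self.m = [0.0] * n  # first-moment (mean) estimates
        self.v = [0.0] * n  # second-moment (uncentered variance) estimates
        self.t = 0

    def step(self, params, grads):
        self.t += 1
        for i, g in enumerate(grads):
            self.m[i] = self.b1 * self.m[i] + (1 - self.b1) * g
            self.v[i] = self.b2 * self.v[i] + (1 - self.b2) * g * g
            # Bias correction for the zero-initialized moments.
            m_hat = self.m[i] / (1 - self.b1 ** self.t)
            v_hat = self.v[i] / (1 - self.b2 ** self.t)
            params[i] -= self.lr * m_hat / (math.sqrt(v_hat) + self.eps)

logits = [0.0, 0.0, 0.0]  # scores of three pairs; pair 0 is correct
target = 0
initial_loss = cross_entropy(logits, target)
opt = Adam(len(logits))
for _ in range(200):
    opt.step(logits, ce_grad(logits, target))
final_loss = cross_entropy(logits, target)
print(initial_loss, final_loss)
```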

Inference:

At inference time, to classify an input image into one of the object-attribute pairs, the score described above is computed for every object-attribute pair, whether it was seen or unseen during training. Unseen pairs tend to get low scores because they were never present during training; to counter this, a bias (called the calibration bias) is added to the final scores of those pairs.
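A minimal sketch of calibrated prediction: add the bias to the scores of unseen pairs only, then take the argmax. The scores and pair names are made-up values for illustration.

```python
def predict(scores, unseen, bias):
    """Highest-scoring pair after adding `bias` to every unseen pair."""
    return max(scores, key=lambda p: scores[p] + (bias if p in unseen else 0.0))

# Hypothetical scores of one test image.
scores = {
    ("red", "wine"): 1.2,          # seen during training
    ("rusty", "wire"): 0.3,        # seen during training
    ("crushed", "envelope"): 0.9,  # unseen during training
}
unseen = {("crushed", "envelope")}

print(predict(scores, unseen, 0.0))  # ('red', 'wine')
print(predict(scores, unseen, 0.5))  # ('crushed', 'envelope'): 0.9 + 0.5 > 1.2
```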

The choice of bias trades off seen and unseen accuracy: a higher bias improves accuracy on unseen categories, while a lower bias favors seen categories.
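Sweeping the bias makes this trade-off visible. The sketch below evaluates seen and unseen accuracy at a few bias values on two invented test samples; it assumes the evaluation set contains at least one sample of each kind, otherwise the division would fail.

```python
def accuracies(samples, unseen, bias):
    """samples: list of (scores_dict, true_pair). Returns (seen_acc, unseen_acc)."""
    hits = {True: 0, False: 0}
    totals = {True: 0, False: 0}
    for scores, label in samples:
        pred = max(scores, key=lambda p: scores[p] + (bias if p in unseen else 0.0))
        is_unseen = label in unseen
        totals[is_unseen] += 1
        hits[is_unseen] += pred == label
    return hits[False] / totals[False], hits[True] / totals[True]

seen_pair, unseen_pair = ("red", "wine"), ("crushed", "envelope")
unseen = {unseen_pair}
samples = [
    ({seen_pair: 1.0, unseen_pair: 0.5}, seen_pair),    # correct while bias < 0.5
    ({seen_pair: 1.0, unseen_pair: 0.8}, unseen_pair),  # correct once bias > 0.2
]

curve = [accuracies(samples, unseen, b) for b in (0.0, 0.3, 1.0)]
print(curve)  # [(1.0, 0.0), (1.0, 1.0), (0.0, 1.0)]
```

Each bias gives one (seen, unseen) accuracy point; plotting these points over many bias values produces a seen-unseen curve like the one below.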

Seen-unseen accuracy graph for various approaches (SOURCE: Paper)

Compositionality Analysis

To analyse whether the trained model represents compositions effectively, the authors made a 2-D t-SNE plot (a 2-D projection of higher-dimensional vectors) of the outputs of the gating network (i.e., all the gating weights for each pair). If the model has learnt compositional representations well, object-attribute pairs that are visually similar should lie close together in the plot.
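t-SNE itself is too long to write out here, so the sketch below swaps in a simpler check of the same intuition: compare flattened gating vectors with cosine similarity and look up each pair's nearest neighbor. The gating vectors are invented numbers chosen only to mimic the qualitative pattern described in the quote below; they are not outputs of the paper's model.

```python
import math

def cosine(u, v):
    dot = sum(a * b for a, b in zip(u, v))
    return dot / (math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v)))

def nearest(name, vectors):
    """Pair whose gating vector is most similar (by cosine) to `name`'s."""
    others = [n for n in vectors if n != name]
    return max(others, key=lambda n: cosine(vectors[name], vectors[n]))

# Invented, flattened gating vectors for illustration only.
gatings = {
    "large table": [0.9, 0.1, 0.0, 0.0],
    "small table": [0.8, 0.2, 0.0, 0.0],
    "rusty wire": [0.0, 0.1, 0.9, 0.0],
    "rusty water": [0.1, 0.0, 0.0, 0.9],
}

print(nearest("large table", gatings))                         # small table
print(cosine(gatings["rusty wire"], gatings["rusty water"]))   # 0.0
```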

“Pairs tagged with text boxes that have white background, indicate examples where changing the attribute results in similar gatings (e.g., large/small table); conversely, pairs in black background indicate examples where the change of attribute/object leads to very dissimilar gatings (e.g., molten/brushed/coil steel, rusty water/rusty wire).” (SOURCE: Paper)

The paper also compares its results with previous works. For example, the RedWine method composes the SVM classifier weights of the primitive (attribute and object) classifiers to build a classifier for their composition. Please refer to this paper (Misra et al., “From Red Wine to Red Tomato: Composition with Context”, CVPR 2017) for more details.
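A minimal sketch of the RedWine idea: a transformation maps the concatenated primitive classifier weights to a weight vector for the pair, which is then used as a linear classifier. In RedWine this transformation is learned; here it is a fixed, hypothetical averaging matrix `T` chosen only so the arithmetic is easy to follow, and the primitive weight vectors are made up.

```python
def compose_classifier(w_attr, w_obj, T):
    """Map concatenated primitive classifier weights to a pair classifier."""
    z = w_attr + w_obj  # list concatenation
    return [sum(t * zi for t, zi in zip(row, z)) for row in T]

def classify(feature, w):
    # Linear classifier score: dot product of weights and image feature.
    return sum(a * b for a, b in zip(feature, w))

dim = 3
# Hypothetical fixed transformation: averages attribute and object weights.
T = [[0.5 if c % dim == r else 0.0 for c in range(2 * dim)] for r in range(dim)]

w_red = [1.0, 0.0, 0.0]   # made-up "red" classifier weights
w_wine = [0.0, 1.0, 0.0]  # made-up "wine" classifier weights
w_red_wine = compose_classifier(w_red, w_wine, T)
print(w_red_wine)                                 # [0.5, 0.5, 0.0]
print(classify([2.0, 2.0, 0.0], w_red_wine))      # 2.0
```

The contrast with task-driven modular networks is that RedWine composes the classifier weights directly, whereas TMN keeps the module weights shared and lets the pair decide only how the modules are wired together.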

That’s all I have to tell you about this topic, hope you found it insightful. Arigato Gozaimasu!


Harsh-Sensei

Pursuing B.Tech in Computer Science Engineering at IIT Bombay. Eternally excited about robotics, machine learning and computer graphics