Ch 10. Vision Transformer Part II — Iterative Erasing of Unattended Image Regions in PyTorch

Helping the model better detect objects in images by iteratively erasing (i.e. darkening) regions of the image unattended by ViT using its self-attention weights

Lucrece (Jahyun) Shin
Artificialis
9 min read · Jan 28, 2022

--

*This post’s associated Colab Notebook contains step-by-step code for the ViT iterative erasing prediction algorithm.

[ViT Model Overview] We split an image into fixed-size patches, linearly embed each of them, add position embeddings, and feed the resulting sequence of vectors to a standard transformer encoder. In order to perform classification, an extra learnable [class] token is prepended to the sequence. (Source: ViT Paper)

In Vision Transformer Part I, I discussed a fairly new image classification model called Vision Transformer (ViT), introduced in the An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale paper (2020), and how I fine-tuned ViT on an Xray threat detection dataset. In this post, I will discuss how I improved ViT’s prediction performance using the Iterative Erasing Prediction Strategy, which iteratively masks the test image with interpolated ViT attention weights, highlighting the object of interest and raising the class confidence as shown below:

Three iterations of “erasing” the image using [class] token’s final layer attention weights.

Here is the list of topics for this post:

1. Iterative Erasing using Visual Attention

  • Step 0. Define ViT_IterativeErasing class
  • Step 1. Extract ViT [class] token’s final layer attention weights
  • Step 2. Mask the test image with attention mask
  • Step 3. Iteratively mask the image until a class is detected

2. ViT Performance using Iterative Erasing

3. Possibilities for Future Research

1. Iterative Erasing using Visual Attention

Motivation: Human Perception

When we are judging whether an object is present in an image, we zoom in on the area that seems to contain the object and pay less and less attention to the surrounding areas, until we feel certain that the object is indeed present in the image. Such an attention process may take less than a second in the brain, so we might not even be aware of it. This was the motivation behind the algorithm that iteratively “erases” (i.e. darkens) the unattended regions of the image until the model becomes quite certain that a class object is present in the image. Keep in mind that there is no training involved for this algorithm, as it is only a prediction heuristic for the testing stage.

The entire step-by-step code for the iterative erasing algorithm is shown in my Colab Notebook. Here I will go over the procedure by explaining key Python/PyTorch code snippets line by line:

Step 0. Define ViT_IterativeErasing class

Here I am defining the ViT_IterativeErasing class that contains the useful methods for the algorithm; a minimal skeleton sketch follows the parameter list below.

Here are the class input parameters:

  • encoder: trained ViT encoder
  • classifier: trained classifier
  • detect_th: the image is repeatedly erased until the model’s predicted probability of the resulting image containing a class exceeds this detection threshold (or the number of masking iterations reaches max_iter)
  • reconsider_th: if the model’s predicted probability of the resulting image containing a class is below this reconsideration threshold, iteration is stopped and a “no class detected” result is returned.
  • max_iter: maximum number of iterations to erase images
  • img_transform: torchvision.transforms object to apply to raw images (optional)
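
Since the full implementation lives in the Colab Notebook, here is only a minimal sketch of what such a class skeleton could look like; the method names are my own placeholders, their bodies are filled in over Steps 1–3, and the default values simply mirror the threshold values discussed later in this post:

```python
class ViT_IterativeErasing:
    """Prediction-time heuristic: iteratively erase (darken) the regions of an
    image that the ViT [class] token does not attend to."""

    def __init__(self, encoder, classifier, detect_th=0.1, reconsider_th=0.001,
                 max_iter=3, img_transform=None):
        self.encoder = encoder              # trained ViT encoder
        self.classifier = classifier        # trained classification head
        self.detect_th = detect_th          # report a class once its probability exceeds this
        self.reconsider_th = reconsider_th  # report "no class" if all probabilities fall below this
        self.max_iter = max_iter            # maximum number of erasing iterations
        self.img_transform = img_transform  # optional torchvision.transforms pre-processing

    def get_attention_mask(self, x):
        # Step 1: extract the [class] token's final-layer attention map
        ...

    def mask_image(self, img, att_map):
        # Step 2: interpolate the attention map and darken unattended pixels
        ...

    def predict(self, img):
        # Step 3: iterate masking until a class is detected (or max_iter is reached)
        ...
```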

Step 1. Extract [class] token’s final layer attention weights

When given a test image, the ViT encoder spits out two outputs: the final embedding vectors for all input tokens and a stack of attention weights for all layers and all heads. So here I am performing multiple operations on the ViT encoder’s attention weights output for a given test image. Ultimately, I want to get only the [class] token’s final attention weights with respect to the input image tokens (a sketch of this computation follows the list below). Recall that for ViT, a learnable classification token similar to BERT’s [CLS] token is prepended to the sequence of input image tokens, and its state at the output of the ViT encoder serves as the image representation.

Three ViT Variants and associated hyperparameters

It’s useful to remember that I used ViT-Base for my experiments, containing 12 layers, 12 attention heads, and a final embedding dimension of 768, as shown in the table.

  • line 1: Pass the pre-processed image x through the trained ViT encoder. The two outputs are 1. features: final embedding vectors of all image tokens + [class] token and 2. att_mat: a stack of attention weights for all layers and all heads.
  • line 2: Stack the 12 self-attention layers (ViT-Base).
  • line 3: Average across 12 heads (ViT-Base).
  • lines 4–6: Since each encoder layer contains a residual connection, the values in layer l+1 become V[l+1] = V[l] + W_att·V[l] = (W_att + I)·V[l], where W_att is the attention matrix. So add an identity matrix to the attention matrix, then normalize it so each row sums to 1.
  • line 7: Compute the attention map grid size (each attention map is a square).
  • lines 9–12: To compute the attention of layer i features w.r.t. input image tokens, recursively multiply all previous layers’ attention matrices.
  • line 14: Extract the last self-attention layer’s attention maps.
  • line 15: Extract [class] token’s attention map (which is the first of 197) w.r.t. input image tokens (thus [1:]).
  • lines 16–17: Resize [class] token’s attention map into square dimension and normalize.
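
The snippet itself isn’t reproduced here, but a rough sketch of the same attention rollout computation might look like the following; I am assuming the encoder returns the final token embeddings together with a list of per-layer attention tensors of shape [batch, heads, tokens, tokens], and a batch size of 1:

```python
import torch

def get_class_attention_map(encoder, x):
    """Attention rollout: the [class] token's attention over the image patch
    tokens, returned as a (grid_size x grid_size) map normalized to [0, 1]."""
    features, att_mat = encoder(x)             # att_mat: list of [1, heads, T, T]; features unused here
    att_mat = torch.stack(att_mat).squeeze(1)  # [layers, heads, T, T] (batch size 1)
    att_mat = att_mat.mean(dim=1)              # average over heads -> [layers, T, T]

    # Residual connections: V[l+1] = (W_att + I) V[l], so add identity and re-normalize
    aug_att_mat = att_mat + torch.eye(att_mat.size(1))
    aug_att_mat = aug_att_mat / aug_att_mat.sum(dim=-1, keepdim=True)

    # Recursively multiply so each layer's attention is w.r.t. the original input tokens
    joint_attentions = torch.zeros_like(aug_att_mat)
    joint_attentions[0] = aug_att_mat[0]
    for i in range(1, aug_att_mat.size(0)):
        joint_attentions[i] = torch.matmul(aug_att_mat[i], joint_attentions[i - 1])

    # Last layer, [class] token row (index 0), dropping the [class] column itself
    v = joint_attentions[-1]
    grid_size = int((v.size(-1) - 1) ** 0.5)   # e.g. 14 for 196 patch tokens
    mask = v[0, 1:].reshape(grid_size, grid_size)
    return mask / mask.max()                   # normalize to [0, 1]
```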

Step 2. Mask the image with attention mask

Here I am masking the original test image with the [class] token’s final layer attention map. Since the image size and attention map size don’t match, I first have to interpolate the attention map to the same size as the image (a code sketch of this operation appears after the list). Below is an example of an image before and after the masking operation:

Before masking/erasing operation (left) using [class] token’s last layer attention map and after (right)

  • line 1: Resize (interpolate) [class] token’s attention map (e.g. 14 by 14) to test image size (e.g. 224 by 224).
  • lines 2–3: If dimensions are reversed, reverse them back.
  • line 6: Multiply the mask into each channel of the image; this is the key masking/erasing operation. The idea is that all attention weight values lie between 0 and 1, so multiplying a pixel by an attention weight near 0 reduces its value and darkens it, while multiplying by a weight near 1 leaves the pixel roughly unchanged.
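
As a rough sketch, assuming the image is a [3, H, W] tensor and the attention map is the (e.g. 14×14) output of Step 1, the masking step could look like this:

```python
import torch.nn.functional as F

def mask_image(img, att_map):
    """Darken image regions in proportion to how little attention they receive.
    img: [3, H, W] float tensor, att_map: [grid, grid] tensor with values in [0, 1]."""
    H, W = img.shape[1], img.shape[2]
    # Interpolate the attention map (e.g. 14x14) up to the image size (e.g. 224x224)
    mask = F.interpolate(att_map[None, None, ...], size=(H, W),
                         mode="bilinear", align_corners=False)[0, 0]
    # Multiply every channel by the mask: weights near 0 darken pixels,
    # weights near 1 leave them roughly unchanged
    return img * mask.unsqueeze(0)
```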

I also converted the [class] token’s last layer attention map into a heatmap using cv2.applyColorMap and overlaid it on top of the original image. Here are some examples, where you can see that the [class] token pays more attention to the parts of the image containing the class object:
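
A possible version of that overlay, using cv2.applyColorMap as mentioned (assuming the image is a uint8 H×W×3 BGR array and the attention map has already been resized to the image size and scaled to [0, 1]):

```python
import cv2
import numpy as np

def overlay_attention(img_bgr, att_map):
    """Overlay a JET heatmap of the attention map on top of the original image."""
    heatmap = cv2.applyColorMap((att_map * 255).astype(np.uint8), cv2.COLORMAP_JET)
    return cv2.addWeighted(img_bgr, 0.6, heatmap, 0.4, 0)  # blend image and heatmap
```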

Here, [class] token’s last layer attention map pays most attention to the sharp tip of the knife as well as the bottom tip near the handle.
Here, [class] token’s last layer attention map pays most attention to the gun part.

Step 3. Iteratively mask (erase) image until a class is detected

As a final step, I iteratively mask (i.e. erase) the image until ViT sees a non-trivial probability of a class object being present in the test image. Here is an example of 3 iterations of erasing:

Three iterations of “erasing” the image using [class] token’s final layer attention weights.

Here P(gun) for each image indicates ViT’s predicted probability of classifying the image as gun. With each erasing step, the surrounding areas that do not contain the gun are darkened more and more severely. The model also becomes more confident that there is a gun in the image after each iteration, raising the confidence from 2.3% to 58.8%! This result shows the effectiveness of the iterative erasing prediction algorithm in correctly identifying the threat object in the image.

Since I worked with a threat detection dataset, I designed the labels such that the first class (class 0) represents the “no threat detected” class. The rest (class 1, class 2, etc.) represent threat objects such as guns and knives. You can read more about it here.

  • line 2: Mask (erase) the image using the masking function from Step 2.
  • line 4: Check if the model probability of classifying the image into each threat object class (class 1, 2, etc., excluding class 0) exceeds reconsider_th. I set reconsider_th to a small value of 0.001, meaning that if the model sees a greater than 0.1% probability of a threat object being in the image, I mask the image and test it one more time.
  • lines 5–7: [Option 1] — If the model probability for each threat object is below reconsider_th for all classes, stop iterating and just return class 0.
  • line 9: Check if the model probability of classifying the image into each threat object class (class 1, 2, etc., excluding class 0) exceeds detect_th. I set detect_th to 0.1, meaning that if the model sees a larger than 10% probability of a threat object being in the image, I consider the object as detected.
  • lines 10–11: [Option 2.1] — If the model’s predicted probability for a single threat object is above detect_th, stop iterating and return the index of detected class.
  • lines 12–13: [Option 2.2] — If the model’s predicted probabilities of multiple threat objects are above detect_th, stop iterating and return the indices of all detected classes.
  • line 15: [Option 3] — If options 1 and 2 do not apply, mask the image again and repeat the process until max_iter iterations are reached.
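
Putting the three options together, the prediction loop might look roughly like this sketch; predict_proba is a hypothetical helper that runs the encoder and classifier and returns softmax probabilities (with class 0 being the “no threat” class), and the fallback to class 0 after max_iter is my own assumption:

```python
def iterative_erasing_predict(img, get_attention_mask, mask_image, predict_proba,
                              detect_th=0.1, reconsider_th=0.001, max_iter=3):
    """Iteratively erase unattended regions until a threat class is detected,
    all threat probabilities drop below reconsider_th, or max_iter is reached."""
    for _ in range(max_iter):
        # Erase the regions the [class] token currently ignores
        att_map = get_attention_mask(img)
        img = mask_image(img, att_map)

        probs = predict_proba(img)    # 1-D tensor of class probabilities, class 0 = "no threat"
        threat_probs = probs[1:]      # probabilities of the threat classes only

        # Option 1: every threat probability is negligible -> return "no threat"
        if (threat_probs < reconsider_th).all():
            return [0]

        # Options 2.1 / 2.2: one or more threat classes exceed the detection threshold
        detected = (threat_probs >= detect_th).nonzero(as_tuple=True)[0] + 1
        if len(detected) > 0:
            return detected.tolist()

        # Option 3: otherwise, loop and erase again
    return [0]  # assumption: fall back to "no threat" if max_iter is exhausted
```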

2. ViT Performance using Iterative Erasing

As introduced in ViT Part I post, for my masters research at University of Toronto, I worked on developing an Automatic Threat Detection for Airport Xray Baggage Scanner. Given Xray scan images like below, the model had to detect any gun or knife if present.

Samples of the three classes from web (source) and Xray (target) domains

The following table shows gun and knife recalls for Xray images using the iterative erasing prediction strategy. Recalls for both gun and knife are significantly higher than with vanilla prediction, getting closer to those of the ADDA model with ResNet50 pre-trained on Stylized ImageNet. As a note, this iterative erasing is not possible with ResNet50, as ResNets do not use attention and there are no attention weights available.

Updated recall table for Xray images comparing ResNet50, ViT, and ViT with iterative erasing prediction heuristic

*Source-only means that the model is fine-tuned with web images (source domain) of gun and knife only, and NO Xray images (target domain) are used during fine-tuning. This method was used since I was not provided with sufficient Xray data (only 30–100 images per class) to train a neural network without overfitting. You can read more about domain adaptation and ADDA (Adversarial Discriminative Domain Adaptation) in this post if you’re interested.

3. Possibilities for Future Research

Attention Objective

Seeing the improved performance with the iterative erasing prediction strategy, I experimented with making the attention weights trainable by adding an attention objective on top of the standard classification objective. The idea came from the Tell Me Where to Look: Guided Attention Inference Network paper (2018), where highly-attended areas of an image are masked and the objective becomes to minimize the probability of that masked image being classified as the originally labeled class.

Masking highly-attended areas of the image. The new goal is to minimize the probability of the masked image being classified as the originally labeled class (“cat” for the left pair and “knife” for the right pair).
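
For concreteness, here is a rough sketch of what such a masked-image objective could look like; the helper logic and the simple softmax penalty are my own simplifications for illustration, not the paper’s exact formulation or the loss I actually used:

```python
import torch.nn.functional as F

def guided_attention_loss(model, img, label, att_map, high_att_th=0.5):
    """Mask out highly-attended regions and penalize the model for still
    predicting the original class on the erased image."""
    H, W = img.shape[-2:]
    mask = F.interpolate(att_map[None, None], size=(H, W),
                         mode="bilinear", align_corners=False)[0, 0]
    erased = img * (mask < high_att_th).float()       # keep only low-attention regions

    logits = model(erased)                            # classify the erased image
    p_original = F.softmax(logits, dim=-1)[:, label]  # probability of the original label
    return p_original.mean()                          # minimize this probability
```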

I fine-tuned ViT the same as before, but added the masked images as input and the guided attention as another loss term. The result was… Well, it was horrible😭, much worse than the attention maps from vanilla fine-tuning. It was funny that the attention weights of a model trained with the classification loss only attend to reasonable areas, while training to improve the attention by adding the guided attention objective resulted in worse attention. So a possible future research direction would involve improving this guided attention objective to properly train the ViT attention weights.

ViT with Domain Adaptation

Moreover, I tried training Adversarial Discriminative Domain Adaptation (ADDA) with ViT as the backbone, which did not work out well either. Because I got a fairly good result using ResNet50 pre-trained on Stylized ImageNet, I did not have enough time to debug ADDA with ViT. Working on this could also be a possible future research project.

In Conclusion… (adieu for Xray, too)

This is it for ViT and Xray threat detection, but I am sure that there will be plenty of opportunities for me to use ViT in the future, as transformers are quickly finding their place in computer vision applications, as mentioned in this blog post about computer vision trends in 2021.

Also, this post will (most likely😂) conclude the list of my masters research project posts for threat detection for airport Xray scanner. If you’ve been following along, I can’t thank you enough and hopefully you learned something along the way. Thank you readers and I will be back soon with more deep learning posts, so stay tuned! 😉🐬🦋💙
