DeiT šŸ”„: Training Data-Efficient Image Transformers & Distillation Through Attention (Facebook AI, ICML'21)

Momal Ijaz · Published in AIGuys · Feb 6, 2022

This article covers the second paper in the "Transformers in Vision" series, which summarizes recent papers (2020-2022) from top conferences focusing on transformers in vision.

*NerdFacts-🤓 contain additional intricate details; you can skip them and still get the high-level flow of the paper!

✅ Background

After the booming success of transformers in the NLP domain, many computer vision researchers started applying them to computer vision tasks as well, to bring the power of attention to a field dominated by convolutional networks. One of the famous works in this line was ViT, the Vision Transformer. ViT did pretty well on image classification and outperformed SOTA convolution-based architectures (e.g., ResNet) in terms of accuracy and pre-training cost. But these results were obtained only after training ViT on a huge dataset for multiple TPU-core days, which limits its adoption and usage: one needs to pre-train ViT on a huge dataset to unlock its full potential, because on smaller datasets convolutional networks perform a lot better than ViTs.

ViTs make good classifiers if trained on lots of data

DeiT

[NerdFact-🤓: After Google's team introduced ViT, the research community got busy building efficient variants of it. Facebook AI apparently won the race and introduced DeiT.]

DeiT stands for Data-efficient image Transformer. The work focuses on building a convolution-free model that is trained on less data and can still outperform convolution-based networks.

Specifically, DeiT takes ViT as the base architecture, introduces a few modifications, and trains it with specific strategies to reduce ViT's data dependency. The authors call their modified ViT DeiT: a transformer on a diet :-), no high-calorie big data! DeiT was able to outperform the classic ViT and convolutional networks of comparable sizes on ImageNet image classification, without any external data.

The rest of the article covers the DeiT architecture (aka modifications in classic ViT), training strategies, experiments, and results.

1. DeiT Architecture

To understand the DeiT architecture, you should have a basic idea of transformers, the Vision Transformer (ViT), and knowledge distillation. I discussed transformers and ViT earlier (the linked articles are a good source if you need more explanation), so let's talk about knowledge distillation.

1.1 🧠 Knowledge Distillation

Knowledge distillation is a very simple and smart concept, introduced by Geoffrey Hinton in 2015. Watch him explain it himself. The basic idea is that we have two networks: a stronger, larger, pre-trained teacher network and a weak, small, randomly initialized student network. The student network learns and tries to become more like the teacher by minimizing a loss function whose target is the teacher's probability distribution (i.e., the output of the teacher's softmax layer).

[NerdFact-🤓: In most cases, the teacher's softmax distribution is very close to the original ground-truth labels, so training against it is not much different from training with the ground truth. But by adding a temperature term to the softmax, the distribution becomes much smoother, and it lets the student learn which classes the teacher thinks are most similar to the ground truth: that is the distilled knowledge!]
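To see what that temperature does, here is a tiny, purely illustrative PyTorch snippet (the logits are made-up numbers, not from any real teacher):

```python
import torch
import torch.nn.functional as F

# Made-up teacher logits for a 4-class problem.
logits = torch.tensor([4.0, 2.0, 1.0, 0.5])

for temperature in (1.0, 3.0):
    probs = F.softmax(logits / temperature, dim=0)
    print(f"T={temperature}: {[round(p, 3) for p in probs.tolist()]}")

# T=1.0 is spiky: almost all the probability mass sits on the top class.
# T=3.0 is smoother: the runner-up classes get visible probability,
# which is exactly the "dark knowledge" the student can learn from.
```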

Let's discuss the two types of distillation; DeiT uses the hard one.

a. Soft Distillation: the student tries to recreate the distribution predicted by the teacher. In soft distillation, we minimize a KL divergence* loss between the softmax of the teacher and the softmax of the student.

Soft Distillation Objective

ψ = softmax | Lce = cross-entropy loss | Zs = logits of the student | Zt = logits of the teacher | y = ground truth | yt = teacher's predicted labels | λ = weight of the loss terms | τ = softmax temperature

The first term (red box) computes the cross-entropy loss between the student's prediction and the ground-truth labels; the second term (green box) is the distillation component of the loss. It computes a KL divergence loss between the softmax of the student's and the teacher's temperature-scaled logits. The contribution of each part is controlled by the weight term λ. τ is the softmax temperature.
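To make the objective concrete, here is a minimal PyTorch-style sketch of a soft distillation loss along these lines; the function name and the default λ and τ values are placeholders of mine, not the paper's exact settings:

```python
import torch
import torch.nn.functional as F

def soft_distillation_loss(student_logits, teacher_logits, labels,
                           lam=0.5, tau=3.0):
    # (1 - λ) * cross-entropy between the student and the ground truth.
    ce = F.cross_entropy(student_logits, labels)

    # λ * τ² * KL divergence between the temperature-softened teacher and
    # student distributions; the τ² factor keeps gradient magnitudes
    # comparable to the plain cross-entropy term.
    kl = F.kl_div(
        F.log_softmax(student_logits / tau, dim=1),
        F.softmax(teacher_logits / tau, dim=1),
        reduction="batchmean",
    ) * (tau * tau)

    return (1.0 - lam) * ce + lam * kl
```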

[NerdFact-🤓: A higher temperature (τ) means a more melted / softer / spread-out softmax distribution; a lower temperature gives a spikier distribution.]

b. Hard Distillation: the student tries to recreate the labels predicted by the teacher. Here we minimize a cross-entropy loss between the teacher's labels (argmax of its softmax) and the student's softmax.

Hard Distillation Objective

In hard distillation, the first term (red box) is the same as in soft distillation. The second term (orange box) reflects the change from soft to hard distillation: this time the distillation component is the cross-entropy loss between the student's softmax and the teacher's labels. Both components contribute equally to the total loss.
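A matching sketch for the hard objective, again with illustrative names rather than the paper's actual training code:

```python
import torch
import torch.nn.functional as F

def hard_distillation_loss(student_logits, teacher_logits, labels):
    # The teacher's "hard" labels are just the argmax of its logits.
    teacher_labels = teacher_logits.argmax(dim=1)

    ce_true = F.cross_entropy(student_logits, labels)             # ground-truth term
    ce_teacher = F.cross_entropy(student_logits, teacher_labels)  # distillation term

    # Both components contribute equally (weight 1/2 each).
    return 0.5 * ce_true + 0.5 * ce_teacher
```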

* [NerdFact-🤓: The KL divergence loss is similar to cross-entropy; both measure the difference between two probability distributions, and since they differ only by the entropy of the (fixed) target distribution, minimizing one is equivalent to minimizing the other. Read here!]

1.2 āš—ļøDistillation Token:

The second modification that turns ViT into DeiT is the addition of a distillation token to the architecture.

DeiT distillation token

In a vanilla ViT, just one learnable CLS token is added to the network; it captures the self-attention over all the input image patches and is passed through a simple MLP and softmax to perform classification at the end. In the DeiT architecture, the authors add an additional class-type token called the distillation token. This token is exactly like the CLS token: it is randomly initialized, learnable, and sits at a fixed last position, as opposed to the fixed first position of the CLS token.

1.3 Putting it all together!

Slide courtesy: Moazzam S., Fatemah N., Alec K., Connor M. (source: https://www.crcv.ucf.edu/wp-content/uploads/2018/11/deit_presentation.pdf)

So let's pass an image through DeiT. We divide the image into 16x16 patches (the patch size from ViT), pass them through an embedding layer, and get fixed-size patch embeddings of dimension d. We prepend one randomly initialized d-sized vector, called the class token, and append another one, called the distillation token. We add position embeddings to all tokens and pass them through a stack of encoders (3 in the picture above). Each encoder comprises a self-attention and an FFN layer. The output size stays the same as the input. We discard all patch tokens, take only the outputs of the CLS token and the distillation token, pass them through two separate linear layers to project them to the number of classes, and compute the loss (training) or predict a class (inference).
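For readers who prefer code, here is a heavily simplified PyTorch-style sketch of that forward pass. The class name, dimensions, and the use of nn.TransformerEncoder are illustrative stand-ins, not DeiT's real implementation:

```python
import torch
import torch.nn as nn

class TinyDeiT(nn.Module):
    def __init__(self, img_size=224, patch=16, dim=192, depth=3, heads=3, num_classes=1000):
        super().__init__()
        num_patches = (img_size // patch) ** 2
        # Patch embedding: a strided conv cuts the image into 16x16 patches
        # and projects each one to a d-dimensional vector.
        self.patch_embed = nn.Conv2d(3, dim, kernel_size=patch, stride=patch)
        self.cls_token = nn.Parameter(torch.zeros(1, 1, dim))    # fixed first position
        self.dist_token = nn.Parameter(torch.zeros(1, 1, dim))   # fixed last position
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 2, dim))
        layer = nn.TransformerEncoderLayer(dim, heads, dim * 4, batch_first=True)
        self.encoder = nn.TransformerEncoder(layer, depth)
        self.head_cls = nn.Linear(dim, num_classes)    # supervised by true labels
        self.head_dist = nn.Linear(dim, num_classes)   # supervised by the teacher

    def forward(self, x):
        b = x.size(0)
        patches = self.patch_embed(x).flatten(2).transpose(1, 2)   # (B, N, dim)
        tokens = torch.cat(
            [self.cls_token.expand(b, -1, -1), patches, self.dist_token.expand(b, -1, -1)],
            dim=1,
        ) + self.pos_embed
        out = self.encoder(tokens)
        # Only the CLS (first) and distillation (last) outputs are used.
        return self.head_cls(out[:, 0]), self.head_dist(out[:, -1])
```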

DeiT's loss function

Both the CLS token and the distillation token have their own objective functions to learn from. These objectives are:

  • class token's objective = the first term in the loss function (red box), against the true labels
  • distillation token's objective = the second term in the loss function (orange box), against the teacher's predicted labels (the teacher here is RegNetY-16GF)

[NerdFact-🤓: At inference time, the authors average the softmax outputs of the distillation and CLS heads and then take the argmax, which gives better results.]
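Putting it together, here is a hedged sketch of one training step and of the inference-time fusion, reusing the illustrative TinyDeiT above; the teacher argument stands in for a pre-trained network such as RegNetY-16GF:

```python
import torch
import torch.nn.functional as F

def deit_training_loss(model, teacher, images, labels):
    # The frozen teacher predicts hard labels; no gradients flow through it.
    with torch.no_grad():
        teacher_labels = teacher(images).argmax(dim=1)
    cls_logits, dist_logits = model(images)
    # CLS head learns from the ground truth, distillation head from the teacher.
    return 0.5 * F.cross_entropy(cls_logits, labels) \
         + 0.5 * F.cross_entropy(dist_logits, teacher_labels)

def deit_predict(model, images):
    cls_logits, dist_logits = model(images)
    # Average the two heads' softmax outputs, then take the argmax.
    probs = (F.softmax(cls_logits, dim=1) + F.softmax(dist_logits, dim=1)) / 2
    return probs.argmax(dim=1)
```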

2. Experiments and Results:

This is a very well-written, comprehensive paper, and the authors at Facebook AI conducted quite a few experiments to show how effective their novel technique is.

2.1 😯 Secret Sauce: Fine-Tuning and Augmentation

The authors pre-trained their student and teacher models on ImageNet-1k (224x224) and fine-tuned them on the same dataset at a higher resolution (384x384). They also used extensive data augmentation to give the illusion of a large dataset with fewer available samples; the augmentation techniques include CutMix, MixUp, and RandAugment (playing with contrast, rotation, brightness, etc.), as sketched below.

Sample Results of Augmentation Techniques
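For a flavor of what such an augmentation stack looks like in code, here is a hedged torchvision-style sketch; the paper's actual recipe (built on the timm library) also layers in Mixup/CutMix, random erasing, and repeated augmentation with carefully tuned settings:

```python
import torchvision.transforms as T

IMAGENET_MEAN, IMAGENET_STD = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]

# Illustrative pre-training pipeline at 224x224.
train_transforms = T.Compose([
    T.RandomResizedCrop(224),
    T.RandomHorizontalFlip(),
    T.RandAugment(),  # random contrast / rotation / brightness, etc.
    T.ToTensor(),
    T.Normalize(IMAGENET_MEAN, IMAGENET_STD),
])

# Fine-tuning reuses the same data at the higher 384x384 resolution.
finetune_transforms = T.Compose([
    T.RandomResizedCrop(384),
    T.RandomHorizontalFlip(),
    T.ToTensor(),
    T.Normalize(IMAGENET_MEAN, IMAGENET_STD),
])
```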

2.2 Comparing Teachers šŸ‘Øā€šŸ«:

Analyzing the same student's performance with different teachers

The authors tried different pre-trained teachers to distill their knowledge into a DeiT student model, which used the novel distillation process and was fine-tuned at 384x384 resolution.

It can be seen that the transformer student learned the least from a transformer teacher (first row) and the most from a big convolution teacher (last row). Also, in all cases, the student outperformed its teacher network!

[NerdFact-🤓: Dr. Mubarak Shah, lead of the CRCV lab @ UCF, thinks this is because a transformer student is more different from a convolution teacher than from a transformer teacher, and hence has a lot more to learn from the convolution teacher.]

2.3 Hard vs. Soft Distillation ⚗ļø:

Soft vs. Hard distillation

The authors compared different model variants with hard and soft distillation on ImageNet. Ti224, S224, and B224 are the tiny, small, and base variants of the model with image dimension 224x224; B384 is the base model fine-tuned at 384x384. The authors observed that hard distillation gives better results than soft distillation. Also, adding the distillation token (yellow highlighted part) slightly improves the model's performance.

2.4 šŸ† Award for best performing, fastest model goes toā€¦

Accuracy vs. throughput trade-off: Transformers vs. convolution-based networks

… DeiT (congratulations to the transformer family!)

Transformer-based models outperformed the convolution-based SOTA in terms of the accuracy-throughput trade-off. The strongest DeiT model, trained for 1000 epochs and fine-tuned at 384 resolution with the novel distillation-token-based hard distillation strategy, beat the strongest convolution-based network, KDforAA-B8: both models reach the same accuracy of 85.8%, but DeiT is much faster, with a throughput of 85 img/sec vs. 25 img/sec for KDforAA-B8. This winning DeiT model is also faster than its teacher at similar accuracy.

DeiT also holds its own against ViT-B on this trade-off: ViT-B gives 85.9 @ 88 img/sec, while DeiT offers 85.8 @ 83 img/sec. *All compared models have a similar or comparable number of parameters, so deeper networks get no unfair advantage in this analysis.

2.5 Distillation works!

Efficiency of ViT and EfficientNet vs. DeiT

DeiT without distillation had lower accuracy than EfficientNet, but adding the novel token-based distillation made DeiT outperform EfficientNet in terms of accuracy and throughput.

Conclusion

  • DeiT introduced a novel distillation technique that lets ViT perform and generalize well without being pre-trained on huge datasets. DeiT is eco-friendly, as it does not need large data or long pre-training times to perform well.
  • DeiT outperformed ViTs and convolutional models in terms of the trade-off between throughput and accuracy.
  • DeiT pre-trained on just ImageNet performs quite well on downstream tasks, like fine-grained classification on smaller datasets including CIFAR-100, Oxford-102 Flowers, etc.

šŸ§ Butā€¦

DeiT performs well and learns more from a convolution-based teacher; doesn't that make DeiT another hybrid architecture, since it is not completely free of convolutional dependency?

Distilling knowledge from a conv teacher makes DeiT learn the convolutional spatial inductive biases that are missing in a transformer teacher. Doesn't that call for a novel spatial-attention-based transformer, or stronger positional encodings?

Adding one new distillation token and distilling knowledge from a conv teacher makes DeiT better. What if we used more than one strong pre-trained teacher, each with its own distillation token, to teach a DeiT student? Would this student be any better than the ensemble of its teachers?

… Happy Learning ā¤ļø!

References:

  1. DeiT presentation slides (CRCV, UCF): https://www.crcv.ucf.edu/wp-content/uploads/2018/11/deit_presentation.pdf
  2. Knowledge Distillation, Intel Labs Distiller documentation: https://intellabs.github.io/distiller/knowledge_distillation.html
  3. Touvron et al., Training data-efficient image transformers & distillation through attention (DeiT), ICML 2021: https://arxiv.org/pdf/2012.12877.pdf
