DeiT: Training Data-Efficient Image Transformers & Distillation through Attention, Facebook AI, ICML'21
This article is the second in the "Transformers in Vision" series, which summarizes recent advanced papers on transformers in vision that were submitted to top conferences between 2020 and 2022.
*NerdFacts contain additional intricate details; you can skip them and still get the high-level flow of the paper!
Background
After the booming success of transformers in NLP, many researchers started applying them to computer vision tasks as well, to bring the power of attention to a field dominated by convolutional networks. One of the most famous works in this line was ViT, the Vision Transformer. ViT did very well on image classification and outperformed state-of-the-art convolution-based architectures (ResNets) in terms of accuracy and pre-training cost. However, these results were obtained by pre-training ViT on a huge dataset (Google's JFT-300M) for many TPU-core days, which limits its adoption and usage. To unlock ViT's full potential, one needs to pre-train it on such a huge dataset, because on smaller datasets convolutional networks perform much better than ViTs.
DeiT
[NerdFact: After Google's team introduced ViTs, the research community became active in making efficient variants of ViTs. Facebook AI apparently won the race and introduced DeiT.]
DeiT stands for Data-efficient image Transformer. It focuses on a convolution-free model that is trained on less data and can still outperform convolution-based networks.
Specifically, DeiT takes ViT as its base architecture, introduces a few modifications, and trains it with specific strategies to reduce ViT's data dependency. The authors call their modified ViT DeiT. A transformer on a diet :-), no high-calorie big data! DeiT was able to outperform the classic ViT and convolutional networks of comparable size for image classification on ImageNet, without any external data.
The rest of the article covers the DeiT architecture (i.e., the modifications to the classic ViT), training strategies, experiments, and results.
1. DeiT Architecture
To understand the DeiT architecture, you should have a basic idea of transformers, the Vision Transformer (ViT), and knowledge distillation. I discussed transformers and ViTs earlier (the linked articles are a good source if you need more explanation), so let's talk about knowledge distillation.
1.1 Knowledge Distillation
Knowledge distillation is a very simple and smart concept, introduced by Geoffrey Hinton and colleagues in 2015. Watch him explain it himself. The basic idea is that we have two networks: a stronger, larger, pre-trained teacher network and a weaker, smaller, randomly initialized student network. The student learns to become more like the teacher by minimizing a loss function whose target is the teacher's probability distribution (i.e., the output of the teacher's softmax layer).
[NerdFact: Without a temperature, a confident teacher's softmax is nearly one-hot, i.e., very close to the ground-truth labels, so training on it has much the same effect as training on the ground truth. Adding a temperature term to the softmax makes the distribution much smoother, which lets the student learn which classes the teacher thinks are most similar to the true class. That is the knowledge being distilled!]
Let's discuss two types of distillation; DeiT uses the hard one.
a. Soft Distillation: The student tries to reproduce the distribution predicted by the teacher. In soft distillation, we minimize the KL divergence* loss between the temperature-softened softmax outputs of the teacher and the student.
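The soft-distillation loss appears only as an image in the original post. Transcribed from the DeiT paper using the symbols in the legend below, it reads as follows (the KL term is written with the teacher's softened distribution as the target, which is the usual convention; the τ² factor, inherited from Hinton et al., keeps the gradient scale of the distillation term roughly independent of τ):

```latex
\mathcal{L}_{\text{global}} = (1-\lambda)\,\mathcal{L}_{\text{CE}}\big(\psi(Z_s),\, y\big) \;+\; \lambda\,\tau^{2}\,\mathrm{KL}\big(\psi(Z_t/\tau)\,\big\|\,\psi(Z_s/\tau)\big)
```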
ψ = softmax | L_CE = cross-entropy loss | Z_s = logits of the student | Z_t = logits of the teacher | y = ground truth | y_t = teacher's predicted labels | λ = weight of the loss terms
The first term (red box) computes the cross-entropy loss between the student's prediction and the ground-truth labels; the second term (green box) is the distillation component of the loss. It computes the KL divergence between the temperature-scaled softmax of the student's logits and that of the teacher's logits. The contribution of each part is controlled by the weight λ, and τ is the softmax temperature.
[NerdFact: A higher temperature (τ) gives a more melted / softer / spread-out softmax distribution; a lower temperature gives a spikier one.]
b. Hard Distillation: The student tries to reproduce the labels predicted by the teacher. Here we minimize the cross-entropy loss between the teacher's labels (the argmax of its output) and the student's softmax.
In hard distillation, the first term (red box) is the same as in soft distillation. The second term (orange box) reflects the difference between soft and hard distillation: this time, the distillation component is the cross-entropy loss between the student's softmax and the teacher's labels. Both components contribute equally (a weight of 1/2 each) to the total loss.
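A minimal sketch of the soft-distillation loss for a single image, in plain Python. The function names and the default values lam=0.1 and tau=3.0 are illustrative assumptions, not tuned settings; a real implementation would work on batched tensors.

```python
import math

def softmax(logits, tau=1.0):
    """psi(z / tau) for one example's logits."""
    scaled = [z / tau for z in logits]
    m = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(z - m) for z in scaled]
    total = sum(exps)
    return [e / total for e in exps]

def kl_divergence(p, q):
    """KL(p || q) for two discrete distributions over the same classes."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)

def soft_distillation_loss(z_s, z_t, y, lam=0.1, tau=3.0):
    """(1 - lam) * CE(psi(Z_s), y) + lam * tau^2 * KL(psi(Z_t/tau) || psi(Z_s/tau)).

    z_s, z_t: student / teacher logits for one image; y: ground-truth class index.
    """
    ce = -math.log(softmax(z_s)[y])        # red box: student vs. ground truth
    kl = kl_divergence(softmax(z_t, tau),  # green box: the teacher's softened
                       softmax(z_s, tau))  # distribution is the target
    return (1 - lam) * ce + lam * tau ** 2 * kl

student_logits = [2.0, 1.0, 0.1]
teacher_logits = [3.0, 2.5, -1.0]
print(soft_distillation_loss(student_logits, teacher_logits, y=0))
```

If the student's logits equal the teacher's, the KL term vanishes and only the weighted cross-entropy remains.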
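To see the temperature effect concretely, here is a small sketch with made-up teacher logits for four hypothetical classes. It prints the softened distribution and its entropy for a few values of τ: the class ranking never changes, but as τ grows more probability mass leaks to the runner-up classes.

```python
import math

def softmax_with_temperature(logits, tau):
    """psi(z / tau): divide logits by the temperature, then normalize."""
    scaled = [z / tau for z in logits]
    m = max(scaled)  # shift for numerical stability
    exps = [math.exp(z - m) for z in scaled]
    total = sum(exps)
    return [e / total for e in exps]

def entropy(p):
    """Shannon entropy in nats; higher means a flatter distribution."""
    return -sum(pi * math.log(pi) for pi in p if pi > 0)

# Hypothetical teacher logits for 4 classes (e.g. cat, dog, fox, truck)
teacher_logits = [8.0, 5.0, 4.0, -2.0]

for tau in (1.0, 3.0, 10.0):
    probs = softmax_with_temperature(teacher_logits, tau)
    print(f"tau={tau:>4}: " + ", ".join(f"{p:.3f}" for p in probs),
          f"| entropy={entropy(probs):.3f}")
```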
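As with the soft variant, the equation image did not survive. The hard-distillation objective as given in the DeiT paper, in the same notation, is:

```latex
\mathcal{L}_{\text{global}}^{\text{hardDistill}} = \tfrac{1}{2}\,\mathcal{L}_{\text{CE}}\big(\psi(Z_s),\, y\big) \;+\; \tfrac{1}{2}\,\mathcal{L}_{\text{CE}}\big(\psi(Z_s),\, y_t\big), \qquad y_t = \operatorname{argmax}_c Z_t(c)
```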
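* [NerdFact: KL divergence is closely related to cross-entropy: both measure how one probability distribution differs from another, and KL(p‖q) = CE(p, q) − H(p), where H(p) is the entropy of the target distribution. When the target (the teacher's output) is fixed, H(p) is a constant, so minimizing either loss gives the student the same gradients. That is why the two can be used interchangeably here.]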
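A quick numerical check of the identity KL(p‖q) = CE(p, q) − H(p), using a made-up teacher distribution p and two candidate student distributions q. The gap CE − KL stays equal to the teacher's entropy whatever the student predicts, which is why the two losses differ only by a constant.

```python
import math

def cross_entropy(p, q):
    """H(p, q) = -sum p log q: cost of encoding targets p with predictions q."""
    return -sum(pi * math.log(qi) for pi, qi in zip(p, q) if pi > 0)

def entropy(p):
    """H(p): depends only on the target distribution."""
    return -sum(pi * math.log(pi) for pi in p if pi > 0)

def kl_divergence(p, q):
    """KL(p || q) = H(p, q) - H(p)."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)

teacher = [0.7, 0.2, 0.1]  # fixed target distribution p
for student in ([0.6, 0.3, 0.1], [0.2, 0.5, 0.3]):  # two candidate q's
    gap = cross_entropy(teacher, student) - kl_divergence(teacher, student)
    print(f"CE - KL = {gap:.6f}, H(teacher) = {entropy(teacher):.6f}")
```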
1.2 Distillation Token:
The second modification to convert ViT into DeiT is the addition of a distillation token to the architecture.
In a vanilla ViT, just one learnable CLS token is added to the network. Through self-attention, it aggregates information from all the input image patches, and its output is passed through a simple MLP head and softmax to perform classification at the end. In the DeiT architecture, the authors added an additional class-like token called the distillation token. This token is exactly like the CLS token: it is randomly initialized, learnable, and has a fixed position, the last one, unlike the first position of the CLS token.
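The token layout can be sketched with plain lists. This is a toy illustration under stated assumptions: the embedding size d=4 and the 3x3 patch grid are made up (DeiT-Ti actually uses d=192), the helper names are hypothetical, and the "learnable" vectors are just small random values here, whereas a real model would register them as trainable parameters. The layout follows the description above: CLS first, distillation token last.

```python
import random

def build_token_sequence(patch_embeddings, cls_token, dist_token, pos_embeddings):
    """Assemble the encoder input for one image: [CLS, patch_1..patch_N, DIST].

    Each position then gets its own learned position embedding added.
    """
    tokens = [cls_token] + patch_embeddings + [dist_token]
    assert len(tokens) == len(pos_embeddings), "one position embedding per token"
    return [[t + p for t, p in zip(tok, pos)] for tok, pos in zip(tokens, pos_embeddings)]

def rand_vec(rng, d):
    """A d-sized vector of small random values (stand-in for a learned parameter)."""
    return [rng.gauss(0.0, 0.02) for _ in range(d)]

d = 4            # toy embedding size
num_patches = 9  # toy: a 3x3 grid of patches
rng = random.Random(0)

cls_token = rand_vec(rng, d)   # same kind of vector as the distillation token...
dist_token = rand_vec(rng, d)  # ...but a different role and position
patches = [rand_vec(rng, d) for _ in range(num_patches)]
pos = [rand_vec(rng, d) for _ in range(num_patches + 2)]

seq = build_token_sequence(patches, cls_token, dist_token, pos)
print(len(seq), len(seq[0]))  # 11 4  (N + 2 tokens, each of size d)
```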
1.3 Putting it all together!
So let's pass an image through DeiT. We divide the image into 16x16 patches (the patch size from ViT), pass them through an embedding layer, and get fixed-size patch embeddings of size d. We add one randomly initialized d-sized vector at the start, called the class token, and one at the end, called the distillation token. We add position embeddings to all tokens and pass them through a stack of encoders (3 in the above picture). Each encoder comprises a multi-head self-attention layer and an FFN layer, and its output has the same size as its input. We then discard all other tokens and keep only the outputs of the CLS token and the distillation token, pass them through two separate linear layers that project them to the number of classes, and compute the loss (training) or predict a class (inference).
Both the CLS token and the distillation token have their own objective functions to learn from. These objective functions are:
- class token's objective function = first term in the loss function (red box): the true labels
- distillation token's objective function = second term in the loss function (orange box): the teacher's predicted labels (the teacher here is RegNetY-16GF)
[NerdFact: At inference time, the authors averaged the softmax outputs of the CLS and distillation heads and took the argmax of the result to get better results.]
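The walk-through above can be summarized as shape arithmetic for one image (batch dimension omitted). The defaults are assumptions chosen to match a DeiT-S-like configuration (224x224 input, 16x16 patches, embedding size 384) and ImageNet's 1000 classes; the function name is hypothetical.

```python
def deit_shapes(image_size=224, patch_size=16, embed_dim=384, num_classes=1000):
    """Trace per-image tensor shapes through a DeiT forward pass (batch dim omitted)."""
    assert image_size % patch_size == 0, "image must tile exactly into patches"
    grid = image_size // patch_size  # patches per side: 224 // 16 = 14
    num_patches = grid * grid        # N = 196 for a 224x224 input
    seq_len = num_patches + 2        # + CLS token + distillation token
    return {
        "image": (3, image_size, image_size),
        # each 16x16x3 patch is flattened, then linearly projected to embed_dim
        "flattened_patches": (num_patches, patch_size * patch_size * 3),
        "patch_embeddings": (num_patches, embed_dim),
        # CLS + patches + distillation token, plus position embeddings
        "encoder_input": (seq_len, embed_dim),
        # every encoder block (self-attention + FFN) preserves this shape
        "encoder_output": (seq_len, embed_dim),
        # only the CLS and distillation outputs are kept; each gets its own head
        "cls_logits": (num_classes,),
        "dist_logits": (num_classes,),
    }

for name, shape in deit_shapes().items():
    print(f"{name:>17}: {shape}")
```

Calling deit_shapes(image_size=384) shows why fine-tuning at a higher resolution is costlier: the sequence grows from 198 to 578 tokens, and self-attention cost grows quadratically with sequence length.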
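The two objectives can be sketched as one loss for a single image: the CLS head is trained against the true label and the distillation head against the teacher's hard label, each weighted 1/2. The function names and toy logits are hypothetical; the frozen teacher (RegNetY-16GF in the paper) is represented only by its logits. Passing the same logits as cls_logits and dist_logits gives the plain hard-distillation loss without a distillation token.

```python
import math

def log_softmax(logits):
    """Numerically stable log(softmax(z))."""
    m = max(logits)
    log_sum = m + math.log(sum(math.exp(z - m) for z in logits))
    return [z - log_sum for z in logits]

def cross_entropy(logits, label):
    """CE between softmax(logits) and a hard integer label."""
    return -log_softmax(logits)[label]

def argmax(values):
    return max(range(len(values)), key=values.__getitem__)

def deit_hard_distillation_loss(cls_logits, dist_logits, teacher_logits, y):
    """0.5 * CE(CLS head, true label) + 0.5 * CE(distillation head, teacher label)."""
    y_t = argmax(teacher_logits)  # the teacher's hard decision; the teacher is frozen
    return 0.5 * cross_entropy(cls_logits, y) + 0.5 * cross_entropy(dist_logits, y_t)

# Toy 3-class example where the teacher disagrees with the ground truth.
cls_logits = [2.0, 0.5, -1.0]
dist_logits = [0.3, 1.8, -0.5]
teacher_logits = [1.0, 2.2, 0.1]  # teacher predicts class 1
print(deit_hard_distillation_loss(cls_logits, dist_logits, teacher_logits, y=0))
```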
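A minimal sketch of that late fusion, with hypothetical names and toy logits. Averaging and summing the two softmax outputs give the same argmax; the average is used here so the fused scores remain a probability distribution. In this example the CLS head alone would pick class 0 and the distillation head class 1, and the fused prediction follows the more confident head.

```python
import math

def softmax(logits):
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]

def deit_predict(cls_logits, dist_logits):
    """Average the two heads' softmax outputs, then take the most likely class."""
    p_cls, p_dist = softmax(cls_logits), softmax(dist_logits)
    p_avg = [(a + b) / 2 for a, b in zip(p_cls, p_dist)]
    return max(range(len(p_avg)), key=p_avg.__getitem__), p_avg

label, probs = deit_predict([2.0, 1.9, -1.0], [0.5, 2.5, 0.0])
print(label, [round(p, 3) for p in probs])
```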
2. Experiments and Results:
This is a very well-written, comprehensive paper, and the authors at Facebook AI conducted quite a few experiments to show how effective their novel technique is.
2.1 Secret Sauce: Fine-Tuning and Augmentation
The authors trained their student and teacher models on ImageNet-1k at 224x224 resolution, without any external data, and then fine-tuned them on the same dataset at a higher resolution (384x384). They also used extensive data augmentation to simulate a larger dataset from the available samples; the techniques include CutMix, MixUp, and RandAugment (playing with contrast, rotation, brightness, etc.).
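To give a feel for one of these augmentations, here is a minimal MixUp sketch in plain Python. It assumes images are flat lists of pixel values and labels are class indices; alpha=0.8 is an illustrative choice, and the function name is hypothetical. CutMix (pasting a rectangular region and mixing labels by area) and RandAugment are not shown.

```python
import random

def mixup(image_a, label_a, image_b, label_b, num_classes, alpha=0.8, rng=random):
    """Blend two images and their one-hot labels with weight lam ~ Beta(alpha, alpha)."""
    lam = rng.betavariate(alpha, alpha)
    mixed_image = [lam * a + (1 - lam) * b for a, b in zip(image_a, image_b)]
    mixed_label = [0.0] * num_classes
    mixed_label[label_a] += lam      # += also handles label_a == label_b
    mixed_label[label_b] += 1 - lam
    return mixed_image, mixed_label, lam

rng = random.Random(42)
img_cat = [1.0] * 6  # toy 6-pixel "images"
img_dog = [0.0] * 6
x, y, lam = mixup(img_cat, 0, img_dog, 1, num_classes=3, rng=rng)
print(f"lam={lam:.3f}", x[:2], y)
```

The mixed label is a soft target, so the model is trained with cross-entropy against a distribution rather than a single class.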
2.2 Comparing Teachers:
The authors tried different pre-trained teachers to distill knowledge into a DeiT student model, which used the novel distillation process and was fine-tuned at 384x384 resolution.
It can be seen that the transformer-student learned the least from a transformer-teacher (first row) and the most from a big convolution-teacher (last row). Also, in all cases, the students outperformed their teacher networks!
[NerdFact: Dr. Mubarak Shah, who leads the CRCV lab at UCF, thinks this is because a transformer-student is more different from a convolution-teacher than from a transformer-teacher, and hence has a lot more to learn from the former.]
2.3 Hard vs. Soft Distillation:
The authors compared different model variants with hard and soft distillation on ImageNet. Ti224, S224, and B224 are the tiny, small, and base variants of the model with 224x224 images; B384 is the base model fine-tuned at 384x384. The authors observed that hard distillation gives better results than soft distillation. Also, adding the distillation token (yellow highlighted part) improves the model's performance slightly.
2.4 Award for the best-performing, fastest model goes to…
… DeiT (Congratulations to the transformers family!)
Transformer-based models outperformed the convolution-based SOTA in terms of the accuracy-throughput trade-off. The strongest DeiT model, trained for 1000 epochs and fine-tuned at 384 resolution with the novel distillation-token-based hard distillation strategy, outperformed the strongest convolution-based network, KDForAA-B8, in terms of this trade-off. Both models have the same accuracy of 85.8%, but DeiT is much faster, with a throughput of 85 img/sec vs. 25 img/sec for KDForAA-B8. This winning DeiT model is also faster than its teacher, with similar accuracy.
DeiT also closely matches ViT-B in terms of this trade-off: ViT-B gives an accuracy of 85.9% at 88 img/sec, while DeiT offers 85.8% at 83 img/sec. *All compared models have a similar number of parameters, so deeper networks have no advantage in this analysis.
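To put the throughput gap in perspective, a small calculation using the figures quoted above (85 vs. 25 img/sec). The 50,000-image workload is an assumption chosen to match the size of the ImageNet validation set; the function name is hypothetical.

```python
def speedup_and_time(throughput_fast, throughput_slow, num_images=50_000):
    """Relative speed and wall-clock seconds to process num_images at each throughput."""
    return (throughput_fast / throughput_slow,
            num_images / throughput_fast,
            num_images / throughput_slow)

ratio, t_deit, t_kd = speedup_and_time(85, 25)
print(f"DeiT is {ratio:.1f}x faster: {t_deit:.0f}s vs {t_kd:.0f}s for 50k images")
# DeiT is 3.4x faster: 588s vs 2000s for 50k images
```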
2.5 Distillation works!
DeiT without distillation had lower accuracy than EfficientNet, but adding the novel token-based distillation made DeiT outperform EfficientNets in terms of accuracy and throughput.
Conclusion
- DeiT introduced a novel distillation technique that makes ViT perform and generalize well without being pre-trained on huge datasets. DeiT is eco-friendly, as it does not need large datasets or long pre-training times to perform well.
- DeiT outperformed ViTs and convolutional models in terms of the trade-off between throughput and accuracy.
- DeiT pre-trained on just ImageNet can perform pretty well on downstream tasks, like fine-grained classification on smaller datasets including CIFAR-100, Oxford-102 Flowers, etc.
But…
DeiT performs well and learns more from a convolution-based teacher; doesn't that make DeiT another hybrid architecture, since it's not completely free from convolutional dependency?
Distilling knowledge from a conv-teacher makes DeiT learn the convolutional spatial inductive biases that a transformer-teacher lacks. Doesn't that call for a novel spatial-attention-based transformer, or stronger positional encodings?
Adding one new distillation token and distilling knowledge from a conv-teacher makes DeiT better. What if we used more than one strong pre-trained teacher, each with its own distillation token, to teach a DeiT student? Would this student be any better than the ensemble of its teachers?
…Happy Learning ❤️!