MLP-Mixer: An all-MLP Architecture for Vision — Paper Summary

Gowthami Somepalli
ML Summaries
Aug 31, 2021


Paper: MLP-Mixer: An all-MLP Architecture for Vision
Link: https://arxiv.org/abs/2105.01601
Authors: Ilya Tolstikhin, Neil Houlsby, Alexander Kolesnikov, Lucas Beyer, Xiaohua Zhai, Thomas Unterthiner, Jessica Yung, Andreas Steiner, Daniel Keysers, Jakob Uszkoreit, Mario Lucic, Alexey Dosovitskiy
Tags: Machine Learning, Deep Learning, MLP architectures, Deep Learning architectures
Code: https://github.com/google-research/vision_transformer
Misc. info:

What?

A competitive all-MLP architecture (no convolutions or self-attention) that performs well on image classification benchmarks such as ImageNet.

Why?

A simpler architecture with even fewer inductive biases than the Vision Transformer[1]. Also, novel deep learning architectures are all the rage right now!

How?

Figure 1: MLP-Mixer architecture illustration.

Architecture

The input to the model is a sequence of S non-overlapping square image patches, similar to the Vision Transformer (ViT)[1]. Unlike the hybrid ViT variant[1], where a CNN feature map may be used to generate the patch or “token” representations, MLP-Mixer linearly projects every patch into C dimensions using the same projection matrix. So the input to the Mixer layers is an S×C table (patches × channels).
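A minimal NumPy sketch of this patch-embedding step, assuming a 224×224 RGB image, 16×16 patches and C = 768 (illustrative values). The function names `patchify` and `embed_patches` are hypothetical, not taken from the official code. Projecting each flattened patch with one shared matrix is mathematically the same as a convolution whose kernel size and stride both equal the patch size.

```python
import numpy as np

def patchify(image: np.ndarray, patch: int) -> np.ndarray:
    """Split an (H, W, c) image into S = (H/P)*(W/P) flattened, non-overlapping patches."""
    h, w, c = image.shape
    assert h % patch == 0 and w % patch == 0, "image size must be divisible by the patch size"
    x = image.reshape(h // patch, patch, w // patch, patch, c)
    x = x.transpose(0, 2, 1, 3, 4)              # (H/P, W/P, P, P, c)
    return x.reshape(-1, patch * patch * c)     # (S, P*P*c), patches in row-major grid order

def embed_patches(image, patch, w_embed, b_embed):
    """Project every patch with the SAME matrix: (S, P*P*c) @ (P*P*c, C) -> (S, C)."""
    return patchify(image, patch) @ w_embed + b_embed

rng = np.random.default_rng(0)
P, C = 16, 768                                  # illustrative, Mixer-B/16-like values
image = rng.standard_normal((224, 224, 3))
w_embed = rng.standard_normal((P * P * 3, C)) * 0.02
b_embed = np.zeros(C)
X = embed_patches(image, P, w_embed, b_embed)
print(X.shape)  # (196, 768): S = (224/16)**2 = 196 patches, C = 768 channels
```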

The main contribution of the paper is the Mixer layer, which has 2 main steps: 1) token-mixing MLPs (MLP1 in the figure above), applied at a given channel across all tokens, followed by 2) channel-mixing MLPs (MLP2 in the figure above), applied to a given patch across all channels. Within a Mixer layer, a single token-mixing MLP is shared across all channels and a single channel-mixing MLP is shared across all tokens, i.e. each has only one set of weights. Layer normalization (over the channels of each token) is applied before each of the two MLPs, and each MLP is wrapped in a skip connection. As with self-attention layers in a Transformer, the output of a Mixer layer has the same S×C shape as its input.
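To make the two steps concrete, here is a minimal NumPy sketch of one Mixer layer, assuming the tanh approximation of GELU and small illustrative sizes. Weights are stored as (input, output) matrices, i.e. transposed relative to the paper's W_1 ∈ R^{D_S×S} notation; the helper names and the initialization scale are my own assumptions.

```python
import numpy as np

def gelu(x):
    # tanh approximation of GELU
    return 0.5 * x * (1.0 + np.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * x**3)))

def layer_norm(x, gamma, beta, eps=1e-6):
    # normalize each row (one token) over its C channels
    mu = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return gamma * (x - mu) / np.sqrt(var + eps) + beta

def mlp(x, w1, b1, w2, b2):
    # Dense -> GELU -> Dense, applied independently to every row of x
    return gelu(x @ w1 + b1) @ w2 + b2

def mixer_layer(X, p):
    """One Mixer layer: (S, C) -> (S, C)."""
    # 1) Token mixing: transpose to (C, S) so ONE MLP runs over the S tokens of each channel
    y = layer_norm(X, p["ln1_g"], p["ln1_b"]).T
    U = X + mlp(y, p["tok_w1"], p["tok_b1"], p["tok_w2"], p["tok_b2"]).T  # skip connection
    # 2) Channel mixing: ONE MLP runs over the C channels of each token
    y = layer_norm(U, p["ln2_g"], p["ln2_b"])
    return U + mlp(y, p["ch_w1"], p["ch_b1"], p["ch_w2"], p["ch_b2"])     # skip connection

def init_params(rng, S, C, D_S, D_C, scale=0.02):
    def w(*shape):
        return rng.standard_normal(shape) * scale
    return {
        "ln1_g": np.ones(C), "ln1_b": np.zeros(C),
        "tok_w1": w(S, D_S), "tok_b1": np.zeros(D_S),
        "tok_w2": w(D_S, S), "tok_b2": np.zeros(S),
        "ln2_g": np.ones(C), "ln2_b": np.zeros(C),
        "ch_w1": w(C, D_C), "ch_b1": np.zeros(D_C),
        "ch_w2": w(D_C, C), "ch_b2": np.zeros(C),
    }

rng = np.random.default_rng(0)
S, C, D_S, D_C = 196, 64, 32, 256             # small illustrative sizes
X = rng.standard_normal((S, C))
Y = mixer_layer(X, init_params(rng, S, C, D_S, D_C))
print(Y.shape)  # (196, 64): same shape as the input
```

Because the output shape equals the input shape, such layers can simply be stacked.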

Each MLP consists of 2 fully connected layers with one non-linearity in between, GELU[2]. (A bit of trivia: one of the authors mentioned that the choice of GELU has no particular significance, and that they did not try other non-linearities.)
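For reference, GELU[2] is x·Φ(x), where Φ is the standard normal CDF. The short sketch below compares the exact form with the widely used tanh approximation; which variant a particular Mixer implementation uses is an assumption this summary does not settle.

```python
import math

def gelu_exact(x: float) -> float:
    # GELU(x) = x * Phi(x), with Phi the standard normal CDF
    return 0.5 * x * (1.0 + math.erf(x / math.sqrt(2.0)))

def gelu_tanh(x: float) -> float:
    # common tanh approximation of the same function
    return 0.5 * x * (1.0 + math.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * x**3)))

for x in (-3.0, -1.0, 0.0, 1.0, 3.0):
    print(f"x={x:+.1f}  exact={gelu_exact(x):+.5f}  tanh={gelu_tanh(x):+.5f}")
```

Unlike ReLU, GELU is smooth and slightly negative for moderately negative inputs.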

D_S and D_C are the hidden-layer widths of the token-mixing and channel-mixing MLPs, respectively. Since both are hyperparameters chosen independently of the number of input patches S, the computational complexity of the model is linear in S (unlike attention-based models, whose complexity scales quadratically).
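A back-of-the-envelope sketch of this scaling argument: the functions below count multiply-accumulates (MACs) of the dense layers in one Mixer layer, and of the two S×S products at the core of single-head self-attention. Biases, layer norm, GELU and the attention projections are ignored, and the function names are hypothetical.

```python
def mixer_layer_macs(S: int, C: int, D_S: int, D_C: int) -> int:
    """MACs of the dense layers of one Mixer layer."""
    token_mixing = C * (S * D_S + D_S * S)    # one S -> D_S -> S MLP per channel
    channel_mixing = S * (C * D_C + D_C * C)  # one C -> D_C -> C MLP per token
    return token_mixing + channel_mixing

def attention_core_macs(S: int, C: int) -> int:
    """MACs of Q K^T (S x S scores) and scores @ V for single-head attention of width C."""
    return S * S * C + S * S * C

for S in (196, 392, 784):  # doubling the number of patches each time
    print(S, mixer_layer_macs(S, 768, 384, 3072), attention_core_macs(S, 768))
```

With D_S and D_C held fixed, doubling S doubles the Mixer cost but quadruples the cost of the attention scores.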

In a way, the channel-mixing part of MLP-Mixer can be seen as a convolution with 1x1 filters. The token-mixing part, in contrast, gives every Mixer layer a receptive field covering the complete image, whereas conv-nets have small receptive fields in their early layers.
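Both views can be checked numerically. In the sketch below (hypothetical sizes: a 14×14 grid of patch tokens with 8 channels), a dense channel-mixing layer applied to every token gives the same result as an explicit 1x1 convolution over the patch grid, while a single linear token-mixing map makes every output token depend on every input patch.

```python
import numpy as np

rng = np.random.default_rng(0)
Hp, Wp, C = 14, 14, 8                    # 14x14 grid of patch tokens, 8 channels (illustrative)
S = Hp * Wp
X = rng.standard_normal((S, C))          # S x C table, tokens in row-major grid order
W_ch = rng.standard_normal((C, C))       # channel-mixing weights (one dense layer)

# Channel mixing: the same dense layer applied to every token (row) independently
dense = X @ W_ch

# The same computation written as a 1x1 convolution over the (Hp, Wp, C) feature map
fmap = X.reshape(Hp, Wp, C)
conv = np.empty((Hp, Wp, C))
for i in range(Hp):
    for j in range(Wp):
        conv[i, j] = fmap[i, j] @ W_ch   # a 1x1 kernel only sees position (i, j)
print(np.allclose(dense, conv.reshape(S, C)))   # True

# Token mixing: a linear map across tokens, applied to every channel (column)
W_tok = rng.standard_normal((S, S))
X_pert = X.copy()
X_pert[0] += 1.0                         # perturb only the top-left patch
changed = np.abs(W_tok.T @ X_pert - W_tok.T @ X).sum(axis=1) > 0
print(changed.all())                     # True: every output token "sees" that patch
```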

After the Mixer layers, we simply apply global average pooling across the S patches, which gives a C-dimensional embedding that is passed to a fully connected layer to produce the final class prediction!
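A minimal sketch of this head, assuming the last Mixer layer outputs an S×C table; names and sizes are illustrative and any final normalization is omitted. Because the mean is taken over patches, the pooled vector has C entries regardless of S, and reordering the patches does not change it.

```python
import numpy as np

def classifier_head(tokens, w_head, b_head):
    """tokens: (S, C) output of the last Mixer layer -> (num_classes,) logits."""
    pooled = tokens.mean(axis=0)        # global average pooling over the S patches -> (C,)
    return pooled @ w_head + b_head     # a single linear (fully connected) layer

rng = np.random.default_rng(0)
S, C, num_classes = 196, 768, 1000      # illustrative sizes
tokens = rng.standard_normal((S, C))
w_head = rng.standard_normal((C, num_classes)) * 0.02
b_head = np.zeros(num_classes)
logits = classifier_head(tokens, w_head, b_head)
print(logits.shape)  # (1000,)
```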

Main results

Table 1: Fine-tuning results.

In Table 1, the authors present fine-tuning results on ImageNet, as well as the average over 5 datasets (ImageNet, CIFAR-10, CIFAR-100, Pets, Flowers), after pre-training on ImageNet-21k and on Google’s proprietary JFT-300M dataset. The authors also report the inference throughput of each model, as well as the pre-training cost, measured in thousands of TPUv3 core-days! 😱

The blue dots represent models with attention, the yellow dots models with convolutions, and the purple dots Mixer models. Mixer results seem to be very close to those of the other standard models, if not better. Mixers are faster than ViTs at inference, but they need longer pre-training to reach a similar level of performance!

Table 2: Various MLP-Mixer configurations

Similar to ViTs, there are multiple Mixer models of increasing size, with wider hidden dimensions and MLPs and more Mixer layers, as described in Table 2. The suffix gives the patch size: Mixer-B/16, for example, is the Base model with 16×16 patches.
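To get a feel for how these hyperparameters translate into model size, here is a small parameter-counting sketch. It counts only the patch-projection stem and the Mixer layers (weights, biases and layer-norm scale/shift), ignores the classification head, and assumes 224×224 RGB inputs. The example values (12 layers, C = 768, D_S = 384, D_C = 3072, 16×16 patches) are the ones I recall for Mixer-B/16; treat them as illustrative and check Table 2.

```python
def mixer_backbone_params(num_layers: int, patch: int, C: int, D_S: int, D_C: int,
                          image_size: int = 224, in_ch: int = 3) -> int:
    """Parameters of the stem + Mixer layers (no classification head)."""
    S = (image_size // patch) ** 2                  # number of patches (tokens)
    stem = patch * patch * in_ch * C + C            # shared patch projection + bias
    layer_norms = 2 * (2 * C)                       # two LayerNorms, scale and shift each
    token_mlp = (S * D_S + D_S) + (D_S * S + S)     # S -> D_S -> S
    channel_mlp = (C * D_C + D_C) + (D_C * C + C)   # C -> D_C -> C
    return stem + num_layers * (layer_norms + token_mlp + channel_mlp)

print(mixer_backbone_params(num_layers=12, patch=16, C=768, D_S=384, D_C=3072))  # 59109936
```

Almost all of the parameters sit in the channel-mixing MLPs; the token-mixing weights grow with S, which is why patch size also affects model size.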

In addition, as the amount of pre-training data grows, beyond a certain point the performance of Mixers exceeds that of Vision Transformers, as seen in Figure 2.

Left (Figure 2): Validation accuracy with increasing pre-training data size for Mixer models, ViTs and convolutional models. Right (Figure 3): Even if we shuffle the patches of the data, the Mixer model performs the same, while ResNet accuracy drops drastically.

The authors also performed an additional experiment by modifying the training data in two ways: (1) shuffling the patches and shuffling the pixels within each patch, and (2) globally shuffling all the pixels of the image, in both cases with one fixed permutation applied to every image (an example can be seen in Figure 3 above). Mixer’s performance is exactly the same in the first case, while in the second case it drops, albeit less than in ResNet’s[3] case! The authors explain this by the Mixer’s invariance to the order of the patches and of the pixels within them.
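The two data modifications can be sketched as follows, assuming H×W×3 images and 16×16 patches. The permutations are drawn once from a fixed seed and then reused for every image, which is the property the argument relies on; the function names are hypothetical.

```python
import numpy as np

def patch_and_pixel_shuffle(image, patch, patch_perm, pixel_perm):
    """(1) Reorder the patches and, inside every patch, reorder the pixels."""
    h, w, c = image.shape
    gh, gw = h // patch, w // patch
    x = image.reshape(gh, patch, gw, patch, c).transpose(0, 2, 1, 3, 4)
    x = x.reshape(gh * gw, patch * patch, c)        # (S, pixels per patch, c)
    x = x[patch_perm][:, pixel_perm]                # shuffle patches, then pixels in each patch
    x = x.reshape(gh, gw, patch, patch, c).transpose(0, 2, 1, 3, 4)
    return x.reshape(h, w, c)

def global_shuffle(image, pixel_perm):
    """(2) Permute all H*W pixel positions with one permutation."""
    h, w, c = image.shape
    return image.reshape(h * w, c)[pixel_perm].reshape(h, w, c)

rng = np.random.default_rng(0)                      # fixed seed -> same permutations for all images
H = W = 224
P = 16
patch_perm = rng.permutation((H // P) * (W // P))
pixel_perm = rng.permutation(P * P)
global_perm = rng.permutation(H * W)

image = rng.random((H, W, 3))
a = patch_and_pixel_shuffle(image, P, patch_perm, pixel_perm)
b = global_shuffle(image, global_perm)
print(a.shape, b.shape)
```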

In the patch-shuffling case, since the same permutation is applied to all images, for MLP-Mixer this is equivalent to learning additional permutation matrices within each patch and across the patches: the within-patch permutation can be absorbed into the rows of the patch-projection matrix, and the across-patch permutation into the token-mixing weights. CNNs, on the other hand, have the built-in prior that nearby pixels are related to each other, and shuffling the pixels prevents them from making proper use of this prior.
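A small numerical check of this argument, assuming a single token-mixing MLP with the tanh GELU (layer norm and channel mixing act on each token separately, so they simply commute with a reordering of the tokens). If the patches are shuffled with a fixed permutation, permuting the S-sized axes of the token-mixing weights reproduces the original outputs, only reordered, so the pooled features are identical. The last check shows the analogous fact for pixels inside a patch. All sizes and names are illustrative.

```python
import numpy as np

def gelu(x):
    return 0.5 * x * (1.0 + np.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * x**3)))

def token_mixing(X, w1, b1, w2, b2):
    # X: (S, C); one S -> D_S -> S MLP applied to every channel (column), plus skip
    return X + (gelu(X.T @ w1 + b1) @ w2 + b2).T

rng = np.random.default_rng(0)
S, C, D_S = 16, 8, 32
X = rng.standard_normal((S, C))
w1, b1 = rng.standard_normal((S, D_S)), rng.standard_normal(D_S)
w2, b2 = rng.standard_normal((D_S, S)), rng.standard_normal(S)

perm = rng.permutation(S)                  # one fixed patch shuffle for the whole dataset
out = token_mixing(X, w1, b1, w2, b2)
# Weights a Mixer could learn on shuffled data: permute every S-sized axis
out_shuf = token_mixing(X[perm], w1[perm], b1, w2[:, perm], b2[perm])
print(np.allclose(out_shuf, out[perm]))             # True: same outputs, reordered
print(np.allclose(out_shuf.mean(0), out.mean(0)))    # True: identical after pooling

# A within-patch pixel shuffle is absorbed by permuting the patch-projection rows
v = rng.standard_normal(12)                # one flattened patch (e.g. 2x2 pixels x 3 channels)
W_stem = rng.standard_normal((12, C))
q = rng.permutation(12)
print(np.allclose(v[q] @ W_stem[q], v @ W_stem))    # True
```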

Figure 4: Visualizations of the hidden units of the first, second and third token-mixing MLPs of a Mixer-B/16 model trained on the JFT dataset. Each pixel in each block represents the weight for one patch of the original image.

In Figure 4, the authors visualize the hidden units of the pre-trained token-mixing MLPs. For better visualization, they sorted all hidden units according to a heuristic that tries to show low-frequency filters first, and arranged them so that each unit is followed by its closest inverse.
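A sketch of how such a figure could be produced from the first token-mixing layer, assuming a Mixer-B/16-like setup (S = 196 patches on a 14×14 grid, D_S = 384) with weights stored as an S×D_S matrix. Each hidden unit has one weight per patch, so it can be reshaped into a 14×14 image. The paper's exact sorting heuristic is not reproduced here: `low_freq_score` (share of spectral energy at the lowest spatial frequencies) and the greedy "closest inverse" pairing (most negative cosine similarity) are hypothetical stand-ins.

```python
import numpy as np

def unit_images(w1, grid):
    """w1: (S, D_S) first token-mixing weights -> (D_S, grid, grid), one image per hidden unit."""
    return w1.T.reshape(-1, grid, grid)

def low_freq_score(img, k=2):
    """HYPOTHETICAL heuristic: fraction of 2-D spectral energy at frequencies <= k per axis."""
    power = np.abs(np.fft.fft2(img)) ** 2
    f = np.abs(np.fft.fftfreq(img.shape[0]) * img.shape[0])   # integer frequency magnitudes
    mask = (f[:, None] <= k) & (f[None, :] <= k)
    return power[mask].sum() / power.sum()

def display_order(imgs):
    """Low-frequency units first; each followed by its most anti-correlated remaining unit."""
    flat = imgs.reshape(len(imgs), -1)
    flat = flat / np.linalg.norm(flat, axis=1, keepdims=True)
    remaining = sorted(range(len(imgs)), key=lambda i: -low_freq_score(imgs[i]))
    order = []
    while remaining:
        i = remaining.pop(0)
        order.append(i)
        if remaining:
            sims = flat[remaining] @ flat[i]        # cosine similarities to unit i
            order.append(remaining.pop(int(np.argmin(sims))))
    return order

rng = np.random.default_rng(0)
w1 = rng.standard_normal((196, 384))                # stand-in for trained weights
imgs = unit_images(w1, grid=14)
order = display_order(imgs)
print(imgs.shape, order[:4])
```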

An important point to note is that these visualizations depend heavily on the dataset (and augmentations) the model is pre-trained on. Check out the appendix of the paper to see the visualizations compared across different pre-training scenarios.

Comments:

Overall, it is an interesting paper, with a new and simple architecture and surprisingly good results. Following are a couple of questions I had while reading. I don’t have answers to them right now, so feel free to comment if you have some thoughts on them.

  1. In the permutation experiment, I wonder whether ViTs also behave the same as Mixers, since the permutation is constant throughout the dataset.
  2. I wonder how Mixers would perform without any pre-training.

Do let me know what you think in the comments!

Bibliography:
[1] Dosovitskiy, Alexey, et al. “An image is worth 16x16 words: Transformers for image recognition at scale.” arXiv preprint arXiv:2010.11929 (2020).
[2] Hendrycks, Dan, and Kevin Gimpel. “Gaussian error linear units (GELUs).” arXiv preprint arXiv:1606.08415 (2016).
[3] He, Kaiming, et al. “Deep residual learning for image recognition.” Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition (2016).

Ph.D. student in Computer Science at University of Maryland. Find me at: https://somepago.github.io/