When Vision Transformers Outperform ResNets without Pre-training or Strong Data Augmentations — Paper Summary

Gowthami Somepalli
ML Summaries
Feb 21, 2022


Paper: When Vision Transformers Outperform ResNets without Pre-training or Strong Data Augmentations
Link: https://openreview.net/forum?id=LtKcMgGOeLt
Authors: Xiangning Chen, Cho-Jui Hsieh, Boqing Gong
Tags: Vision Transformer, Optimization, SAM, Analysis
Code: https://github.com/google-research/vision_transformer
Misc. info: Accepted to ICLR’22

What?

Training a ViT is hard and computationally expensive, and without extensive pre-training or strong data augmentations, ViTs tend to underperform ResNets of comparable size. This paper proposes that using a principled optimizer such as SAM (Sharpness-Aware Minimization) [1] puts ViTs on par with convolutional models without the need for these tricks.

Why?

Vision Transformers are ubiquitous these days. However, training them requires a large amount of compute, since they are highly sensitive to hyperparameters and to the amount of training data. This paper tries to understand why ViTs behave this way from a loss-landscape perspective.

How?

Prior results: Below are some results and conclusions from prior work that the authors build on in this paper.

  1. [2] showed that the trainability of a neural network can be characterized by the condition number of its neural tangent kernel (NTK), i.e., the ratio of the kernel's largest to smallest eigenvalue. The smaller the condition number, the easier the model is to train.
  2. SGD and Adam are first-order optimizers: they ignore the curvature of the loss surface. They only try to find a minimum and do not care whether that minimum is flat or sharp. Prior work has shown that flatter minima lead to better generalization (see page 4 of the paper for the list of references supporting this claim).
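  3. SAM works by optimizing the following objective, which means we want to find weights w that minimize the worst-case loss within a neighborhood of w, rather than only the loss at w itself. Please check out the paper for how the min-max problem is solved.
     \min_{w} \max_{\|\epsilon\|_2 \le \rho} L_{\text{train}}(w + \epsilon)
Optimization objective of SAM. \rho is the size (radius) of the neighborhood ball and \epsilon is the weight perturbation.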
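In practice the inner maximization is solved approximately: Foret et al. [1] take a single normalized gradient-ascent step of length ρ to estimate the worst-case perturbation, then update the weights using the gradient evaluated at that perturbed point. Below is a minimal NumPy sketch of that two-step update on a toy quadratic loss. The names `sam_step`, `grad_fn`, `lr`, and `rho`, and the toy loss itself, are illustrative assumptions, not the paper's code.

```python
import numpy as np

def sam_step(w, grad_fn, lr=0.1, rho=0.05, eps=1e-12):
    """One SAM update (sketch): perturb toward the approximate worst case
    inside the rho-ball, then descend using the gradient taken there."""
    g = grad_fn(w)
    # First-order solution of the inner max: a step of length rho along g.
    e_hat = rho * g / (np.linalg.norm(g) + eps)
    # Gradient of the loss at the perturbed weights w + e_hat.
    g_sam = grad_fn(w + e_hat)
    # Update the ORIGINAL weights; the perturbation itself is discarded.
    return w - lr * g_sam

# Toy quadratic loss L(w) = 0.5 * w^T A w with one sharp and one flat direction.
A = np.diag([10.0, 0.1])
loss_fn = lambda w: 0.5 * w @ A @ w
grad_fn = lambda w: A @ w

w = np.array([1.0, 1.0])
for _ in range(50):
    w = sam_step(w, grad_fn, lr=0.05, rho=0.05)
```

Each step calls `grad_fn` twice (once at w, once at w + ε̂), which is why a SAM step costs roughly two forward-backward passes instead of one.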

Experimental Setup: The authors evaluated ResNet-152 and the B and S variants of ViT and MLP-Mixer. All models are trained from scratch on ImageNet for 300 epochs with basic Inception-style preprocessing. Please check the paper for complete training details.

The maximum eigenvalue of the Hessian, λ_max, is computed on 10% of the ImageNet training images using the power iteration method (with 100 iterations to ensure convergence).
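Power iteration only needs Hessian-vector products (HVPs), so the full Hessian never has to be formed. Here is a minimal sketch, assuming the gradient is available as a function. The finite-difference HVP is a simplification swapped in for illustration (autodiff frameworks can compute exact HVPs), and both helper names are hypothetical. In the paper's setting, `grad_fn` would be the average gradient over the 10% ImageNet subset. Note that power iteration converges to the largest-magnitude eigenvalue, which equals λ_max whenever the largest positive eigenvalue dominates, as expected near a minimum.

```python
import numpy as np

def top_hessian_eigenvalue(hvp, dim, iters=100, seed=0):
    """Power iteration for the dominant Hessian eigenvalue (sketch).

    hvp(v) must return H @ v; the Hessian H itself is never formed.
    Converges to the largest-magnitude eigenvalue.
    """
    rng = np.random.default_rng(seed)
    v = rng.standard_normal(dim)
    v /= np.linalg.norm(v)
    eig = 0.0
    for _ in range(iters):
        hv = hvp(v)
        eig = float(v @ hv)           # Rayleigh quotient of the current unit vector
        v = hv / np.linalg.norm(hv)   # renormalize for the next iteration
    return eig

def finite_difference_hvp(grad_fn, w, r=1e-4):
    """Hypothetical helper: approximate v -> H(w) @ v by central differences of gradients."""
    return lambda v: (grad_fn(w + r * v) - grad_fn(w - r * v)) / (2.0 * r)

# Toy quadratic loss 0.5 * w^T A w, whose Hessian is A (eigenvalues 10, 1, 0.1).
A = np.diag([10.0, 1.0, 0.1])
grad_fn = lambda w: A @ w
hvp = finite_difference_hvp(grad_fn, w=np.ones(3))
lam_max = top_hessian_eigenvalue(hvp, dim=3)
```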

Main Results:

1. When trained with first-order optimizers, ViT and Mixer models converge to sharper minima than ResNets. When SAM is used for optimization instead, they converge to flatter minima. [Figure 1]

Fig 1.

2 (a). For both the NTK condition number and the maximum Hessian eigenvalue, λ_max, the ordering is ResNet < ViT < Mixer. This implies that ResNets are the easiest to train and converge to the flattest minima, while Mixers are at the worst end of the spectrum. [Top row of Table 1]
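To make the NTK condition number concrete: on n inputs, the empirical NTK is the n×n Gram matrix Θ = J Jᵀ, where J is the Jacobian of the network outputs with respect to all parameters, and its condition number is λ_max(Θ)/λ_min(Θ). Below is a small sketch, assuming a toy one-hidden-layer tanh network with a scalar output and a hand-derived Jacobian. This is not one of the architectures or the code used in the paper.

```python
import numpy as np

def empirical_ntk(X, W, v):
    """Empirical NTK Theta = J @ J.T for f(x) = v . tanh(W x) / sqrt(m).

    X: (n, d) inputs, W: (m, d) hidden weights, v: (m,) output weights.
    Parameters are ordered as [v, W.ravel()] in the Jacobian J.
    """
    n, m = X.shape[0], W.shape[0]
    H = np.tanh(X @ W.T)                      # (n, m) hidden activations
    J_v = H / np.sqrt(m)                      # df/dv, shape (n, m)
    # df/dW[i, j] = v[i] * (1 - tanh^2) * x[j] / sqrt(m), flattened to (n, m*d)
    J_W = ((v * (1.0 - H**2))[:, :, None] * X[:, None, :]).reshape(n, -1) / np.sqrt(m)
    J = np.concatenate([J_v, J_W], axis=1)    # (n, num_params)
    return J @ J.T

def condition_number(theta):
    """Ratio of largest to smallest eigenvalue of a symmetric PSD kernel."""
    eigs = np.linalg.eigvalsh(theta)          # sorted ascending
    return eigs[-1] / eigs[0]

rng = np.random.default_rng(0)
X = rng.standard_normal((8, 3))               # 8 inputs of dimension 3
W = rng.standard_normal((64, 3))              # 64 hidden units
v = rng.standard_normal(64)
theta = empirical_ntk(X, W, v)
kappa = condition_number(theta)
```

A smaller `kappa` means the kernel's eigenvalues are more evenly spread, which in the sense of [2] corresponds to an easier-to-train model.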

2 (b). The bottom section of Table 1 shows that ViTs trained with SAM outperform ResNets (with or without SAM). Mixers benefit the most from SAM, while ResNets benefit the least. The authors conjecture that the inductive biases inherent to convolutional networks, such as translation invariance and locality, might be responsible for the flatter minima of ResNets.

2 (c). ViTs trained with SAM are also more robust when evaluated on the ImageNet-C [3] dataset, which corrupts images with noise, blur, etc.

Table 1

3 (a). The authors tried to understand which part of the architecture is responsible for such high λ_max values (λ_max roughly measures the maximum curvature). To do so, they split the Hessian into smaller diagonal blocks, one per layer, and computed λ_max for each block. Results for a few blocks are shown in the following table (Tab 2).
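Computing λ_max per layer amounts to taking the diagonal block of the Hessian that belongs to each layer's parameters and ignoring the cross-layer blocks. Here is a toy sketch, assuming a small explicit Hessian and a hypothetical `layer_sizes` list that gives each layer's parameter count in order. At ViT scale, one would instead run power iteration with Hessian-vector products restricted to a single layer's parameters.

```python
import numpy as np

def blockwise_lambda_max(H, layer_sizes):
    """Largest eigenvalue of each layer's diagonal block of a symmetric Hessian H.

    layer_sizes lists each layer's parameter count, in the order the
    parameters are laid out in H; cross-layer blocks are ignored.
    """
    results, start = [], 0
    for size in layer_sizes:
        block = H[start:start + size, start:start + size]
        results.append(float(np.linalg.eigvalsh(block)[-1]))
        start += size
    return results

# Toy 5x5 Hessian for three hypothetical layers with 2, 2, and 1 parameters.
H = np.diag([5.0, 1.0, 2.0, 0.5, 0.3])
H[0, 2] = H[2, 0] = 0.8   # cross-layer coupling, dropped by the block view
per_layer = blockwise_lambda_max(H, layer_sizes=[2, 2, 1])
```

Because the cross-layer blocks are dropped, the per-layer values need not match λ_max of the full Hessian. They only localize where the curvature sits.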

3 (b). Table 2 shows that the embedding layer has the sharpest geometry. Going deeper, from Block 1 to Block 12, λ_max decreases. SAM reduces the dominant eigenvalue in every layer, and drastically so in the early layers.

3 (c). Table 2 also shows that the weight norms go up when SAM is used. Since SAM generalizes better despite the larger weights, this suggests that the weight decay commonly used with first-order optimizers may not effectively regularize ViTs or Mixers.

Table 2

4 (a). In most cases, SAM improves the performance of ViTs and Mixers more than strong augmentations do, even in low-data regimes. For ResNets, however, augmentations work better in the low-data case (as seen in the i1k 1/10 case, i.e., ImageNet-1k with 1/10 of the training data). See Table 3 below for the complete results.

Table 3

4 (b). In the table above, the results of ViT+AUG are very close to those of ViT+SAM in most cases. The authors ask: do augmentations also flatten the loss surface? The analysis shows that they do, but in a different way. λ_max for ViT+AUG (Fig 2, right) is even higher than for ViT without augmentation. However, the loss surface of ViT+AUG (Fig 2, left) looks flat near the convergence point, but not once we venture farther away from it.

The authors say, “The difference is that SAM enforces the smoothness by reducing the largest curvature via a minimax formulation to optimize the worst-case scenario, while augmentations ignore the worse-case curvature and instead smooth the landscape over the directions concerning the inductive biases induced by the augmentations.”

Fig 2. (Left) Cross entropy loss surfaces of ViT in different settings. (Right) Max eigenvalue of Hessian for each of the models.
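One way to see how a loss-surface plot can look flat while λ_max is large: along a random unit direction d, the loss sees the curvature dᵀHd, whose expected value is the average Hessian eigenvalue (trace/dimension), not the largest one. Below is a sketch on an assumed synthetic quadratic loss with one sharp direction and many flat ones. Published loss-surface plots typically use filter-normalized random directions, whereas this uses a plain unit vector, so it illustrates the general point in the quote above rather than reproducing Fig 2.

```python
import numpy as np

def loss_slice(loss_fn, w, alphas, seed=0):
    """Evaluate L(w + alpha * d) along one random unit direction d."""
    rng = np.random.default_rng(seed)
    d = rng.standard_normal(w.shape)
    d /= np.linalg.norm(d)
    return np.array([loss_fn(w + a * d) for a in alphas])

# Synthetic quadratic: one sharp direction (curvature 100), 999 flat ones (0.01).
dim = 1000
curvatures = np.full(dim, 0.01)
curvatures[0] = 100.0
loss_fn = lambda w: 0.5 * np.sum(curvatures * w**2)

alphas = np.linspace(-1.0, 1.0, 21)
random_slice = loss_slice(loss_fn, np.zeros(dim), alphas)
# Loss along the sharpest direction for comparison: 0.5 * lambda_max * alpha^2.
worst_case = 0.5 * curvatures.max() * alphas**2
```

The random slice stays far below `worst_case`, even though λ_max of this loss is 100.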

Misc. results: The authors made a couple of observations that are somewhat tangential to the main topic of the paper:

  1. ViT activations are sparse, so there is large potential for pruning: for most layers, less than 10% of neurons have values greater than zero. For ResNets, this fraction is above 50%.
  2. ViT+SAM also produced qualitatively better segmentation maps.
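The sparsity statistic in item 1 is simple to measure once activations have been recorded. Here is a sketch, assuming activations are collected per layer into a dict of NumPy arrays. The layer names, the synthetic data, and the recording mechanism are all hypothetical.

```python
import numpy as np

def active_fraction(activations, threshold=0.0):
    """Per layer, the fraction of recorded activation values above threshold."""
    return {name: float(np.mean(a > threshold)) for name, a in activations.items()}

# Hypothetical recorded activations with shape (num_tokens, hidden_dim).
rng = np.random.default_rng(0)
recorded = {
    "block1_mlp": rng.standard_normal((16, 64)) - 2.0,  # shifted: mostly negative
    "block2_mlp": rng.standard_normal((16, 64)),        # roughly half positive
}
fractions = active_fraction(recorded)
```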

Comments:

Overall, a good paper. It is well written and not too dense, and the authors convey their message clearly with interesting analysis. A few questions and issues I had:

  1. I wish the authors had shown how computationally expensive SAM is compared to augmentations.
  2. Does the claim “SAM helps with training and generalization of ViTs” hold for smaller datasets, say CIFAR-10? (They showed transfer results, but not the training-from-scratch case.)

Bibliography:
[1] — Foret, Pierre, et al. “Sharpness-aware minimization for efficiently improving generalization.” arXiv preprint arXiv:2010.01412 (2020).
[2] — Xiao, Lechao, Jeffrey Pennington, and Samuel Schoenholz. “Disentangling trainability and generalization in deep neural networks.” International Conference on Machine Learning. PMLR, 2020.
[3] — Hendrycks, Dan, and Thomas Dietterich. “Benchmarking neural network robustness to common corruptions and perturbations.” arXiv preprint arXiv:1903.12261 (2019).


Gowthami Somepalli

Ph.D. student in Computer Science at University of Maryland. Find me at: https://somepago.github.io/