Why Do Better Loss Functions Lead to Less Transferable Features? — Paper Summary
Paper: Why Do Better Loss Functions Lead to Less Transferable Features?
Link: https://openreview.net/forum?id=8twKpG5s8Qh
Authors: Simon Kornblith, Ting Chen, Honglak Lee, Mohammad Norouzi
Tags: Computer Vision, Transfer Learning, Centered Kernel Alignment, Analysis
Code: —
Misc. info: Accepted to NeurIPS’21
What?
The choice of loss function affects a model's classification accuracy on ImageNet, but it also affects how well its features transfer to other tasks. The authors systematically study a range of losses and try to understand why the losses that improve ImageNet accuracy produce less transferable features.
Why?
Softmax cross-entropy is the standard loss for classification models, but alternatives such as label smoothing have been proposed to improve accuracy. It is worth understanding whether the accuracy gain is worth the trade-off, especially if we want to reuse the model for other downstream tasks.
How?
Losses studied (see pages 2–3 of the paper for the exact mathematical formulations, along with references to the papers that proposed or used them):
- Softmax cross-entropy
- Label smoothing
- Dropout on the penultimate layer
- Final-layer regularization: L2 penalty on the weights of the final layer
- Logit penalty: L2 penalty on the logits themselves
- Logit normalization: softmax applied to normalized logits, scaled by a temperature
- Cosine softmax: softmax cross-entropy applied to the cosine similarity between each class weight vector and the penultimate-layer features, plus a bias
- Sigmoid cross-entropy: binary cross-entropy on each logit, summed over all classes
- Squared error: as the name says, mean squared error, plus a few tricks
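The logit-based losses described above can be sketched in NumPy as follows. This is a minimal illustration under my own assumptions, not the paper's code: the hyperparameter names and defaults (`smoothing`, `beta`, `tau`, `kappa`, `m`) are hypothetical, and the squared-error variant only stands in for the paper's "tricks". Dropout and final-layer weight regularization are left out because they regularize the network rather than change how the logits are scored.

```python
import numpy as np


def one_hot(labels, k):
    # (n,) integer labels -> (n, k) one-hot matrix.
    return np.eye(k)[labels]


def log_softmax(z):
    # Numerically stable log-softmax over the last axis.
    z = z - z.max(axis=-1, keepdims=True)
    return z - np.log(np.exp(z).sum(axis=-1, keepdims=True))


def softmax_ce(logits, labels, smoothing=0.0):
    # Softmax cross-entropy; smoothing > 0 gives label smoothing.
    k = logits.shape[-1]
    targets = (1.0 - smoothing) * one_hot(labels, k) + smoothing / k
    return -(targets * log_softmax(logits)).sum(axis=-1).mean()


def logit_penalty(logits, labels, beta=1e-3):
    # Softmax CE plus an L2 penalty on the logits themselves (beta is hypothetical).
    return softmax_ce(logits, labels) + beta * (logits ** 2).sum(axis=-1).mean()


def logit_normalization(logits, labels, tau=0.05):
    # Softmax CE on L2-normalized logits divided by a temperature tau (assumed value).
    z = logits / np.linalg.norm(logits, axis=-1, keepdims=True)
    return softmax_ce(z / tau, labels)


def cosine_softmax(h, w, b, labels, tau=0.05):
    # h: (n, d) penultimate features, w: (d, k) class weights, b: (k,) bias.
    # Logits are cosine similarities scaled by 1/tau, plus the bias.
    h_n = h / np.linalg.norm(h, axis=1, keepdims=True)
    w_n = w / np.linalg.norm(w, axis=0, keepdims=True)
    return softmax_ce(h_n @ w_n / tau + b, labels)


def sigmoid_ce(logits, labels):
    # Independent binary CE per class (stable BCE-with-logits form), summed over classes.
    y = one_hot(labels, logits.shape[-1])
    per_class = np.maximum(logits, 0.0) - logits * y + np.log1p(np.exp(-np.abs(logits)))
    return per_class.sum(axis=-1).mean()


def squared_error(logits, labels, kappa=1.0, m=1.0):
    # Squared error against one-hot targets. kappa (weight on the true-class term)
    # and m (true-class target value) are hypothetical stand-ins for the "tricks".
    onehot = one_hot(labels, logits.shape[-1])
    weights = 1.0 + (kappa - 1.0) * onehot
    return (weights * (logits - m * onehot) ** 2).sum(axis=-1).mean()
```

With all-zero logits, every softmax-based variant reduces to log(K), which is a handy sanity check when swapping losses in and out.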
Experimental setup: A ResNet-50 is trained on ImageNet with each loss, and the hyperparameters for every loss are tuned on a held-out validation set. Two transfer settings are considered: linear transfer (freeze the network, add a new final layer, and train only that layer) and k-nearest neighbors (extract features for all training and test images of the downstream dataset and classify the test images with k-NN). Eight downstream datasets are used for transfer, including CIFAR-10, CIFAR-100, Food, and others.
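To make the k-NN transfer setting concrete, here is a minimal NumPy sketch, not the paper's evaluation code. It assumes features have already been extracted from the frozen network, uses cosine similarity with a plain majority vote, and treats `knn_classify` and `k=5` as hypothetical choices; the paper's exact similarity measure and choice of k may differ.

```python
import numpy as np


def knn_classify(train_feats, train_labels, test_feats, k=5):
    # Cosine similarity between every test and every train embedding.
    tr = train_feats / np.linalg.norm(train_feats, axis=1, keepdims=True)
    te = test_feats / np.linalg.norm(test_feats, axis=1, keepdims=True)
    sims = te @ tr.T  # (n_test, n_train)
    # Indices of the k most similar training examples for each test example.
    nn = np.argsort(-sims, axis=1)[:, :k]
    neighbor_labels = train_labels[nn]  # (n_test, k)
    # Majority vote; ties go to the smallest class index (argmax behaviour).
    n_classes = int(train_labels.max()) + 1
    votes = np.zeros((len(test_feats), n_classes), dtype=int)
    for c in range(n_classes):
        votes[:, c] = (neighbor_labels == c).sum(axis=1)
    return votes.argmax(axis=1)
```

Because the network is frozen and nothing is trained, this setting directly probes how the pretraining loss shaped the feature geometry.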
A couple of concepts we need before looking at the main results:
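Centered kernel alignment (CKA): This metric measures how similar the representations (activations) of two layers are on the same set of examples, and it is invariant to orthogonal transformations (such as rotations) and isotropic scaling of those representations. Let X and Y be activation matrices (examples × features) from different layers, or from the same layer in different runs, with each column centered. For the linear-kernel version (linear CKA):
$$\mathrm{CKA}(X, Y) = \frac{\lVert Y^\top X \rVert_F^2}{\lVert X^\top X \rVert_F \, \lVert Y^\top Y \rVert_F}$$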
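Linear CKA (CKA with a linear kernel) can be sketched in a few lines of NumPy. This is an illustrative implementation, assuming X and Y hold activations for the same n examples in their rows; `linear_cka` is a hypothetical name, not a library function.

```python
import numpy as np


def linear_cka(x, y):
    # x: (n, p1), y: (n, p2) activations for the same n examples.
    # CKA = ||Y^T X||_F^2 / (||X^T X||_F * ||Y^T Y||_F) on column-centered data.
    x = x - x.mean(axis=0, keepdims=True)
    y = y - y.mean(axis=0, keepdims=True)
    cross = np.linalg.norm(y.T @ x, ord="fro") ** 2
    return cross / (np.linalg.norm(x.T @ x, ord="fro") * np.linalg.norm(y.T @ y, ord="fro"))
```

For example, `linear_cka(x, 3.0 * x @ q)` for any orthogonal matrix `q` returns 1 (up to floating-point error), which is exactly the rotation and scaling invariance described above.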
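Class separation (R²): Measures how the average within-class distance compares with the average distance across the whole dataset, i.e. how well separated the classes are in the embedding space. It is computed as
$$R^2 = 1 - \frac{\bar{d}_{\text{within}}}{\bar{d}_{\text{total}}}, \qquad \bar{d}_{\text{within}} = \sum_{k=1}^{K}\sum_{m=1}^{N_k}\sum_{n=1}^{N_k}\frac{1 - \operatorname{sim}(x_{k,m}, x_{k,n})}{K N_k^2}, \qquad \bar{d}_{\text{total}} = \sum_{j=1}^{K}\sum_{k=1}^{K}\sum_{m=1}^{N_j}\sum_{n=1}^{N_k}\frac{1 - \operatorname{sim}(x_{j,m}, x_{k,n})}{K^2 N_j N_k}$$
where x_{k,m} is the embedding of example m in class k, K is the number of classes, N_k is the number of examples in class k, and sim(x, y) = xᵀy / (‖x‖‖y‖) is cosine similarity, i.e. the dot product scaled by the two norms. R² close to 1 means tight, well-separated classes; R² near 0 means within-class distances are about as large as distances across the whole dataset.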
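A sketch of the R² computation in NumPy, assuming the definition R² = 1 − (average within-class cosine distance) / (average overall cosine distance), with each class (and each class pair) weighted equally regardless of its size. `class_separation_r2` is a hypothetical helper, not code from the paper.

```python
import numpy as np


def class_separation_r2(feats, labels):
    # Pairwise cosine distance 1 - sim(x, y) between all embeddings.
    f = feats / np.linalg.norm(feats, axis=1, keepdims=True)
    dist = 1.0 - f @ f.T
    classes = np.unique(labels)
    k = len(classes)
    masks = [labels == c for c in classes]
    # Mean distance inside each class block, then averaged over classes.
    d_within = sum(dist[np.ix_(mk, mk)].mean() for mk in masks) / k
    # Mean distance inside every (class j, class k) block, averaged over all K^2 pairs.
    d_total = sum(dist[np.ix_(mj, mk)].mean() for mj in masks for mk in masks) / k ** 2
    return 1.0 - d_within / d_total
```

Two classes whose embeddings point in orthogonal directions, with no spread inside each class, give R² = 1; any within-class spread pulls the value down.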
Main results
- Losses that lead to higher ImageNet accuracy lead to less transferable features. Plain softmax cross-entropy gives the most transferable features (Table 1 in the paper).
- When the network is fine-tuned on the downstream datasets, the choice of loss barely matters: all losses perform about the same.
- The choice of loss affects the features of the deeper layers, as seen from the CKA scores (Fig. 1) and the sparsity of ReLU activations across layers (Fig. 2).
- The models with the highest R² have the least transferable features (Figures 4 and 5).
- The models whose representations have greater class separation are “overfit,” not to the pretraining data points but to the pretraining classes: they are better at classifying those classes, but worse when the downstream task requires classifying different classes. (This gave me an idea; see my comments at the end.)
- Augmentation can improve accuracy without changing the class separation much.
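For the ReLU sparsity result, the quantity is easy to compute once post-ReLU activations have been collected for each layer. A hedged sketch, assuming "sparsity" means the fraction of activations that are exactly zero; `layer_sparsities` and the layer names are hypothetical.

```python
import numpy as np


def layer_sparsities(acts_by_layer):
    # acts_by_layer: dict mapping a layer name to its post-ReLU activations
    # (any shape, e.g. batch x channels x height x width).
    # Returns the fraction of exactly-zero activations per layer.
    return {name: float(np.mean(np.asarray(a) == 0.0)) for name, a in acts_by_layer.items()}


# Usage with random pre-activations passed through a ReLU (illustrative only).
rng = np.random.default_rng(0)
acts = {"block4": np.maximum(rng.normal(size=(8, 16)), 0.0)}
print(layer_sparsities(acts))  # roughly half the entries are zero
```

Comparing this number layer by layer across models trained with different losses is one way to see where in the network the losses start to diverge.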
Comments:
- On the overfitting point: if we train a 10-way classifier and transfer it to another 10-way classification dataset, will these results still hold? Is it the number of classes that matters, or do the pretraining and transfer datasets need to have the same (or a similar) distribution?
- Overall, a very interesting paper. It reads really well, the experiments are well thought out, and it puts forward a clear message.
- I guess the takeaway is that softmax CE is still the best loss, but I wonder whether there are applications where the squared error loss is useful.
- Check out this prior work from the same authors.