Domain Generalization (in Computer Vision)

Harsh-Sensei
4 min read · Dec 14, 2022


Introduction

In this blog, we will first look at the problem of Domain Generalization (DG) and its applications, especially in the realm of computer vision. This will be followed by recent techniques used to tackle the problem (ahhh… transformers again) and a brief overview of DG algorithms such as Fishr and SWAD.

Art Anya (left); Sketch Anya (right) (Source: SpyXfamily)

Domain Generalization

Though identifying an elephant (or Anya Forger) from a sketch is intuitive for humans once they have seen the real thing, life is not so easy for AI. Distribution shifts in the data at inference time significantly affect the accuracy of machine learning models, and generalizing ML models across such shifts is a critical and well-studied task. The shifts in distribution may be due to changes in background, lighting, texture and many more. We formulate this problem for computer vision as follows:

Problem Formulation:

Problem Formulation
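The figure above states the formulation; as a rough reconstruction (notation assumed, in the spirit of DomainBed): given M source (training) domains, each with its own distribution, learn a predictor that performs well on an unseen test distribution that differs from every training distribution.

```latex
S_i = \{(x^{(i)}_j, y^{(i)}_j)\}_{j=1}^{n_i} \sim P_i, \quad i = 1,\dots,M
\qquad
\min_{f}\; \mathbb{E}_{(x,y)\sim P_{\text{test}}}\big[\ell(f(x), y)\big],
\quad P_{\text{test}} \notin \{P_1,\dots,P_M\}
```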

Shown below are various common setups in the machine learning realm. Note that in the domain generalization setup, the test distribution is not available as a training input.

Inputs to various setups in test and train time (Source: DomainBed)

Researchers have used many approaches to tackle this problem, varying in model architecture, the objective loss being minimized, and so on.

This work, DomainBed, compiles many such studies and provides a standardized platform for comparing previous as well as upcoming approaches to the DG problem.

Datasets

Some common datasets often used for DG analysis, each sectioned into domains and sub-sectioned into classes, are shown below:

Common datasets used for DG (Source : DomainBed)

Algorithms

Algorithms used for building generalizable machine learning models robust to distribution shift:

Empirical Risk Minimization (ERM) : The learning paradigm that aims at finding a hypothesis h minimizing the training error L (given below) over a sample S (drawn from some unknown distribution D) is referred to as ERM.

Empirical Loss (Source: Understanding ML)
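For the 0-1 loss, a standard form of this empirical loss (following Understanding Machine Learning; notation assumed) over a sample S of size m is:

```latex
L_S(h) \;=\; \frac{\lvert \{\, i \in [m] : h(x_i) \neq y_i \,\} \rvert}{m}
\;=\; \frac{1}{m} \sum_{i=1}^{m} \mathbf{1}\!\left[\, h(x_i) \neq y_i \,\right]
```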

Fishr : This approach adds a regularization term on top of ERM in order to induce domain invariance in the learned model. The domain-level variances of gradients are matched across training domains by optimizing the modified loss (see the sketch below). For more details, refer to this paper.
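Below is a minimal sketch of the idea in PyTorch, assuming a featurizer/classifier split; the helper names and the exact penalty weighting are illustrative, not the paper's reference implementation.

```python
import torch
import torch.nn as nn

def per_sample_grads(classifier, feats, labels):
    """Per-sample gradients of the loss w.r.t. the classifier parameters."""
    grads = []
    for x, y in zip(feats, labels):
        loss = nn.functional.cross_entropy(classifier(x.unsqueeze(0)), y.unsqueeze(0))
        g = torch.autograd.grad(loss, classifier.parameters(), create_graph=True)
        grads.append(torch.cat([p.reshape(-1) for p in g]))
    return torch.stack(grads)                 # (n_samples, n_params)

def fishr_penalty(classifier, feats_per_domain, labels_per_domain):
    """Penalize the spread of per-parameter gradient variances across domains."""
    variances = []
    for feats, labels in zip(feats_per_domain, labels_per_domain):
        g = per_sample_grads(classifier, feats, labels)
        variances.append(g.var(dim=0))        # gradient variance within this domain
    variances = torch.stack(variances)        # (n_domains, n_params)
    mean_var = variances.mean(dim=0, keepdim=True)
    return ((variances - mean_var) ** 2).sum(dim=1).mean()

# Usage sketch: total_loss = erm_loss + lambda_fishr * fishr_penalty(...)
```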

SWAD (Stochastic Weight Averaging Densely) : Several studies show that simply solving ERM on a complex, non-convex loss function can easily lead to sub-optimal generalizability by seeking sharp minima. This approach instead seeks a flat minimum, resulting in a smaller domain generalization gap. It builds on Stochastic Weight Averaging (SWA), which updates a pretrained model with a high constant learning rate schedule, gathers the model parameters every K epochs during the update, and averages them as a model ensemble; SWA thus finds an ensembled solution over the different local optima reached by a learning rate large enough to escape a local minimum. SWAD instead collects and averages weights densely, at every iteration within a selected interval, to locate the flat minimum more precisely (a sketch follows below). For more details, refer to this paper.
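A minimal sketch of dense weight averaging in the spirit of SWAD, assuming a PyTorch model; the class and method names are illustrative, and the window-selection logic of the actual algorithm is omitted.

```python
import copy
import torch

class WeightAverager:
    """Running average of model weights collected at every training step."""
    def __init__(self, model):
        self.avg_state = copy.deepcopy(model.state_dict())
        self.n = 1

    @torch.no_grad()
    def update(self, model):
        self.n += 1
        for k, v in model.state_dict().items():
            if v.dtype.is_floating_point:     # skip integer buffers (e.g. BN counters)
                self.avg_state[k] += (v - self.avg_state[k]) / self.n

    def apply_to(self, model):
        model.load_state_dict(self.avg_state)

# Usage sketch: inside the chosen averaging window, call averager.update(model)
# after every optimizer step, then averager.apply_to(model) before evaluation.
```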

Model Selection

One of the major concerns with earlier approaches was how their hyperparameters and architectures were selected. Since the test distribution can differ from every training domain available for validation, model selection is not as straightforward as in other supervised learning tasks. Three ways of model selection are described below:

  1. Training-domain validation set : Each training domain is split into a train and a validation set. The validation sets of all domains are pooled together to form a single validation set, and the model with the maximum accuracy on this pooled validation set is chosen (see the sketch after this list).
  2. Leave-one-domain-out cross-validation : Given n training domains, n models are trained on (n-1) domains with the same hyperparameters, leaving one domain out for validation each time. The average over the n accuracies is taken, and the model (i.e., those hyperparameters) with the maximum average accuracy is chosen.
  3. Test-domain validation set (oracle) : This method selects the model by accessing the test domain; however, the paper limits the number of queries to the test domain (~20) for model and hyperparameter selection.
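A hypothetical sketch of strategy 1 (training-domain validation): pool the per-domain validation splits and pick the candidate run with the best pooled accuracy. `runs` and `val_loaders` are assumed helpers, not DomainBed's actual API.

```python
import torch

@torch.no_grad()
def pooled_accuracy(model, val_loaders):
    correct, total = 0, 0
    for loader in val_loaders:                  # one validation loader per training domain
        for x, y in loader:
            preds = model(x).argmax(dim=1)
            correct += (preds == y).sum().item()
            total += y.numel()
    return correct / total

def select_model(runs, val_loaders):
    # runs: list of (model, hyperparameters) candidates
    return max(runs, key=lambda run: pooled_accuracy(run[0], val_loaders))
```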

Transformers for DG

Recent works (GMOE) have shown that vision transformers can yield SOTA accuracies in the DG setup. Further, embedding Mixture-of-Experts layers in multiple transformer blocks boosts the accuracies even more. And what’s more? These novel architectures are orthogonal to DG algorithms like Fishr and SWAD, and combining them pushes the accuracies even further. The setup of transformers for DG is similar to classification with a ViT:

  1. The input image is divided into patches, which are embedded into the feature space (with a linear/MLP projection), and positional encodings are added.
  2. A classification token is prepended to these embeddings, and the entire set of patches (along with the classification token) is passed through multiple layers of attention and MLPs.
  3. Classification is done using a classification head (usually an MLP, again) on the classification token only (a minimal sketch follows below).
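A minimal ViT-style classifier sketch in PyTorch; the sizes and layer counts are illustrative, not GMOE's actual configuration.

```python
import torch
import torch.nn as nn

class TinyViT(nn.Module):
    def __init__(self, img_size=224, patch=16, dim=192, depth=4, heads=3, n_classes=7):
        super().__init__()
        n_patches = (img_size // patch) ** 2
        self.patch_embed = nn.Conv2d(3, dim, kernel_size=patch, stride=patch)  # patchify + embed
        self.cls_token = nn.Parameter(torch.zeros(1, 1, dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, n_patches + 1, dim))
        layer = nn.TransformerEncoderLayer(d_model=dim, nhead=heads,
                                           dim_feedforward=4 * dim, batch_first=True)
        self.encoder = nn.TransformerEncoder(layer, num_layers=depth)
        self.head = nn.Linear(dim, n_classes)

    def forward(self, x):                                   # x: (B, 3, H, W)
        tokens = self.patch_embed(x).flatten(2).transpose(1, 2)   # (B, N, dim)
        cls = self.cls_token.expand(x.shape[0], -1, -1)
        tokens = torch.cat([cls, tokens], dim=1) + self.pos_embed
        tokens = self.encoder(tokens)
        return self.head(tokens[:, 0])                      # classify from the [CLS] token
```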
Vision Transformers with Mixture of Experts (Source : )
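The figure above shows MoE layers embedded in transformer blocks. Below is a hypothetical sketch of such a Mixture-of-Experts feed-forward layer; the top-1 routing and sizes are assumptions for illustration, not GMOE's exact design.

```python
import torch
import torch.nn as nn

class MoEFeedForward(nn.Module):
    """Replace a transformer block's MLP with several expert MLPs and a router."""
    def __init__(self, dim=192, hidden=768, n_experts=4):
        super().__init__()
        self.router = nn.Linear(dim, n_experts)
        self.experts = nn.ModuleList(
            nn.Sequential(nn.Linear(dim, hidden), nn.GELU(), nn.Linear(hidden, dim))
            for _ in range(n_experts)
        )

    def forward(self, tokens):                      # tokens: (B, N, dim)
        gate = self.router(tokens).softmax(dim=-1)  # routing probabilities (B, N, n_experts)
        weight, idx = gate.max(dim=-1)              # top-1 expert per token
        out = torch.zeros_like(tokens)
        for e, expert in enumerate(self.experts):
            mask = idx == e                         # tokens routed to expert e
            if mask.any():
                out[mask] = weight[mask].unsqueeze(-1) * expert(tokens[mask])
        return out
```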

References

  1. In Search of Lost Domain Generalization : https://arxiv.org/pdf/2007.01434.pdf
  2. Empirical Risk Minimization : https://towardsdatascience.com/learning-theory-empirical-risk-minimization-d3573f90ff77
  3. Transformers with MoEs for DG : https://arxiv.org/pdf/2206.04046v5.pdf
