Multi-Domain Learning

Labhesh Patel
Jumio Engineering & Data Science
6 min read · Jun 23, 2020

Machine learning is becoming ubiquitous in the modern world and is finding applications in ever newer and more varied problem areas. One of the primary reasons for this rapid advancement is the availability of vast amounts of data to train increasingly complex models. The present training paradigms, however, are restricted in the variety of data they can handle. Most recent methods work with a single data source which typically comes from a narrow domain. The models often end up learning the bias inherent in the dataset and only perform well on the specific task. This approach also severely limits the ability of such models to generalize to data from new, previously unseen domains.

These problems are exacerbated when the data comes from domains with an extremely high degree of variance and a single model must handle multiple such datasets. Training traditional models on such a wide variety of data domains makes it impossible for them to learn the nuances specific to each domain, which are often crucial to performing a given task well. The common approach to this problem is to retrain the model for each new data domain and apply the retrained model to classify data points from that domain. Such an approach is extremely inefficient.

At Jumio we experience this challenge in several of the problems we work on. Consider the case of ID verification, something we care deeply about. The problem involves processing an image which could represent anything from a driver's license to a passport or a national identity card. The challenge is to verify the authenticity of the document itself and also use it to verify the correctness of the personal information provided by the user.

As you can imagine, such a task is extremely hard due to the large variation in the kinds of ID cards which may be provided by the user. These documents have very different appearances depending on the country issuing them. Within a given country there may be additional degrees of variation depending on when and by which authority the ID was issued. Given the sensitivity of the task, maintaining a high degree of fidelity in the predictions generated by the model is crucial. A single model trained for this task would need to learn generic features for each of these IDs and hence would be unable to focus on the very fine-grained features which are often needed to perform this task well. On the other hand, training a separate model or retraining the model for each document type and subtype is extremely inefficient and would require a large amount of data from every domain to train properly.

These problems motivated us to move toward multi-domain learning. Simply put, this is a paradigm which allows models to learn from a wide variety of domains without losing their ability to learn more nuanced features inherent in each of those domains.

Figure 1: This image shows an example of multi-domain learning where the images in the training set come from multiple different domains and the output domain could also differ substantially from the input domains.

Multi-domain learning is a relatively unexplored topic in the literature but there are some interesting approaches ready to be explored to address these challenges. One such approach is the “Branch-Activated Multi-Domain Convolutional Neural Network for Visual Tracking (BAMDCNN)” by Chen et al. They explore the use of a multi-branch architecture with different branches learning domain specific nuances while sharing a large number of domain agnostic layers. This hybrid architecture allows the model to learn fine grained features for each of the domains while not increasing the computation cost and number of parameters dramatically. The overall architecture of the BAMDCNN model is shown in Figure 2.

Figure 2: Overall model architecture of BAMDCNN

Model Architecture

We take inspiration from the BAMDCNN model and how it tackles visual tracking to present an approach for the ID verification problem we introduced earlier.

Our multi-domain learning model has three key components:

  • Generic Image Feature Extractor
  • Group Algorithm Based on Similarity
  • Branch Activation Method

Now let’s try to understand each of these components in detail:

I. Generic Image Feature Extractor

This component is the shared part of the architecture, which is domain-agnostic in nature. The same convolutional layers are used to extract image features, irrespective of the domain the image belongs to. The BAMDCNN architecture uses the relatively shallow VGG-M network for feature extraction to ensure computational efficiency while keeping a smaller receptive field. This allows the model to focus on finer local features rather than global image features.
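To make this hybrid layout concrete, here is a minimal NumPy sketch in which a single shared layer stands in for the domain-agnostic backbone (VGG-M convolutions in the actual BAMDCNN) and small per-domain matrices stand in for the branches. The sizes, names, and the linear/ReLU layers below are illustrative assumptions, not details taken from the paper.

```python
import numpy as np

rng = np.random.default_rng(0)

# Illustrative sizes only (the real backbone is convolutional).
FEAT_IN, FEAT_OUT, N_BRANCHES, N_CLASSES = 8, 4, 3, 2

W_shared = rng.normal(size=(FEAT_OUT, FEAT_IN))                # domain-agnostic layer
W_branch = rng.normal(size=(N_BRANCHES, N_CLASSES, FEAT_OUT))  # per-domain heads

def forward(x, branch_id):
    """Shared feature extraction followed by a domain-specific branch."""
    features = np.maximum(W_shared @ x, 0.0)  # shared ReLU features, reused by every branch
    return W_branch[branch_id] @ features     # only the selected branch runs

x = rng.normal(size=FEAT_IN)
logits = forward(x, branch_id=1)
```

Because every branch consumes the same shared features, adding a new domain costs only one extra head (`N_CLASSES × FEAT_OUT` parameters here) rather than a whole new network.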

II. Group Algorithm Based on Similarity

For our work, we make no assumptions about the domains our data comes from, and so we design our architecture to allow for an arbitrary number of domains with any amount of variation between them. To learn domain-specific features in the branches of the multi-branch architecture, the model needs some implicit notion of domain identification. Since we allow the domains to vary arbitrarily, this distinction must be data driven. We apply a grouping algorithm based on similarity which uses unsupervised clustering to derive the notion of domains from the data itself. One major advantage this offers is the ability to dynamically adjust the number of domains depending on the data. This frees us from having to pre-specify the number of domains and allows us to use the approach in an online learning setting where the model is constantly adapting its definition of “domains.”

The clustering is performed in an online manner to allow for integration in an online learning setting. Each new ID seen by the clustering approach is first passed through the feature extractor. The resulting image representation is assigned to the cluster centroid closest to it provided that it satisfies two conditions:

i) The distance of this new example from the closest cluster center is below a certain threshold distance D

ii) The cluster whose center is closest to this sample contains fewer samples than a certain threshold M

The first condition is imposed to ensure coherence within a domain cluster so the variance within a domain doesn’t become too large. The second condition ensures that specific clusters don’t become too large, potentially absorbing multiple distinct domains within themselves. This also promotes the identification of more compact sub-domains within each larger domain.

If the new sample satisfies both constraints, we assign it to the cluster and update the cluster center. If the first condition fails, the sample becomes the center of a new cluster; any points which are closer to the new sample than to their own respective cluster centers are reassigned to it, and all affected cluster centers are updated. If the sample fails the second condition, we assign it to the cluster and then split that cluster into two: we take the two farthest points in the cluster as the new cluster centers and reassign each point in the cluster to one of the two based on proximity.
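The assignment rules above can be sketched as a short online routine. In the sketch below, the function name `assign_online`, the mean-based centroid updates, and the exact order of the reassignment steps are illustrative assumptions; only the two thresholds `D` and `M` come from the conditions described above.

```python
import numpy as np

def update_center(points):
    return np.mean(points, axis=0)

def assign_online(x, centers, members, D=1.0, M=4):
    """Assign one new feature vector x to a cluster, applying the two
    threshold conditions; returns the index of the receiving cluster."""
    x = np.asarray(x, dtype=float)
    if not centers:                       # first sample seeds the first cluster
        centers.append(x.copy()); members.append([x.copy()])
        return 0
    dists = [np.linalg.norm(x - c) for c in centers]
    k = int(np.argmin(dists))
    if dists[k] >= D:
        # Condition (i) fails: x seeds a new cluster, and points now closer
        # to x than to their own center migrate over.
        centers.append(x.copy()); members.append([x.copy()])
        new = len(centers) - 1
        for j in range(new):
            kept = [p for p in members[j]
                    if np.linalg.norm(p - centers[j]) <= np.linalg.norm(p - x)]
            moved = [p for p in members[j]
                     if np.linalg.norm(p - centers[j]) > np.linalg.norm(p - x)]
            members[new].extend(moved)
            members[j] = kept
            if kept:                      # empty clusters are left as-is in this sketch
                centers[j] = update_center(kept)
        centers[new] = update_center(members[new])
        return new
    if len(members[k]) >= M:
        # Condition (ii) fails: assign x, then split the oversized cluster
        # around its two farthest points.
        members[k].append(x)
        pts = members[k]
        i, j = max(((a, b) for a in range(len(pts)) for b in range(a + 1, len(pts))),
                   key=lambda ab: np.linalg.norm(pts[ab[0]] - pts[ab[1]]))
        c1, c2 = pts[i], pts[j]
        near1 = [p for p in pts if np.linalg.norm(p - c1) <= np.linalg.norm(p - c2)]
        near2 = [p for p in pts if np.linalg.norm(p - c1) > np.linalg.norm(p - c2)]
        members[k], centers[k] = near1, update_center(near1)
        members.append(near2); centers.append(update_center(near2))
        return k
    members[k].append(x)                  # both conditions hold: normal assignment
    centers[k] = update_center(members[k])
    return k
```

Each cluster index produced here later selects which branch of the network is activated for that sample.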

III. Branch Activation Method

Once the clustering step is complete, a branch is created for each cluster, since the clusters are representative of different domains. The model is trained using regular backpropagation with the same loss functions as regular models. The only difference is that a domain-specific branch updates its weights only when the network is trained on points belonging to the cluster corresponding to that branch. The shared layers still undergo weight updates for every training sample. This training method allows the model to learn domain-specific features in the branch weights while a majority of the computation is shared between the different domains. It ensures the number of parameters doesn't increase substantially, making training more efficient than maintaining a separate model for each domain.
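The selective update rule can be illustrated with a toy two-branch linear model in NumPy: the shared weights receive a gradient from every sample, while each branch head is updated only by samples grouped into its cluster. The model, sizes, learning rate, and squared-error loss are all illustrative assumptions, not details from the paper.

```python
import numpy as np

rng = np.random.default_rng(1)
D_IN, D_HID, N_BRANCHES = 4, 3, 2                           # illustrative sizes

W_shared = rng.normal(scale=0.1, size=(D_HID, D_IN))        # shared layers
w_branch = rng.normal(scale=0.1, size=(N_BRANCHES, D_HID))  # one head per cluster

def train_step(x, y, branch, lr=0.1):
    """One SGD step on squared error; only the selected branch head updates."""
    h = W_shared @ x                       # shared features (linear for brevity)
    pred = w_branch[branch] @ h            # branch-specific scalar output
    err = pred - y
    grad_branch = err * h                               # d(loss)/d(branch head)
    grad_shared = err * np.outer(w_branch[branch], x)   # d(loss)/d(shared layer)
    w_branch[branch] -= lr * grad_branch   # gated: only this cluster's branch
    W_shared[:] -= lr * grad_shared        # shared layers update on every sample
    return 0.5 * err ** 2

# Two samples, each tagged with the cluster/branch it was grouped into.
data = [(rng.normal(size=D_IN), 1.0, 0), (rng.normal(size=D_IN), -1.0, 1)]
losses = [sum(train_step(x, y, b) for x, y, b in data) for _ in range(300)]
```

In a real network the same gating is achieved by routing each mini-batch through its cluster's branch, so the shared layers accumulate gradients from all domains while each branch sees only its own.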

Conclusion

The extremely diverse nature of today's data, with representations from a large variety of domains, calls for models which can generalize across all of these domains while retaining their ability to learn the fine-grained nuances specific to each one. We explored one such approach and discussed how it can be used to solve the problem of ID verification given the inherent diversity of ID cards. We hope these training principles will guide more people to incorporate multi-domain learning into their own models!
