Deep Domain Adaptation using PyTorch Adapt

Consider the scenario where you’ve trained a model on MNIST digits, and you’ve been told that your model must also work on a dataset of randomly colored digits (MNIST-M).

It turns out that you don’t have the labels for this dataset, so you can’t use supervised learning to retrain your model. But you can use domain adaptation, a family of algorithms for repurposing existing models to work in new domains.

Enter PyTorch Adapt, a new modular library for domain adaptation. You can use it with vanilla PyTorch, or with a provided framework wrapper. Let’s see how it works on the MNIST → MNIST-M task.

Using PyTorch Adapt for the MNIST → MNIST-M task

The following snippets are from this Jupyter notebook.

1. Download the datasets and initialize a dataloader creator.

2. Set up the models (G and C), the domain adaptation algorithm (DANN), and the validator (IMValidator). We move much of the classifier model (C) into the trunk (G) because this tends to work better for DANN. To simplify our code, we’ll use the PyTorch Ignite wrapper.
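To give a feel for what DANN does during training, here is a dependency-free toy. It is not PyTorch Adapt’s implementation, and everything in it is an assumption made for illustration: a one-parameter feature extractor `f = w * x`, a one-parameter domain classifier `d = sigmoid(v * f)`, and the helper name `dann_domain_grads`. The domain classifier receives the ordinary gradient of its loss, while the feature extractor receives that gradient with its sign flipped (the gradient reversal trick), which nudges it toward features the classifier can’t tell apart.

```python
import math


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


def dann_domain_grads(w, v, x, domain_label, reversal_weight=1.0):
    """Gradients of the domain loss for a toy scalar model (hypothetical helper).

    Feature extractor: f = w * x. Domain classifier: d = sigmoid(v * f).
    Loss: binary cross-entropy between d and domain_label (0 = source, 1 = target).
    """
    f = w * x
    d = sigmoid(v * f)
    dloss_dlogit = d - domain_label          # gradient of sigmoid + BCE w.r.t. the logit
    grad_v = dloss_dlogit * f                # classifier gets the ordinary gradient
    grad_f = dloss_dlogit * v                # gradient flowing back into the feature
    grad_w = -reversal_weight * grad_f * x   # sign flipped before reaching the extractor
    return grad_v, grad_w


# One toy gradient-descent step on a target-domain sample (all values are made up).
w, v, lr = 0.5, 2.0, 0.1
grad_v, grad_w = dann_domain_grads(w, v, x=1.0, domain_label=1.0)
v -= lr * grad_v  # classifier gets better at telling the domains apart
w -= lr * grad_w  # extractor moves in the direction that makes that harder
```

In an autograd framework, the same sign flip is usually packaged as a layer that is the identity on the forward pass and negates gradients on the backward pass.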

3. Set up the visualization hook. For this demo, I’ve written a simple function that visualizes the features during training. (See the notebook for the function definition.) Since we’re using the PyTorch Ignite wrapper, we can use Ignite’s event handling system to add the visualizer hook.
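If the event-handler pattern is unfamiliar, the toy below shows the general idea using only the standard library. `TinyEngine`, its string event names, and the `visualizer` function are hypothetical stand-ins, not Ignite’s or PyTorch Adapt’s actual API: handlers are registered against an event, and the training loop calls them whenever that event fires.

```python
from collections import defaultdict


class TinyEngine:
    """Hypothetical stand-in for an event-driven training loop (not Ignite's API)."""

    def __init__(self):
        self.handlers = defaultdict(list)
        self.epoch = 0

    def add_event_handler(self, event, handler):
        self.handlers[event].append(handler)

    def fire(self, event):
        # Call every handler registered for this event, passing the engine itself.
        for handler in self.handlers[event]:
            handler(self)

    def run(self, max_epochs):
        self.fire("started")
        for self.epoch in range(1, max_epochs + 1):
            # ... one epoch of training would happen here ...
            self.fire("epoch_completed")


visualized_epochs = []


def visualizer(engine):
    # A real hook would extract features and plot them; this one only records the epoch.
    visualized_epochs.append(engine.epoch)


engine = TinyEngine()
engine.add_event_handler("started", visualizer)          # snapshot before training
engine.add_event_handler("epoch_completed", visualizer)  # snapshot after every epoch
engine.run(max_epochs=4)
```

The appeal of this design is that the training loop never needs to know about plotting: the hook can be added or removed without touching the loop.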

4. Train the model. Here we train for only 4 epochs, though typically many more are needed for the best performance.

5. Compare the feature visualizations. Before training, the MNIST (blue) features are well clustered, but there is little overlap with the MNIST-M (orange) features. After 4 epochs of training, the overlap between the two domains has increased, indicating that the model is adapting to the new domain.
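The overlap seen in the plots can also be summarized with a number. The sketch below is not part of the notebook. It assumes the features are plain lists of floats and uses the squared distance between the two domains’ mean feature vectors (a linear-kernel MMD estimate) as a rough proxy; the function names and the toy feature values are made up for illustration. Smaller values suggest more overlap, though matching means alone doesn’t guarantee that the two distributions match.

```python
def mean_vector(features):
    """Column-wise mean of a list of equal-length feature vectors."""
    n = len(features)
    return [sum(column) / n for column in zip(*features)]


def linear_mmd(source_features, target_features):
    """Squared Euclidean distance between the source and target feature means."""
    mu_s = mean_vector(source_features)
    mu_t = mean_vector(target_features)
    return sum((s - t) ** 2 for s, t in zip(mu_s, mu_t))


# Made-up 2-D features standing in for the two domains.
source = [[0.0, 0.0], [2.0, 0.0]]
target_before = [[4.0, 4.0], [6.0, 4.0]]  # far from the source cluster
target_after = [[1.0, 1.0], [3.0, 1.0]]   # closer after some adaptation

gap_before = linear_mmd(source, target_before)
gap_after = linear_mmd(source, target_after)
```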

6. Compute accuracy on MNIST-M… or not? In a real-world application, you wouldn’t be able to compute accuracy because the target data doesn’t come with labels. But for the purpose of this demo, we’ll cheat and check the accuracy anyway.

The best accuracy after 4 epochs is 65.6%, up from a starting accuracy of 57.4% (a gain of 8.2 percentage points). So training seems to be headed in the right direction.
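Outside a demo, where the target labels really are missing, a label-free score has to stand in for accuracy. The sketch below illustrates the information-maximization idea that a validator like IMValidator is named after; it is an assumption-laden toy, not the library’s code, and the function names are hypothetical. The score rewards predictions that are individually confident (low per-sample entropy) and collectively varied (high entropy of the average prediction).

```python
import math


def entropy(probs):
    """Shannon entropy (in nats) of one probability vector."""
    return -sum(p * math.log(p) for p in probs if p > 0)


def information_maximization_score(predictions):
    """Entropy of the mean prediction minus the mean per-sample entropy.

    `predictions` is a list of softmax outputs, one per unlabeled target sample.
    Higher is better under this heuristic.
    """
    n = len(predictions)
    mean_prediction = [sum(column) / n for column in zip(*predictions)]
    diversity = entropy(mean_prediction)
    average_entropy = sum(entropy(p) for p in predictions) / n
    return diversity - average_entropy


# Made-up two-class predictions for two target samples.
confident_and_diverse = information_maximization_score([[1.0, 0.0], [0.0, 1.0]])
uncertain = information_maximization_score([[0.5, 0.5], [0.5, 0.5]])
collapsed = information_maximization_score([[1.0, 0.0], [1.0, 0.0]])
```

Here only the first case scores above zero: the uncertain model is penalized for high per-sample entropy, and the collapsed model for predicting the same class everywhere. A score like this is a heuristic and can disagree with true accuracy.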

To be continued

This post provides a brief overview of PyTorch Adapt. In my next post, I’ll explain how to easily customize algorithms using the pytorch_adapt.hooks module.


Computer science PhD student @ Cornell University (Cornell Tech).