torchdistill — a modular, configuration-driven framework for reproducible deep learning and knowledge distillation experiments

Author: Yoshitomo Matsubara

Published in PyTorch

This article summarizes key features and concepts of torchdistill (v1.0.0). Refer to the official documentation for its APIs and research projects.


torchdistill is a modular, configuration-driven machine learning open source software (ML OSS) for reproducible deep learning and knowledge distillation experiments. It is available as a PyPI package (pip install torchdistill), offers various state-of-the-art knowledge distillation methods, and lets you design (new) experiments simply by editing a declarative YAML config file instead of writing Python code. Even when you need to extract intermediate representations from teacher/student models, you will NOT need to reimplement the models (which often means changing the interface of their forward function). Instead, you specify the module path(s) in the YAML file.
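As I understand it, torchdistill captures intermediate representations with PyTorch forward hooks registered on the modules named in the config (its forward hook manager). The stdlib-only sketch below only mimics that idea so it runs without PyTorch: it resolves a dotted module path such as 'layers.1' against a toy model and wraps the callable found there so its input and output are recorded, while the model's own forward stays untouched. Scale, ToyModel, resolve, and register_io_hook are hypothetical names for illustration, not torchdistill APIs.

```python
from typing import Any, Callable, Dict, List


class Scale:
    """Toy 'layer' that multiplies its input by a constant factor."""

    def __init__(self, factor: float) -> None:
        self.factor = factor

    def __call__(self, x: float) -> float:
        return x * self.factor


class ToyModel:
    """Toy 'model' that chains its layers; its forward knows nothing about hooks."""

    def __init__(self) -> None:
        self.layers: List[Callable[[float], float]] = [Scale(2.0), Scale(3.0), Scale(0.5)]

    def forward(self, x: float) -> float:
        for layer in self.layers:
            x = layer(x)
        return x


def resolve(root: Any, path: str) -> Any:
    """Follow a dotted path such as 'layers.1'; numeric parts index into sequences."""
    obj = root
    for part in path.split('.'):
        obj = obj[int(part)] if part.isdigit() else getattr(obj, part)
    return obj


def register_io_hook(root: Any, path: str, store: Dict[str, Dict[str, Any]]) -> None:
    """Wrap the callable at `path` so each call records its input and output in `store`."""
    parent_path, _, last = path.rpartition('.')
    parent = resolve(root, parent_path) if parent_path else root
    original = resolve(parent, last)

    def wrapped(*args: Any, **kwargs: Any) -> Any:
        output = original(*args, **kwargs)
        store[path] = {'input': args, 'output': output}  # keeps only the latest call
        return output

    # Swap the wrapper in at the same location; the ToyModel class itself is not edited.
    if last.isdigit():
        parent[int(last)] = wrapped
    else:
        setattr(parent, last, wrapped)


model = ToyModel()
captured: Dict[str, Dict[str, Any]] = {}
register_io_hook(model, 'layers.1', captured)
print(model.forward(1.0))      # 3.0
print(captured['layers.1'])    # {'input': (2.0,), 'output': 6.0}
```

In real PyTorch code, the same effect comes from registering a forward hook on the submodule at that path, which is why the path string alone is enough in the config.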

torchdistill is modular, configuration-driven

Pipeline with abstracted modules + a declarative PyYAML config file

In torchdistill, many components and PyTorch modules are abstracted, e.g., models, datasets, optimizers, losses, and more! You can define them in a declarative PyYAML config file, which contains almost everything needed to reproduce the experimental result and can be seen as a summary of your experiment. In many cases, you will NOT need to write Python code at all. Take a look at some of the configurations available in configs/ to see which modules are abstracted and how they are defined in a declarative PyYAML config file to design an experiment.

For example, you can instantiate torchvision.datasets.CIFAR10 for its training and test datasets with a few lines of Python and a PyYAML configuration file (test.yaml).

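```python
from torchdistill.common import yaml_util

# Loading test.yaml runs the !import_call constructors, so the datasets are already built
config = yaml_util.load_yaml_file('./test.yaml')
train_dataset = config['datasets']['cifar10/train']
test_dataset = config['datasets']['cifar10/test']
```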

test.yaml

```yaml
datasets:
  &cifar10_train cifar10/train: !import_call
    _name: &dataset_name 'cifar10'
    _root: &root_dir !join ['~/datasets/', *dataset_name]
    key: 'torchvision.datasets.CIFAR10'
    init:
      kwargs:
        root: *root_dir
        train: True
        download: True
        transform: !import_call
          key: 'torchvision.transforms.Compose'
          init:
            kwargs:
              transforms:
                - !import_call
                  key: 'torchvision.transforms.RandomCrop'
                  init:
                    kwargs:
                      size: 32
                      padding: 4
                - !import_call
                  key: 'torchvision.transforms.RandomHorizontalFlip'
                  init:
                    kwargs:
                      p: 0.5
                - !import_call
                  key: 'torchvision.transforms.ToTensor'
                  init:
                - !import_call
                  key: 'torchvision.transforms.Normalize'
                  init:
                    kwargs: &normalize_kwargs
                      mean: [0.49139968, 0.48215841, 0.44653091]
                      std: [0.24703223, 0.24348513, 0.26158784]
  &cifar10_test cifar10/test: !import_call
    key: 'torchvision.datasets.CIFAR10'
    init:
      kwargs:
        root: *root_dir
        train: False
        download: True
        transform: !import_call
          key: 'torchvision.transforms.Compose'
          init:
            kwargs:
              transforms:
                - !import_call
                  key: 'torchvision.transforms.ToTensor'
                  init:
                - !import_call
                  key: 'torchvision.transforms.Normalize'
                  init:
                    kwargs: *normalize_kwargs
```

If you take a close look at the above PyYAML, you will notice !import_call, a constructor that instantiates the class specified by key, using kwargs under init as its keyword arguments (if kwargs does not exist, an empty dict is used). As you can see, the instantiation process is recursive.
I.e., config['datasets']['cifar10/train'] is set by instantiating the following classes in this order:

  1. torchvision.transforms.RandomCrop as part of the list transforms
  2. torchvision.transforms.RandomHorizontalFlip as part of the list transforms
  3. torchvision.transforms.ToTensor as part of the list transforms
  4. torchvision.transforms.Normalize as part of the list transforms
  5. torchvision.transforms.Compose as transform
  6. torchvision.datasets.CIFAR10 as cifar10/train
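To make the recursion concrete, here is a stdlib-only sketch of what an !import_call-style constructor does once the YAML has been parsed into plain dicts and lists. It is not torchdistill's implementation: instantiate and import_object are hypothetical helpers, the YAML tag is replaced by an '!import_call' marker key (an assumption so the example needs no PyYAML), and stdlib classes stand in for torchvision so it runs anywhere. Nested specs are built first (depth-first), which is why the transforms are instantiated before Compose and CIFAR10 in the list above.

```python
import importlib
from typing import Any, List


def import_object(dotted_path: str) -> Any:
    """Import 'package.module.Name' and return the object called Name."""
    module_path, _, name = dotted_path.rpartition('.')
    return getattr(importlib.import_module(module_path), name)


def instantiate(spec: Any, log: List[str]) -> Any:
    """Recursively build objects from dict specs marked with an '!import_call' key."""
    if isinstance(spec, list):
        return [instantiate(item, log) for item in spec]
    if not isinstance(spec, dict):
        return spec  # scalars (str, int, float, ...) are used as-is
    if not spec.get('!import_call'):
        return {k: instantiate(v, log) for k, v in spec.items()}
    init = spec.get('init') or {}      # an empty 'init:' parses to None
    kwargs = init.get('kwargs') or {}  # missing kwargs -> empty dict
    # Depth-first: nested specs are built before the object that receives them.
    built = {k: instantiate(v, log) for k, v in kwargs.items()}
    obj = import_object(spec['key'])(**built)
    log.append(spec['key'])
    return obj


# Stdlib classes stand in for torchvision so this runs without it.
config = {
    'datasets': {
        'toy/train': {
            '!import_call': True,
            'key': 'types.SimpleNamespace',       # plays the role of CIFAR10
            'init': {'kwargs': {
                'root': '~/datasets/toy',
                'transform': {
                    '!import_call': True,
                    'key': 'argparse.Namespace',  # plays the role of Compose
                    'init': {'kwargs': {'transforms': [
                        {'!import_call': True, 'key': 'datetime.timedelta',
                         'init': {'kwargs': {'minutes': 1}}},
                        {'!import_call': True, 'key': 'fractions.Fraction',
                         'init': None},           # like ToTensor's empty init
                    ]}},
                },
            }},
        },
    },
}

log: List[str] = []
datasets = instantiate(config['datasets'], log)
print(log)
# ['datetime.timedelta', 'fractions.Fraction', 'argparse.Namespace', 'types.SimpleNamespace']
```

Because the class is looked up from a plain string, swapping one class for another is a one-line change in the config, which mirrors the CIFAR10 to CIFAR100 example.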

Basically, you can use !import_call for any module/function available in your locally installed torch and torchvision.
E.g., if you want to use CIFAR100 instead of CIFAR10, you can simply replace torchvision.datasets.CIFAR10 with torchvision.datasets.CIFAR100.

If you want to use !import_call for your own modules, refer to the documentation.

You can instantiate other types of modules (e.g., models) using !import_call, but it is not necessary to use the constructor for every module type in a PyYAML file.

Example 1: Reproducing results of KD methods for ImageNet (ILSVRC 2012)

Using torchdistill, I reimplemented about 20 different KD methods. Some of them were evaluated in their original papers with ResNet-34 and ResNet-18, a popular teacher-student pair for ImageNet (ILSVRC 2012).

I attempted to reproduce the reported accuracy of ResNet-18 on the ImageNet (ILSVRC 2012) dataset, using ResNet-34 as the teacher (except for Tf-KD, where ResNet-18 was used as a teacher). The hyperparameters are the same as those provided in the papers/code or by the authors.

Benchmark results: https://yoshitomo-matsubara.net/torchdistill/benchmarks.html

The student model trained with each of the reimplemented methods achieved better accuracy than the same model trained without the teacher model. However, most of the results (even the numbers reported in the original papers) did not outperform the standard KD method proposed in “Distilling the Knowledge in a Neural Network” (Hinton et al., 2014). See the first torchdistill paper for details.
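For reference, a common formulation of the standard KD loss (following Hinton et al.) mixes a temperature-softened KL-divergence term between teacher and student with the usual cross-entropy on the ground-truth label. The stdlib-only sketch below computes it for a single example. The function names, alpha=0.5, and temperature=4.0 are illustrative assumptions, not torchdistill's loss API or the hyperparameters used in the benchmark above.

```python
import math
from typing import List


def softmax(logits: List[float], temperature: float = 1.0) -> List[float]:
    """Numerically stable softmax of logits / temperature."""
    scaled = [z / temperature for z in logits]
    peak = max(scaled)
    exps = [math.exp(z - peak) for z in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def kd_loss(student_logits: List[float], teacher_logits: List[float], label: int,
            alpha: float = 0.5, temperature: float = 4.0) -> float:
    """alpha * T^2 * KL(teacher_T || student_T) + (1 - alpha) * CE(label, student)."""
    p_teacher = softmax(teacher_logits, temperature)
    p_student = softmax(student_logits, temperature)
    # Soft term: KL(p_teacher || p_student) between the softened distributions.
    soft = sum(pt * math.log(pt / ps) for pt, ps in zip(p_teacher, p_student) if pt > 0)
    # Hard term: ordinary cross-entropy against the ground-truth label at T = 1.
    hard = -math.log(softmax(student_logits)[label])
    # Scaling by T^2 keeps the soft term's gradient magnitude roughly independent of T.
    return alpha * temperature ** 2 * soft + (1.0 - alpha) * hard


teacher = [2.0, 1.0, 0.1]
student = [1.5, 1.2, 0.3]
print(kd_loss(student, teacher, label=0))
```

When the student's logits match the teacher's, the soft term vanishes and only the (1 - alpha) weighted cross-entropy remains; raising the temperature spreads the teacher's probability mass over the non-target classes, which is the extra signal the student learns from.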

All the configurations, checkpoints, and scripts are available in the official code repository.

Example 2: Reproducing GLUE test results of BERT

Similarly, I attempted to reproduce the GLUE test results of fine-tuned BERT models reported in “BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding” (Devlin et al., 2019), using torchdistill with Hugging Face libraries such as transformers, datasets, evaluate, and accelerate.

Results source: https://aclanthology.org/2023.nlposs-1.18/

The results are test scores from the GLUE Benchmark (not dev-set results). The fine-tuned BERT-Base and BERT-Large models achieved test results comparable to those reported by Devlin et al., 2019.

Besides the standard fine-tuning (FT) experiments, I conducted knowledge distillation (KD) experiments for fine-tuning BERT-Base, using fine-tuned BERT-Large models as teachers. The KD method (Hinton et al., 2014) improved the performance of BERT-Base models on most of the tasks, compared to those fine-tuned without the teacher models.

Those experiments were done on Google Colab. All the configurations and scripts are available in the official code repository. The model checkpoints and training log files are available in the Hugging Face model repositories.

Final words

This article briefly introduced torchdistill and only a few of its features. As a PyPI package, torchdistill also offers popular small models, a forward hook manager, dataset/model wrappers, loss modules for the reimplemented KD methods, and more.

The GitHub repository also provides example scripts, Google Colab examples, and demos, e.g., extracting intermediate layers’ input/output (embeddings) without any modifications to the model implementations. To learn more about torchdistill, see the repository and documentation.

If you have a question or a feature request, use GitHub Discussions. Before posting, please search through GitHub Issues and Discussions to make sure your issue/question/request has not already been addressed.
