The Migration Guide from Chainer to PyTorch
Authored by Preferred Networks
Welcome to the migration guide from Chainer to PyTorch! As announced in December 2019, the Chainer team has decided to shift our development efforts to the PyTorch ecosystem. Today we would like to introduce pytorch-pfn-extras (PPE), which bridges the gap between Chainer and PyTorch, and show you how to migrate from Chainer to PyTorch using PPE.
What is pytorch-pfn-extras?
pytorch-pfn-extras is a pure Python library which implements supplementary components for PyTorch. The initial version of the library contains the following features:
- Extensions Manager & Reporter provides an interface for extending your training code. It is similar to Chainer’s Training API, but PPE does not manage the training loop itself, so you can use it with your own hand-written training loops or with training frameworks such as Ignite. PPE ships with a variety of extensions ported from Chainer for logging, progress bars, parameter/variable statistics, snapshots, and more.
- Lazy Modules provides Linear/ConvXd modules that infer their parameter shapes, so you can define Linear/ConvXd layers without specifying the input size.
- Distributed Snapshot provides an extension that manages the generation of snapshots during training.
Migration from Chainer code
The goal of this section is to illustrate how to translate a script written in Chainer to PyTorch + pytorch-pfn-extras, taking a simple MNIST MLP training script as an example.
You can see the original Chainer code and the migrated PyTorch code in full. Comparing the two, you may notice that the scripts are very similar.
Training/Evaluation loops
In the training loop, you can create an Extensions Manager (ppe.training.ExtensionsManager) and register extensions with it to enrich your training loop. The manager plays a role similar to the Trainer object in Chainer, but it does not manage the training loop itself.
The manager keeps track of the registered extensions and calls them at the appropriate times. It counts one iteration each time a manager.run_iteration() context-manager block is executed. In this example, ppe.reporting.report is called once per iteration, that is, len(train_iter) times per epoch, and the average of the reported values is displayed at the end of each epoch.
# Chainer
trainer.run()
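Where Chainer hides the loop inside trainer.run(), with PPE you write the loop yourself, typically in the while not manager.stop_trigger: / with manager.run_iteration(): pattern. The sketch below is a toy, pure-Python model of the bookkeeping described above, not PPE's implementation: the hypothetical ToyManager counts iterations as each run_iteration() block finishes, collects values passed to its hypothetical report() method (standing in for ppe.reporting.report), and prints their average at the end of every epoch.

```python
from contextlib import contextmanager
from statistics import mean


class ToyManager:
    """Hypothetical stand-in for an extensions manager; not PPE's code."""

    def __init__(self, max_epochs, iters_per_epoch):
        self.max_epochs = max_epochs
        self.iters_per_epoch = iters_per_epoch
        self.iteration = 0        # iterations completed so far
        self._observed = []       # values reported during the current epoch
        self.epoch_averages = []  # one averaged value per finished epoch

    @property
    def stop_trigger(self):
        # True once max_epochs worth of iterations have run.
        return self.iteration >= self.max_epochs * self.iters_per_epoch

    def report(self, value):
        # Hypothetical stand-in for ppe.reporting.report.
        self._observed.append(value)

    @contextmanager
    def run_iteration(self):
        yield                     # the user's training step runs here
        self.iteration += 1
        if self.iteration % self.iters_per_epoch == 0:
            # End of epoch: average what was reported, then reset.
            avg = mean(self._observed)
            self.epoch_averages.append(avg)
            epoch = self.iteration // self.iters_per_epoch
            print(f"epoch {epoch}: train/loss={avg:.3f}")
            self._observed.clear()


batches = [1.0, 2.0, 3.0]  # stands in for train_iter: 3 iterations per epoch
manager = ToyManager(max_epochs=2, iters_per_epoch=len(batches))

epoch = 0
while not manager.stop_trigger:
    epoch += 1
    for x in batches:
        with manager.run_iteration():
            loss = x / epoch      # placeholder for forward/backward/step
            manager.report(loss)
```

Running it prints epoch 1: train/loss=2.000 and then epoch 2: train/loss=1.000. Because the loop belongs to you, anything can go inside the with block; the manager only needs to know that an iteration happened.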
The evaluation function can be passed to extensions.Evaluator through its eval_func argument. The Evaluator runs automatically at the end of each epoch and calls the evaluation function once per batch of test_iter, passing that single batch as the argument.
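As a rough illustration of this calling convention, the pure-Python sketch below defines a hypothetical toy_evaluate function (not PPE's Evaluator). It calls eval_func once per batch and averages the results, following Chainer's Evaluator convention of unpacking tuple batches into positional arguments and dict batches into keyword arguments. As a simplification, eval_func here returns its metric directly instead of reporting it.

```python
from statistics import mean


def toy_evaluate(test_iter, eval_func):
    """Hypothetical evaluator: call eval_func on every batch and average.

    Tuple batches are unpacked into positional arguments, dict batches
    into keyword arguments; anything else is passed as one argument.
    """
    results = []
    for batch in test_iter:
        if isinstance(batch, tuple):
            results.append(eval_func(*batch))
        elif isinstance(batch, dict):
            results.append(eval_func(**batch))
        else:
            results.append(eval_func(batch))
    return mean(results)


def accuracy(data, target):
    # Toy metric: "data" already holds predictions, standing in for model(data).
    return sum(p == t for p, t in zip(data, target)) / len(target)


# Two (data, target) batches standing in for test_iter.
test_iter = [([1, 0, 1, 1], [1, 0, 0, 1]), ([0, 0], [0, 0])]
val_acc = toy_evaluate(test_iter, accuracy)
print(f"val/acc={val_acc:.3f}")  # val/acc=0.875
```

Averaging per-batch values weights every batch equally, so a small final batch counts as much as a full one; that is a design choice of this toy, worth keeping in mind when comparing numbers across frameworks.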
Models
Most Chainer functions and links (links correspond to “modules” in PyTorch) have equivalents in PyTorch. Please refer to the mapping table in the Framework Migration Guide.
The following example illustrates the differences between network definitions in the two frameworks:
In the PyTorch code above, we use the LazyLinear module implemented in PPE. Unlike the torch.nn.Linear module, it accepts None as the input size argument; in that case, the input size is inferred during the forward pass. Note that you need to run a dummy forward call before registering the parameters with the optimizer, because the parameters are not materialized until the input size is known (see the full code for details).
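To see why the dummy forward call matters, here is a pure-Python sketch of lazy shape inference. ToyLazyLinear is hypothetical and only loosely modeled on LazyLinear: it accepts in_features=None, allocates its weights on the first forward call from the width of the input, and exposes no parameters before that. An optimizer built from parameters() before the dummy call would therefore have nothing to update.

```python
import random


class ToyLazyLinear:
    """Hypothetical lazy linear layer: y = W x + b, with in_features inferred."""

    def __init__(self, in_features, out_features):
        self.in_features = in_features   # may be None (infer later)
        self.out_features = out_features
        self.weight = None
        self.bias = None
        if in_features is not None:
            self._materialize(in_features)

    def _materialize(self, in_features):
        # Allocate parameters once the input width is known.
        self.in_features = in_features
        rng = random.Random(0)
        self.weight = [[rng.uniform(-0.1, 0.1) for _ in range(in_features)]
                       for _ in range(self.out_features)]
        self.bias = [0.0] * self.out_features

    def parameters(self):
        # Nothing to hand to an optimizer until the weights exist.
        return [] if self.weight is None else [self.weight, self.bias]

    def forward(self, x):
        if self.weight is None:
            self._materialize(len(x))    # infer in_features from the input
        elif len(x) != self.in_features:
            raise ValueError(f"expected {self.in_features} inputs, got {len(x)}")
        return [sum(w * xi for w, xi in zip(row, x)) + b
                for row, b in zip(self.weight, self.bias)]


layer = ToyLazyLinear(None, 3)
print(len(layer.parameters()))        # 0: no parameters before the first call
layer.forward([0.0] * 784)            # dummy forward, e.g. a flattened 28x28 image
print(layer.in_features, len(layer.parameters()))  # 784 2
```

Once the shape is fixed, later inputs of a different width are rejected, which mirrors why the dummy input should have the same shape as the real data.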
Snapshot
You can take snapshots by using the snapshot extension.
# Chainer
trainer.extend(extensions.snapshot(n_retains=1, autoload=True), trigger=(frequency, 'epoch'))
# PyTorch
manager.extend(ppe.training.extensions.snapshot(), trigger=(frequency, 'epoch'))
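The trigger=(frequency, 'epoch') argument makes the extension fire every frequency epochs, at the last iteration of each such epoch. A minimal pure-Python sketch of that interval logic, assuming the manager knows how many iterations make up one epoch (interval_fires is a hypothetical helper, not PPE's trigger class):

```python
def interval_fires(iteration, period, unit, iters_per_epoch):
    """Return True if a (period, unit) interval trigger fires after `iteration`.

    `iteration` is the 1-based count of completed iterations.
    """
    if unit == 'iteration':
        return iteration % period == 0
    if unit == 'epoch':
        # Fire only at the last iteration of every `period`-th epoch.
        return iteration % (period * iters_per_epoch) == 0
    raise ValueError(f"unknown unit: {unit!r}")


# With 100 iterations per epoch and trigger=(2, 'epoch'),
# the snapshot would be taken after iterations 200 and 400.
fired = [it for it in range(1, 501) if interval_fires(it, 2, 'epoch', 100)]
print(fired)  # [200, 400]
```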
All information used in a training loop (e.g., models, optimizers, updaters, iterators) is collected from the manager and serialized accordingly. When used in conjunction with torch.distributed, pass the saver_rank keyword argument to the snapshot extension to specify the MPI rank that writes the actual snapshot.
All ranks need to execute the extension because it synchronizes internally; otherwise, a deadlock will occur.
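The deadlock is easiest to see in a simulation. The sketch below uses Python threads as stand-ins for distributed ranks and threading.Barrier as a stand-in for the collective synchronization inside the extension. take_snapshot, WORLD_SIZE, SAVER_RANK and the timeout are hypothetical simplifications, not PPE code: every rank must reach the barrier, and only the saver rank writes.

```python
import threading

WORLD_SIZE = 4                # number of simulated ranks
SAVER_RANK = 0                # the only rank that writes the file
written = []                  # ranks that "wrote" a snapshot
lock = threading.Lock()


def take_snapshot(rank, barrier):
    # Collective step: no rank proceeds until every rank has arrived.
    barrier.wait(timeout=1.0)
    if rank == SAVER_RANK:
        with lock:
            written.append(rank)


def run(ranks_calling):
    """Call take_snapshot on the given ranks; return True if no rank got stuck."""
    barrier = threading.Barrier(WORLD_SIZE)
    stuck = []

    def worker(rank):
        try:
            take_snapshot(rank, barrier)
        except threading.BrokenBarrierError:
            # The timeout stands in for a real deadlock: the missing
            # rank never arrives, so the others would wait forever.
            with lock:
                stuck.append(rank)

    threads = [threading.Thread(target=worker, args=(r,)) for r in ranks_calling]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return not stuck


print(run(range(WORLD_SIZE)))  # True: every rank ran the extension
print(run([0, 1, 2]))          # False: rank 3 skipped it
print(written)                 # [0]
```

In a real distributed job there is no timeout to rescue you, which is why the extension must be registered and triggered identically on every rank.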
Further Reading
For further information, please see our Framework Migration Guide and the examples in the pytorch-pfn-extras repository.