PyTorch — Dynamic Batching

Illia Polosukhin
Sep 6, 2017 · 3 min read

If you have been reading my blog, you may have seen that I was a TensorFlow contributor and built a lot of high-level APIs there.

In February 2017, though, I left Google and co-founded my own company, where we are teaching machines to write code from natural language.

As part of this work, we are building Deep Learning models that read and write code in a tree format. After trying to manage this complexity in TensorFlow, I decided to give PyTorch a try.

PyTorch is a framework built by Facebook AI researchers, and it has been growing in popularity in the Natural Language and Reinforcement Learning research communities. Its main benefit is its dynamic graph-building principle: compared to TensorFlow, where the graph is built once and then “executed” many times, PyTorch lets you rebuild the graph dynamically with simple Python logic, as if you were doing computation with numpy arrays.


This flexibility has attracted people who work with complex input/output data [e.g. language, trees, graphs] or who need to run custom logic in the middle of the computation [e.g. Deep RL].

Here I want to talk about batching. Even though PyTorch is fast thanks to GPU acceleration and to pushing computation into C modules, if you are not batching your computation you are still going to pay the toll.

Recursive neural networks [TreeLSTM, for example] are especially hard to batch, as each example is a different tree.

The naive implementation would look like this:
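The gist embedded here isn’t reproduced, but the shape of such a naive, node-at-a-time encoder is easy to sketch. Below is a pure-Python illustration with the actual tensor math stubbed out as floats; the `Tree`, `leaf_op`, and `merge_op` names are mine, not the original gist’s:

```python
# Naive recursion: each node gets its own separate call,
# so nothing is ever batched and every op runs on its own.
class Tree:
    def __init__(self, value, children=()):
        self.value = value
        self.children = list(children)

def leaf_op(value):
    # Stand-in for an embedding lookup / LSTM cell on a leaf.
    return float(value)

def merge_op(child_states):
    # Stand-in for a TreeLSTM cell combining child states.
    return sum(child_states) + 1.0

def encode_tree(tree):
    if not tree.children:
        return leaf_op(tree.value)
    return merge_op([encode_tree(c) for c in tree.children])
```

Every call to `leaf_op` and `merge_op` here is a separate (in the real model, GPU) operation, even when many nodes could be processed in one batched call.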

There is a way to batch this manually: go through each operation that processes inputs differently, figure out how to batch the inputs, and then unbatch the outputs. James Bradbury gives an example of this in his great article.

The alternative is to have a system that decides how to batch things for us, depending on the exact inputs/outputs we want to compute. This is inspired by the method described in the paper “Deep Learning with Dynamic Computation Graphs” by Moshe Looks et al. [implemented in TensorFlow Fold, which seems to be unmaintained], very well depicted in this animation:


I have implemented these principles in a simple TorchFold class, with the following interface:
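The interface gist isn’t reproduced above; paraphrasing from the description in this post, TorchFold exposes roughly two calls (this is an interface sketch, not the real implementation, and exact signatures may differ):

```python
class Fold:
    """Sketch of the TorchFold interface."""

    def add(self, op, *args):
        """Queue a call to `model.<op>(*args)` and return a placeholder
        node that later add() calls can consume as an argument."""
        raise NotImplementedError

    def apply(self, model, nodes):
        """Execute the queued ops, batching calls to the same op
        together, and return concrete values for `nodes`."""
        raise NotImplementedError
```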

See full implementation at

Now, if we want to encode a tree with the TreeLSTM / Model from the previous gist, here is how we need to change the code:
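The folded gist isn’t shown here either, so below is a self-contained sketch of the idea: a toy, pure-Python `Fold` (grouping queued ops by depth and op name, then calling each op once per group) stands in for the real TorchFold, and the tensor math is again stubbed out as floats. The `encode_tree_folded` name follows the post; everything else is illustrative:

```python
from collections import defaultdict

class Node:
    # Placeholder for a value that will exist only after apply().
    def __init__(self, op, depth, args):
        self.op, self.depth, self.args = op, depth, args
        self.result = None

class Fold:
    # Toy stand-in for TorchFold: record ops, then execute them
    # level by level, batching all calls to the same op together.
    def __init__(self):
        self.nodes = []

    def add(self, op, *args):
        depth = 1 + max((a.depth for a in args if isinstance(a, Node)),
                        default=0)
        node = Node(op, depth, args)
        self.nodes.append(node)
        return node

    def apply(self, model, roots):
        by_group = defaultdict(list)
        for n in self.nodes:
            by_group[(n.depth, n.op)].append(n)
        for (depth, op), group in sorted(by_group.items()):
            # One batched call per (depth, op) instead of one per node.
            batch = [[a.result if isinstance(a, Node) else a
                      for a in n.args] for n in group]
            for n, r in zip(group, getattr(model, op)(batch)):
                n.result = r
        return [n.result for n in roots]

class Model:
    # Batched ops: each receives a whole list of argument tuples.
    def leaf(self, batch):
        return [float(value) for (value,) in batch]

    def merge(self, batch):
        return [left + right + 1.0 for (left, right) in batch]

def encode_tree_folded(fold, tree):
    value, children = tree
    if not children:
        return fold.add('leaf', value)
    left = encode_tree_folded(fold, children[0])
    right = encode_tree_folded(fold, children[1])
    return fold.add('merge', left, right)

fold = Fold()
root = encode_tree_folded(fold, (0, [(1, []), (2, [])]))
print(fold.apply(Model(), [root]))  # both leaves batched, then one merge
```

The recursion only records placeholder nodes; all real work happens in `apply`, where both `leaf` calls collapse into a single batched call.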

Here, at every invocation of encode_tree_folded, the Fold graph is constructed dynamically by adding nodes via fold.add, where op is the name of the function in the model to be called. It automatically figures out which ops can be grouped together and which must follow sequentially.

Then, at fold.apply time, the operations from the passed model are called with batched input tensors [possibly with different batch sizes at different steps], and outputs are routed automatically to the next steps.

Comparing speed between unfolded and folded versions (on a simple model here):

Regular: 0.18 sec/step (100 dim), 2.19 sec/step (500 dim)

Fold: 0.05 sec/step (100 dim), 0.22 sec/step (500 dim)

This gives a 3–10x speedup by reducing inefficiency in the computation.

This tool is generally useful for any complex architecture [including RNNs], as it removes the need to think about batching, at least for first experiments.

You can find implementation and examples here:

PS. While writing this article, I found a recent article on this topic, with an implementation for DyNet.

PPS. Since upgrading to PyTorch 0.2.0 I have seen a slight degradation in TorchFold performance, so for best speed try running with 0.1.12 until it’s fixed.


Teaching Machines to Code
