Increasing Mini-batch Size without Increasing Memory

David Morton
2 min read · Apr 9, 2018


The most popular technique used to train neural networks today is gradient descent, where the error of the network is minimized by computing the first partial derivative of the error with respect to every adjustable parameter in the network. Knowing all of these partial derivatives (a.k.a. the gradient) is enough to determine how to update all the parameters and lower the error a little bit. But since most datasets are too large to fit into memory, a randomly sampled mini-batch approach, dubbed stochastic gradient descent, is typical.
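As a concrete picture of a single update, here is a toy sketch (the linear model, random data, and learning rate are made up purely for illustration):

```python
import torch
from torch import nn

# A made-up toy model and one random mini-batch, just to show a single update.
model = nn.Linear(10, 1)
inputs, targets = torch.randn(32, 10), torch.randn(32, 1)

loss = nn.functional.mse_loss(model(inputs), targets)
loss.backward()                      # fills .grad for every parameter

learning_rate = 0.01
with torch.no_grad():                # one plain gradient descent step
    for param in model.parameters():
        param -= learning_rate * param.grad   # move against the gradient
        param.grad.zero_()                    # clear for the next step
```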

Recent research has claimed that by using larger mini-batches, one can safely make larger parameter updates (higher learning rates) and can therefore train networks faster, both in terms of the number of parameter updates and in terms of wall-clock time.

In this short post I will describe how you can train neural networks in pytorch with large effective mini-batches without increasing memory usage. The technique is simple: you just compute and sum gradients over multiple mini-batches, and only after the specified number of mini-batches do you update the model parameters. In this way you trade compute time for memory, but with the larger learning rate afforded by the large (effective) batch size you can still come out ahead. At least that’s what I hope to show later, in another post.

The technique (with example pytorch code)

A typical training loop in pytorch looks something like this:
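(A minimal sketch; model, criterion, optimizer, train_loader, and num_epochs are assumed to be defined elsewhere, e.g. a network, nn.CrossEntropyLoss, torch.optim.SGD, and a DataLoader.)

```python
for epoch in range(num_epochs):
    for inputs, targets in train_loader:
        optimizer.zero_grad()              # clear gradients from the last step
        outputs = model(inputs)            # forward pass
        loss = criterion(outputs, targets)
        loss.backward()                    # compute gradients
        optimizer.step()                   # update the parameters
```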

With this modified loop, you can obtain an effective mini-batch size that is a factor of batch_multiplier larger, where batch_multiplier is a positive integer:
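(Again a sketch with the same assumed names; batch_multiplier = 4 is an arbitrary example value.)

```python
batch_multiplier = 4                       # effective batch size = 4x the real one

for epoch in range(num_epochs):
    optimizer.zero_grad()
    for i, (inputs, targets) in enumerate(train_loader):
        outputs = model(inputs)
        # divide so the accumulated gradient matches one big averaged batch
        loss = criterion(outputs, targets) / batch_multiplier
        loss.backward()                    # gradients are summed into .grad
        if (i + 1) % batch_multiplier == 0:
            optimizer.step()               # update after batch_multiplier mini-batches
            optimizer.zero_grad()
```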

Notice that all we’ve done is call the optimizer’s step less often and divide the loss by batch_multiplier. The division is needed because loss.backward sums the gradients each time through the loop.
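You can see this summing behaviour in a tiny standalone example:

```python
import torch

w = torch.tensor(1.0, requires_grad=True)
(2 * w).backward()      # d(2w)/dw = 2
(3 * w).backward()      # d(3w)/dw = 3, added to the existing gradient
print(w.grad)           # tensor(5.) -- gradients are summed, not replaced
```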

Verifying equivalence with large mini-batches

To demonstrate that these virtual batches are equivalent to large (real) batches, let’s examine a learning rate survey. I’ve posted about learning rate surveys before, but briefly, they are a plot of the training loss vs. the learning rate for a model trained with an exponentially increasing learning rate.
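Roughly, such a survey can be produced by multiplying the learning rate by a constant factor every mini-batch and recording the loss. The sketch below is only illustrative (the actual figure was made with pytorch-sconce, as mentioned below); min_lr, max_lr, and num_steps are arbitrary choices, and model, criterion, and train_loader are assumed as before.

```python
import torch

min_lr, max_lr, num_steps = 1e-5, 10.0, 500
gamma = (max_lr / min_lr) ** (1.0 / num_steps)   # per-step multiplier

optimizer = torch.optim.SGD(model.parameters(), lr=min_lr)
scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=gamma)

lrs, losses = [], []
for step, (inputs, targets) in enumerate(train_loader):
    if step >= num_steps:
        break
    optimizer.zero_grad()
    loss = criterion(model(inputs), targets)
    loss.backward()
    optimizer.step()
    lrs.append(scheduler.get_last_lr()[0])
    losses.append(loss.item())
    scheduler.step()                     # grow the learning rate exponentially

# plotting losses vs. lrs on a log x-axis gives the survey curve
```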

As we can see, the models trained with actual large batches behave indistinguishably from the models trained using our new technique. Also, it’s clear from this survey that large batch sizes do in fact allow us to train with a larger learning rate before training becomes unstable.

This figure was generated using the pytorch-sconce library, which has support for virtual batches via the batch_multiplier parameter. Code to reproduce this figure can be found here.
