Effect of Batch Size on Neural Net Training
--
co-authored with Apurva Pathak
Welcome to the first installment in our Deep Learning Experiments series, where we run experiments to evaluate commonly held assumptions about training neural networks. Our goal is to better understand the different design choices that affect model training and evaluation. To do so, we come up with questions about each design choice and then run experiments to answer them.
In this article, we seek to better understand the impact of batch size on training neural networks. In particular, we will cover the following:
- What is batch size?
- Why does batch size matter?
- How do small and large batches perform empirically?
- Why do large batches tend to perform worse, and how can we close the performance gap?
What is batch size?
Neural networks are trained to minimize a loss function of the following form:

J(\theta) = \frac{1}{m} \sum_{i=1}^{m} J_i(\theta)

where
- theta represents the model parameters
- m is the number of training data examples
- each value of i represents a single training data example
- J_i represents the loss function applied to a single training example
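To make the notation concrete, here is a minimal PyTorch-style sketch showing that the overall loss is simply the average of the per-example losses. The toy model, data, and names below are illustrative assumptions, not the setup used in our experiments:

```python
import torch
import torch.nn as nn

# Toy setup: a linear classifier standing in for the network's parameters theta.
model = nn.Linear(10, 2)
loss_fn = nn.CrossEntropyLoss(reduction="none")  # returns J_i for each example

x = torch.randn(100, 10)           # m = 100 training examples
y = torch.randint(0, 2, (100,))    # labels

per_example_losses = loss_fn(model(x), y)  # J_i(theta), one value per example
overall_loss = per_example_losses.mean()   # J(theta) = (1/m) * sum_i J_i(theta)
```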
Typically, this is done using gradient descent, which computes the gradient of the loss function with respect to the parameters and takes a step in the direction of the negative gradient. Stochastic gradient descent instead computes the gradient on a subset of the training data, B_k, as opposed to the entire training dataset:

\nabla J_{B_k}(\theta) = \frac{1}{|B_k|} \sum_{i \in B_k} \nabla J_i(\theta)

B_k is a batch sampled from the training dataset, and its size can vary from 1 to m (the total number of training data points) [1]. This is typically referred to as mini-batch training with a batch size of |B_k|. We can think of these batch-level gradients as approximations of the ‘true’ gradient, the gradient of the overall loss function with respect to theta. We use mini-batches because training with them tends to converge more quickly: the model can update its weights without making a full pass through the training data.
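As a rough sketch of what mini-batch training looks like in code, each step computes the gradient of the loss on a single batch B_k and immediately updates the parameters. The model, data, learning rate, and batch size here are placeholder assumptions for illustration:

```python
import torch
import torch.nn as nn

model = nn.Linear(10, 2)
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

x = torch.randn(1000, 10)            # m = 1000 training examples
y = torch.randint(0, 2, (1000,))
batch_size = 32                      # |B_k|

perm = torch.randperm(len(x))        # sample batches by shuffling the data
for start in range(0, len(x), batch_size):
    idx = perm[start:start + batch_size]      # indices of batch B_k
    optimizer.zero_grad()
    loss = loss_fn(model(x[idx]), y[idx])     # loss averaged over B_k
    loss.backward()                           # gradient of the batch loss w.r.t. theta
    optimizer.step()                          # theta <- theta - lr * batch gradient
```

Note that the weights are updated after every batch of |B_k| examples rather than once per full pass over the data, which is why mini-batch training typically converges more quickly.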
Why does batch size matter?
Keskar et al. note that stochastic gradient descent is sequential and uses small batches, so it cannot be easily parallelized [1]. Using larger batch sizes would allow us to parallelize computations to a greater degree, since we…