Optimizing deeper networks with KFAC in PyTorch.

Optimization with first-order methods like Adam becomes less effective as batch size and depth increase. Second-order methods like KFAC (an approximate natural gradient method) are a bit more expensive per step, but much less affected by depth. For a difficult problem this translates into savings in wall-clock time.

I’ve recently experimented with KFAC in PyTorch. Its imperative style of programming made it easier to prototype optimization algorithms than the graph-based approach of TensorFlow. For a fully connected network, an existing optimizer can be augmented with KFAC preconditioning in just a few lines of PyTorch; see “Implementation” below.

Consider the following autoencoder on MNIST
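
A minimal PyTorch sketch of such a model, assuming the standard deep MNIST autoencoder layer sizes (784-1000-500-250-30, mirrored back up) with sigmoid activations; the exact sizes are my assumption and may differ from the original experiment:

```python
import torch.nn as nn

def make_autoencoder():
    # Fully connected MNIST autoencoder; layer sizes follow the standard
    # deep-autoencoder benchmark and are assumed, not taken from the post.
    sizes = [784, 1000, 500, 250, 30, 250, 500, 1000, 784]
    layers = []
    for d_in, d_out in zip(sizes[:-1], sizes[1:]):
        layers += [nn.Linear(d_in, d_out), nn.Sigmoid()]
    return nn.Sequential(*layers)
```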

Optimizing this architecture with a batch size of 10k, the advantage of KFAC is stark: 100x fewer iterations and 25x less wall-clock time than Adam to reach the test-loss minimum.

A description of the experiment is here

Derivation

The traditional derivation of KFAC (Martens, Grosse, Ba) is motivated by information geometry. Below I give an alternative derivation: the KFAC-style update is simply the Newton step for a deep linear neural network.

To make the derivation concrete, consider optimizing a deep fully connected linear autoencoder. Without loss of generality, we can write our predictions Y as a function of the parameter matrix W as follows:
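
A plausible reconstruction of the missing equation, using the same prime-for-transpose notation as the rest of the post: A collects the activations entering the layer W (the input X propagated through the layers below) and B' is the product of the layers above, so that

    Y = B' W A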

Given labels \hat{Y} we can write our prediction error e and loss J:
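
A reconstruction, assuming the usual sum-of-squares loss:

    e = Y - \hat{Y}, \qquad J = \tfrac{1}{2} \mathrm{tr}(e e') = \tfrac{1}{2} \|e\|^2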

To minimize J, we differentiate with respect to W and get the following for our gradient and SGD update rule:
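
A reconstruction consistent with the “Be” quantity discussed next (\alpha is the learning rate):

    G = \frac{\partial J}{\partial W} = B e A', \qquad W \leftarrow W - \alpha B e A'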

Note that the quantity “Be” in the equation above is equal to the backprop matrix you get in a reverse-mode AD algorithm. It is the “grad_output” quantity passed into the PyTorch backward() method.

To get the Hessian, we differentiate our gradient G again and get the result in terms of a Kronecker product:
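
A sketch of the result; the ordering of the Kronecker factors depends on the vectorization convention:

    H = \frac{\partial^2 J}{\partial \mathrm{vec}(W) \, \partial \mathrm{vec}(W)'} = (A A') \otimes (B B')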

Dividing the gradient by the Hessian and rearranging, our Newton update step becomes:
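
Reconstructed from the Hessian above, using the identity (P \otimes Q)^{-1} \mathrm{vec}(G) = \mathrm{vec}(Q^{-1} G P^{-1}) for symmetric P, Q:

    W \leftarrow W - \alpha (B B')^{-1} G (A A')^{-1} = W - \alpha (B B')^{-1} B e A' (A A')^{-1}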

The matrices on each side of G are known as whitening matrices: the first is the backprop whitening matrix, while the second is the activation whitening matrix.

Note that the matrix B is not directly available during backprop, and using “grad_output” in its place gets us Bee’B’ instead of BB’. That’s not a problem, since we can generate any “e” by selecting the target labels accordingly (both options are sketched in code right after the list). Some choices for e:

  1. Padded identity matrix, so that ee’ is the identity. Then Bee’B’ = BB’ exactly.
  2. IID Gaussian values. Then Bee’B’ = BB’ in expectation.
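
Both choices can be realized by constructing the labels from the network output. A minimal sketch (the helper name and the scaling are mine, not from the original post; with rows as examples, the quantity ee’ above corresponds to e.t() @ e here):

```python
import torch

def synthetic_targets(Y, kind="identity"):
    """Pick labels Yhat = Y - e, so that backprop propagates a chosen error e."""
    n, d = Y.shape  # batch size x output dimension
    if kind == "identity":
        # Padded identity (needs n >= d): e.t() @ e is exactly the identity.
        e = torch.eye(n, d, dtype=Y.dtype, device=Y.device)
    else:
        # IID Gaussian, scaled so e.t() @ e equals the identity in expectation.
        e = torch.randn(n, d, dtype=Y.dtype, device=Y.device) / n ** 0.5
    return (Y - e).detach()
```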

A detailed derivation is here

Implementation

The basic implementation has three parts:

  1. capture: compute gradients and save forward/backprop values
  2. invert: compute whitening matrices
  3. kfac: apply whitening matrices during gradient computation

Steps “capture” and “kfac” can be accomplished with a version of Addmm that has a custom “backward” method:
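
A minimal sketch of such a function using torch.autograd.Function; the global dictionaries, the mode flag, and all names here are my own bookkeeping rather than the post’s actual implementation:

```python
import torch

# Per-layer statistics and their inverses, keyed by a layer tag, plus a flag
# that switches the backward pass between the "capture" and "kfac" phases.
stats, whiteners = {}, {}
mode = "capture"

class KfacAddmm(torch.autograd.Function):
    """Computes bias + activations @ weight (like torch.addmm), capturing
    KFAC statistics and preconditioning the weight gradient on the way back."""

    @staticmethod
    def forward(ctx, bias, activations, weight, tag):
        ctx.save_for_backward(activations, weight)
        ctx.tag = tag
        return torch.addmm(bias, activations, weight)

    @staticmethod
    def backward(ctx, grad_output):
        activations, weight = ctx.saved_tensors
        n = activations.shape[0]
        if mode == "capture":
            # Save the activation and backprop second-moment matrices.
            stats[ctx.tag] = (activations.t() @ activations / n,
                              grad_output.t() @ grad_output / n)
        grad_weight = activations.t() @ grad_output
        if mode == "kfac":
            # Whiten: (AA')^-1 on the activation side, (BB')^-1 on the backprop side.
            A_inv, B_inv = whiteners[ctx.tag]
            grad_weight = A_inv @ grad_weight @ B_inv
        # Gradients w.r.t. bias, activations and weight; the tag gets none.
        return grad_output.sum(0), grad_output @ weight.t(), grad_weight, None
```

A layer would then route its forward pass through KfacAddmm.apply(bias, x, weight, tag) in place of the usual linear call.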

The body of the training loop then looks like this:
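
A sketch under the assumptions above: the model is built from layers that call KfacAddmm, optimizer and loader are set up as usual, synthetic_targets is the helper from earlier, and the damping value and update frequency are guesses of mine.

```python
import torch
import torch.nn.functional as F

lam = 1e-2          # damping added before inversion (value is a guess)
invert_every = 20   # how often to refresh the whitening matrices

for step, (images, _) in enumerate(loader):
    x = images.view(images.size(0), -1)    # flatten MNIST images to vectors

    # 1. capture: backprop a chosen synthetic error e to collect the statistics.
    mode = "capture"                       # rebinds the flag read inside backward()
    optimizer.zero_grad()
    out = model(x)
    err = out - synthetic_targets(out)     # value of err is exactly e
    (0.5 * (err * err).sum()).backward()

    # 2. invert: periodically recompute the whitening matrices.
    if step % invert_every == 0:
        for tag, (A, B) in stats.items():
            whiteners[tag] = (
                torch.inverse(A + lam * torch.eye(A.shape[0], device=A.device)),
                torch.inverse(B + lam * torch.eye(B.shape[0], device=B.device)),
            )

    # 3. kfac: an ordinary optimizer step; gradients are preconditioned in backward().
    mode = "kfac"
    optimizer.zero_grad()
    loss = F.mse_loss(model(x), x)
    loss.backward()
    optimizer.step()
```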

The full implementation is here

Note

The difference between Adam and KFAC shrank to about a 5x improvement in wall-clock time when I tweaked the experiment to make it more amenable to SGD:

  1. Replace sigmoid activations with ReLU
  2. Add weight normalization
  3. Use batch size 128 for Adam