SGD implementation in PyTorch
The subtle difference can affect your hyper-parameter schedule
PyTorch's implementation of SGD with momentum/Nesterov differs subtly from Sutskever et al. and from implementations in some other frameworks.
This is the formula the paper uses (g is the gradient, v the velocity, p the parameters/weights, rho the momentum coefficient, and lr the learning rate):
v = rho * v + lr * g
p = p - v
And this is the formula PyTorch uses:
v = rho * v + g
p = p - lr * v
The only difference is where the learning rate is applied. In the paper the learning rate is applied when calculating new velocity, and in PyTorch the learning rate is applied when calculating new parameters/weights.
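As a quick check of the two update rules, here is a minimal pure-Python sketch on a single scalar parameter. The function names `paper_step` and `pytorch_step` and the toy loss 0.5 * p**2 are illustrative assumptions, not part of the paper or of PyTorch. With a constant learning rate the two rules should produce the same parameters, because the paper's velocity is simply lr times PyTorch's.

```python
def paper_step(p, v, g, lr, rho):
    """Paper-style update: lr is applied when computing the velocity."""
    v = rho * v + lr * g
    p = p - v
    return p, v


def pytorch_step(p, v, g, lr, rho):
    """PyTorch-style update: lr is applied when updating the parameters."""
    v = rho * v + g
    p = p - lr * v
    return p, v


def run(step, lr=0.1, rho=0.9, n_steps=20, p0=1.0):
    """Minimise the toy loss 0.5 * p**2 (whose gradient is p) and record p."""
    p, v = p0, 0.0
    history = []
    for _ in range(n_steps):
        g = p  # gradient of 0.5 * p**2
        p, v = step(p, v, g, lr, rho)
        history.append(p)
    return history


paper_history = run(paper_step)
pytorch_history = run(pytorch_step)
max_gap = max(abs(a - b) for a, b in zip(paper_history, pytorch_history))
print(f"largest gap between the two trajectories: {max_gap:.2e}")
```

Any gap printed here comes from floating-point rounding only; the difference between the two formulas shows up once the learning rate changes during training.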
How it can affect the tuning schedule
It may or may not have an observable impact on training and validation loss, but being aware of the difference can help steer the tuning schedule in the right direction.
For example, suppose we're using the torch.optim.lr_scheduler.ReduceLROnPlateau schedule, which reduces the learning rate once the validation score plateaus. If we mistook the PyTorch SGD implementation for the one in the paper, we'd expect new gradients to have much less influence on later velocity updates after each reduction; in other words, we'd expect the momentum to increase. In reality the momentum does not change. Instead, we just get smaller changes in the parameters at each iteration.
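The effect of a learning-rate cut can be seen in a small simulation. This is a pure-Python sketch under simplifying assumptions (one scalar parameter, a constant gradient of 1.0, and a hand-written learning-rate list standing in for the scheduler); the helper `step_sizes` is hypothetical and not part of PyTorch.

```python
def step_sizes(style, lrs, rho=0.9, g=1.0):
    """Return the size of each parameter update for a constant gradient g.

    style="pytorch": v = rho * v + g;      update = lr * v
    style="paper":   v = rho * v + lr * g; update = v
    """
    v = 0.0
    sizes = []
    for lr in lrs:
        if style == "pytorch":
            v = rho * v + g
            sizes.append(lr * v)
        elif style == "paper":
            v = rho * v + lr * g
            sizes.append(v)
        else:
            raise ValueError(f"unknown style: {style!r}")
    return sizes


# 50 steps at lr=0.1, then the learning rate is cut by a factor of 10.
lrs = [0.1] * 50 + [0.01] * 5
pytorch_sizes = step_sizes("pytorch", lrs)
paper_sizes = step_sizes("paper", lrs)

for i in range(49, 53):
    print(f"step {i}: lr={lrs[i]:<5} "
          f"pytorch={pytorch_sizes[i]:.4f} paper={paper_sizes[i]:.4f}")
```

In the PyTorch-style run the update shrinks by roughly the same factor as the learning rate straight away, while in the paper-style run the old velocity keeps driving large updates and decays only gradually.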
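If we substitute v' = lr * v in the PyTorch formula, we get p = p - v' and v' = lr * (rho * v + g) = rho * (lr * v) + lr * g. As long as the learning rate stays constant, lr * v is just the previous v', so this is exactly the formula in the paper with the same rho. The two diverge only when the learning rate changes: the paper's formula keeps the accumulated velocity as it was and shrinks only the new gradient term, whereas PyTorch applies the new learning rate to the whole velocity, old gradients included. So what PyTorch does is effectively rescale the accumulated velocity along with the learning rate, so the balance between past and new gradients (the momentum) stays invariant to changes in learning rate.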
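To make the effect of a changing learning rate explicit, the substitution can be written out with step indices. This is a sketch that assumes the learning rate used at step t is lr_t and is never zero; the subscripts are notation introduced here, not taken from the paper or from the PyTorch source.

```latex
\begin{align*}
\text{PyTorch:}\quad & v_{t+1} = \rho\, v_t + g_t, \qquad p_{t+1} = p_t - \mathrm{lr}_{t+1}\, v_{t+1} \\
\text{Define:}\quad & v'_t = \mathrm{lr}_t\, v_t \\
\text{Then:}\quad & v'_{t+1} = \mathrm{lr}_{t+1}\,(\rho\, v_t + g_t)
  = \rho\,\frac{\mathrm{lr}_{t+1}}{\mathrm{lr}_t}\, v'_t + \mathrm{lr}_{t+1}\, g_t,
  \qquad p_{t+1} = p_t - v'_{t+1}
\end{align*}
```

With a constant learning rate the ratio is 1 and the last line has the same form as the paper's update. When the learning rate drops, the velocity carried over from earlier steps is scaled down by the same factor as the new gradient term.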
It's not hard to modify the SGD implementation in PyTorch and make it consistent with the paper (if that's what you want).
If we take a look at the source code, we'd find it quite easy to read:
Lines 23 and 25 get the gradients. Lines 30, 33, and 35 update the velocity. Line 39 updates the parameters.
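For readers without the listing at hand, the update described above can be sketched as a simplified, standalone function. This is not the actual `torch.optim.SGD` source: it works on a single Python float, leaves out weight decay and dampening, and the name `sgd_step_pytorch_style` is made up for illustration, so the line numbers mentioned above do not apply to it.

```python
def sgd_step_pytorch_style(p, grad, buf, lr, momentum=0.0, nesterov=False):
    """One simplified SGD step for a scalar parameter.

    p is the parameter, grad its gradient, and buf the velocity
    (None before the first step). Returns the updated (p, buf).
    """
    d_p = grad                          # get the gradient
    if momentum != 0:
        if buf is None:
            buf = d_p                   # first step: velocity starts as the gradient
        else:
            buf = momentum * buf + d_p  # update the velocity (no lr involved)
        if nesterov:
            d_p = d_p + momentum * buf  # Nesterov variant
        else:
            d_p = buf
    p = p - lr * d_p                    # lr is applied only when updating p
    return p, buf


p, buf = 1.0, None
for grad in [2.0, 2.0]:
    p, buf = sgd_step_pytorch_style(p, grad, buf, lr=0.1, momentum=0.9)
    print(f"p={p:.4f} velocity={buf:.4f}")
```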
So if we just tweak line 19 and apply the learning rate directly on the gradients (line 27) we’d have implemented the formula in the paper:
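A paper-style variant might look like the following sketch. It is a standalone simplification on a single Python float (no weight decay or dampening), not a patch to the real `torch.optim.SGD`, and the name `sgd_step_paper_style` is an assumption made for illustration. The only changes from a PyTorch-style step are that the learning rate multiplies the gradient before the velocity update and no longer appears in the parameter update.

```python
def sgd_step_paper_style(p, grad, buf, lr, momentum=0.0, nesterov=False):
    """One simplified paper-style SGD step for a scalar parameter.

    p is the parameter, grad its gradient, and buf the velocity
    (None before the first step). Returns the updated (p, buf).
    """
    d_p = lr * grad                     # lr is applied directly to the gradient
    if momentum != 0:
        if buf is None:
            buf = d_p
        else:
            buf = momentum * buf + d_p  # velocity accumulates lr-scaled gradients
        if nesterov:
            d_p = d_p + momentum * buf
        else:
            d_p = buf
    p = p - d_p                         # no lr at the parameter update
    return p, buf


p, buf = 1.0, None
# The last step uses a 10x smaller learning rate.
for grad, lr in [(2.0, 0.1), (2.0, 0.1), (2.0, 0.01)]:
    p, buf = sgd_step_paper_style(p, grad, buf, lr=lr, momentum=0.9)
    print(f"lr={lr}: p={p:.4f} velocity={buf:.4f}")
```

After the learning rate is cut in the last step, the velocity barely shrinks, because the accumulated part is not rescaled; only the new gradient's contribution gets smaller.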
The small difference in implementation might not be a big deal, but it can cause confusion when tuning if you have not understood it correctly. Moreover, tuning algorithms that are based on the alternative formula may not work as expected in PyTorch. For example, the YellowFin paper uses a formulation in which the learning rate and the momentum coefficient are decoupled.
You’ll need to be careful when implementing those algorithms in PyTorch. Otherwise a huge amount of time is likely to be wasted on debugging.
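One possible way to reconcile the two formulations when a tuner changes the learning rate between steps is to rescale the stored velocity by old_lr / new_lr before the next update. The sketch below checks this on scalars in pure Python. It is an illustration under assumptions, not how YellowFin or any PyTorch optimizer is implemented, and all function names are hypothetical.

```python
def paper_style_run(grads, lrs, rho, p0=0.0):
    """Reference: v = rho * v + lr * g; p = p - v."""
    p, v = p0, 0.0
    for g, lr in zip(grads, lrs):
        v = rho * v + lr * g
        p = p - v
    return p


def pytorch_style_run(grads, lrs, rho, p0=0.0, rescale=False):
    """PyTorch-style: v = rho * v + g; p = p - lr * v.

    With rescale=True the velocity is multiplied by prev_lr / lr whenever
    the learning rate changes, so that lr * v (the paper's velocity) is
    carried over unchanged. Assumes every lr is nonzero.
    """
    p, v, prev_lr = p0, 0.0, None
    for g, lr in zip(grads, lrs):
        if rescale and prev_lr is not None and lr != prev_lr:
            v = v * prev_lr / lr
        v = rho * v + g
        p = p - lr * v
        prev_lr = lr
    return p


grads = [1.0, -0.5, 2.0, 0.3, -1.2, 0.8]
lrs = [0.1, 0.1, 0.05, 0.05, 0.2, 0.01]  # learning rate changes between steps

p_paper = paper_style_run(grads, lrs, rho=0.9)
p_plain = pytorch_style_run(grads, lrs, rho=0.9)
p_rescaled = pytorch_style_run(grads, lrs, rho=0.9, rescale=True)
print(f"paper={p_paper:.6f} plain={p_plain:.6f} rescaled={p_rescaled:.6f}")
```

The rescaled run should track the paper-style run up to rounding, while the plain PyTorch-style run drifts away once the learning rate starts changing.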