Why are deeper neural networks harder to train?
Deep neural networks are an important class of machine learning models that have been applied to numerous areas, such as natural language processing [WU:2016], computer vision [KR:2012] and speech recognition [HI:2012]. Training such networks is often performed successfully by minimising a high-dimensional non-convex objective function. Theoretically, however, we have only scratched the surface of this optimisation problem, and a number of important questions about its behaviour remain open.
A recent paper “Gradient Descent Finds Global Minima of Deep Neural Networks” by Du, Lee, Li, Wang and Zhai [DU:2018] sheds light on two unexplained behaviours of deep neural networks:
- randomly initialised first-order methods can achieve zero training loss, even if the labels are arbitrary, and
- deeper networks are harder to train.
Du et al.’s paper considers three deep neural network architectures:
- multilayer fully-connected neural networks,
- deep residual network (ResNet)*, and
- convolutional ResNet.
They use a randomly initialised gradient descent algorithm to find a global minimiser of the empirical loss. Their paper focuses on proving how much over-parameterisation is needed to ensure the global convergence of gradient descent. This provides a sufficient condition for zero training loss for the three specified architectures, as well as an understanding of why deeper networks are harder to train.
Zero training loss for random initialised first order methods
It is well known that randomly initialised first-order methods, such as gradient descent, can achieve zero training loss in deep learning, even for arbitrary labels [ZH:2016]. Over-parameterisation is the generally agreed reason for this: if the neural network has a sufficiently large capacity, it can fit all the training data. Highly over-parameterised architectures are common in practice; Wide Residual Networks, for example, have one hundred times more parameters than training examples [ZA:2016].
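This phenomenon is easy to reproduce at toy scale. The sketch below (a hypothetical setup with made-up sizes, not Du et al.'s code) trains an over-parameterised two-layer ReLU network with full-batch gradient descent on purely random labels; the training loss is still driven towards zero.

```python
import numpy as np

rng = np.random.default_rng(0)
n, d, m = 4, 3, 200            # n samples, input dim d, width m >> n (over-parameterised)

# Random unit-norm inputs and *arbitrary* (random) labels.
X = rng.standard_normal((n, d))
X /= np.linalg.norm(X, axis=1, keepdims=True)
y = rng.choice([-1.0, 1.0], size=n)

# Two-layer ReLU network f(x) = a . relu(Wx) / sqrt(m); only W is trained.
W = rng.standard_normal((m, d))
a = rng.choice([-1.0, 1.0], size=m)

def loss_and_grad(W):
    Z = X @ W.T                                    # pre-activations, shape (n, m)
    err = np.maximum(Z, 0) @ a / np.sqrt(m) - y    # residuals f(x_i) - y_i
    G = (a[:, None] / np.sqrt(m)) * (((Z > 0).T * err) @ X)
    return 0.5 * np.sum(err ** 2), G

loss0, _ = loss_and_grad(W)
for _ in range(2000):                              # full-batch gradient descent
    _, G = loss_and_grad(W)
    W -= 0.2 * G
loss_final, _ = loss_and_grad(W)
print(loss0, loss_final)                           # loss near zero despite random labels
```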
Theorems 3.1, 4.1 and 5.1 [DU:2018] provide sufficient conditions guaranteeing that gradient descent achieves zero training loss for deep over-parameterised neural networks of each of the three architectures. These theorems state that if the width is large enough and the step size is set appropriately, then gradient descent converges to a global minimum with zero loss at a linear rate.
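Schematically, suppressing the architecture-specific constants and width requirements, the guarantee in all three theorems has the same shape: with high probability over the random initialisation, the empirical loss after $k$ gradient descent steps with step size $\eta$ satisfies

```latex
L(\mathbf{W}(k)) \;\le\; \left(1 - \frac{\eta\,\lambda_{\min}}{2}\right)^{k} L(\mathbf{W}(0)),
```

where $\lambda_{\min} > 0$ is the smallest eigenvalue of the relevant Gram matrix. The loss therefore decays geometrically, which is what "converges at a linear rate" means here.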
Du et al. require the following conditions:
- the last layer’s Gram matrix is strictly positive definite,
- every two-by-two sub-matrix of every layer (for multilayer fully-connected neural networks) or of the first layer (for ResNet) has a lower-bounded eigenvalue, and
- the activation function is Lipschitz and smooth.
These are relatively weak assumptions. The first condition is a non-degeneracy condition on the Gram matrix. The second is a stability assumption guaranteeing that, if the width is large, the Gram matrix at initialisation will be close to the population Gram matrix. The last condition holds for many activation functions and allows non-linear functions such as softplus.
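The first condition can be illustrated numerically. The sketch below (hypothetical sizes, not taken from the paper) builds a randomly initialised wide layer with a smooth Lipschitz activation (softplus) and checks that the resulting empirical Gram matrix of the hidden features is strictly positive definite:

```python
import numpy as np

def softplus(x):
    """Softplus: a smooth, Lipschitz activation of the kind the theorems allow."""
    return np.log1p(np.exp(x))

rng = np.random.default_rng(0)
n, d, m = 20, 5, 1000          # n samples, input dim d, width m (hypothetical sizes)

# Distinct random unit-norm inputs keep the Gram matrix non-degenerate.
X = rng.standard_normal((n, d))
X /= np.linalg.norm(X, axis=1, keepdims=True)

# One randomly initialised hidden layer; its features define the Gram matrix.
W = rng.standard_normal((d, m))
features = softplus(X @ W) / np.sqrt(m)
gram = features @ features.T   # n x n Gram matrix of hidden representations

lambda_min = np.linalg.eigvalsh(gram).min()
print(lambda_min)              # strictly positive => non-degenerate
```

For large width, this empirical Gram matrix concentrates around its population counterpart, which is the matrix the theorems actually constrain.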
Deeper networks are harder to train
Heuristically, we know deeper networks are harder to train, and various methods have been proposed to deal with this problem. He et al. [HE:2016] proposed the deep residual network (ResNet) architecture, which allows randomly initialised first-order methods to train neural networks with an order of magnitude more layers. Du et al. note that “Theoretically, Hardt and Ma [HA:2016] showed that residual links in linear networks prevent gradient vanishing in a large neighbourhood of zero, but for neural networks with non-linear activations, the advantages of using residual connections are not well understood.” [DU:2018]. Why do residual connections in a ResNet architecture offer better convergence than fully-connected feedforward networks?
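Hardt and Ma's linear-network observation is easy to see numerically. In the sketch below (hypothetical dimensions), the end-to-end map of a deep plain linear network near zero, a product of small random matrices, collapses to numerically zero, so gradients through it vanish; adding identity skip-connections keeps the end-to-end map well away from zero.

```python
import numpy as np

rng = np.random.default_rng(0)
d, depth = 16, 30              # width and depth (hypothetical sizes)

# Layer weights in a "large neighbourhood of zero": small random matrices.
As = [0.1 * rng.standard_normal((d, d)) / np.sqrt(d) for _ in range(depth)]

# Plain linear network: the end-to-end map is the product W_L ... W_1.
plain = np.eye(d)
for A in As:
    plain = A @ plain          # norm shrinks geometrically with depth

# Linear ResNet (Hardt & Ma): each layer is I + A_l, so the product stays near I.
res = np.eye(d)
for A in As:
    res = (np.eye(d) + A) @ res

print(np.linalg.norm(plain), np.linalg.norm(res))
```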
The answer to this question falls out of Du et al.’s analysis. The bounds in Theorems 3.1, 4.1 and 5.1 [DU:2018] depend on the number of neurons per layer, and the nature of this dependency differs between architectures: the required width scales exponentially with depth for feedforward networks, but only polynomially for ResNet. For feedforward networks, “the exponential dependency results from the amplification factor of multilayer fully-connected neural network architecture” [DU:2018]. These theorems clearly demonstrate the advantage of using residual connections and the underlying complexity of the multilayer fully-connected neural network.
Du et al.’s paper [DU:2018] proves that gradient descent on deep over-parameterised networks can achieve zero training loss under relatively weak assumptions on the neural network. They also explain why deeper networks are harder to train and why ResNet is better than a multilayer fully-connected neural network in terms of the convergence of the training loss.
In practice, stochastic gradient descent is more likely to be used than full-batch gradient descent. However, Du et al. expect that their analysis can be extended to stochastic gradient methods with similar convergence rates.
Their focus is on training loss; they do not consider test loss. It is important to be able to prove that gradient descent also finds solutions with low test loss. Unfortunately, existing research in this area is limited and this remains an open question.
* Du et al.’s ResNet architecture is modified, without loss of generality, to use skip-connections at every layer, in contrast to the standard ResNet, which uses skip-connections every two layers. This simplifies the analysis, and the results in the paper can be generalised to standard architectures.
[DU:2018] Du, Simon S., Lee, Jason D., Li, Haochuan, Wang, Liwei, Zhai, Xiyu: Gradient Descent Finds Global Minima of Deep Neural Networks. arXiv preprint arXiv:1811.03804, 2018.
[HA:2016] Hardt, Moritz, Ma, Tengyu: Identity matters in deep learning. arXiv preprint arXiv:1611.04231, 2016.
[HE:2016] He, Kaiming, Zhang, Xiangyu, Ren, Shaoqing, Sun, Jian: Deep residual learning for image recognition. In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition, pages 770–778, 2016.
[HI:2012] Hinton, Geoffrey, Deng, Li, Yu, Dong, Dahl, George E, Mohamed, Abdel-rahman, Jaitly, Navdeep, Senior, Andrew, Vanhoucke, Vincent, Nguyen, Patrick, Sainath, Tara N, et al.: Deep neural networks for acoustic modeling in speech recognition: The shared views of four research groups. IEEE Signal Processing Magazine, 29(6):82–97, 2012.
[KR:2012] Krizhevsky, Alex, Sutskever, Ilya, Hinton, Geoffrey E.: Imagenet classification with deep convolutional neural networks. In Advances in Neural Information Processing Systems, pages 1097–1105, 2012.
[WU:2016] Wu, Yonghui, Schuster, Mike, Chen, Zhifeng, Le, Quoc V., Norouzi, Mohammad, Macherey, Wolfgang, Krikun, Maxim, Cao, Yuan, Gao, Qin, Macherey, Klaus, Klingner, Jeff, Shah, Apurva, Johnson, Melvin, Liu, Xiaobing, Kaiser, Lukasz, Gouws, Stephan, Kato, Yoshikiyo, Kudo, Taku, Kazawa, Hideto, Stevens, Keith, Kurian, George, Patil, Nishant, Wang, Wei, Young, Cliff, Smith, Jason, Riesa, Jason, Rudnick, Alex, Vinyals, Oriol, Corrado, Greg, Hughes, Macduff, Dean, Jeffrey: Google’s neural machine translation system: Bridging the gap between human and machine translation. CoRR, abs/1609.08144, 2016.
[ZA:2016] Zagoruyko, Sergey, Komodakis, Nikos: Wide residual networks. arXiv preprint arXiv:1605.07146, 2016.
[ZH:2016] Zhang, Chiyuan, Bengio, Samy, Hardt, Moritz, Recht, Benjamin, Vinyals, Oriol: Understanding deep learning requires rethinking generalization. arXiv preprint arXiv:1611.03530, 2016.