K-FACとは?
大規模深層学習のための二次最適化の実現
深層学習における効率的な二次最適化の実現手法であるKronecker-factored Approximate Curvature (K-FAC) (James Martens et al., 2015) についてまとめる。
概要
Kronecker-factored Approximate Curvature (K-FAC) は(当時)University of Toronto の Ph.D. 学生であった James Martens らによって、2015年のICMLで提案された深層ニューラルネットワークのための最適化手法である。大規模な深層学習において応用が限られていた二次最適化手法のボトルネックである曲率 (Curvature) 計算の効率的な実現手法である。K-FACをはじめとする効率的な二次最適化手法の登場により、近年では、深層学習の学習時間短縮の目的において、二次最適化手法の高い収束性による恩恵が見直されてきている。
自然勾配学習法
自然勾配学習法(Natural Gradient Learning)は1998年にShun-ichi Amariによって提案された情報幾何に基づくニューラルネットワークの最適化手法である。自然勾配学習ではFisher情報行列を曲率として用いることで目的関数の地形を正確に捉え、「反復数」において学習の高速な収束を実現する。このため、自然勾配学習法は二次最適化の効率的な実現法としても知られる。
入力xに対して、yの確率を出力する確率モデルp(y|x;θ)のFisher情報行列は以下で定義される。
(データについての期待値をとるこの定義はempirical Fisherと呼ばれる。ミニバッチを用いた確率的学習の場合はミニバッチ内での平均に置き換えられる。)画像分類のための深層学習においては、負の対数尤度の期待値を(ニューラルネットワークによる予測と正解データとの誤差を定量化した)誤差関数とするのが一般的であり、Fisher情報行列を曲率の近似とみなすことが可能である。誤差関数に負の対数尤度が用いられている場合の誤差関数のHessianとFisher情報行列との関係は以下の式で表される。
このFisher情報行列を用いて、自然勾配学習法では、以下の更新則でパラメータが更新される。
ここで、関数の一次勾配にFisher情報行列の逆行列を適用した結果を自然勾配と呼ぶ。しかし、深層学習で用いられるニューラルネットワークの膨大なパラメータ数Nに対し、Fisher 情報行列サイズはN×Nであるため、この逆行列計算は現実的ではない。この問題のため、「学習時間」において高速化されない自然勾配学習法の利用は限られてきた。
自然勾配近似手法
近年、自然勾配学習法において計算のボトルネックとなっているFisher情報行列の逆行列計算を軽量化する(もしくは計算しない)ための近似手法が提案されてきており、深層学習の研究者らによって自然勾配学習法の「高速な収束性」が再評価されてきている。
近似のアプローチは大きく分けて3つ。(分類はこちらの文献を参考にした。)
- Fisher情報行列を(逆行列が計算し易いように)近似する
- Fisher情報行列を単位行列に近づける(reparameterization)
- 自然勾配を直接近似する
クロネッカー因子分解を用いた自然勾配近似(K-FAC)
K-FACも自然勾配近似手法の一つであると見なすことができ、上記の3つのアプローチのうち「1. Fisher情報行列を(逆行列が計算し易いように)近似する」手法に相当する。特に、その他の自然勾配近似手法と比べて、数学的原理に基づく最も効率的な近似手法である。
K-FACではFisher情報行列のブロック対角近似を行う。ここで、各対角ブロックは、Fisher情報行列のうち、ニューラルネットワークの各層のパラメータに対応した値である(例:3層のネットワークだと3つの対角ブロックができる)。次に、各ブロックを2つの行列のクロネッカー積で近似を行う(これをクロネッカー因子分解 (Kronekcer-factorization)と呼ぶ)。
ここで、行列のクロネッカー積の重要な性質である
を用いることで、Fisher情報行列の逆行列計算が(Fisher情報行列と比べ)非常に小さいクロネッカー因子の逆行列計算によって近似されることになる。
クロネッカー因子のサイズ(どれほど逆行列が軽量化されるか)を明らかにするために、クロネッカー因子分解の仕組みを解説する。1つの全結合層を例に取り、Fisher情報行列のこの層に対応する対角ブロック(便宜上、Fisherブロックと呼ぶ)に注目する。i層目におけるFisherブロックは
と表せる(期待値の表記は簡略化)。ここで、∇iはi層目のパラメータについての勾配である。深層ニューラルネットワークにおける効率的な勾配計算方法である誤差逆伝播法を用いれば
として、(各サンプルごとの)対数尤度の勾配を二つのベクトルのクロネッカー積で表現することができる。この関係を用いるとFisherブロックは
と「クロネッカー積の期待値」の形に変形することができる。K-FACではこのFisherブロックを近似し、「クロネッカー積の期待値」を「期待値のクロネッカー積」に変換する(クロネッカー因子分解)。
先に説明したように、クロネッカー因子分解により、Fisherブロックの逆行列計算が大幅に軽量化される。この効果を画像分類の分野で頻繁に用いられるAlexNetを例に取り解説する。ImageNet(1,000クラス分類)向けのAlexNetの最終層の全結合層に注目して行列サイズを比較した結果が以下の図になる。
ここまでをまとめると、K-FACとは、
- Fisher情報行列のブロック対角化(各対角ブロックは各層に対応)により、「層を跨いだパラメータ」の相関を無視
※ ブロック三重対角(隣接する層のパラメータの相関は考慮)を用いる手法も存在する。 - 各対角ブロック(Fisherブロック)のクロネッカー因子分解により、各層における「入力」と「出力についての勾配」の相関を無視
- 1. 2. の近似により、効率的にFisher情報行列の逆行列を計算し、自然勾配を求める。
を行う自然勾配近似法であると言える。ここでは全結合層のためのK-FACを紹介したが、畳み込み層のためのK-FACでは、2.に加えさらに近似が適用される。詳細は別資料を参照されたい。
最後に画像データセットCIFAR-10の分類問題(10クラス分類)を例にK-FACの有効性を示す。深層学習において一般的に用いられる確率的勾配降下法(SGD)、近似を伴わない自然勾配学習法(NGD)、そしてK-FACの学習曲線の比較を下図に示す。「反復数」においてNGDはSGDよりも高速に収束していることが分かるが、NGDは一反復あたりの計算時間が重いことから「経過時間」においてはSGDよりも遅い、ということが分かる。一方でNGDの近似手法であるK-FACは「反復数」においてNGDの学習曲線を再現している上に、効率的な計算の軽量化により、「経過時間」においても、SGDより高速に学習を完了させている。
これが、K-FACを始めとする自然勾配近似手法を導入する動機であるが、ImageNetを始めとする大規模深層学習においては、K-FACの応用は限られており、SGDに対する有効性の検証は十分になされていない。
その他のK-FACの応用例
本記事では、画像分類のための畳み込みニューラルネットワーク(CNN)の学習へK-FACを応用した例を紹介した。K-FACは既に他の種類のネットワークに応用されている。
- Recurrent Neural Network (RNN) への応用
Kronecker-factored Curvature Approximations for Recurrent Neural Networks,
James Martens, Jimmy Ba, Matt Johnson,
ICLR2018. - Reinforcement Learning への応用
An Empirical Analysis of Proximal Policy Optimization with Kronecker-factored Natural Gradients,
Jiaming Song, Yuhuai Wu,
arXiv:1801.05566 [cs.AI], Jan 2018. - Bayesian Deep Learning (変分推論)への応用
Noisy Natural Gradient as Variational Inference,
Guodong Zhang, Shengyang Sun, David Duvenaud, Roger Grosse,
arXiv:1712.02390 [cs.LG], Dec 2018.
K-FACの実装例
- TensorFlow
https://github.com/tensorflow/kfac - PyTorch
https://github.com/yaroslavvb/kfac_pytorch
(こちらの記事で紹介されています。) - Chainer
https://github.com/tyohei/chainerkfac
本記事では深層学習における効率的な二次最適化の実現手法である自然勾配法に始まり、自然勾配法の近似手法の種類、近似手法の一つであるK-FACの概要を説明し、「大規模深層学習におけるK-FACの導入」について、その背景とこれまでの研究成果について紹介しました。数学的な厳密さは犠牲にして、直感的な理解を目指しました。
最後まで読んでいただきありがとうございます。