Neural Tangent Kernel(NTK)の概要

Kai
LSC PSD
Published in
16 min readJan 23, 2020

今ニューラルネットワーク業界に旋風を巻き起こしていて、今後主流になるかもしれないし、埋もれるかもしれない理論である「Neural Tangent Kernel(ニューラルタンジェントカーネル、略してNTK)」の日本語版解説です。NTKがどういったものなのかを理解することを目的としています。

はじめに(英語のお勉強)
Tangentは日本語で「接線」です。サインコサインタンジェントだと思ったそこのあなた!違いますよ!

Neural Tangent Kernel(NTK)とは

2018年末に提案され、その理論が機械学習の真理に近いとは言われているものの、イマイチ結果に結びつかない理論です。

ニューラルネットワークモデルはy=f(x,θ)(y=出力、x=入力、θ=重みの集合)という関数で表すことができ、一般的には個々の重みを調整することで正しいyを導いていくアルゴリズムです。NTKは重みを調整することに主眼を置くのではなく、重みが変化することによって関数の形がどう変わるのか、すなわち重み全体の分布(特徴)がどう変化するのかに焦点を当てています。
θが無限個ある(理論の上ではθが加算無限個あっても特に問題ないと思います。いざ実装する際には大問題ですが。)という仮定の下では、重みベクトルのカーネルは初期カーネルとほとんど同じ、つまり重みの状態は初期値からほとんど変化していないことが分かっています。このような条件下では、重みを定数とみなせるのでモデルを線形近似することが出来、勾配降下法を容易に用いることができるという理論です。

……というわけで、簡単な例を用いながらNTKを理解していきたいと思います。(英語で非常に丁寧な解説記事がありましたので、以下の内容はそれをベースに書き足したり削ったりしながら書いています)

簡単な例

すごく簡単な例から考えてみたいと思います。1次元の入力、及び1次元の出力があり、隠れ層(中間層)が2つでその幅がm、reluを用いたシンプルなニューラルネットワークで考えてみます。文字で見るより図で見た方が早いでしょう。

m=100でランダムに初期化させたとき、xとf(x,w)の関係は下記の図のようになります。

m=100でreluなネットワーク。100×100もあると大体似たような挙動になります。

この隠れ層の幅を無限にすると、この図はガウス過程に従った図となりますが、それをまじめに考察しだすと厄介な話になります。ガウス過程については今後出てきませんが、NTKを実装する際には大事になるのでこのワードだけ頭の片隅に置いといてください。というわけでこの図はこれ以降扱いません。

さて今回のモデルに対していくつかの表現方法の共通認識を持っておきたいと思います。

  • ニューラルネットワーク関数はf(x,w)と呼びます。ここでxは入力、wは重みのベクトルを全て並べたもの(サイズp)です。
  • 今回の例では、データセットは点(x,y)でN個あるとします。そのためデータセットは {x̄ᵢ,ȳᵢ}(i=1~N)と表せます。

これをシンプルなアプローチで学習させることを考えます。つまり、最小二乗法によって損失を計算し、フルバッチの勾配降下法をするだけです。この損失は以下のように記述できます。

これはベクトル表記を使えばもっと簡単にできます。

  • 最初に、全ての出力データセットȳをつなげ、サイズがNのベクトルを作ります。
  • 同様に、この関数による全てのモデル出力f(x̄ᵢ, w)を並べて予測ベクトルy(w)∈Rᴺを1つ作ります。基本的にy(w)ᵢ=f(x̄ᵢ, w)です。これはニューラルネットワーク関数f(・, w)を関数空間上の一つベクトルとしてみることに似ています。

これによって損失は以下の通りシンプルに書けます。

ここで、データセットのサイズNはどこでも変更することはありません。つまりNは固定値です。定数なので損失関数において不要な値となりました。したがってNは、式の簡潔さを維持しつつ、かつ全体に大きな影響を与えないように外すことが可能です。(導関数をスッキリさせるために1/2はそのまま残しておきます。)

これによってベクトル表記にすることが出来ました。上図の単純な例ではN=2です(2つの青い点)。

ネットワークの訓練は、勾配降下法を用いてこの損失を最小化させるだけです。

最小二乗法の勾配降下法にてm=100のニューラルネットワークを訓練している様子

ここで、2つの隠れ層(m×mの行列)間の重みについて、訓練の進捗状況を並べたgifを作成すると、非常に興味深い傾向がみられます。

m=10,100,1000に変えた時の訓練の経過を示したアニメーション

上記の画像、jpegじゃありません。幅10の時だけは少し変化が見て取れますが、それ以上の幅の時には静止画のように見えるでしょう。非常に怠け者なモデルですね!重みベクトルのノルムについて、初期値との相対的な変化を調べることでこれを定量的に確認することが出来ます。

上記の訓練結果をプロットすると以下の通りです。

Lossのカーブと、重みの変化量に関するグラフ

重みベクトルのノルムは、一つのベクトルに含まれる全ての重みを使って計算しています。また、その他ハイパーパラメータ(学習率など)は全て一定です。

さて、隠れ層が大きくなればなるほど、重みは変わらないという結果となりました。この事実についてもう少し詳しく考えてみましょう。

テイラー展開

というわけでネットワーク関数を重みの初期値の周りでテイラー展開します。

一次近似まで行い、単純な線形関数に近似しちゃいましょう。入力に対するモデルの全出力(y(w))に対してベクトル表記を用いると次のように書き直すことが出来ます。

行列のサイズを明示的にすると以下の通りです。

全てのベクトルと行列のサイズを明示的に示した図

初期出力y(w0)や、ある重みにおける勾配ベクトル∇wy(w)(=ヤコビアン)は単なる定数です。なのでこの近似式は重みにのみ依存する線形モデルとみなせます。そのため最小二乗法の損失を最小化するタスクは、単なる線形回帰です!しかしモデルの勾配計算は線形演算ではないので、モデル関数は入力に対して未だ非線形であることには注意してください。とはいいつつ実はこのモデルは、初期化時のモデル出力の勾配を定義とした特徴マップφ(x)による、単なる線形モデルとみなせます。入力xはこの特徴ベクトルによって写像されます。

この特徴マップは、Neural Tangent Kernelと呼ばれるカーネルを入力側に自然に誘発します。このカーネルの詳細は後程説明するとして、最初にこの線形近似について考えてみましょう。でも真面目に考え出すと長い話になりますので、ここではあくまで感覚的なまとめに留めておきます。

線形近似の際に考えるべきことは、ヤコビアンの変化量ではなく、ヤコビアンの相対的変化量です。これは、初期値における
・(モデル出力と実際の出力の差)÷(重みにおける勾配の変化量=ノルム)
によって、いわゆる重み空間における”距離”を定義し、これに
“ヤコビアンの変化率=ヘッセ行列∇²w y(w0)”をかけ、”ヤコビアンのノルム “で割ることで考えることが出来ます。なので以下の様にκ(w0)を設定します。

モデルを線形近似したいということは、このκ(w0)=ヤコビアンの相対的変化率を出来るだけ小さくしたいという事です(≪1)。ヤコビアンが変わらない=勾配が変わらない=線形だということです。これまでの結果から、隠れ層の幅mをm→∞に発散させると、κ→0に収束することが分かっています。ただし、κ→0は重みが正しく初期化されている場合のみ成立しない事に注意です。その初期値は、入力層のサイズに反比例する分散を持った、独立で平均が0であるガウス分布に従う必要があります(LeCunの初期化)。

また、ヤコビアン∇wy(w)にほとんど影響を与えない程微小なwの変化で、十分に‖(y(w₀)−)‖ の変化を誘発できることがChizat氏とBach氏の論文「Lazy Training in Differentiable Programming」によって示されています。これについては直感的な理解をしておきましょう。

一つ一つの隠れ層が大きいとき、出力に影響を与えるニューロンは大量に存在することになります。これらのニューロンの重み全てが僅かに変化するだけでも、その出力を大きく変化させる可能性があるため、ニューロンはデータに合わせようと思っても、ほとんど変える必要がないということです。

重みの変化量が少ない程、線形近似の精度は高まります。隠れ層の幅が大きくなると、ニューロンの移動量は減少し、モデルは線形近似に近付いていきます。

参照元サイトではこの辺のことについてもっと詳しく証明していますが、ここでの目的はNTKの概要を掴むことなので、すっ飛ばします。

モデル出力の調整

κ(w0)についてもっと考えてみます。

このモデルの出力に係数αをかけます。

これは当然モデルの倍率を変えているだけです。…うーん ‖(y(w0)−ȳ)‖が邪魔ですね。初期化時にこのモデルは常に0を出力する、つまりy(w0)=0と仮定してしまえばこれを取り除くことが出来ます。

これでαを∞に飛ばすことで、モデルを線形(κ(w0)→0)にすることができます!上にあげた”Lazy~”という論文の著者らは、このκ(w0)の量をモデルの逆相対尺度と呼びました。

これを視覚化するために、1次元モデルを準備しました。このモデルの重みwは、w0=0.4としています。

また、初期状態で出力が0になるようにします。モデルの線形性に対するαの効果を確認するには、xを特定の値、例えばx=1.5で関数の値を見るだけです。αを変えると次のようになります。

x=1.5のときの、1変数モデルにおけるα(出力調整係数)の変化とグラフの図

αが大きくなるにつれて実際の関数が線形(基本的には接線)化していくことが観察できます。

また、ある1点に対する2つのモデルの損失の形を調べることもできます。線形関数の損失の学習軌道は綺麗な放物線を描きます。なので実際の損失もαが増加につれ、この放物線に近付くと予想されます。

上図は正規化された損失です。正規化は、実際の損失がαによって変化しないよう、単にα²で割っています。右のgifはw=0.4近傍に寄せていっているので、細かな違いをよりはっきり見ることが出来ます。

ここでは2つの重要な視点があります。

  • 実際の損失は線形化された損失に予想通り近付いています。
  • 2つの損失の形の最小値は近くなり、さらに重要な事に、初期化値にも近くなります。これは、モデルを訓練している間、重みがほとんど動かないという前の観察結果と完全に一致しています。

勾配流(Gradient flow)

ここまでで、ニューラルネットワークやその他非線形なモデルが、いつ正確に線形近似されるかについては分かりました。次に、勾配降下法における訓練の動きについてみていきましょう。

これをちょっと式変形するとこうなります。

左辺は導関数の差分近似式に似ていますね。なのでこの方程式は微分方程式を差分近似した式とみなせます。学習率ηを限りなく0に近づけると、重みベクトルの更新を微小時間で行っている微分方程式の様に書き換えられます。

これを勾配流と呼びます。標準的な勾配降下の連続時間に相当するものです。ポイントは、学習率が十分に小さい場合、パラメータ空間における勾配降下の軌跡は、この微分方程式の解の軌跡に近似できることです。つまり、パラメータの学習は「パラメータの分布」を学習していると解釈できるという事です。さて、時間の導関数はドットを付けて、下記の様に書き換えます。

さらにここから時間変数を削除してみましょう。損失関数の代わりに勾配を使用すると、次のようになります。

ここでチェーンルール(合成関数を微分するときのルール)を利用し、この勾配流から導かれるモデル出力y(w)(これは基本的に関数空間上の運動となります)の振る舞いを以下のように導出できます。

赤の量(∇y(w)ᵀ∇y(w))を、Neural Tangent Kernel(ニューラルタンジェントカーネル、略してNTK)と呼びます。ある時間tにおけるwによる、2点のカーネル行列(=グラム行列)の値です。こいつをH(w)と書くことにしましょう。

ネットワーク関数をテイラー展開した部分に戻ると、線形化モデルには特徴マップφ(x)=∇wf(x, w0)があることが分かっています。この特徴マップに対応するカーネル行列(=グラム行列)は、全てのサンプル点の特徴マップ間の互いの内積を取ることによって得られます。これはまさにH(w0)のことです!

NTKの初期状態は、サンプルの特徴マップ間の互いの内積からなります。これはサンプル点上の外積として解釈することもできます(x̄ᵢはデータセットの入力であることを思い出しましょう)

ある時間tにおけるNTK(H(wt))を考えた時、これは時間に依存しているので簡単に解くことは出来ないと思うかもしれません。しかし、あるモデルが線形近似に近い場合(κ(w0) ≪1)、モデル出力のヤコビアンは訓練が進行しても変化しません、言い換えると

ということです。これはtangent kernelが訓練中一定であるので、「カーネルレジーム」と呼ばれます。そのため、訓練中の動きは非常に単純な線形常微分方程式に従います。

また、明らかにy(w)=ȳは常微分方程式における平衡状態であり、訓練のロスが0であるときに相当します。u=y(w)-とおくと、を消すように変数を変えることが出来ます。すると以下の様に単純化することができます。

この常微分方程式の解は指数行列で与えられます。

このモデルは十分にパラメータ数があるため、NTK=∇y(w0)ᵀ∇y(w0)はいつも正定値です。というのも、十分なニューロン数があれば、適当な学習率のGDまたはSGDで最適解に至れることが分かっているからです。正定値であるNTKのスペクトル分解を行う事により、勾配流の軌跡を、対応する固有値に比例した速度で減衰する独立な1D成分(固有ベクトル)に分解することが出来ます。重要な事は、これらは全て(全ての最小固有値が正であるため)減衰するという事であり、これは勾配流が常にlossが0である平衡に収束することを意味します。

この一連の議論を通して、勾配降下はそれがその線形化(これは倍率αを∞に飛ばすだけで達成できる)に近い限り、任意の非線形モデルに対して訓練ロスを0にすることが出来ることを示しました。これは勾配降下が訓練ロスが0になることを示している最近の論文における、ほとんどの証明の本質となります。

まとめ

NTK理論は素晴らしいのですが、実際はwを無限にして正確に計算されたNTK(正確にはそれを拡張した畳み込みNTK=CNTK)でさえ、MNISTやCIFARのような標準ベンチマークで7%ほど他のモデルに性能が負けていました。一応、最近の研究によってこのギャップは埋めることが出来ましたが、それでもResNet止まりのパフォーマンスしか出せていません。

NTKはニューラルネットワークの学習に新しい視点をもたらしました。発表から1年以上たった今も精力的に研究が進められています。訓練が進むにつれてカーネルがどのように変化するのかを知ることは、ニューラルネットワークのより良い理論を求めることに繋がっていくと思います。

--

--