間隔反復のDeep Learningへの応用

Shugo
KARAKURI Techblog
Published in
14 min readNov 20, 2019

こんにちは!カラクリ株式会社のAIチームの松本です.

本記事では,間隔反復を用いることによりDeep Learningの学習を高速化させる,という論文をもとに,実際に実装してみて効果を測定してみました.この論文では,間隔反復を用いることでDeep Learningの学習を高速化しつつ,精度はあまり下がらない,と結論づけていますが,それを追検証してみたいと思います.

論文:H. Amiri et al. “Repeat before Forgetting: Spaced Repetition for Efficient and Effective
Training of Neural Networks” (EMNLP 2017)

間隔反復とは,一度学習したものを,徐々に復習までの期間を徐々に延ばしながら繰り返し学習していくことにより,記憶の定着を図る学習方法のことです.間隔反復の一例としては,単語帳を使って行うライトナーシステムが有名で,今回はこのライトナーシステムをDeep Learningに応用して効果測定を行いました.ライトナーシステムについては後ほど説明しますが,名前を聞いたことはなくても実践したことのある人は結構いるのではないでしょうか.

1. Introduction

早速ライトナーシステムとはなにかについて説明したいと思います.

ライトナーシステムとは,学習する対象を正誤に基づいてグループ分けしていき,正答したものについては復習する回数を減らし,誤答率の高いものをより集中して学習することにより,学習効果を高めていく学習方法のことです.

学習者は,最初すべての学習対象を第1グループに置いておき,すべての質問に答えます.その後,正解したものは次のグループに送り,間違ったものについては第1グループに戻すといった行動を繰り返していきます.第1グループは高頻度(例えば毎日)で学習を繰り返すものの,先に送られたグループでは徐々に学習間隔を延ばしていく(例えば第nグループは2^(n-1)日ごと)ことによって学習をしていきます.

具体的な例を上げると,単語帳を3個用意する状況を考えましょう.単語帳1は毎日,単語帳2は2日ごとに,単語帳3は4日ごとに学習することにします.最初,単語帳1にすべてのカードを入れておきます.単語帳1で正答したものについては単語帳2に渡し,単語帳2で正答したものについては単語帳3に渡し,また,単語帳2もしくは3で誤答したものについては単語帳1に戻し,学習を進めていきます.

これにより,何度も間違えるものに関しては単語帳1に長く留まるため高頻度(ほぼ毎日)で学習することになりますが,あまり間違えないものに関しては単語帳3に長く留まるため学習頻度が落ち,よく間違えるものに集中して学習できます.

それでは,このライトナーシステムをどのようにDeep Learningに用いるかですが,通常Deep Learningで学習を進める際,各epochですべてのdataを学習に用いるところを,epochごとに使うdatasetを変更することによって実現します.正解を与えたdataに関しては学習に混ぜる頻度を減らし(次の単語帳に送る),間違えたものについては頻度を上げる(最初の単語帳に戻す)ということを行うことにより,各epochごとに使うdatasetを変えていきます.

例えば,dataがもともと10000個あるとして,第1 epochはすべてのdataについて学習を行います(最初はすべてのdataが単語帳1にあると考えれば良いです).このうち,1~8000に関しては正解のラベルを与えたとするとこれらは2 epochごとに学習するとします(このとき1~8000個目のdataに関しては単語帳2に移ったと考えておきます).第2 epochについては8001~10000個目のdataについてだけ学習を行い,8001~9000個目のdataに関して正解ラベルを与えたとします(このとき8001~9000個目のdataも単語帳2に移ったと考えます.つまり1~9000個目までのdataが単語帳2に移っています).このとき,すべてのdataを学習しないので,iterationの回数がすべてのdataを学習したときと比べて小さくなります.そして,第3 epochでは再びすべてのdataについて学習を行い,1~9000個目のdataについては正解したものは4日ごとに学習させ(単語帳3に移ったものと考えます),間違えたものに関しては再び毎日学習させます(単語帳1に戻ったと考えます).これらを繰り返すことによって,各epochに用いるdataを増減させることにより,学習効率を高めることを目指します.

2. Settings

実験1

MNISTとFASHION_MINISTの2つのdatasetを用いて実証してみました.ニューラルネットワークの構造は,ノードが100個の全結合層を1つとしたものを使い,optimizerとしてはAdamを使いました.[注1]

また,ハイパーパラメータについては下記の通りとしました.

・ミニバッチサイズ:300

・学習係数:0.001

また,実験1のライトナーシステムは,3つのグループにdataを振り分け,第nグループに入っているdataは2^(n-1) epochごとにに学習するようにします.つまり,第1グループに入っているdataは毎epoch,第2グループに入っているdataは2 pochごと,第3グループに入っているdataは4 epochごとに学習するというような状況を考えます.

実験2

事前学習済みの BERT を用いて,livedoor ニュースコーパス分類タスクを学習 (fine-tuning) しました.optimizerとしてはAdamを使いました.ハイパーパラメータ設定は,当ブログのこちらの記事に準じました.

実験2のライトナーシステムは,5つのグループにdatasetを振り分け,第nグループに入っているdataは2^(n-1) epochごとに学習するようにします.

[注1]
実際には,datasetをMNISTにしたものや,全結合層の前に畳み込み層を入れたもの,optimizerをSGDにしたもの,学習率を変更したものなどに対しても実験を行ってみたものの,ほとんど結果が変わらなかったため,上記の条件のもののみを結果として提示しています.

3. Results

まず,ライトナーシステムを用いた学習において最も特徴的だったのは,すべてのdataを学習するepochにおけるaccuracy/lossが高く/低く,あまり正解できないラベル(グループ1にあたるもの)を中心に学習するepochではaccuracy/lossが低く/高くなっている点です.これは実験1では顕著に現れた一方,実験2では実験1ほどきれいな周期としては見られませんでしたが,すべてのdataを学習したepochで毎回ではないもののaccuracy/lossが上がる/下がることが多かったです.

(左)FASHION-MNISTを用いた実験1,(右)livedoorニュースコーパスを用いた実験2.横軸はepoch数.

そして,最も重要となる速さと精度の比較ですが,検証誤差の最小値が3 epoch連続更新されないときに早期終了させます.しかし,ライトナーシステムにおいて精度の高い“良い”モデルとなるのがすべてのdataを学習させたepochのときとなるので,epoch数を合わせて比較するために,epoch数が早期終了した時点に最も近くライトナーシステムでも精度の高いモデルとなっているときのものと比較することとします.

通常の学習結果

実験1-vanilla

実験2-vanilla

ライトナーシステムを用いた学習結果

実験1-Leitner-1

実験2-Leitner-1

通常の学習結果と同epoch数におけるライトナーシステムを用いた学習結果を比較してみると,実験1では,ライトナーシステムを用いた方が約1.74倍速く,validation dataに対する精度も1%未満の低下で十分な速さを実現しています.また,実験2では,ライトナーシステムを用いた方が約1.41倍速く,精度は約1.75%の低下となっています.論文の結果ほどの差とはなりませんでしたが,一定の結果を確認することができました.

しかし,iteration回数(学習時間はiteration回数に比例しています)で比較してみると,通常の学習結果とライトナーシステムを用いた学習結果はほとんど変わりませんでした

実験1(すべてのdataを学習したepochのみを表示)

実験1において,横軸をiteration,(左)縦軸をaccuracy,(右)縦軸をlossとした図

実験2

実験2において,横軸をiteration,(左)縦軸をaccuracy,(右)縦軸をlossとした図

ライトナーシステムを用いた学習においても早期終了するまで学習させることを考えます.ライトナーシステムを用いた際は,実験1では4 epochごと,実験2では16 epochに良いモデルと考えられるので,良いモデルに対して検証誤差の最小値が3回更新されないとき,つまり実験1に対しては12 epoch,実験2に対しては48 epoch連続で検証誤差の最小値が更新されないときに早期終了させることとして実行します.

このとき,実験1ではライトナーシステムと通常の学習結果でそれほど大きな差は出ませんでしたが,実験2では約3.0%の精度向上となりました.

ライトナーシステムを用い,早期終了まで学習

実験1-Leitner-2

実験2-Leitner-2

しかし,これは早期終了を判断するまでのiteration数を増やしすぎたことが原因かもしれません.実験2-vanillaにおいては早期終了した時点から早期終了を判断するまでに要するiteration数 が831 iterationであるのに対して,実験2-Leitner-2では早期終了を判断するまでに904 iteration分の学習を行っています.

早期終了を判断するまでのiteration数を減らすとどのような結果になるかも見てみたいのですが,実験2においては12 epoch目が実際の最小値となっているので,早期終了を判断するタイミングを早めても同様の結果となってしまい,今回の実験では,精度の向上がライトナーシステムに起因しているのか早期終了を判断するまでのiteration数が増えたことによるのかがわかりませんでした.

実験1では,通常の学習結果とライトナーシステムを用いた学習結果に大きな差は見られませんでしたが,実験2で見たようにライトナーシステムを用いたほうが早期終了した時点での正答率が向上する場合もあり,こちらに関しては今後追加で調べていきたいと思います.

また,実験2では,通常は良いモデルとならないと考えられる12 epoch目が最適モデルとなっている理由についても考察する必要があると思います.

4. Summary

今回の検証結果では,

ライトナーシステムを用いることにより,高精度を保ちつつ,高速に学習を行える

という論文の最も重要な結果が残念ながら否定される結果となってしまいました.

FASHION-MNISTを用いた実験1では明確に見えた周期性が,livedoorニュースコーパスを用いた実験2ではあまり見えなかった点については,以下のような2つの可能性を考えています.

  • 実験1では正答率の低いdata(つまり単語帳1に残り続けるdata)に対してのみ学習を行った際にvalidation accuracy/lossが低下/上昇することから,正答率の高いdataと低いdataでモデルの最適解に距離があり,continual learningにおけるcatastrophic forgettingのような現象が起きている可能性.つまり,実験1では下図における紫色の円内における外れ値が多く存在している一方,実験2では赤枠内のような分類が困難な点は存在するものの外れ値はあまり多く存在しない可能性.
2値分類における判別が困難な点(赤枠内)と外れ値となっている点(紫枠内)
  • 実験2では,ライトナーシステムによる周期性がもたらす精度の変化よりも学習曲線自体のもたらす精度の変化が大きくなっているために周期性の挙動が目立たなくなっている可能性.

この考察が正しいとすると,実験2では,正答率が低いdataに対してのみの学習においても正答率の向上が期待されるため,本来良いモデルになると期待されていない点で最適モデルとなっていることも納得がいきます.

5. Future Work

今回の実験で,ライトナーシステムを用いた場合,周期性の見える学習と周期性の見えにくい学習があることがわかり,周期性が見える原因として,catastrophic forgettingが起きている可能性を考えました.catastrophic forgettingを防ぐためには,continual learningで用いられているEWC法をライトナーシステムと組み合わせて用いることが考えられる,それによって高速化ができるかどうかを見てみることなどが考えられます.

(参考記事)チャットボットサービスにおける continual learning の検証

そもそもライトナーシステムを用いて学習した際に周期性が現れるdatasetでは,外れ値としてlabelが間違っているようなdataが混じっていると考えられるので,実際に一部のlabelをランダムに変えてみて周期性が現れるかを確認することもできると思います.また,今回実験2で見られたように,周期性が見えにくいdatasetではライトナーシステムのほうが高精度のモデルができたように,外れ値の少ないdatasetでは分類における境界値のdataを繰り返し学習するため,精度の向上も期待できるので,こちらに関してもさらなる検証を行う価値があるように思います.

また,今回の実験ではライトナーシステムのみしか扱いませんでしたが,もとの論文では,他にも様々な方法で学習に用いるdatasetを変更させて高速化を試みています.それらについても結果として提示されているのが,epoch対比での学習効率のみとなっているため,iteration(≒経過時間)対比での学習効率を実証して見る価値はあると思います!

--

--