OptunaのPruningが抱える課題

Tsukasa OMOTO
6 min readFeb 2, 2019

--

偶然にも、Optuna の開発者の1人である g_votte さんと NeurIPS2018読み会で、Optuna の Pruning(枝刈り)が抱える問題について議論させていただきました。この記事はその議論内容をまとめたものです。

先に1点断っておきます。記事タイトルに「Optunaの」と付けましたが、ここに書いてあることは枝刈りを使うどんなライブラリでも発生しうる問題です。Optuna 以外に枝刈りを提供するライブラリを私は知らないため、画期的なライブラリであるという敬意を込めて「Optunaの」と付けました。ご了承ください。

枝刈りが学習の初期段階に大きく依存する

どういうことかと言うと、これは図を見ればすぐにわかります。以下の2つの図は、学習率 (learning rate) を変えて、適当に LightGBM を回して得られた学習曲線です。

学習曲線(全体)
学習曲線(拡大)

学習の初期段階では、 lr=0.1 の曲線が最も低い loss を描いていますが、最終的には lr=0.025 のグラフが最も低い loss を達成しています。

学習率のような学習速度を変化させるパラメータを探索対象に含んで枝刈りをしてしまうと、モデル性能が良くなる可能性のあるパラメータをがっつり削ってしまう恐れがあります。

この問題は、例えば MedianPrunerのパラメータである n_warmup_steps を大きくすることで対処できますが、この値を大きくすることは枝刈りの効率を下げることになります。

ということで、学習初期の loss から最終的な loss を推定するような技術が欲しくなります。そして、これは特に根拠ない直感ですが、頑張れば出来そうな気がします。

k-fold cross validation との組み合わせが自明ではない

k-fold cross validation と枝刈りを組み合わせたいとき、どこの loss を利用するのか、枝刈りのタイミングをどうすれば良いのか自明ではなさそうです。

Optuna を利用して、枝刈りなしの k-fold cross validation を実現するには、FAQ を参考にして、例えば、以下の図のように2段構成にすれば出来ます。

k-fold cross validation (k=5) with Optuna

ここで、local objective は Optuna の枝刈り機能を利用するためだけに用意されたダミーな objective であることに注意してください。これの実装(擬似コード)は以下のようになります。

これに枝刈りが差し込める箇所は、 CV 全体を管理する study の 27行目と、各 fold の学習結果を管理する 19行目 の study になります。

まず、19行目の方に枝刈りを差し込むことには、前半で述べた学習曲線の問題は残りますが、CV に対して特に問題を起こさないです。

問題は 27行目 の方で、こちらは CV 全体の loss を管理するため、どの loss と比較しているのか注意が必要です。

CV 全体の loss と各 fold の loss の比較

ここで、1回の試行(15行目の __call__ を1回呼ぶこと)の中で、その試行が有望であるかのか無いのかの判定が難しいという問題が発生します。

例えば、最初の fold の loss と CV 全体の loss を比較することは通常はできないと思います。それぞれ、異なるデータセットで計算した loss と見ることができるからです。

この記事の執筆時には、1つの study が比較する loss を複数保持したり、指定するような機能は存在しないようです。そこで、Optuna の枝刈り機能に頼らないアイディアを考えてみると、

  • 1回の試行の中で、各 fold のうち、1つでも枝刈りされた local study があるとき、そこで全体の試行を打ち切る
  • 1回の試行の中で、1つでも枝刈りされなかった local study が存在するとき、枝刈りされてしまった全ての local study を枝刈りなしで学習し直す

などが容易に思いつきます。

私の経験では、前者の場合、ほとんどの試行が枝刈りされてしまい、うまく探索できていないように感じました。結果、現在は後者のアイディアを採用しています。この時、k の値にもよりますが、枝刈りなしの時に比べて半分、前者に比べて倍ぐらいの時間がかかっているようです。

hold-out で先に実験をし、残ったパラメータで CV

上記のアイディアとその問題について g_votte さんと議論したところ、先にホールドアウトでパラメータの候補を探索し、その候補に対して CV をするというのが探索効率的に良さそうだというアイディアが出てきました。

ホールドアウトのところは普通に枝刈りを導入し、枝刈りされず残ったパラメータの候補の全部か一部を枝刈りなしの CV の対象にすれば良さそうで、これは確かに、最終的に得られるパラメータの性能と枝刈り効率の2つの面で良さそうな印象を受けました。

枝刈りを採用する上で発生する2つの問題について紹介しましたが、これらの問題は誰もがすぐに直面する問題で、正直、すでに誰か偉い人が解決策やベストプラクティスを提案してるものだと思っていましたが、少し探してみたところ、そのような研究を見つけることができませんでした。g_votte さんも見たことがないようなことを仰っていました。

もしこれに関連する研究やアイディアをご存じの方がいらっしゃいましたらご連絡頂けると嬉しいです。また、最後になりましたが、ご多忙の中、議論にお付き合い頂きました g_votte さんに感謝申し上げます。

--

--