TensorFlow の SessionRunHook で Optuna の枝刈り

Shuhei Fujiwara
5 min readDec 18, 2018

--

この記事は TensorFlow Advent Calendar 2018 の 23 日目のはずだったけど、早く書き終わって我慢できなかったので空いてた 6 日目にねじ込んでしまったやつです。

こういう話があったので、ちょっと TensorFlow 用に Optuna の枝刈り機能を実装できそうか手を動かしてみました。

できた?

一応動くね!

前提知識

  • TensorFlow の Estimator をちょっと使ったことがある
  • Optuna が何なのかは知っていてサンプルコードを試しに動かしたことがある(枝刈りが何かの説明とかはしません

何ができれば良い?

Optuna 側に callback 的なもので定期的に現在の評価指標と反復回数を渡せれば OK のようです。

lightgbm の場合はこんなふうに callback が実装されています。

各学習の途中でこうやって Optuna に trial.report で評価指標を報告して、今枝刈りするべきかは trial.should_prune に反復回数を渡すと教えてくれるようです。

SessionRunHook ってなんぞや

Estimator には callback に近いものとして SessionRunHook というものが用意されています。Estimator は裏でよしなに tf.Session を作って、必要に応じて tf.Session.run を走らせてくれるのですが、その tf.Session.run の前後や tf.Session が作られたり閉じられたりするタイミングでやっておいて欲しい処理をねじ込むことができます。

Evaluation の結果をどこから持ってくる?

Evaluation を行った際に保存される summary ファイルから持ってくるのが良いでしょう。

TensorFlow の early stoppingそのような形で実装されています。

とは言え、 tf.Session.run が走る度 (1 step ごと) にファイルを読みにいくのはヤバ過ぎるので、適当に n 反復ごとに summary を取りにいくように書くのが良いでしょう。もちろん TensorFlow の early stopping でも同じようなことをしています。

余談 (読まなくて良いです)

なんかわざわざファイルから持ってくるの面倒くさくない?と思われそうですが、このへんは Estimator の設計思想に因るところが大きいかなと(勝手に)思っています。

Estimator では evaluation の処理も training で保存された checkpoint をファイルから読み込むというやり方をしています。そもそも Estimator の設計は training と evaluation が別々のセッションで動くという形になっているので、モデルの重みを共有するにはファイルを経由する必要があるのです。一見するとイケてないですが、 training と evaluation の処理を完全に分離するメリットはかなり大きいと思います。たとえば分散学習のときに evaluation の処理は専用の worker を用意してしまえば、 evaluation の頻度を上げたから training が全然終わらないみたいな状況を避けることができます。

最終的にできあがった SessionRunHook

Session が作られる前に begin が次のような処理を行います:

  • Estimator で step 数を保存するために使う Tensor への参照を取得

tf.Session.run の前に before_run が次のような処理を行います:

  • before で取得した global_step_tensor を tf.Session.run のついでに評価して欲しい Tensor として返す

tf.Session.run の後に after_run 次のような処理を行っています:

  • SecondOrStepTimer を使って n 反復ごとに Estimator の eval_dir にを見に行き、 summary から最新の情報を持ってくる
  • Summary には step 数の情報も残っているので、前回見た時より step 数が増えていたら Optuna の trial に結果を渡す
  • あとは他のフレームワークと同じ

刈るぞ!

刈り過ぎた。

今回のサンプルコードはここに全部置いてあります。

残課題

  • テストしてねえ(ていうかこれ何をどうやってテストするんだ!?
  • 分散学習をちゃんと考慮してない気がする(これ報告は master だけで良いはずだけど worker はどう片付けるのが良いんだろ?
  • tf.keras の場合は callback があるのでそっちも試したい(これは放っておいても誰かが作るでしょって感じするけど

などなど。

最近 TensorFlow 力が落ちていたので良いリハビリになりました。

--

--