Graph Attention Network Layerを実装する Part1
Graph Convolutionの一種である、Graph Attention NetworkをKerasのCustom Layerとして実装します。これは、Graph Convolutionを自然言語に応用する検証の一貫です。
実際実装したレイヤーは以下になります。今回”Part1”と題していますが、Part2では実際のデータセットを使ってGraph Attention Networkの力を検証してみようと思います。
実装の背景
Graph Attention Networkの実装は著者により公開されています。また、そのKeras実装も有志の方の手で公開されています。これを使えばいい・・・なら楽なのですが、何れもバッチに対応していません。端的には、単一のグラフで動作することを想定しており、データ個別にグラフを持つことを想定していません。
自然言語に応用する際は、各文がグラフ構造を持つことが想定されます(単語間の依存関係から作成したグラフなど)。また、化学物質などを扱う際も、物質それぞれがグラフ構造を持っています。そのため、各データがグラフを持つケースに対応できるようにします。
実装のポイント
バッチに対応する場合、最初の次元がバッチサイズになります。そのため、行列演算が少しややこしくなります。tf.matmulは3次元以上の場合最初の次元がバッチと想定する(行列のリストとみなす)ため、バッチ対応時にはよく使う関数です。tf.tensordotは任意の次元について内積を計算できますが、互いにバッチサイズを含む行列の計算には対応できないので注意が必要です。
Kerasでは、Custom Layerを使い演算処理をまとめることができます。端的には、関数化するようなものです。Custom Layer化することで、テストが非常に行いやすくなります。レイヤ単位で動作を確認しておくことで、学習が上手くいかない場合個別のレイヤの問題と、レイヤの組み合わせの問題を切り分けて考えることができます。
機械学習モデルの開発は、以下の3ステップで時間を浪費してしまうことがままあります。KerasではCustom Layer、他のフレームワークでも同等の仕組みがあると思いますが、テスト可能な単位に切り分けることが肝要です。
- スクラッチでゼロからモデルを書く
- 学習が上手くいかないので、ハイパーパラメーターを調整する
- モデルの適当な部分をいじる(2に戻る)
テスト
テストでは、主に以下3点を見ています。
- Forward Check: 意図したshapeが出力されるか
- Training Check: 学習するかチェックする
- Expectation Check: 意図した内容を学習するかチェックする
Forward Checkでは、意図した行列変換を経て出力に至るかをチェックします。これが、まず最初に行うチェックです。途中でaxisをミスっているということはままあります。
Training Checkでは「これが解けないとダメ」という問題を作りチェックします。アルゴリズム的な問題を作ることもあれば、MNISTやBostonの価格予測など、デフォルトでデータが入っているような問題を使う場合もあります。
Expectationでは意図した内容を学習するかチェックします。Attentionであれば、意図した箇所にAttentionが当たるかなどです。ニューラルネットでは、学習するけれど意図したものではないということが往々にしてあります。「狙った内容を学習しているか」がExpectation Checkになります。
今回は、以下のコードでテストしています。
Training Check、Expectationでは、「最大値を持つ近傍ノードの値を出力」「近傍ノードのうち、自身に最も近いノードの値を出力」の2つを使っています(make_problems
のmax
とdistance
です)。結果、Training CheckはパスしたのですがExpectationはエラーでした。具体的には、意図したノード(最大値/最も近い)にAttentionがうまく当たりませんでした。
Expectation Check
上手くいかない対策として、以下の対応を行いましたが特に効果はありませんでした。ここはちょっと謎なところがあるので、いったん棚上げにしてPart2では実際の問題を解いてみようと思います。
- 理論面の問題: 論文中ではsourceノードとtargetノードの特徴をconcatしているが、実際には合計してしてる(これは著者自身がそうしてしまっているのだが・・・)。そのため、論文に忠実なconcatで実装する(Bi-directionalの場合もcocatの方が優秀なので、足すよりはよくなるはずという見込みだったが・・・)。
- attentionのみで解く: attention無しで実行したところ同等の精度が出た。そこから、kernel単体で学習しきってしまい、attentionを掛ける必要がなくなったのではないかと考察。attentionの重みのみで実験してみたが、結果は変わらず。
- 問題の修正: max/distanceに該当するノードの値が、他のノードに比べてそれほど差異があるものではない=他のノードを見てしまっても問題ない状態になっているのではと考察。分散/値範囲を広げて実験してみたが、結果は変わらず。
論文中では、Attentionが意図したノードに係っているかは詳しく分析していません。一応attention weightの重みをビジュアライズしたものがあるのですが、妥当かどうかはちょっとわかりません。
論文中では複数のkernelを使っており、Attentionよりそれが寄与している可能性もなくはないのではと感じています(これを証明するには実験が必要ですが)。上手くAttentionが当たらないのは上記に上げた通り問題自体に何か原因がある可能性もあるため、一旦は実データでの検証を行ってみようと思います(Part2に続く)。