Graph Convolutionで自然言語処理を行う(テキスト分類編) Part5 (End)
Part5ではテキスト分類にあったGraph Convolutionの使い方を検証します。というのも、Part4の実験においてGraph Convolution単体の精度がランダムに選ぶのと大差なかったためです(精度はF1スコアのことで、以下も同様です)。もちろん、先行研究に倣いLSTMと組み合わせれば精度は上がります。ただそれはLSTMのおかげであってGraph Convolutionの効果とは言い難いです。そのため、Graph Convolution単体でテキスト分類の精度を上げることに挑戦します。
本記事で改善したポイントは以下2点となります。
- ノードの特徴を維持するようにする
- グラフの構築とモデルへの入力とで、前処理を分ける
改善した結果、以下のような結果となりました。Graph Convolution単体でもそこそこの成果が得られ、またLSTMでブーストができることを確認できました。
今回の検証では、Graphに”Static”という種類が追加されています。これは、前方の単語に対し固定的に接続を行ったGraphです(接続は片側方向)。例えば、”How are you”の場合以下のようになります。RNNの処理を模した接続です。
単語分散表現による分類ついてはMeanを取るよりAddの方が精度が格段に高かっため、Addを使用するようにしています(Additive Word Vector)。これに伴い、Graph Convolutionについても各ノードの結果をマージする方法をMeanからAddに切り替えています。
実験からは、以下の知見が得られました(以下の記述には、全て「テキスト分類においては」という但し書きがつきます)。
- Graph Convolution単体で、ベースラインに近しい精度を出すことができる。ただ、上回るわけではない。
- グラフ構造(Dependency/Similarity)による精度の差は小さく、単純に前方の単語にリンクを張ったほうが精度は良い(Static)。
- LSTM/Bi-directional LSTMと組み合わせることで精度を上げることができる。組み合わせる場合、Graph Convolutionを行う前に組み込むほうが良い。
- 組み合わせによるメリットは、精度面だけでなく学習速度面もある。
- Graph Convolutionの層を重ねると精度は悪化する。Headを増やしてもそれほど効果がない
実験に使用したNotebookは以下になります。
以下で、各改善ポイントと結果の詳細について解説を行います。
ノードの特徴を維持するようにする
Graph Convolutionの分類精度は、単純に単語分散表現をAddするより格段に悪くなっています。
Graph Convolutionでは、ノードの特徴は周辺ノードの特徴を集約して作成されます。以下は、Graph Attention Networkの実装コードです。
mask = -10e9 * (1.0 - A)
attention += maskattention = tf.nn.softmax(attention)
dropout_attn = Dropout(self.dropout_rate)(attention)node_feature = tf.matmul(dropout_attn, dropout_feat)
tf.matmul(dropout_attn, dropout_feat)
を行うと、Attentionの配分がある特徴が集約され、配分のないノードの特徴は含まれなくなります。そのため、自身への接続がない限りGraph Convolution後のノード特徴には自身の特徴は含まれなくなります。極端な話、接続がないノードは特徴がなくなってしまいます。
これを解決する手法として、常に自身への接続を持つ方法があります(隣接行列の対角を1にする)。ただ、これは本来「自分自身への接続」を表現するものであり、自身の特徴を残すために使うのは適切ではないでしょう。
そこで、単純にノードの特徴を足すようにしました。Residual Connectionのような形になります。
aggregation = tf.matmul(dropout_attn, dropout_feat)
node_features = drop_feat + aggregation
この結果、精度は大きく改善しました。
グラフの構築とモデルへの入力とで、前処理を分ける
通常、モデルへデータを入力する際は語彙数を制限します。これは、登場頻度の低い単語を「未知語」としてまとめてしまう処理です。今まではグラフを構築する際も前処理済みのデータを使っていたのですが、「未知語」というトークンが含まれることでグラフを作成するための構文解析に影響が出るのは火を見るより明らかです(Part4より)。
そこで、グラフの構築には前処理を行わない生のデータを使用するようにしました。その結果、以下のように正常にグラフが構築できるようになりました。
グラフのエッジに対するAttentionをヒートマップで可視化したものが以下になります。
学習結果の解析
LSTMを併用する効果についてみてみます。以下は、左が素のDependencyで、右がLSTMをプラスした場合の結果です。
LSTMを併用したほうが、val_accの上昇が早いことがわかります。これはSimilarityについても同様でした。
Graph Convolutionの層を重ねると、精度が落ちるという興味深い現象がありました。2 Layerにしたところ、以下のようによろしくない感じになりました。
原因は定かではなく、ちょっと謎です(=>後日、Twitterにて指摘をいただき同様の現象を指摘している論文を教えていただきまいた。Representation Learning on Graphs with Jumping Knowledge Networksです)。層を重ねるほど遠い接続先を考慮できるのですが、それが悪さをしているのかもしれません。
この他Headの増減なども行ってみましたが、劇的には精度が改善することはありませんでした。こうしたチューニングは人がやるより機械にやらせた方が良いので、ハイパーパラメーター探索のツールの導入も検討したいところです。
Part5、そしてテキスト分類編はこれで終了です。次回は、次の検証に移る前に最近の研究のサーベイ、またハイパーパラメーター探索ツールの検証などを行っておきたいと思います。