Graph Convolutionを自然言語処理に応用する Part2
Graph Convolutionを自然言語処理に応用するための調査Part2となります。Part2では、いよいよ実装を見ていきます。Graph Convolutionサーベイ編であるPart1はこちらからご覧ください。
Part2ではGraph Attention Networkの理論と実装を確認していきます。それに先立ち、まずPart1で学んだ内容に沿ってGraph Attention Networkという手法を整理しておきましょう。
Graph Attention Networkは
- 「Spatial Method」なGraph Covolutionであり、
- 「エッジの情報を加味するが、潜在表現までにはしない」手法です
まず、Graph Convolutionの研究を読む際は畳み込みの手法がSpatialかSpectralかに注意を払う必要がありました。そして、エッジの情報を加味する・しない、する場合どの程度かという区分けがありました。Graph Attention Networkは、速度そこそこで、エッジの情報もそこそこ加味できる手法です。ここから先エッジの情報を捨てても速度が欲しい場合はGraph Convolution、よりリッチな情報が欲しい場合はEdge EmbeddingなGraph Neural Networkを選択することになります。その意味では、Graph Attention Networkは最初に試すには良い手法でしょう。
Graph Attention Network
ここからはGraph Attention Networkの理論と実装双方を確認していきます。実装は公式のTensorFlow実装があるのですが、コメントが丁寧で見やすいKeras実装もあるため、そちらをベースに確認していきます(Keras実装は、公式のTensorFlowリポジトリからリンクが貼られています)。
処理の解説は、以下の流れで行います。
- 入力となる、グラフのデータ形式の確認
- ノードに対し重みをかける処理の確認
- Attention計算処理の確認
Graph Attention Networkへの入力は、当然グラフとなります。グラフは、以下のように表現されます。
- A: グラフの接続状況を表す行列(隣接行列=Adjacency Matrix)
- N: ノードの情報。各ノードは、サイズFのベクトルで表現されているとする。
この2つが入力となります。
入力されたノードに重みをかける処理は、以下のように行われます。重みをかけることで、ノードの表現をFからF’へ圧縮します。
linear_transf_X = K.dot(X, kernel) # (N x F')
F’へ圧縮した表現をベースに、Attentionの計算を行います。Attentionを計算するにあたっては、別途Attention用の重みをかけます。この重みにより、各ノードを一つの値で表現します。接続される側(self)、する側(neighbors)で重みは別々にし、2つを合算することで(N, N)の行列を作成します。
(N, N)の行列になるのは、N個のノードに対し遷移先のノードがN個あるためです(自分自身への接続を含む)。この算出は、実装においてはbroadcastを使い上手く計算されています(オリジナルのTensorFlow実装でも同様)。
dense = attn_for_self + K.transpose(attn_for_neighs) # (N x N) via broadcasting
演算結果であるdenseに対し活性関数、そしてsoftmaxを適用することでAttentionを計算します。この時、接続がないノードは無視するようマスクをかけます(オリジナルの実装では、隣接行列をbiasに変換し、biasの加算によりこの処理を行っています)。
# Add nonlinearty
dense = LeakyReLU(alpha=0.2)(dense)# Mask values before activation (Vaswani et al., 2017)
mask = K.exp(A * -10e9) * -10e9
masked = dense + mask# Feed masked values to softmax
softmax = K.softmax(masked) # (N x N), attention coefficients
ここまで紹介したAttentionを計算するための一連の処理は、論文における以下の図に該当する処理です(Figure 1 Left)。まず、自分自身(self=i)と接続相手(neighbor=j)のattention coefficientsを算出します(解説したとおり実装上は別々の重みを掛けた上での合算で算出します)。その後は、活性関数(下図中央のノード)を適用したのちsoftmaxでAttentionの計算を行っています。
これでAttentionが計算できました。あとは、Attentionを「ノードに重みをかけた結果」にかければ計算は終了です。なお、AttentionにはDropoutを適用しておきます。
dropout = Dropout(self.attn_dropout)(softmax) # (N x N)# Linear combination with neighbors' features
node_features = K.dot(dropout, linear_transf_X) # (N x F')
以上が、Graph Attention Networkにおける計算処理になります。この一連の計算は、kernel毎に行われます。CNNが複数のkernel(フィルタ)で処理を行うように、Graph Attention Networkでも複数のkernelを使った処理が可能です。各kernelの演算結果は、結合するか平均を取るかでマージします。
for head in range(self.attn_heads):
...
outputs.append(node_features)
if self.attn_heads_reduction == 'concat':
output = K.concatenate(outputs) # (N x KF')
else:
output = K.mean(K.stack(outputs), axis=0) # N x F')
if self.activation is not None:
# In case of 'average', we compute the activation here (Eq 6)
output = self.activation(output)
論文中の以下の図(Figure 1 Right)では、3つのkernel(緑・青・紫)を使って計算した結果をcocat/averageしてマージする様子を示しています。
以上でGraph Attention Networkの計算方法については理解できました。では、実際に動かして動作を確認してみましょう。
Graph Attention Networkを動かす
検証に使われているのは、CORAという論文間の引用ネットワークを表したデータセットです。1件でも参照している/されているものは2708で、各論文は7つのカテゴリに分けられます。論文の内容(Nodeの情報)と引用のネットワーク(Edgeの情報)から、きちんとクラス分類できるかがお題となります。
前処理済みのデータがあり(tkipf/pygcn)、このデータが使われています。前処理済みのデータでは、論文の内容(Nodeの情報)はStop word/低頻度後(10以下)を落とした1433の単語頻度で表されています。つまり、グラフの情報は以下のようにあらわされます。
- A: 隣接行列 (2708, 2708)
- N: ノード特徴(2708, 1433) ※ソースコード中ではX
- ラベル: (2708, 7)
実装ではGraph Attention Layerを2層重ねています。モデルの構築は以下のようになっています。
# Build model
model = Model(inputs=[X_in, A_in], outputs=graph_attention_2)
optimizer = Adam(lr=learning_rate)
model.compile(optimizer=optimizer, loss='categorical_crossentropy',
weighted_metrics=['acc'])
学習は以下のように実行しています。
model.fit([X, A],
Y_train,
sample_weight=idx_train,
epochs=epochs,
batch_size=N,
validation_data=validation_data,
shuffle=False,
callbacks=[es_callback, tb_callback])
X_in, A_inはそれぞれ(1, 1433)、(1, 2708)となっており、バッチサイズがN(=2708)となっています。このためバッチ単位で見た場合に本来の姿である(2708, 1433)のノード特徴、(2708, 2708)の隣接行列から(2708, 7)のラベルを予測しているという形になります。バッチをshuffleした場合グラフ定義が壊れてしまうため、shuffle=Falseとなっています。
実行すると精度が上がっていき(またval_lossが下がっていき)、きちんと動作しているのがわかります。
Epoch 1/2000
2708/2708 [==============================] - 6s 2ms/step - loss: 2.0158 - weighted_acc: 0.1071 - val_loss: 2.0056 - val_weighted_acc: 0.3480
Epoch 2/2000
2708/2708 [==============================] - 4s 1ms/step - loss: 2.0058 - weighted_acc: 0.2000 - val_loss: 1.9954 - val_weighted_acc: 0.5360
Epoch 3/2000
2708/2708 [==============================] - 4s 1ms/step - loss: 1.9929 - weighted_acc: 0.3643 - val_loss: 1.9850 - val_weighted_acc: 0.6400
Epoch 4/2000
2708/2708 [==============================] - 4s 1ms/step - loss: 1.9805 - weighted_acc: 0.4429 - val_loss: 1.9758 - val_weighted_acc: 0.7020
Epoch 5/2000
2
Part2では、Graph Attention Networkの理論と実装を確認しました。動作確認に使用したCORAはまさに自然言語テキストの分類そのものであり、すぐに応用へと移れそうです。
ただ、Graphを扱ったタスクはノードのクラス推定以外にLink Predictionといった接続推定など様々なものがあります。そのため、Part3では応用の前にいったん立ち止まり、Graph形式のデータにおけるタスクをいったん整理したいと思います。