Learning beyond datasets : Knowledge Graph Augmented Neural Networks for Natural language Processing — Explained!

Anshuman Mourya
Jan 17, 2019

Neural networks are heavily dependent on data. Ever wondered how we can use information other than the training examples in a dataset to train a neural network? I came across the NAACL’18 paper "Learning beyond datasets", which proposes enhancing learning models with world knowledge in the form of a knowledge graph.

Motivation

Learning models mostly depend on the availability of training samples (labeled or unlabeled). If we possess world knowledge collected from past experience and use it to train our model along with the training data, can the model's performance be improved? The paper addresses the question:

Is it possible to develop learning models that can be trained in a way that it is able to infuse a general body of world knowledge for prediction apart from learning based on training data?

Knowledge Graph

The model should have access to structured world knowledge that need not be domain specific. A knowledge graph (KG) contains structured world knowledge in the form of fact triplets. Each triplet contains a subject entity, a relation, and an object entity. For example, <Dark Knight, other_name, Batman> or <New Delhi, capital, India> can be fact triplets of a knowledge graph.

To get an intuition, consider the example of classifying a news article: “Narendra Modi offered his condolences to the flood victims in Kerala”. If we have the triplets <Narendra Modi, Prime Minister, India> and <Kerala, State, India>, then it becomes much easier to classify this as political news.

KG entities and relations must be encoded as numerical representations before they can be used in the model. Many KG embeddings are available, such as TransE, TransH, ManifoldE, TransG, ProjE, and DKRL. The paper uses the DKRL embedding. It is a semantically enriched embedding that makes use of the provided entity descriptions and has the property that translating an entity by a relation gives the entity it is related to (e1 + r = e2).

Image Source: Original Paper
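
To make the translation property e1 + r = e2 concrete, here is a minimal sketch in Python with NumPy. The 3-dimensional vectors are invented purely for illustration (they are not real DKRL embeddings), and the `translate` helper is a hypothetical name: it adds the relation vector to the head entity and returns the nearest other entity.

```python
import numpy as np

# Toy 3-dimensional embeddings, invented for illustration (not real DKRL vectors).
entities = {
    "New Delhi": np.array([0.9, 0.1, 0.0]),
    "India": np.array([1.0, 1.0, 0.2]),
    "Kerala": np.array([0.2, 0.6, 0.9]),
}
relations = {
    "capital": np.array([0.1, 0.9, 0.2]),
}

def translate(head, relation):
    """Return the entity whose embedding is closest to head + relation."""
    target = entities[head] + relations[relation]
    candidates = [name for name in entities if name != head]
    return min(candidates, key=lambda name: np.linalg.norm(entities[name] - target))

print(translate("New Delhi", "capital"))  # India
```

In a trained embedding the equality only holds approximately, which is why the sketch looks for the nearest neighbour rather than an exact match.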

KG Retrieval Models

A general supervised learning model tries to find parameters θ such that

max over θ of P(y | x, θ)

for training data x and ground-truth label y. The world knowledge feature vectors x_w are retrieved from the KG using the data x with some function F and a new set of parameters θ⁽²⁾, that is, x_w = F(x, θ⁽²⁾). So, the objective now becomes

max over θ of P(y | x, x_w, θ⁽¹⁾), where θ = {θ⁽¹⁾, θ⁽²⁾}

Two models are proposed in the paper; both use LSTM encodings as the form of P and soft attention as the form of F. The models are designed with the classification task in mind.

Figure representing the retrieval of entity/relation vector from KG. Image Source : Original Paper

Vanilla KG Retrieval

KG entities and relations are available as DKRL embeddings. The following steps are performed:

  1. The input x is taken as a sequence of word vectors x = (x₁, x₂, …, xₜ). This sequence is encoded with an LSTM to generate a hidden state hₖ at each time step. The final representation of x (let's say o) is the average of the hidden states over all time steps. o is then projected to another dimension with a weight matrix W, followed by a ReLU, which gives the context representation C = ReLU(oᵀW).
  2. Step 1 is duplicated using separate LSTMs to form two more context vectors, Cₑ and Cᵣ. These context vectors are used to extract the relevant entity and relation vectors from the list of knowledge graph entities and relations.
  3. Since the DKRL embeddings of the KG possess the translation property, a fact from the KG can be retrieved by extracting an entity vector and a relation vector and translating to obtain the other entity. Soft attention is applied over the list of KG entities to extract an entity vector. The attention weight for an entity eᵢ (represented by its DKRL embedding), given the entity context vector Cₑ, is

     αᵢ = exp(Cₑ · eᵢ) / Σⱼ exp(Cₑ · eⱼ), with j running from 1 to |E|

     where |E| is the number of entities in the KG. The same is done for the KG relations, using Cᵣ, to obtain their attention weights.
  4. The final entity and relation vectors are computed as attention-weighted sums of the individual entity/relation vectors: e = Σᵢ αᵢ eᵢ, and similarly for r.
  5. Now we have the fact triple f = [e, r, e + r]. f is transformed to another space using a weight matrix V followed by a ReLU activation. This vector is concatenated with the context vector C of the input x, transformed using a weight matrix U, and passed through a softmax to generate the predictions: y = softmax([ReLU(fᵀV) ; C]ᵀU).

The final prediction y includes information from both the data samples and the knowledge graph. The attention mechanism tunes itself during training to retrieve facts that are helpful for classification.
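
The five steps above can be sketched as a single forward pass in NumPy. This is only an illustration under several assumptions, not the authors' implementation: the LSTM outputs are replaced by random stand-in hidden states, attention scores are plain dot products, all weights are random rather than trained, and every size and name (`vanilla_kg_forward`, `attend`, `context`, the dictionary `p`) is hypothetical.

```python
import numpy as np

rng = np.random.default_rng(0)

def relu(z):
    return np.maximum(z, 0.0)

def softmax(z):
    z = z - z.max()                    # shift for numerical stability
    e = np.exp(z)
    return e / e.sum()

def context(hidden_states, W):
    """Step 1: average the hidden states, project with W, apply ReLU."""
    o = hidden_states.mean(axis=0)
    return relu(o @ W)

def attend(query, vectors):
    """Steps 3-4: soft attention over KG vectors, then their weighted sum."""
    alpha = softmax(vectors @ query)   # one weight per entity/relation
    return alpha @ vectors, alpha

def vanilla_kg_forward(H, H_e, H_r, E, R, p):
    C = context(H, p["W"])             # context of the input text
    C_e = context(H_e, p["W_e"])       # step 2: entity context
    C_r = context(H_r, p["W_r"])       # step 2: relation context
    e, alpha_e = attend(C_e, E)
    r, alpha_r = attend(C_r, R)
    f = np.concatenate([e, r, e + r])  # step 5: fact triple [e, r, e + r]
    fact = relu(f @ p["V"])
    y = softmax(np.concatenate([fact, C]) @ p["U"])
    return y, alpha_e, alpha_r

# Hypothetical sizes: T time steps, hidden size h, context size c,
# KG embedding size d, fact projection size k.
T, h, c, d, k = 6, 8, 7, 5, 4
n_ent, n_rel, n_cls = 100, 10, 3
E = rng.normal(size=(n_ent, d))        # stand-in for DKRL entity embeddings
R = rng.normal(size=(n_rel, d))        # stand-in for DKRL relation embeddings
H, H_e, H_r = (rng.normal(size=(T, h)) for _ in range(3))  # stand-in LSTM states
p = {
    "W": rng.normal(size=(h, c)),
    "W_e": rng.normal(size=(h, d)),
    "W_r": rng.normal(size=(h, d)),
    "V": rng.normal(size=(3 * d, k)),
    "U": rng.normal(size=(k + c, n_cls)),
}
y, alpha_e, alpha_r = vanilla_kg_forward(H, H_e, H_r, E, R, p)
print(y, y.sum())
```

Note that the entity and relation contexts must have the same dimension d as the KG embeddings so that the dot-product scores are defined, while the input context C is free to have a different size.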

Figure representing the pretraining and parameter transfer to overcome gradient saturation problem. Image Source : Original Paper

While training the Vanilla model, a gradient saturation problem occurred, due to which the model started ignoring the KG part and learned only from the training data. To overcome this, the authors first trained the KG part separately for some epochs and then transferred those weights for joint training with the complete model. This resolved the problem.

CNN based KG Retrieval

Cluster representation using two Conv layers. Image Source: Original Paper

In the Vanilla model, the attention mechanism has to be applied over a very large space. The CNN-based model is proposed to reduce the number of entities and relations over which attention has to be computed. This is done by first clustering the entities/relations into l clusters of equal size using k-means clustering and then encoding each cluster using convolutional filters. The entities/relations in each cluster are stacked to form a 2-D input to the CNN encoder. The CNN encoder contains a 1-D convolutional layer followed by a pooling layer, and then another 1-D convolutional layer followed by a pooling layer. This convolutional architecture is trained simultaneously with the rest of the architecture. It learns to represent the information from the most relevant entity in the cluster and adapts accordingly for the classification task. The attention mechanism is no longer burdened with the large number of entities/relations.
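
As a rough sketch of the cluster encoding described above, the NumPy snippet below stacks each cluster into a 2-D array, applies a 1-D convolution and pooling twice, and then attends over the l cluster vectors instead of all entities. Several things here are assumptions made for illustration: the clusters are formed by simply reshaping the entity array (a stand-in for the k-means step), a single filter is shared across embedding dimensions, the second pooling is a global max so that each cluster collapses to one vector, and all sizes, weights, and function names are hypothetical.

```python
import numpy as np

rng = np.random.default_rng(0)

def softmax(z):
    z = z - z.max()
    e = np.exp(z)
    return e / e.sum()

def conv1d(X, w):
    """Valid 1-D convolution (cross-correlation) along axis 0.

    The same filter w slides over the stacked entities and is shared
    across all embedding dimensions (columns of X).
    """
    k = len(w)
    windows = np.stack([X[i:i + k] for i in range(X.shape[0] - k + 1)])
    return np.einsum("nkd,k->nd", windows, w)

def max_pool(X, size):
    """Non-overlapping max pooling along axis 0 (a ragged tail is dropped)."""
    n = (X.shape[0] // size) * size
    return X[:n].reshape(-1, size, X.shape[1]).max(axis=1)

def encode_cluster(X, w1, w2):
    """Conv -> pool -> conv -> global max pool: one vector per cluster."""
    z = max_pool(conv1d(X, w1), 2)
    z = conv1d(z, w2)
    return z.max(axis=0)

# Hypothetical sizes: 160 entities of dimension 5, grouped into l = 10 clusters.
n_ent, d, l = 160, 5, 10
E = rng.normal(size=(n_ent, d))           # stand-in for DKRL entity embeddings
clusters = E.reshape(l, n_ent // l, d)    # stand-in for equal-size k-means clusters
w1, w2 = rng.normal(size=3), rng.normal(size=3)

cluster_vectors = np.stack([encode_cluster(X, w1, w2) for X in clusters])
C_e = rng.normal(size=d)                  # stand-in entity context vector
alpha = softmax(cluster_vectors @ C_e)    # attention over l clusters, not |E| entities
print(cluster_vectors.shape, alpha.shape)
```

With these sizes the attention is computed over 10 cluster vectors rather than 160 entity vectors, which is the reduction the CNN-based model is after.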

Results

The paper presents results on two datasets: News20 for news classification and SNLI for natural language inference. The Freebase 15k KG is used for the News20 classification task and the WordNet KG is used for the SNLI dataset.

Results from the paper

The convolution-based model shows significant improvement over the plain LSTM. The main focus was on improving any deep learning model rather than beating the state of the art. A weaker baseline showed more improvement than a stronger one (which is already strong enough to classify on its own). Experiments were also performed on the DBpedia ontology classification dataset with a stronger baseline of 98.6%, which gained 0.2% after augmenting with the KG model. Thus the model is capable of improving stronger baselines as well.

They also tried reducing the training data size by some fraction. Training the KG-based model resulted in a lower loss than the plain LSTM model, and it also showed significant improvement in accuracy.

Conclusion

This paper was the first of its kind, proposing the novel Vanilla and CNN-based KG retrieval models. The idea of incorporating world knowledge can be extended to other NLP tasks as well.

Anshuman Mourya

Deep Learning Enthusiast. Member of Statistics and Machine Learning lab @IISc Bangalore