Document Classification in NLP | Towards AI
The current state of Deep Learning practice fascinates me because of the interesting similarities between the way we humans process images and text and the algorithms we employ to have computers process images and text.
The Continuous Bag of Words approach, in which the embeddings of every word in a sentence are aggregated (for example, averaged) into a single vector representing the entire sentence, can achieve very respectable results on many NLP tasks.
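A minimal sketch of this averaging approach, assuming a hypothetical toy vocabulary and random (untrained) NumPy embeddings. It also illustrates the limitation discussed next: shuffling the words produces exactly the same sentence vector.

```python
import numpy as np

# Hypothetical toy vocabulary with random, untrained 8-dim embeddings.
rng = np.random.default_rng(0)
vocab = {"word": 0, "order": 1, "is": 2, "crucial": 3}
embeddings = rng.normal(size=(len(vocab), 8))

def sentence_vector(tokens):
    """Average the word embeddings of a sentence into one vector."""
    ids = [vocab[t] for t in tokens]
    return embeddings[ids].mean(axis=0)

a = sentence_vector(["word", "order", "is", "crucial"])
b = sentence_vector(["crucial", "is", "order", "word"])
print(np.allclose(a, b))  # True: averaging ignores word order
```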
Crucial however in is word understanding sentences component order a. Oops, let me try that again. Is order however a in crucial word sentences understanding component. No, that’s not right either. What I meant to say is: “However, word order is a crucial component in understanding sentences.” This idea gave rise to recurrent algorithms like LSTMs and GRUs, which do use word order to make predictions. Here is a good resource for digging deeper into recurrent algorithms. In short, these models keep track of a state that is updated from word to word. This state can be thought of as memory: it carries the ideas and concepts of the preceding words into the processing of the next word.
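To make "a state that is updated from word to word" concrete, here is a minimal sketch using a plain (Elman-style) recurrent update rather than an LSTM or GRU, to keep it short. The weights are random and untrained, and the word vectors are hypothetical; the point is only that reversing the word order changes the final state.

```python
import numpy as np

rng = np.random.default_rng(0)
emb_dim, hidden_dim = 8, 4
W = rng.normal(scale=0.5, size=(hidden_dim, emb_dim))     # input-to-state weights
U = rng.normal(scale=0.5, size=(hidden_dim, hidden_dim))  # state-to-state weights
b = np.zeros(hidden_dim)

def run_rnn(word_vectors):
    """Update a hidden state word by word and return the final state."""
    h = np.zeros(hidden_dim)
    for x in word_vectors:
        h = np.tanh(W @ x + U @ h + b)  # new state mixes the current word with memory
    return h

words = rng.normal(size=(4, emb_dim))  # four hypothetical word embeddings
forward = run_rnn(words)
reversed_order = run_rnn(words[::-1])
print(np.allclose(forward, reversed_order))  # False: the final state depends on order
```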
As Kevin so elegantly put it, not all words in a sentence are equal in importance. Giving more predictive power to the more important words in a sentence is the core idea behind attention. To make the problem more complex, words do not have a constant level of importance; the importance of a word depends on the context of the sentence.
To better understand attention, I implemented a simplified version of this paper. The authors went one step beyond a standard attention model and used a hierarchical approach, learning not just the importance of words in a sentence but also the importance of sentences in a document. The dataset I was working with had only one or two sentences per observation, so I felt the sentence attention layer was unnecessary.
Full disclosure: I originally implemented this model on a dataset that must remain private. My example code uses the publicly available IMDB dataset, where some reviews may be long enough to warrant using hierarchical attention.
So, how does attention work? It’s one thing to look at a given sentence and say which words are important. However, the model is useless if it does not generalize, so it needs to learn the properties of words, how those properties interact, and which interactions make a word significant.
Step One: Represent each word in the vocabulary as an embedding vector of N dimensions. This is a super common approach in NLP; more information here.
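A minimal sketch of Step One, assuming a hypothetical vocabulary built from a two-sentence toy corpus, an illustrative N = 16, and a random embedding matrix that would be learned (or loaded from pretrained vectors) in practice. Index 0 is reserved for unknown words, which is my own convention here.

```python
import numpy as np

# Hypothetical vocabulary built from a toy corpus; index 0 is reserved for unknown words.
corpus = ["the movie was entertaining", "the plot was bad"]
vocab = {"<unk>": 0}
for sentence in corpus:
    for token in sentence.split():
        vocab.setdefault(token, len(vocab))

N = 16  # embedding dimension (illustrative choice)
rng = np.random.default_rng(0)
embedding_matrix = rng.normal(scale=0.1, size=(len(vocab), N))  # learned in practice

def embed(sentence):
    """Map each token to its row of the embedding matrix."""
    ids = [vocab.get(token, vocab["<unk>"]) for token in sentence.lower().split()]
    return embedding_matrix[ids]  # shape: (num_words, N)

x = embed("The movie was bad")
print(x.shape)  # (4, 16)
```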
Step Two: Send each sentence of embedding vectors through a GRU. The GRU maintains a hidden state between words. Typically for prediction we only care about the final state, but for this model we want to keep every intermediate state as well. Let hᵢ be the vector representing the hidden state after word i. I followed the paper in using a bidirectional GRU; this is likely helpful, though not strictly necessary. A bidirectional GRU runs through the sentence both forward and backward, so each word i has hidden states hᵢᶠ and hᵢᵇ, and we simply concatenate these two vectors into hᵢ and proceed.
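A minimal NumPy sketch of Step Two, not my actual training code: a GRU cell in the Cho et al. formulation with random, untrained weights, run forward and backward over five hypothetical word embeddings. The backward states are re-aligned to word order before concatenation so that row i of the result is hᵢ = [hᵢᶠ; hᵢᵇ]. The hidden size of 50 per direction is an illustrative assumption, chosen so each hᵢ has 100 dimensions.

```python
import numpy as np

def sigmoid(v):
    return 1.0 / (1.0 + np.exp(-v))

class GRU:
    """Minimal GRU (Cho et al. formulation) with random, untrained weights."""
    def __init__(self, input_dim, hidden_dim, rng):
        self.hidden_dim = hidden_dim
        shape_x, shape_h = (hidden_dim, input_dim), (hidden_dim, hidden_dim)
        self.Wz, self.Wr, self.Wn = (rng.normal(scale=0.1, size=shape_x) for _ in range(3))
        self.Uz, self.Ur, self.Un = (rng.normal(scale=0.1, size=shape_h) for _ in range(3))
        self.bz = self.br = self.bn = np.zeros(hidden_dim)

    def run(self, xs):
        """Return the hidden state after every word, shape (num_words, hidden_dim)."""
        h = np.zeros(self.hidden_dim)
        states = []
        for x in xs:
            z = sigmoid(self.Wz @ x + self.Uz @ h + self.bz)        # update gate
            r = sigmoid(self.Wr @ x + self.Ur @ h + self.br)        # reset gate
            n = np.tanh(self.Wn @ x + self.Un @ (r * h) + self.bn)  # candidate state
            h = (1 - z) * h + z * n
            states.append(h)
        return np.stack(states)

def bidirectional_states(xs, fwd, bwd):
    """Concatenate forward and backward states so word i gets h_i = [h_i^f; h_i^b]."""
    h_f = fwd.run(xs)
    h_b = bwd.run(xs[::-1])[::-1]  # run over the reversed sentence, then re-align
    return np.concatenate([h_f, h_b], axis=1)

rng = np.random.default_rng(0)
xs = rng.normal(size=(5, 16))  # five hypothetical 16-dim word embeddings
fwd, bwd = GRU(16, 50, rng), GRU(16, 50, rng)
H = bidirectional_states(xs, fwd, bwd)
print(H.shape)  # (5, 100)
```

Keeping every row of H, rather than only the last one, is what makes the attention steps that follow possible.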
Step Three: Feed each hᵢ through a fully connected linear layer that includes a bias term. The paper recommends an output dimension of 100; I have not yet explored the effect of tweaking this hyperparameter, although I think that could be an interesting area of research. Apply tanh element-wise to the resulting vector and call the result uᵢ, again corresponding to word i.
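Step Three in matrix form is uᵢ = tanh(W hᵢ + b). A minimal sketch, assuming random stand-ins for the hᵢ of a five-word sentence (100 dimensions, as from a 2 × 50 bidirectional GRU) and untrained weights; applying the layer to all rows of H at once is equivalent to feeding each hᵢ through separately.

```python
import numpy as np

rng = np.random.default_rng(0)
hidden_dim, attn_dim = 100, 100   # input: concatenated GRU states; output: 100, per the paper
W = rng.normal(scale=0.1, size=(attn_dim, hidden_dim))  # untrained weights
b = np.zeros(attn_dim)                                  # bias term

H = rng.normal(size=(5, hidden_dim))  # stand-in for h_1..h_5 of one sentence
U = np.tanh(H @ W.T + b)              # row i is u_i = tanh(W h_i + b)
print(U.shape)  # (5, 100)
```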
Step Four: Send each uᵢ through another linear layer, this time without a bias term. This layer has a scalar output, so we now have a single scalar value associated with each word i. Then apply the softmax function over the words of each sentence, so that the scalars sum to one within each sentence. Let the resulting weight for word i be called αᵢ.
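A minimal sketch of Step Four, assuming random stand-ins for the uᵢ and an untrained weight vector. The bias-free scalar layer is just a dot product with one learned vector, which I've named `context` here; as I read the paper, this plays the role of its word-level context vector. Subtracting the maximum score before exponentiating is a standard trick for numerical stability and does not change the softmax result.

```python
import numpy as np

def softmax(scores):
    """Numerically stable softmax over a 1-D array."""
    shifted = np.exp(scores - scores.max())
    return shifted / shifted.sum()

rng = np.random.default_rng(0)
U = np.tanh(rng.normal(size=(5, 100)))     # stand-in for u_1..u_5 of one sentence
context = rng.normal(scale=0.1, size=100)  # weights of the bias-free scalar layer
scores = U @ context                       # one scalar per word
alpha = softmax(scores)                    # attention weights for the sentence
print(np.isclose(alpha.sum(), 1.0))  # True
```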
Step Five: We’re almost at prediction time. For each word i in a given sentence we now have a vector hᵢ and an importance scalar αᵢ. It’s crucial to understand that these hᵢ vectors differ from the original word embeddings, since they carry memory of the sentence in both the forward and backward directions. We take the sum of all the hᵢ vectors in the review, each weighted by its αᵢ, and call this review vector s.
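Step Five is s = Σᵢ αᵢ hᵢ. A minimal sketch with random stand-ins for the hᵢ and hand-picked, hypothetical weights that sum to one; the matrix-vector product and the explicit weighted sum give the same vector.

```python
import numpy as np

rng = np.random.default_rng(0)
H = rng.normal(size=(5, 100))                   # stand-in for h_1..h_5
alpha = np.array([0.05, 0.1, 0.6, 0.2, 0.05])   # stand-in attention weights (sum to 1)

s = alpha @ H                                   # s = sum_i alpha_i * h_i, shape (100,)
same = (alpha[:, None] * H).sum(axis=0)         # the same sum written out explicitly
print(s.shape, np.allclose(s, same))  # (100,) True
```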
Step Six: The function applied to s depends on the objective of the model. Because this model performs binary document classification, I applied a final linear layer to the vector s, with a sigmoid on its output so that the single value p it returns is the probability of belonging to class 1.
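A minimal sketch of Step Six, assuming a random stand-in for s and untrained output weights: a linear layer produces one logit, and the sigmoid maps it into (0, 1).

```python
import numpy as np

rng = np.random.default_rng(0)
s = rng.normal(size=100)                 # stand-in review vector
w_out = rng.normal(scale=0.1, size=100)  # untrained output weights
b_out = 0.0

logit = w_out @ s + b_out
p = 1.0 / (1.0 + np.exp(-logit))         # sigmoid squashes the logit into (0, 1)
print(0.0 < p < 1.0)  # True
```

During training, p would typically be paired with a binary cross-entropy loss.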
That’s it as far as prediction goes; everything beyond this architecture is standard Deep Learning practice. The fun part about this model is that it gives us not only a prediction but also the vector of α values, which represents the importance of each word in a given sentence.
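Putting Steps One through Six together, here is a minimal single-level (word attention only) forward pass that returns both p and the α vector. This is a NumPy sketch with random, untrained weights, not the code on my GitHub. The class name, dimensions, and token ids are hypothetical, and training, padding, and batching are omitted.

```python
import numpy as np

def sigmoid(v):
    return 1.0 / (1.0 + np.exp(-v))

def softmax(v):
    e = np.exp(v - v.max())
    return e / e.sum()

def init_gru(rng, d_in, d_h):
    """Random (untrained) GRU parameters for the update, reset and candidate parts."""
    return {"W": rng.normal(scale=0.1, size=(3, d_h, d_in)),
            "U": rng.normal(scale=0.1, size=(3, d_h, d_h)),
            "b": np.zeros((3, d_h))}

def run_gru(p, xs):
    """Return every hidden state of a GRU over the sequence xs."""
    h = np.zeros(p["b"].shape[1])
    states = []
    for x in xs:
        z = sigmoid(p["W"][0] @ x + p["U"][0] @ h + p["b"][0])
        r = sigmoid(p["W"][1] @ x + p["U"][1] @ h + p["b"][1])
        n = np.tanh(p["W"][2] @ x + p["U"][2] @ (r * h) + p["b"][2])
        h = (1 - z) * h + z * n
        states.append(h)
    return np.stack(states)

class WordAttentionClassifier:
    """Hypothetical model: embed -> BiGRU -> word attention -> sigmoid output."""
    def __init__(self, vocab_size, emb_dim=16, gru_dim=50, attn_dim=100, seed=0):
        rng = np.random.default_rng(seed)
        self.E = rng.normal(scale=0.1, size=(vocab_size, emb_dim))
        self.fwd = init_gru(rng, emb_dim, gru_dim)
        self.bwd = init_gru(rng, emb_dim, gru_dim)
        self.W_a = rng.normal(scale=0.1, size=(attn_dim, 2 * gru_dim))
        self.b_a = np.zeros(attn_dim)
        self.context = rng.normal(scale=0.1, size=attn_dim)
        self.w_out = rng.normal(scale=0.1, size=2 * gru_dim)
        self.b_out = 0.0

    def forward(self, token_ids):
        xs = self.E[token_ids]                                    # Step One: embeddings
        H = np.concatenate([run_gru(self.fwd, xs),
                            run_gru(self.bwd, xs[::-1])[::-1]],
                           axis=1)                                # Step Two: BiGRU states
        U = np.tanh(H @ self.W_a.T + self.b_a)                    # Step Three
        alpha = softmax(U @ self.context)                         # Step Four
        s = alpha @ H                                             # Step Five
        p = sigmoid(self.w_out @ s + self.b_out)                  # Step Six
        return p, alpha

model = WordAttentionClassifier(vocab_size=20)
p, alpha = model.forward([3, 7, 1, 12])  # hypothetical token ids
print(alpha.shape)  # (4,)
```

Returning alpha alongside p is what makes the word-importance visualization possible: the same weights used for prediction double as an explanation.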
To test how it works, I put together a small frontend component for interacting with the model. A user simply enters the text of a film review and gets back the probability that the review is positive, along with a visual depiction of how important each word in the review is. The code for launching this interface, as well as for training the model, is available on my GitHub. Here are screenshots from the two reviews that I tried.
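The interface itself isn't reproduced here. As a rough, hypothetical plain-text stand-in for its visual depiction of word importance, the sketch below prints each word with a bar proportional to its α. The weights in the example are made up for illustration and are not output from the trained model.

```python
def render_importance(tokens, alphas, width=20):
    """Return one line per word with a bar whose length is proportional to its weight."""
    top = max(alphas)
    lines = []
    for token, a in zip(tokens, alphas):
        bar = "#" * round(width * a / top)  # the most important word gets the full width
        lines.append(f"{token:>12} {a:.2f} {bar}")
    return "\n".join(lines)

# Hypothetical weights for illustration only; not output from the trained model.
tokens = ["the", "movie", "was", "bad"]
alphas = [0.05, 0.15, 0.05, 0.75]
print(render_importance(tokens, alphas))
```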
I’m sorry to break the hearts of recruiters everywhere, but I’m not looking for frontend dev jobs; I prefer to stick with Machine Learning. The model was very confident in its predictions, which is awesome, but more interesting to me are the words the model chose as important. It appears to have learned to rely on “entertaining” and “bad” to make its predictions, weighting these words significantly more than the rest of the sentence.
Attention is an effective way to improve the predictive power of a model, and it is really interesting to see which words the model deems useful. I hope that after reading this, you’ll consider adding attention to your model to boost both prediction and interpretability.