Report on Text Classification using CNN, RNN & HAN

Akshat Maheshwari
Jul 17, 2018 · 8 min read


Hello World!! I recently joined as an NLP Researcher (Intern 😇) and was asked to work on text classification use cases using deep learning models.

In this article, I will share my experiences and learnings from experimenting with various neural network architectures.

I will cover three main algorithms:

  1. Convolutional Neural Network (CNN)
  2. Recurrent Neural Network (RNN)
  3. Hierarchical Attention Network (HAN)

Text classification was performed on datasets in Danish, Italian, German, English and Turkish.

Let’s get to it. ✅

About Natural Language Processing (NLP)

Text classification is one of the most widely used Natural Language Processing (NLP) tasks across business problems. It is an example of a supervised Machine Learning (ML) task, since a labelled dataset containing text documents and their labels is used to train a classifier.

The goal of text classification is to automatically classify the text documents into one or more predefined categories.

Some examples of text classification are:

  • Understanding audience sentiment (😁 😐 😥) from social media
  • Detection of spam & non-spam emails
  • Auto tagging of customer queries
  • Categorisation of news articles 📰 into predefined topics

Text classification is a very active research area in both academia 📚 and industry. In this post, I will present a few different approaches and compare their performance; all implementations are based on Keras.

All the source code and experiment results can be found in the jatana_research repository.


An end-to-end text classification pipeline is composed of the following components:

  1. Training text: the input text from which our supervised learning model learns to predict the required class.
  2. Feature Vector: a vector that contains information describing the characteristics of the input data.
  3. Labels: the predefined categories/classes that our model will predict.
  4. ML Algo: the algorithm through which our model performs text classification (in our case: CNN, RNN, HAN).
  5. Predictive Model: a model trained on the historical dataset that can perform label predictions.

Analysing Our Data :

We are using three datasets with various classes, as shown in the table below:


Text Classification Using Convolutional Neural Network (CNN) :

A CNN is a class of deep, feed-forward artificial neural networks (where connections between nodes do not form a cycle) that uses a variation of multilayer perceptrons designed to require minimal preprocessing. CNNs are inspired by the animal visual cortex.

I have taken reference from Yoon Kim's paper and this blog post by Denny Britz.

CNNs are generally used in computer vision; however, they have recently been applied to various NLP tasks with promising results 🙌.

Let's briefly see, through a diagram, what happens when we use a CNN on text data. Each convolution fires when a specific pattern is detected. By varying the kernel sizes and concatenating their outputs, you allow the network to detect patterns of multiple sizes (2, 3, or 5 adjacent words). Patterns could be expressions (word n-grams) like “I hate” or “very good”, and CNNs can therefore identify them in the sentence regardless of their position.
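To make the position-invariance argument concrete, here is a minimal NumPy sketch of a single bigram filter followed by ReLU and max-over-time pooling. The toy two-dimensional embeddings, the hand-picked filter weights and the function name `conv1d_max_over_time` are hypothetical; in a real model the filters are learned and the word vectors would come from GloVe.

```python
import numpy as np


def conv1d_max_over_time(embeddings, kernel):
    """Slide one filter over a sentence and keep its strongest response.

    embeddings: (num_words, emb_dim) matrix, one row per word.
    kernel:     (width, emb_dim) filter spanning `width` adjacent words.
    """
    width = kernel.shape[0]
    responses = np.array([
        np.sum(embeddings[i:i + width] * kernel)  # one score per n-gram window
        for i in range(embeddings.shape[0] - width + 1)
    ])
    activations = np.maximum(responses, 0.0)  # ReLU, as in the Conv1D layers
    return activations.max()  # max-over-time pooling discards the position


# Toy 2-dimensional "embeddings" (hypothetical values, not GloVe)
vocab = {"i": [1.0, 0.0], "hate": [0.0, 1.0], "this": [0.2, 0.2], "movie": [0.1, 0.3]}


def embed(words):
    return np.array([vocab[w] for w in words])


# A bigram filter tuned by hand to the pattern "i hate"
kernel = np.array([[1.0, 0.0], [0.0, 1.0]])

start = conv1d_max_over_time(embed(["i", "hate", "this", "movie"]), kernel)
end = conv1d_max_over_time(embed(["this", "movie", "i", "hate"]), kernel)
print(start, end)  # the pattern fires equally strongly in both positions
```

Because pooling keeps only the maximum response, the filter reports “I hate” just as strongly at the start of the sentence as at the end.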


In this section, I have used a simplified CNN to build a classifier. First, Beautiful Soup is used to remove HTML tags, and clean_str removes some unwanted characters:

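```python
import re
from bs4 import BeautifulSoup


def clean_str(string):
    """Remove backslashes and quote characters, then lowercase."""
    string = re.sub(r"\\", "", string)
    string = re.sub(r"\'", "", string)
    string = re.sub(r"\"", "", string)
    return string.strip().lower()


# df is assumed to be a pandas DataFrame with 'message' and 'class' columns
texts = []
labels = []

for i in range(df.message.shape[0]):
    # Strip HTML tags with Beautiful Soup, then clean the remaining text
    text = BeautifulSoup(df.message[i], "html.parser")
    texts.append(clean_str(text.get_text()))

for i in df['class']:
    labels.append(i)
```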

Here I have used the pre-trained GloVe 6B 100-dimensional vectors from Stanford NLP. From the official documentation:

“GloVe is an unsupervised learning algorithm for obtaining vector representations for words. Training is performed on aggregated global word-word co-occurrence statistics from a corpus, and the resulting representations showcase interesting linear substructures of the word vector space.”
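A minimal sketch of how the GloVe text file could be turned into an embedding matrix, with random vectors for words that GloVe does not cover, is shown below. The helper names, the random initialisation (normal with scale 0.1) and the zero padding row are assumptions for illustration, not necessarily what the original code does; the three-dimensional toy lines stand in for glove.6B.100d.txt, where each line is a word followed by its vector components.

```python
import numpy as np


def load_glove(lines):
    """Parse GloVe's text format: a word followed by its vector components."""
    embeddings_index = {}
    for line in lines:
        values = line.rstrip().split(" ")
        embeddings_index[values[0]] = np.asarray(values[1:], dtype="float32")
    return embeddings_index


def build_embedding_matrix(word_index, embeddings_index, dim, seed=0):
    """Row i holds the vector of the word with index i; row 0 is padding."""
    rng = np.random.default_rng(seed)
    # Start from random vectors so that words missing from GloVe stay random
    matrix = rng.normal(scale=0.1, size=(len(word_index) + 1, dim)).astype("float32")
    matrix[0] = 0.0  # index 0 is reserved for padding
    for word, i in word_index.items():
        vector = embeddings_index.get(word)
        if vector is not None:
            matrix[i] = vector
    return matrix


# Tiny stand-in for glove.6B.100d.txt (3 dimensions instead of 100)
glove_lines = ["good 0.1 0.2 0.3", "bad -0.1 -0.2 -0.3"]
word_index = {"good": 1, "bad": 2, "unknownword": 3}  # e.g. a fitted Tokenizer's word_index
matrix = build_embedding_matrix(word_index, load_glove(glove_lines), dim=3)
print(matrix.shape)  # (4, 3)
```

With the real file, `load_glove` could be given an open file handle (`with open("glove.6B.100d.txt", encoding="utf8") as f: embeddings_index = load_glove(f)`), and the resulting matrix passed to Keras' `Embedding` layer through `weights=[...]`.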

For an unknown word, the embedding code simply randomises its vector. Below is a very simple convolutional architecture, following the sample from this blog: three Conv1D layers, each with 128 filters of size 5, followed by max pooling of size 5, 5 and 35 respectively.

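```python
from keras.layers import Input, Conv1D, MaxPooling1D, Flatten, Dense
from keras.models import Model

# embedding_layer (initialised with the GloVe embedding matrix), MAX_SEQUENCE_LENGTH
# and macronum (one entry per class) are assumed to be defined earlier
sequence_input = Input(shape=(MAX_SEQUENCE_LENGTH,), dtype='int32')
embedded_sequences = embedding_layer(sequence_input)
l_cov1 = Conv1D(128, 5, activation='relu')(embedded_sequences)
l_pool1 = MaxPooling1D(5)(l_cov1)
l_cov2 = Conv1D(128, 5, activation='relu')(l_pool1)
l_pool2 = MaxPooling1D(5)(l_cov2)
l_cov3 = Conv1D(128, 5, activation='relu')(l_pool2)
# Acts as global max pooling only if exactly 35 timesteps remain here
l_pool3 = MaxPooling1D(35)(l_cov3)
l_flat = Flatten()(l_pool3)
l_dense = Dense(128, activation='relu')(l_flat)
preds = Dense(len(macronum), activation='softmax')(l_dense)
model = Model(sequence_input, preds)
```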
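The comment on the last pooling layer only holds for a particular input length. The sketch below traces the temporal dimension through the layers, assuming "valid" padding and the default pooling stride (the Keras defaults the code above relies on) and a hypothetical MAX_SEQUENCE_LENGTH of 1000, a value the article does not state.

```python
def conv_out(length, kernel_size):
    """Length after Conv1D with 'valid' padding and stride 1."""
    return length - kernel_size + 1


def pool_out(length, pool_size):
    """Length after MaxPooling1D with default strides (= pool_size) and 'valid' padding."""
    return (length - pool_size) // pool_size + 1


LAYERS = [("conv", 5), ("pool", 5), ("conv", 5), ("pool", 5), ("conv", 5), ("pool", 35)]


def trace_lengths(max_sequence_length):
    """Return (layer, size, output length) for each layer of the CNN above."""
    length = max_sequence_length
    steps = []
    for layer, size in LAYERS:
        length = conv_out(length, size) if layer == "conv" else pool_out(length, size)
        steps.append((layer, size, length))
    return steps


for layer, size, length in trace_lengths(1000):
    print(f"{layer}{size}: {length}")
```

For an input length of 1000 the third convolution leaves exactly 35 timesteps, so MaxPooling1D(35) reduces them to one, which is global max pooling. For shorter inputs the last pool no longer fits (the sketch reports a length of 0 for 500); Keras' GlobalMaxPooling1D layer would avoid tying the pool size to the input length.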

Here is the architecture of the CNN Model.


Text Classification Using Recurrent Neural Network (RNN) :

A recurrent neural network (RNN) is a class of artificial neural network where connections between nodes form a directed graph along a sequence. This allows it to exhibit dynamic temporal behavior for a time sequence.

Using knowledge from an external embedding can enhance the precision of your RNN, because it integrates new (lexical and semantic) information about the words, information that has been trained and distilled on a very large corpus of data. The pre-trained embedding we'll be using is GloVe.

RNNs may look scary 😱. Although they are complex to understand, they are quite interesting. They encapsulate a very beautiful design that overcomes the shortcomings traditional neural networks face when dealing with sequence data: text, time series, videos, DNA sequences, etc.

An RNN is a sequence of neural network blocks linked to each other like a chain, each one passing a message to its successor. If you want to dive into the internal mechanics, I highly recommend Colah's blog.
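The chain idea can be sketched in a few lines of NumPy as a vanilla (Elman) RNN, where each step combines the current word vector with the hidden state passed on from the previous step. This is a simplified illustration with random weights and a hypothetical function name; the LSTM used later in this section adds gates on top of the same recurrence.

```python
import numpy as np


def simple_rnn(inputs, W_x, W_h, b):
    """Vanilla RNN: each step passes its hidden state on to the next step.

    inputs: (timesteps, input_dim); returns all hidden states (timesteps, hidden_dim).
    """
    h = np.zeros(W_h.shape[0])
    states = []
    for x_t in inputs:
        # The new state mixes the current word with the message from the previous block
        h = np.tanh(W_x @ x_t + W_h @ h + b)
        states.append(h)
    return np.stack(states)


rng = np.random.default_rng(0)
inputs = rng.normal(size=(4, 3))  # 4 "words", each a 3-dimensional embedding
W_x = rng.normal(size=(2, 3))     # input-to-hidden weights
W_h = rng.normal(size=(2, 2))     # hidden-to-hidden weights (the "chain")
b = np.zeros(2)
states = simple_rnn(inputs, W_x, W_h, b)
print(states.shape)  # (4, 2): one hidden state per word
```

Because each state depends on the previous one, reordering the words changes the final state, which is how the network becomes sensitive to word order.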


The same preprocessing with Beautiful Soup is done here. We will process text data, which is a sequence type: the order of words is very important to the meaning. Hopefully, RNNs take care of this and can capture long-term dependencies.

To use Keras on text data, we first have to preprocess it. For this, we can use Keras' Tokenizer class. Its num_words argument is the maximum number of words to keep, based on word frequency.

```python
from keras.preprocessing.text import Tokenizer

MAX_NB_WORDS = 20000
tokenizer = Tokenizer(num_words=MAX_NB_WORDS)
tokenizer.fit_on_texts(texts)
```

Once the tokenizer is fitted on the data, we can use it to convert text strings to sequences of numbers. Each number is the index of a word in the tokenizer's dictionary (word_index), where more frequent words get smaller indices; think of it as a mapping.
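As a simplified illustration of that mapping (not Keras' actual implementation), the pure-Python sketch below ranks words by frequency starting at index 1, keeps only indices below num_words (mirroring how the Tokenizer treats num_words), and pads each sequence to a fixed length with zeros at the front, which is the default behaviour of Keras' pad_sequences. The function names are hypothetical, and unlike the Tokenizer this sketch does not strip punctuation.

```python
from collections import Counter


def build_word_index(texts):
    """Rank words by frequency: the most common word gets index 1 (0 is left for padding)."""
    counts = Counter(word for text in texts for word in text.lower().split())
    # sorted() is stable, so ties keep the order in which words were first seen
    ranked = sorted(counts, key=counts.get, reverse=True)
    return {word: i for i, word in enumerate(ranked, start=1)}


def texts_to_padded_sequences(texts, word_index, num_words, maxlen):
    """Map words to indices, keep indices below num_words, pre-pad/pre-truncate to maxlen."""
    padded = []
    for text in texts:
        seq = [word_index[w] for w in text.lower().split()
               if w in word_index and word_index[w] < num_words]
        seq = seq[-maxlen:]  # truncate from the front, keeping the last maxlen tokens
        padded.append([0] * (maxlen - len(seq)) + seq)
    return padded


texts = ["the movie was good", "the movie was bad", "good good movie"]
word_index = build_word_index(texts)
print(word_index)
print(texts_to_padded_sequences(texts, word_index, num_words=4, maxlen=5))
```

With num_words=4 only the three most frequent words survive, so "was" and "bad" disappear from the sequences.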

  • In this section, I will try to tackle the problem using a recurrent neural network and an attention-based LSTM encoder.
  • By using an LSTM encoder, we intend to encode all the information in the text into the last output of the recurrent neural network before running a feed-forward network for classification.
  • This is very similar to neural machine translation and sequence-to-sequence learning. The following figure is from A Hierarchical Neural Autoencoder for Paragraphs and Documents.
  • I'm using the LSTM layer in Keras to implement this. In addition to a forward LSTM, here I have used a bidirectional LSTM and concatenated the last outputs of both directions.
  • Keras provides a very nice wrapper called Bidirectional, which makes this coding exercise effortless. You can see the sample code here:
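```python
from keras.layers import Input, Bidirectional, LSTM, Dense
from keras.models import Model

# embedding_layer, MAX_SEQUENCE_LENGTH and macronum are the same as in the CNN section
sequence_input = Input(shape=(MAX_SEQUENCE_LENGTH,), dtype='int32')
embedded_sequences = embedding_layer(sequence_input)
# Bidirectional concatenates the last outputs of the forward and backward LSTMs (200 values)
l_lstm = Bidirectional(LSTM(100))(embedded_sequences)
preds = Dense(len(macronum), activation='softmax')(l_lstm)
model = Model(sequence_input, preds)
model.compile(loss='categorical_crossentropy', optimizer='rmsprop', metrics=['acc'])
```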
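The concatenation performed by the Bidirectional wrapper can be sketched as follows: one recurrent pass reads the sentence left to right, another reads it right to left, and their final states are joined into one vector. For brevity, this sketch uses a vanilla RNN with random weights in place of an LSTM, and the function names are hypothetical; with LSTM(100) in each direction, the Keras layer above yields a 200-dimensional vector.

```python
import numpy as np


def rnn_last_state(inputs, W_x, W_h):
    """Run a vanilla RNN over the sequence and return only its final hidden state."""
    h = np.zeros(W_h.shape[0])
    for x_t in inputs:
        h = np.tanh(W_x @ x_t + W_h @ h)
    return h


def bidirectional_last_states(inputs, forward_weights, backward_weights):
    """Concatenate the last state of a forward pass and of a backward pass."""
    forward = rnn_last_state(inputs, *forward_weights)          # reads word 1 ... word T
    backward = rnn_last_state(inputs[::-1], *backward_weights)  # reads word T ... word 1
    return np.concatenate([forward, backward])


rng = np.random.default_rng(1)
inputs = rng.normal(size=(6, 4))  # 6 words, 4-dimensional embeddings
units = 3
fwd = (rng.normal(size=(units, 4)), rng.normal(size=(units, units)))
bwd = (rng.normal(size=(units, 4)), rng.normal(size=(units, units)))
encoding = bidirectional_last_states(inputs, fwd, bwd)
print(encoding.shape)  # (6,) = 2 * units
```

The backward half has just read the first word, so the final vector summarises the sentence from both ends.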

Here is the architecture of the RNN Model.


Text Classification Using Hierarchical Attention Network (HAN) :

I have taken reference from the research paper Hierarchical Attention Networks for Document Classification, which is a great guide to document classification using HAN. The same preprocessing with Beautiful Soup is done here, and the pre-trained embedding is again GloVe.
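For reference, the attention step that gives HAN its name can be sketched in NumPy. Following the paper, each encoder state h_t is projected as u_t = tanh(W h_t + b), scored against a learned context vector, normalised with a softmax, and the states are summed using those weights; the paper applies this pooling first over the words of a sentence and then over the sentences of a document. The Keras code in this section stacks BiLSTM encoders without an attention layer, so this sketch (random weights, hypothetical names) only illustrates the paper's mechanism.

```python
import numpy as np


def attention_pool(H, W, b, u_context):
    """Attention pooling as described in the HAN paper.

    H: (timesteps, hidden_dim) encoder states, e.g. BiLSTM outputs for one sentence.
    Returns the weighted sum of the states and the attention weights.
    """
    U = np.tanh(H @ W.T + b)           # hidden representation u_t of each state
    scores = U @ u_context             # similarity with the context vector
    scores = scores - scores.max()     # subtract the max for numerical stability
    alpha = np.exp(scores) / np.exp(scores).sum()  # softmax over timesteps
    return alpha @ H, alpha            # weighted sum of states, and the weights


rng = np.random.default_rng(2)
H = rng.normal(size=(5, 8))     # 5 words, 8-dimensional encoder states
W = rng.normal(size=(8, 8))
b = np.zeros(8)
u_context = rng.normal(size=8)  # trainable in a real model; random here
sentence_vector, alpha = attention_pool(H, W, b, u_context)
print(sentence_vector.shape, alpha.round(3))
```

The weights alpha show which words contribute most to the sentence vector, which is what makes HAN's decisions partly interpretable.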

  • Here I am building a hierarchical LSTM network, so I have to construct the input data as 3D rather than 2D as in the two sections above.
  • The input tensor is therefore [# of reviews in each batch, # of sentences, # of words in each sentence].
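```python
import numpy as np
from keras.preprocessing.text import Tokenizer, text_to_word_sequence

# texts: whole documents; reviews: the same documents, each split into a list of sentences
tokenizer = Tokenizer(num_words=MAX_NB_WORDS)  # `nb_words` is the old Keras 1 name
tokenizer.fit_on_texts(texts)
# 3D input: (documents, sentences, words), zero-padded
data = np.zeros((len(texts), MAX_SENTS, MAX_SENT_LENGTH), dtype='int32')
for i, sentences in enumerate(reviews):
    for j, sent in enumerate(sentences):
        if j < MAX_SENTS:
            wordTokens = text_to_word_sequence(sent)
            k = 0  # position of the next word in sentence j
            for word in wordTokens:
                index = tokenizer.word_index.get(word, 0)
                if k < MAX_SENT_LENGTH and 0 < index < MAX_NB_WORDS:
                    data[i, j, k] = index
                    k += 1
```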

After this, we can use Keras' TimeDistributed wrapper to construct the hierarchical input layers as follows. This post is also a useful reference.

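```python
from keras.layers import Input, Embedding, Bidirectional, LSTM, Dense, TimeDistributed
from keras.models import Model

# word_index, embedding_matrix, EMBEDDING_DIM and macronum come from the earlier steps
# Sentence encoder: word indices of one sentence -> BiLSTM sentence vector
embedding_layer = Embedding(len(word_index) + 1, EMBEDDING_DIM, weights=[embedding_matrix],
                            input_length=MAX_SENT_LENGTH, trainable=True)
sentence_input = Input(shape=(MAX_SENT_LENGTH,), dtype='int32')
embedded_sequences = embedding_layer(sentence_input)
l_lstm = Bidirectional(LSTM(100))(embedded_sequences)
sentEncoder = Model(sentence_input, l_lstm)

# Document encoder: apply sentEncoder to every sentence, then a BiLSTM over sentence vectors
review_input = Input(shape=(MAX_SENTS, MAX_SENT_LENGTH), dtype='int32')
review_encoder = TimeDistributed(sentEncoder)(review_input)
l_lstm_sent = Bidirectional(LSTM(100))(review_encoder)
preds = Dense(len(macronum), activation='softmax')(l_lstm_sent)
model = Model(review_input, preds)
```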
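What TimeDistributed does here can be sketched in NumPy: the same sentence encoder is applied to every sentence slice of the 3D input, turning (documents, sentences, words) into (documents, sentences, features), after which the second BiLSTM reads the sequence of sentence vectors. The averaging encoder below is a hypothetical stand-in for sentEncoder, used only to keep the example short.

```python
import numpy as np


def time_distributed(encoder, batch):
    """Apply the same encoder to every slice batch[:, t], as TimeDistributed does.

    batch: (documents, sentences, words) -> (documents, sentences, encoder features)
    """
    return np.stack([encoder(batch[:, t]) for t in range(batch.shape[1])], axis=1)


rng = np.random.default_rng(3)
word_vectors = rng.normal(size=(50, 4))  # hypothetical 50-word vocabulary, 4-dim vectors


def sentence_encoder(word_ids):
    """Hypothetical stand-in for sentEncoder: average the word vectors of each sentence."""
    return word_vectors[word_ids].mean(axis=1)  # (documents, words) -> (documents, 4)


data = rng.integers(0, 50, size=(2, 3, 6))  # 2 documents, 3 sentences, 6 words each
sentence_vectors = time_distributed(sentence_encoder, data)
print(sentence_vectors.shape)  # (2, 3, 4)
```

Since one encoder is shared across all sentence positions, its weights are learned from every sentence in every document.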

Here is the architecture of the HAN Model.


Here are the plots for Accuracy 📈 and Loss 📉


Observations 👇 :

  • Based on the above plots, CNN achieved good validation accuracy with high consistency. RNN and HAN also achieved high accuracy, but they were not as consistent across all the datasets.
  • RNN was found to be the least suitable architecture for production-ready scenarios.
  • The CNN model outperformed the other two models (RNN and HAN) in terms of training time; however, HAN can perform better than CNN and RNN given a huge dataset.
  • For datasets 1 and 2, which have more training samples, HAN achieved the best validation accuracy, whereas with very few training samples (dataset 3) HAN did not perform as well.
  • With fewer training samples (dataset 3), CNN achieved the best validation accuracy.

Performance Improvements :

To achieve the best performances 😉, we may:

  1. Fine Tune Hyper-Parameters : hyper-parameters are variables set before training that determine the network structure and how the network is trained (e.g. learning rate, batch size, number of epochs). Fine tuning can be done by manual search, grid search, random search, etc.
  2. Improve Text Pre-Processing : pre-processing of the input data can be tailored to your dataset, for example by removing special symbols, numbers, stopwords and so on.
  3. Use Dropout Layer : dropout is a regularization technique that helps avoid overfitting (raising validation accuracy), thus improving the model's generalization power.
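Grid search and random search can be sketched in plain Python: grid search enumerates every combination of values, while random search samples a fixed number of combinations, which is often cheaper when only a few hyper-parameters really matter. The search space and the toy scoring function below are hypothetical; in practice the evaluation step would train the model and return its validation accuracy.

```python
import itertools
import random

# Hypothetical search space; values are illustrative, not the settings used in the experiments
search_space = {
    "learning_rate": [1e-2, 1e-3, 1e-4],
    "batch_size": [32, 64, 128],
    "dropout": [0.2, 0.3, 0.5],
}


def grid_search_configs(space):
    """Every combination of values: 3 * 3 * 3 = 27 training runs here."""
    keys = list(space)
    for values in itertools.product(*(space[k] for k in keys)):
        yield dict(zip(keys, values))


def random_search_configs(space, n_trials, seed=0):
    """A fixed budget of randomly sampled combinations."""
    rng = random.Random(seed)
    for _ in range(n_trials):
        yield {k: rng.choice(v) for k, v in space.items()}


def best_config(configs, evaluate):
    """Return the first configuration with the highest score from `evaluate`."""
    return max(configs, key=evaluate)


def toy_evaluate(cfg):
    # Stand-in for "train the model and return validation accuracy"
    return -abs(cfg["learning_rate"] - 1e-3) - abs(cfg["dropout"] - 0.3)


print(len(list(grid_search_configs(search_space))))  # 27
print(best_config(grid_search_configs(search_space), toy_evaluate))
print(best_config(random_search_configs(search_space, n_trials=10), toy_evaluate))
```

Random search with 10 trials costs far fewer training runs than the 27-run grid, at the price of possibly missing the exact best combination.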

Infrastructure setup:

All the above experiments were performed on an 8-core vCPU machine with an Nvidia Tesla K80 GPU.

Furthermore, all the experiments were performed under the guidance of Rahul Kumar 😎.

I would also like to express my thanks for the very good infrastructure and full support provided throughout my journey 😃.



Thanks to Rahul Kumar

Akshat Maheshwari

Written by

Pre-final year student in Information Technology (IIIT Gwalior), Machine Learning Enthusiast and Technology Explorer.



Jatana brings warp speed to the help desk by automating replies to support requests so your agents can focus on the important details that make your customers go WOW!
