

Report on Text Classification using CNN, RNN & HAN


  1. Convolutional Neural Network (CNN)
  2. Recurrent Neural Network (RNN)
  3. Hierarchical Attention Network (HAN)

About Natural Language Processing (NLP)

Text classification is a common NLP task with applications such as:

  • Understanding audience sentiment (😁 😐 😥) from social media
  • Detection of spam & non-spam emails
  • Auto tagging of customer queries
  • Categorisation of news articles 📰 into predefined topics

An end-to-end text classification pipeline is composed of the following components:

  1. Training Text: The input text from which our supervised learning model learns to predict the required class.
  2. Feature Vector: A vector that contains information describing the characteristics of the input data.
  3. Labels: The predefined categories/classes that our model will predict.
  4. ML Algorithm: The algorithm our model uses to perform text classification (in our case: CNN, RNN, HAN).
  5. Predictive Model: A model trained on a historical dataset that can predict labels for new text.
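
To make the Feature Vector and Labels components concrete, here is a minimal plain-Python sketch that turns training texts into bag-of-words count vectors. Bag-of-words is used here only for illustration; the models below learn word embeddings instead. The toy texts and labels are hypothetical.

```python
from collections import Counter

def build_vocab(texts):
    """Map each distinct lower-cased word to an index, in order of first appearance."""
    vocab = {}
    for text in texts:
        for word in text.lower().split():
            if word not in vocab:
                vocab[word] = len(vocab)
    return vocab

def to_feature_vector(text, vocab):
    """Count how often each vocabulary word occurs in `text` (unknown words are ignored)."""
    counts = Counter(w for w in text.lower().split() if w in vocab)
    return [counts[word] for word in vocab]

# Hypothetical toy training texts and their labels
train_texts = ["free prize click now", "meeting moved to monday", "click for a free gift"]
train_labels = ["spam", "ham", "spam"]

vocab = build_vocab(train_texts)
features = [to_feature_vector(t, vocab) for t in train_texts]
print(features[0])  # [1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0]
```

The ML algorithm is then trained on the pairs `(features[i], train_labels[i])` to produce the predictive model.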

Analysing Our Data :

Text Classification Using Convolutional Neural Network (CNN) :

Image Reference : http://www.wildml.com/2015/11/understanding-convolutional-neural-networks-for-nlp/
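```python
import re

from bs4 import BeautifulSoup

def clean_str(string):
    # Strip backslashes and single/double quotes, then lower-case
    string = re.sub(r"\\", "", string)
    string = re.sub(r"\'", "", string)
    string = re.sub(r"\"", "", string)
    return string.strip().lower()

# `df` is a pandas DataFrame with a `message` (raw text) and a `class` (label) column,
# using the default integer index
texts = []
labels = []

for i in range(df.message.shape[0]):
    text = BeautifulSoup(df.message[i], "html.parser")  # drop any HTML markup
    texts.append(clean_str(text.get_text()))

for i in df['class']:
    labels.append(i)
```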
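
The models below end in `Dense(len(macronum), activation='softmax')`, and the RNN model is compiled with `categorical_crossentropy`, which expects one-hot labels. The post does not show how `macronum` is built. Assuming it is the sorted list of distinct classes, the label conversion could look like the following plain-Python sketch. In Keras, the last step is usually done with `to_categorical`.

```python
def encode_labels(labels):
    """Return the sorted class list and one one-hot vector per label."""
    macronum = sorted(set(labels))  # assumed definition of `macronum`
    class_to_id = {c: i for i, c in enumerate(macronum)}
    one_hot = []
    for label in labels:
        row = [0] * len(macronum)
        row[class_to_id[label]] = 1
        one_hot.append(row)
    return macronum, one_hot

# Hypothetical labels
macronum, y = encode_labels(["sports", "politics", "sports", "tech"])
print(macronum)  # ['politics', 'sports', 'tech']
print(y[0])      # [0, 1, 0]
```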
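
```python
from tensorflow.keras.layers import Input, Conv1D, MaxPooling1D, Flatten, Dense
from tensorflow.keras.models import Model

# Assumed to be defined beforehand: MAX_SEQUENCE_LENGTH, an Embedding layer
# `embedding_layer` with input_length=MAX_SEQUENCE_LENGTH, and `macronum` (the classes)
sequence_input = Input(shape=(MAX_SEQUENCE_LENGTH,), dtype='int32')
embedded_sequences = embedding_layer(sequence_input)
l_cov1 = Conv1D(128, 5, activation='relu')(embedded_sequences)
l_pool1 = MaxPooling1D(5)(l_cov1)
l_cov2 = Conv1D(128, 5, activation='relu')(l_pool1)
l_pool2 = MaxPooling1D(5)(l_cov2)
l_cov3 = Conv1D(128, 5, activation='relu')(l_pool2)
# Global max pooling, assuming exactly 35 time steps remain (true for MAX_SEQUENCE_LENGTH = 1000)
l_pool3 = MaxPooling1D(35)(l_cov3)
l_flat = Flatten()(l_pool3)
l_dense = Dense(128, activation='relu')(l_flat)
preds = Dense(len(macronum), activation='softmax')(l_dense)
model = Model(sequence_input, preds)
```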
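
The `MaxPooling1D(35)` layer only acts as global max pooling for one particular input length. The sketch below uses the Keras defaults: `padding='valid'`, convolution stride 1, and pooling stride equal to the pool size. Under these defaults, each `Conv1D(…, 5)` shortens the sequence by 4, and each `MaxPooling1D(p)` divides its length by p, rounding down. The sketch traces the length through the stack. The starting value of 1000 is an assumption, since the post does not state `MAX_SEQUENCE_LENGTH`.

```python
def conv1d_len(length, kernel_size, stride=1):
    """Output length of a 'valid' 1D convolution."""
    return (length - kernel_size) // stride + 1

def pool1d_len(length, pool_size):
    """Output length of 'valid' max pooling whose stride equals pool_size."""
    return (length - pool_size) // pool_size + 1

def trace_cnn_lengths(seq_len):
    """Follow the sequence length through the three Conv1D(…, 5) + MaxPooling1D blocks."""
    lengths = [seq_len]
    for pool_size in (5, 5, 35):
        seq_len = conv1d_len(seq_len, 5)
        lengths.append(seq_len)
        seq_len = pool1d_len(seq_len, pool_size)
        lengths.append(seq_len)
    return lengths

print(trace_cnn_lengths(1000))  # [1000, 996, 199, 195, 39, 35, 1]
```

With an input length of 500, the trace ends in 0: the 35-wide pooling window no longer fits, and Keras would reject the model. Using `GlobalMaxPooling1D()` for the last pooling step removes this dependency on the input length.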

Here is the architecture of the CNN Model.

Text Classification Using Recurrent Neural Network (RNN) :

Image Reference : http://colah.github.io/posts/2015-08-Understanding-LSTMs/
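```python
from tensorflow.keras.preprocessing.text import Tokenizer
MAX_NB_WORDS = 20000
tokenizer = Tokenizer(num_words=MAX_NB_WORDS)  # later conversions keep only the most frequent words
tokenizer.fit_on_texts(texts)
```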
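
After `fit_on_texts`, the Keras `Tokenizer` assigns integer indices by word frequency, giving the most frequent word index 1. `texts_to_sequences` then drops words whose index is not below `num_words`. `pad_sequences` pads or truncates every sequence to a fixed length, at the front by default. The plain-Python sketch below imitates that behaviour on hypothetical toy data to show what the CNN and RNN inputs look like. It is a simplified illustration (whitespace splitting, no punctuation filtering), not the Keras implementation.

```python
from collections import Counter

def fit_word_index(texts):
    """Index words by descending frequency, starting at 1 (0 is reserved for padding)."""
    counts = Counter()
    for text in texts:
        counts.update(text.lower().split())
    # most_common() sorts by count; words with equal counts keep first-seen order
    return {word: i + 1 for i, (word, _) in enumerate(counts.most_common())}

def texts_to_padded(texts, word_index, num_words, maxlen):
    """Map words to indices (keeping index < num_words), then pre-truncate/pre-pad to maxlen."""
    padded = []
    for text in texts:
        seq = [word_index[w] for w in text.lower().split()
               if w in word_index and word_index[w] < num_words]
        seq = seq[-maxlen:]                              # truncate from the front
        padded.append([0] * (maxlen - len(seq)) + seq)   # pad at the front with zeros
    return padded

toy_texts = ["the cat sat", "the dog sat on the mat"]
word_index = fit_word_index(toy_texts)
print(word_index)  # {'the': 1, 'sat': 2, 'cat': 3, 'dog': 4, 'on': 5, 'mat': 6}
padded = texts_to_padded(toy_texts, word_index, num_words=5, maxlen=4)
print(padded)      # [[0, 1, 3, 2], [1, 4, 2, 1]]
```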
  • In this section, I will tackle the problem using a recurrent neural network and an attention-based LSTM encoder.
  • With an LSTM encoder, we intend to encode all the information in the text into the last output of the recurrent neural network, before passing it to a feed-forward network for classification.
  • This is very similar to neural machine translation and sequence-to-sequence learning. The following figure is from A Hierarchical Neural Autoencoder for Paragraphs and Documents.
Image Reference : https://arxiv.org/pdf/1506.01057v2.pdf
  • I'm using the LSTM layer in Keras to implement this. Instead of a forward-only LSTM, I use a bidirectional LSTM and concatenate the last outputs of both directions.
  • Keras provides a convenient wrapper called Bidirectional, which makes this coding exercise effortless. Sample code is shown below:
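```python
from tensorflow.keras.layers import Input, LSTM, Bidirectional, Dense
from tensorflow.keras.models import Model

sequence_input = Input(shape=(MAX_SEQUENCE_LENGTH,), dtype='int32')
embedded_sequences = embedding_layer(sequence_input)
# Last forward and last backward outputs are concatenated: 2 x 100 = 200 features
l_lstm = Bidirectional(LSTM(100))(embedded_sequences)
preds = Dense(len(macronum), activation='softmax')(l_lstm)
model = Model(sequence_input, preds)
model.compile(loss='categorical_crossentropy', optimizer='rmsprop', metrics=['acc'])
```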
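
The idea of encoding a whole text into the last output of a recurrent network, run in both directions, can be sketched in plain Python. For brevity, the sketch swaps the LSTM cell for a plain tanh (Elman) RNN cell and uses small, hypothetical hand-picked weights. It only illustrates how `Bidirectional`, with its default `merge_mode='concat'`, joins the final forward and backward states. This is why `Bidirectional(LSTM(100))` yields 200 features.

```python
import math

def rnn_last_state(inputs, w_x, w_h, b):
    """Run a tanh RNN over `inputs` (a list of vectors) and return the final hidden state.

    w_x: hidden x input weights, w_h: hidden x hidden weights, b: hidden bias.
    """
    hidden = len(b)
    h = [0.0] * hidden
    for x in inputs:
        # the comprehension reads the previous `h` before it is replaced
        h = [math.tanh(sum(w_x[i][k] * x[k] for k in range(len(x)))
                       + sum(w_h[i][k] * h[k] for k in range(hidden))
                       + b[i])
             for i in range(hidden)]
    return h

def bidirectional_last_states(inputs, fwd_params, bwd_params):
    """Concatenate the last state of a forward pass and of a backward (reversed) pass."""
    forward = rnn_last_state(inputs, *fwd_params)
    backward = rnn_last_state(list(reversed(inputs)), *bwd_params)
    return forward + backward

sentence = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]  # hypothetical 2-d word embeddings
fwd = ([[0.5, -0.3], [0.2, 0.8]], [[0.1, 0.0], [0.0, 0.1]], [0.0, 0.0])
bwd = ([[0.4, 0.4], [-0.6, 0.1]], [[0.2, 0.0], [0.0, 0.2]], [0.0, 0.0])
encoding = bidirectional_last_states(sentence, fwd, bwd)
print(len(encoding))  # 4: two forward features followed by two backward features
```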

Here is the architecture of the RNN Model.

Text Classification Using Hierarchical Attention Network (HAN) :

  • Here I am building a hierarchical LSTM network. The input data has to be constructed as 3D, rather than 2D as in the two sections above.
  • So the input tensor has the shape [# of reviews in each batch, # of sentences, # of words in each sentence].
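```python
import numpy as np
from tensorflow.keras.preprocessing.text import Tokenizer, text_to_word_sequence

# `texts` holds the full documents; `reviews` holds one list of sentence strings per document
tokenizer = Tokenizer(num_words=MAX_NB_WORDS)
tokenizer.fit_on_texts(texts)
word_index = tokenizer.word_index

data = np.zeros((len(texts), MAX_SENTS, MAX_SENT_LENGTH), dtype='int32')
for i, sentences in enumerate(reviews):
    for j, sent in enumerate(sentences):
        if j < MAX_SENTS:
            wordTokens = text_to_word_sequence(sent)
            for k, word in enumerate(wordTokens):
                # keep at most MAX_SENT_LENGTH words and only the MAX_NB_WORDS most frequent
                if k < MAX_SENT_LENGTH and word_index[word] < MAX_NB_WORDS:
                    data[i, j, k] = word_index[word]
```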
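
The indexing loop above can be traced on toy data using a plain-Python version that builds nested lists instead of a NumPy array. Here, a whitespace split stands in for `text_to_word_sequence`. The `word_index`, limits and documents are hypothetical. As in the loop above, the sketch drops sentences beyond `MAX_SENTS` and words beyond `MAX_SENT_LENGTH`. It also skips rare words (index ≥ `MAX_NB_WORDS`) and leaves every unused slot at 0.

```python
def build_3d_input(reviews, word_index, max_sents, max_sent_length, max_nb_words):
    """Return a [documents][sentences][words] grid of word indices, zero-padded."""
    data = [[[0] * max_sent_length for _ in range(max_sents)] for _ in reviews]
    for i, sentences in enumerate(reviews):
        for j, sent in enumerate(sentences):
            if j < max_sents:
                for k, word in enumerate(sent.lower().split()):
                    if k < max_sent_length and word_index[word] < max_nb_words:
                        data[i][j][k] = word_index[word]
    return data

# Hypothetical vocabulary and two documents, each a list of sentences
word_index = {"good": 1, "movie": 2, "bad": 3, "plot": 4, "great": 5, "acting": 6}
reviews = [["good movie", "bad plot"], ["great acting great movie", "good", "bad"]]
data = build_3d_input(reviews, word_index, max_sents=2, max_sent_length=3, max_nb_words=6)
print(data)  # [[[1, 2, 0], [3, 4, 0]], [[5, 0, 5], [1, 0, 0]]]
```

In the second document, "acting" is skipped as a rare word (index 6 is not below 6), "movie" falls beyond the three-word limit, and the third sentence falls beyond the two-sentence limit.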
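
```python
from tensorflow.keras.layers import Input, Embedding, LSTM, Bidirectional, Dense, TimeDistributed
from tensorflow.keras.models import Model

# `embedding_matrix` (len(word_index) + 1 rows, EMBEDDING_DIM columns) is assumed to hold
# pre-trained word vectors; MAX_SENTS, MAX_SENT_LENGTH and `macronum` are defined earlier
embedding_layer = Embedding(len(word_index) + 1, EMBEDDING_DIM, weights=[embedding_matrix],
                            input_length=MAX_SENT_LENGTH, trainable=True)

# Sentence encoder: the words of one sentence -> one vector
sentence_input = Input(shape=(MAX_SENT_LENGTH,), dtype='int32')
embedded_sequences = embedding_layer(sentence_input)
l_lstm = Bidirectional(LSTM(100))(embedded_sequences)
sentEncoder = Model(sentence_input, l_lstm)

# Document encoder: encode every sentence, then run a BiLSTM over the sentence vectors
review_input = Input(shape=(MAX_SENTS, MAX_SENT_LENGTH), dtype='int32')
review_encoder = TimeDistributed(sentEncoder)(review_input)
l_lstm_sent = Bidirectional(LSTM(100))(review_encoder)
preds = Dense(len(macronum), activation='softmax')(l_lstm_sent)
model = Model(review_input, preds)
```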
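
As written, the model above is a hierarchical bidirectional LSTM: each level keeps only the last LSTM output, and no layer computes attention. In the HAN of Yang et al. (2016), each level instead pools all of its hidden states h_t with attention, where u_w is a learned context vector:

- u_t = tanh(W h_t + b)
- α_t = softmax(u_tᵀ u_w)
- summary vector = Σ_t α_t h_t

The plain-Python sketch below computes that pooling for fixed, hypothetical weights. In Keras, it would typically be a custom layer applied to an LSTM with `return_sequences=True`.

```python
import math

def attention_pool(hidden_states, W, b, u_w):
    """HAN-style attention pooling over a list of hidden-state vectors.

    u_t = tanh(W h_t + b); alpha = softmax over t of (u_t . u_w).
    Returns (sum_t alpha_t * h_t, alpha).
    """
    scores = []
    for h in hidden_states:
        u = [math.tanh(sum(W[i][k] * h[k] for k in range(len(h))) + b[i])
             for i in range(len(b))]
        scores.append(sum(ui * ci for ui, ci in zip(u, u_w)))
    # numerically stable softmax over the time steps
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    alpha = [e / total for e in exps]
    dim = len(hidden_states[0])
    pooled = [sum(a * h[d] for a, h in zip(alpha, hidden_states)) for d in range(dim)]
    return pooled, alpha

# Three hypothetical 2-d hidden states, e.g. word-level BiLSTM outputs of one sentence
states = [[1.0, 0.0], [0.0, 1.0], [0.5, 0.5]]
W = [[1.0, 0.0], [0.0, 1.0]]
b = [0.0, 0.0]
u_w = [2.0, 0.0]  # context vector that favours the first feature
sentence_vector, weights = attention_pool(states, W, b, u_w)
```

Here the first state receives the largest weight. With an all-zero context vector, every state gets weight 1/3 and the pooling reduces to a plain average.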

Here is the architecture of the HAN Model.


Here are the plots for Accuracy 📈 and Loss 📉

Observations 👇 :

  • Based on the plots above, CNN achieved good validation accuracy with high consistency. RNN and HAN also achieved high accuracy, but less consistently across the datasets.
  • RNN was found to be the least suitable architecture for production-ready scenarios.
  • The CNN model outperformed the other two models (RNN and HAN) in terms of training time. However, HAN can perform better than CNN and RNN on a very large dataset.
  • On datasets 1 and 2, which have more training samples, HAN achieved the best validation accuracy. On dataset 3, which has very few training samples, HAN did not perform as well.
  • With fewer training samples (dataset 3), CNN achieved the best validation accuracy.

Performance Improvements :

  1. Fine-Tune Hyper-Parameters : Hyper-parameters are variables that are set before training and determine the network structure and how the network is trained (e.g. learning rate, batch size, number of epochs). Tuning can be done by manual search, grid search, random search…
  2. Improve Text Pre-Processing : Pre-process the input data according to the needs of your dataset, e.g. by removing special symbols, numbers, stopwords and so on.
  3. Use Dropout Layer : Dropout is a regularization technique that reduces overfitting, which can raise validation accuracy and improve the model's ability to generalize.
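
Two of these ideas can be sketched in plain Python.

- **Random search.** `random_search` samples hyper-parameter combinations and keeps the best-scoring one. The search space and the `toy_score` objective are hypothetical stand-ins for training a model and reading its validation accuracy.
- **Dropout.** `inverted_dropout` shows what a dropout layer does during training. It zeroes each activation with probability `rate` and scales the surviving values by 1 / (1 − rate). Keras's `Dropout` layer behaves the same way, so it can pass inputs through unchanged at inference time.

```python
import random

def random_search(space, score_fn, n_trials, seed=0):
    """Try `n_trials` random combinations from `space`; return the best (score, params)."""
    rng = random.Random(seed)
    best = None
    for _ in range(n_trials):
        params = {name: rng.choice(values) for name, values in space.items()}
        score = score_fn(params)
        if best is None or score > best[0]:
            best = (score, params)
    return best

def inverted_dropout(activations, rate, rng):
    """Zero each value with probability `rate` (0 <= rate < 1); scale kept values by 1 / (1 - rate)."""
    keep = 1.0 - rate
    return [a / keep if rng.random() < keep else 0.0 for a in activations]

# Hypothetical search space and a made-up objective standing in for validation accuracy
space = {"learning_rate": [1e-2, 1e-3, 1e-4], "batch_size": [32, 64, 128], "epochs": [5, 10]}

def toy_score(p):
    return -abs(p["learning_rate"] - 1e-3) - abs(p["batch_size"] - 64) / 1000 + p["epochs"] / 100

best_score, best_params = random_search(space, toy_score, n_trials=20)
dropped = inverted_dropout([1.0] * 10, rate=0.5, rng=random.Random(1))
```

Grid search would instead loop over every combination in `space`. Random search is often preferred when only a few hyper-parameters matter, because it tries more distinct values of each one for the same number of trials.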

Infrastructure setup:



Akshat Maheshwari

Pre-final year student in Information Technology (IIIT Gwalior), Machine Learning Enthusiast and Technology Explorer.