Building a Text Classifier using RNN

Sri Geetha M
Published in Nerd For Tech
Jun 25, 2021

In our last story, we discussed building a text classifier without using an RNN. In this article, we discuss building a text classifier using a Recurrent Neural Network (RNN).

SMS SPAM CLASSIFICATION USING RNN

Why RNN & Why not ANN?

Let us imagine the first time we were taught to write the letter A on a practice sheet. We would have done something like this.

Practising Letter A

While writing, we would have realised that the pen was drifting out of the lines and that our strokes needed to change. So, if possible, we would erase, or at least change the direction of the pen as we went. This is, loosely, the intuition behind a Recurrent Neural Network (RNN): what it has seen so far shapes how it handles what comes next.

Traditional feed-forward neural networks (ANNs) process each input independently, without regard to its position in a sequence. An ANN completes its forward pass and then computes the error and the gradients through backpropagation. RNNs, by contrast, carry information forward from one step of the sequence to the next while the sequence is being processed. Because the model is temporal and deals with time-series / sequence data, RNNs are trained with backpropagation through time (BPTT): the network is unrolled over the time steps and the error is propagated back through every step. Traditional feed-forward networks cannot capture this, since they assume each input is independent of the others, whereas in time-series / sequence data each input depends on the previous ones. Unlike feed-forward networks, RNNs use their internal state (memory) to process sequences of inputs.
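To make backpropagation through time concrete, here is a minimal sketch with a toy scalar RNN, h(t) = tanh(w·h(t-1) + u·x(t)), and a squared-error loss on the final state only. The weight names w and u, the inputs and the target are illustrative assumptions, not values from the article. The gradient is accumulated by walking backwards over the time steps, and a finite-difference check confirms it.

```python
import math

def forward(w, u, xs, h0=0.0):
    """Toy scalar RNN: h(t) = tanh(w * h(t-1) + u * x(t)). Returns [h0, h1, ..., hT]."""
    hs = [h0]
    for x in xs:
        hs.append(math.tanh(w * hs[-1] + u * x))
    return hs

def loss(w, u, xs, target):
    """Squared error on the final hidden state only."""
    return 0.5 * (forward(w, u, xs)[-1] - target) ** 2

def bptt(w, u, xs, target):
    """Gradients dL/dw and dL/du, computed by backpropagation through time."""
    hs = forward(w, u, xs)
    dh = hs[-1] - target              # dL/dh(T)
    dw = du = 0.0
    for t in range(len(xs), 0, -1):   # walk backwards from the last step to the first
        da = dh * (1.0 - hs[t] ** 2)  # back through tanh: h(t) = tanh(a(t))
        dw += da * hs[t - 1]          # the same w is used at every step, so contributions add up
        du += da * xs[t - 1]
        dh = da * w                   # gradient flowing on to h(t-1)
    return dw, du

w, u, xs, target = 0.8, 0.5, [1.0, -0.5, 0.3, 0.9], 0.2
dw, du = bptt(w, u, xs, target)

# Central finite differences as an independent check.
eps = 1e-6
dw_num = (loss(w + eps, u, xs, target) - loss(w - eps, u, xs, target)) / (2 * eps)
du_num = (loss(w, u + eps, xs, target) - loss(w, u - eps, xs, target)) / (2 * eps)
print(f"BPTT dw={dw:.6f}  numeric dw={dw_num:.6f}")
print(f"BPTT du={du:.6f}  numeric du={du_num:.6f}")
```

The key line is `dh = da * w`: the error at step t is passed back to step t-1 through the recurrent weight, which is exactly what distinguishes BPTT from ordinary backpropagation.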

Recurrent Neural Networks (RNN)

A Recurrent Neural Network (RNN) is a generalization of a feed-forward neural network with an internal memory component. An RNN is recurrent because it applies the same function, with the same weights, to every element of the input sequence, and the result for the current input depends on the previous computation. After producing its output, the network's hidden state is copied and fed back into the recurrent layer for the next time step.

Transforming ANNs to RNNs

The nodes in the different layers of the neural network are compressed to form a single recurrent layer. A, B, and C are the parameters of the network.

Fully Connected Recurrent Neural Network

x: input layer, h: hidden layer, y: output layer; A, B, C: parameters that affect the model output.

x(t-1), x(t), x(t+1), …: input vectors at time steps t-1, t, t+1 respectively. If h(t) is the new state at time step t, then h(t-1) is the old state at time step t-1.

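At any given time step t, the new state h(t) is computed from the current input x(t) and the previous state h(t-1), so it reflects all the inputs seen so far. This state is then fed back into the network for the next time step, and the output y(t) is read off the current state.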
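The recurrence above can be sketched in NumPy using the standard (Elman) update, h(t) = tanh(W_xh·x(t) + W_hh·h(t-1) + b_h) and y(t) = W_hy·h(t) + b_y. The text does not say which of A, B and C plays which role, so the names W_xh, W_hh and W_hy, the tanh activation and the layer sizes are assumptions for illustration.

```python
import numpy as np

def rnn_forward(xs, W_xh, W_hh, W_hy, b_h, b_y):
    """Run a vanilla (Elman) RNN over a sequence.

    xs: array of shape (T, input_dim), one input vector x(t) per time step.
    Returns all hidden states (T, hidden_dim) and all outputs (T, output_dim).
    """
    h = np.zeros(W_hh.shape[0])           # initial state is all zeros
    hs, ys = [], []
    for x in xs:                          # the same weights are reused at every step
        # the new state depends on the current input AND the previous state
        h = np.tanh(W_xh @ x + W_hh @ h + b_h)
        y = W_hy @ h + b_y                # output read off the current state
        hs.append(h)
        ys.append(y)
    return np.stack(hs), np.stack(ys)

rng = np.random.default_rng(0)
input_dim, hidden_dim, output_dim, T = 3, 4, 2, 5
W_xh = rng.normal(size=(hidden_dim, input_dim)) * 0.5
W_hh = rng.normal(size=(hidden_dim, hidden_dim)) * 0.5
W_hy = rng.normal(size=(output_dim, hidden_dim)) * 0.5
b_h = np.zeros(hidden_dim)
b_y = np.zeros(output_dim)

xs = rng.normal(size=(T, input_dim))
hs, ys = rnn_forward(xs, W_xh, W_hh, W_hy, b_h, b_y)
print(hs.shape, ys.shape)  # (5, 4) (5, 2)
```

Because the state is carried forward, changing an early input x(0) changes every later state, while changing the last input leaves all earlier states untouched.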

Types of Recurrent Neural Networks (RNN)

Based on the number of inputs an RNN receives and the number of outputs it generates, RNNs can be categorised into the following types.

TYPES OF RNN
TYPES OF RNN (Continued)

Applications of RNN Types

One limitation of vanilla neural networks, and even of CNNs, is that they work only with predetermined sizes: they take fixed-size inputs and produce fixed-size outputs. RNNs are useful because they allow variable-length sequences as input, as output, or both.

One-to-Many: This takes in one input and produces many outputs. E.g. image captioning, which takes one input image and produces a sequence of words as its caption.

Many-to-One: This takes in a sequence of inputs and produces a single output. E.g. sentiment analysis, which takes sequential text such as a movie review as input and produces a single sentiment, such as whether the review is positive or negative.

Many-to-Many: This takes in sequential input and produces sequential output. There are two scenarios. In the first, e.g. language translation, it takes a sequence in one language and produces a sequence in the target language; the input and output lengths can differ. In the second, e.g. video captioning, the video is converted into a sequence of images, which is captioned with a sequence of text.
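In Keras, the many-to-one case and the step-by-step many-to-many case differ mainly in whether a recurrent layer returns only its last state or its state at every time step (the `return_sequences` argument of `SimpleRNN`). The sketch below uses arbitrary sizes and also illustrates the variable-length point made earlier: declaring the time dimension as `None` lets the same model accept sequences of different lengths. One-to-many and encoder-decoder setups such as translation need more machinery and are not shown.

```python
import numpy as np
import tensorflow as tf

features, units = 3, 4

# Many-to-one: the layer returns only its final hidden state (e.g. one label per review).
many_to_one = tf.keras.Sequential([
    tf.keras.Input(shape=(None, features)),   # None: any number of time steps
    tf.keras.layers.SimpleRNN(units),
])

# Many-to-many (one output per input step): return the hidden state at every time step.
many_to_many = tf.keras.Sequential([
    tf.keras.Input(shape=(None, features)),
    tf.keras.layers.SimpleRNN(units, return_sequences=True),
])

# The same two models accept batches of sequences with different lengths.
for timesteps in (5, 8):
    batch = np.random.rand(2, timesteps, features).astype("float32")
    print(timesteps, tuple(many_to_one(batch).shape), tuple(many_to_many(batch).shape))
# 5 (2, 4) (2, 5, 4)
# 8 (2, 4) (2, 8, 4)
```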

Applications of RNN types

SMS SPAM CLASSIFIER USING RNN

Now let's start building a text classifier using an RNN. The basics of text preprocessing are illustrated below; a more detailed view of preprocessing text is covered in a separate article.

Steps in Text Pre-Processing
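The exact steps in the figure are not reproduced in the text, so the following is only a minimal sketch of common preprocessing steps, assuming lowercasing, removing digits and punctuation, and dropping stop words. The small `STOP_WORDS` set and the `preprocess` helper are hypothetical; real pipelines usually take a fuller stop-word list from a library such as NLTK and may add stemming or lemmatisation.

```python
import re

# Hypothetical, deliberately small stop-word list for illustration.
STOP_WORDS = {"a", "an", "the", "is", "are", "to", "you", "your", "for", "of", "and", "in", "on", "now"}

def preprocess(text: str) -> str:
    """Lowercase, keep only letters, and drop stop words (illustrative only)."""
    text = text.lower()
    text = re.sub(r"[^a-z\s]", " ", text)   # replace digits and punctuation with spaces
    tokens = [t for t in text.split() if t not in STOP_WORDS]
    return " ".join(tokens)

print(preprocess("WINNER!! You have won a FREE prize. Call 12345 now!"))
# winner have won free prize call
```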

The dataset for SMS spam classification (the SMS Spam Collection) can be downloaded from the UCI Machine Learning Repository.
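A minimal sketch of loading and exploring the data (steps 2 and 3 in the list that follows), assuming the layout of the UCI file: one message per line, tab-separated, with the label (`ham` or `spam`) first and no header row. Check this against your own copy. The `load_sms_spam` and `top_words` helpers are hypothetical; `top_words` prints the most frequent words per class as a text-only stand-in for the word cloud, and the sample messages are made up.

```python
import csv
import io
import re
from collections import Counter

import pandas as pd

def load_sms_spam(path_or_buffer):
    """Read the SMS Spam Collection (assumed: tab-separated, label first, no header)."""
    df = pd.read_csv(
        path_or_buffer,
        sep="\t",
        header=None,
        names=["label", "message"],
        quoting=csv.QUOTE_NONE,   # treat quote characters inside messages as plain text
    )
    df["target"] = df["label"].map({"ham": 0, "spam": 1})   # ham -> 0, spam -> 1
    return df

def top_words(messages, n=5):
    """Most frequent words in a set of messages: a text-only stand-in for a word cloud."""
    counts = Counter(w for msg in messages for w in re.findall(r"[a-z']+", msg.lower()))
    return counts.most_common(n)

# Tiny made-up sample in the assumed file format, for illustration only.
sample = io.StringIO(
    "ham\tAre we still meeting for lunch today?\n"
    "spam\tFREE entry! Win a prize, text WIN now\n"
    "ham\tOk, see you at lunch\n"
    "spam\tYou have won a prize! Claim your free prize now\n"
)
data = load_sms_spam(sample)
print(data["target"].value_counts())

# Separate the classes, as in step 3.
data_ham = data[data["target"] == 0]
data_spam = data[data["target"] == 1]
print("spam:", top_words(data_spam["message"], n=3))
print("ham:", top_words(data_ham["message"], n=3))
```

With the real file, pass its path (for example the extracted `SMSSpamCollection` file) instead of the in-memory sample.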

  1. Import necessary packages.
  2. Download and read the dataset with pandas and do some exploratory data analysis to understand it.
  3. Add a simple word cloud function to display the common words used in spam and ham messages. This gives an idea of which words dominate each class. To make a word cloud, first separate the classes into two pandas data frames, data_ham and data_spam.
  4. Prepare the dataset for training: split it into train and test sets, and preprocess the text column containing the messages as discussed above. X_train and X_test are the messages in the train and test sets, and y_train and y_test are their corresponding numeric labels: spam (1) or ham (0).
  5. Add a SimpleRNN layer from Keras.
  6. Compile the model and generate its summary. The model is then trained for up to 50 epochs, but training stops early once the validation loss stops improving. This is done by creating an EarlyStopping() callback and passing it to the callbacks parameter of model.fit().
  7. After making predictions on the test data, generate a classification report (with precision, recall, F1-score and support) and a confusion matrix for analysis.
  8. Once the model's performance looks good, save the trained model as well as the tokenizer as the final step of model building. Saving the tokenizer ensures that new data is processed in exactly the same way as the training data. Python's pickle library can be used to save the tokenizer.
  9. Finally, load the model and the tokenizer and use them to make predictions for new messages. The prediction is 0 for a ham message and 1 for a spam message.
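Steps 4 to 9 can be sketched end to end as follows. This is not the author's Colab code: the messages are a tiny made-up sample, the hyperparameters (maxlen, layer sizes, patience) are assumptions, and the hypothetical `build_vocab` / `encode` helpers stand in for the Keras tokenizer so that the sketch does not depend on the Keras version. The vocabulary is pickled the same way the tokenizer would be.

```python
import os
import pickle
import tempfile

import numpy as np
import tensorflow as tf
from sklearn.metrics import classification_report, confusion_matrix
from sklearn.model_selection import train_test_split

# Tiny made-up sample (1 = spam, 0 = ham); replace with the preprocessed UCI messages.
texts = [
    "free prize waiting claim now", "win cash now text win", "urgent you have won a prize",
    "claim your free entry now", "free cash prize call now", "you won a free holiday claim now",
    "see you at lunch", "are we still on for tonight", "call me when you get home",
    "ok see you later", "thanks for the lunch today", "running late be there soon",
]
labels = np.array([1] * 6 + [0] * 6)

# Step 4: stratified train/test split.
X_train, X_test, y_train, y_test = train_test_split(
    texts, labels, test_size=0.25, random_state=42, stratify=labels)

# Hypothetical stand-in for the Keras tokenizer: word -> id, with 0 = padding and 1 = unknown word.
def build_vocab(messages):
    vocab = {}
    for msg in messages:
        for word in msg.lower().split():
            vocab.setdefault(word, len(vocab) + 2)
    return vocab

def encode(messages, vocab, maxlen):
    """Convert messages to id sequences, keeping the last maxlen ids and zero-padding at the front."""
    out = np.zeros((len(messages), maxlen), dtype="int32")
    for i, msg in enumerate(messages):
        ids = [vocab.get(word, 1) for word in msg.lower().split()][-maxlen:]
        out[i, maxlen - len(ids):] = ids
    return out

maxlen = 8
vocab = build_vocab(X_train)                 # fit the vocabulary on training data only
X_train_ids = encode(X_train, vocab, maxlen)
X_test_ids = encode(X_test, vocab, maxlen)

# Step 5: embedding -> SimpleRNN -> sigmoid output (probability of spam).
model = tf.keras.Sequential([
    tf.keras.Input(shape=(maxlen,)),
    tf.keras.layers.Embedding(input_dim=len(vocab) + 2, output_dim=16),
    tf.keras.layers.SimpleRNN(16),
    tf.keras.layers.Dense(1, activation="sigmoid"),
])

# Step 6: compile, summarise, and train for up to 50 epochs with early stopping on val_loss.
model.compile(optimizer="adam", loss="binary_crossentropy", metrics=["accuracy"])
model.summary()
early_stop = tf.keras.callbacks.EarlyStopping(monitor="val_loss", patience=3, restore_best_weights=True)
model.fit(X_train_ids, y_train, epochs=50, validation_split=0.2, callbacks=[early_stop], verbose=0)

# Step 7: threshold the predicted probabilities at 0.5 and evaluate.
y_pred = (model.predict(X_test_ids, verbose=0).ravel() > 0.5).astype(int)
print(classification_report(y_test, y_pred, labels=[0, 1], target_names=["ham", "spam"], zero_division=0))
print(confusion_matrix(y_test, y_pred, labels=[0, 1]))

# Step 8: save the model and pickle the vocabulary (the tokenizer stand-in).
out_dir = tempfile.mkdtemp()
model_path = os.path.join(out_dir, "spam_rnn.keras")
vocab_path = os.path.join(out_dir, "vocab.pkl")
model.save(model_path)
with open(vocab_path, "wb") as f:
    pickle.dump(vocab, f)

# Step 9: reload both and classify new messages (0 = ham, 1 = spam).
loaded_model = tf.keras.models.load_model(model_path)
with open(vocab_path, "rb") as f:
    loaded_vocab = pickle.load(f)
new_ids = encode(["claim your free prize now", "see you at home"], loaded_vocab, maxlen)
print((loaded_model.predict(new_ids, verbose=0).ravel() > 0.5).astype(int))
```

Padding at the front, which is also the default of Keras's `pad_sequences`, keeps the real words next to the final time step, and that final state is what the Dense layer sees. On a sample this small the scores mean nothing; the point is the shape of the pipeline.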

The complete code is available as a Colab notebook.

THANK YOU & HAPPY LEARNING
