Should We Abandon LSTM for CNN?

Geoffrey So
AI/ML at Symantec
10 min read · Mar 29, 2019

Introduction

It always depends. It depends on the specific problem, the data available and the time you are willing to spend.

The character “Data” from Star Trek: The Next Generation falling for a trap. Image Source: https://blackboxlabs.github.io/2018/04/18/machine-learning-dont-fall-into-this-data-trap/

For the casual readers not steeped in machine learning: you may wonder, what is an LSTM? It is not a cable news alternative. LSTM stands for Long Short-Term Memory, a type of artificial neural network usually used to predict sequences of data. CNN, on the other hand, stands for Convolutional Neural Network, another type of artificial neural network that is often used to classify images. So why the question about abandoning one (LSTM) for the other (CNN)? Because recent literature points out that CNNs can do what LSTMs have been used for and are great at, namely predicting sequences, but in a much faster, more computationally efficient manner.

· http://arxiv.org/abs/1509.01626

· https://arxiv.org/abs/1803.01271

· https://towardsdatascience.com/the-fall-of-rnn-lstm-2d1594c74ce0

If these articles and blogs are right, then why are people still bothering with LSTMs (or other flavors of recurrent neural networks)? Like any good scientist interested in this subject, I tried to reproduce the results from these papers with data I have access to. If the claims of CNN superiority are true, then I expect to find the same results with my data: CNNs should come out on top of LSTMs. And if CNNs do not come out on top, then maybe the claims that CNNs are better need refinement. Maybe CNNs are often better, or only better under certain conditions; either way, that would not spell the doom and gloom for all LSTMs that catchy titles may have you believe.

Background on CNN/LSTM

CNN

So what makes a CNN different from an LSTM? To see what makes a CNN, I will show a picture of what the neural network sees (captured as features) in different layers, which will hopefully give you a high-level understanding of what goes on under the hood of this complex machinery.

Image Source: https://www.kdnuggets.com/2016/11/intuitive-explanation-convolutional-neural-networks.html/3 Paper: http://web.eecs.umich.edu/~honglak/icml09-ConvolutionalDeepBeliefNetworks.pdf

In the lowest layer, the neurons in the network usually identify important small-scale features, such as boundaries, corners, and intensity differences. In higher layers, the network combines these lower-level features into more complex ones, such as simple shapes, forms, and partial objects. In the final layer, the network combines the features from the layers below into objects a human would fully recognize. In this particular example, the layer 1 edge features are used to construct the layer 2 eyes and ears, which are then combined into something resembling human faces.

One may ask, how does a CNN pick out the features in the first place? Great question! It does this by performing what is called a convolution (hence the name convolutional neural network). A convolution is basically a sliding weighted sum: small matrices of weights (called “kernels” in machine learning speak), which start out randomly generated, are multiplied element-wise with each patch of the input, and the products are summed.

Image Source: https://hackernoon.com/visualizing-parts-of-convolutional-neural-networks-using-keras-and-cats-5cc01b214e59
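To make the sliding weighted sum concrete, here is a minimal pure-Python sketch of a "valid" 2D convolution (strictly speaking a cross-correlation, since the kernel is not flipped, which is what deep learning libraries compute under the name convolution). The toy image, the hand-picked vertical-edge kernel, and the function name `conv2d` are illustrative assumptions; a real CNN learns its kernels instead of having them chosen by hand.

```python
def conv2d(image, kernel):
    """'Valid' 2D convolution as used in CNNs (no kernel flip, no padding).

    Each output value is the sum of the element-wise products between the
    kernel and one image patch of the same size.
    """
    kh, kw = len(kernel), len(kernel[0])
    out_h = len(image) - kh + 1
    out_w = len(image[0]) - kw + 1
    out = []
    for i in range(out_h):
        row = []
        for j in range(out_w):
            total = 0.0
            for a in range(kh):
                for b in range(kw):
                    total += image[i + a][j + b] * kernel[a][b]
            row.append(total)
        out.append(row)
    return out

# A 4x5 toy "image": dark (0) on the left, bright (1) on the right.
image = [
    [0, 0, 1, 1, 1],
    [0, 0, 1, 1, 1],
    [0, 0, 1, 1, 1],
    [0, 0, 1, 1, 1],
]

# A hand-picked vertical-edge kernel (hypothetical; a trained CNN learns its own).
kernel = [
    [-1, 1],
    [-1, 1],
]

for row in conv2d(image, kernel):
    print(row)  # the large value marks where dark meets bright
```

Every output row is `[0.0, 2.0, 0.0, 0.0]`: the kernel fires only at the column where the dark-to-bright edge sits, which is the "signal" described next.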

Each kernel is slid across different sections of the image, and when it encounters a patch that matches it well, a strong signal results, indicating that the patch contains the feature the kernel is tuned to. At first, the randomly generated kernels do not find useful features. During training, the backpropagation algorithm works out (through a series of chain-rule multiplications) how much each kernel weight contributed to the network's errors, and the weights are updated to reduce those errors. These updates help the kernels transition from random matrices to matrices that discriminate important features. For example, in the first layer of the network, the kernels adapt to find image edges. This self-improving process is how the machine learning algorithm “learns” from the data it is fed. The algorithm becomes better and better at capturing important features over many passes (called “epochs”) through the same data.
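The learning loop can be sketched on a deliberately tiny problem: one 1D kernel, no nonlinearity, trained by plain gradient descent on mean squared error. This is an illustrative assumption rather than how a full CNN is trained; the names `conv1d` and `train_kernel`, the toy signal, and the hidden "true" kernel `[-1, 1]` are all hypothetical. The gradient is written out by hand with the chain rule, which is what backpropagation automates in a real network.

```python
import random

def conv1d(x, w):
    """'Valid' 1D convolution (cross-correlation) of signal x with kernel w."""
    k = len(w)
    return [sum(w[j] * x[i + j] for j in range(k)) for i in range(len(x) - k + 1)]

def train_kernel(x, target, kernel_size, lr=0.1, epochs=500, seed=0):
    """Learn kernel weights by gradient descent on mean squared error.

    dL/dw_j = (2/N) * sum_i (y_hat_i - y_i) * x[i + j]  (chain rule).
    """
    rng = random.Random(seed)
    w = [rng.uniform(-1, 1) for _ in range(kernel_size)]  # random starting kernel
    for _ in range(epochs):  # each epoch is one pass over the (tiny) dataset
        y_hat = conv1d(x, w)
        errors = [p - t for p, t in zip(y_hat, target)]
        n = len(errors)
        grads = [2.0 / n * sum(e * x[i + j] for i, e in enumerate(errors))
                 for j in range(kernel_size)]
        w = [wj - lr * g for wj, g in zip(w, grads)]  # step against the gradient
    return w

x = [0, 0, 1, 1, 0, 1, 0, 0, 1, 1, 1, 0]
target = conv1d(x, [-1.0, 1.0])  # outputs of a hidden "edge detector" we want to learn
learned = train_kernel(x, target, kernel_size=2)
print([round(v, 3) for v in learned])
```

Starting from random weights, the kernel should converge to the edge detector and print `[-1.0, 1.0]`.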

Readers interested in a hands-on feel for how a neural network learns features can visit TensorFlow Playground. The site has a nice GUI that lets one easily build a small (fully connected rather than convolutional) network and watch the learned features update in real time.

So that is CNN at a high level. What about LSTM?

LSTM

An LSTM is designed to work differently from a CNN because an LSTM is usually used to process and make predictions on sequences of data (in contrast, a CNN is designed to exploit “spatial correlation” in data and works well on images and speech).

At a simple level, an LSTM is just a neuron unit that feeds its output back into itself for the next time step in a sequence. It is a flavor of a more general class of neural networks called recurrent neural networks (RNNs).
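To make the feedback loop concrete, here is a minimal sketch of a one-unit vanilla RNN (not yet an LSTM). The weights `w_h`, `w_x`, and `b` are hypothetical, hand-picked values rather than trained ones, and `rnn_forward` is an illustrative name, not a library function.

```python
import math

def rnn_forward(xs, w_h=0.5, w_x=1.0, b=0.0, h0=0.0):
    """Unroll a one-unit vanilla RNN over a sequence.

    The new hidden state depends on the current input AND the previous
    hidden state: h_t = tanh(w_h * h_{t-1} + w_x * x_t + b).
    """
    h = h0
    states = []
    for x in xs:  # strictly sequential: step t needs the result of step t-1
        h = math.tanh(w_h * h + w_x * x + b)
        states.append(h)
    return states

states = rnn_forward([1.0, 0.0, 0.0, 0.0])
print(states)
```

A single input at the first step keeps echoing through later hidden states, fading a little each time because it is repeatedly scaled by `w_h` and squashed by `tanh`.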

RNNs were designed to retain long-range information, so that in a long sequence, early information is remembered and not lost. However, the feedback loop in RNNs creates two problems. The first, the vanishing gradient problem, occurs when the gradient used for updates in backpropagation is small. Through the consecutive multiplications of the chain rule across the depth of the network, the value shrinks quickly, and when the update becomes tiny or zero, the network stops learning. The second problem is the opposite: the exploding gradient problem occurs when the gradient is large and quickly grows much larger through the same chain of multiplications. The resulting huge updates make training unstable, and the values can overflow (to infinity or NaN), effectively breaking the model. Here’s a nice post that describes these problems in more detail:

The problem is exacerbated by network depth. When you “unroll an RNN” it looks something like this:

Image Source: http://colah.github.io/posts/2015-08-Understanding-LSTMs

The longer the sequence, the deeper the unrolled network (the more “t” units in this diagram), and the worse the vanishing/exploding gradient problem, because gradient updates must pass through more matrix multiplications to travel from step t back to step 0.
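The effect of unrolling depth can be shown with a deliberately simplified case: a one-unit linear RNN, h_t = w * h_{t-1}, where the chain rule makes the gradient from step T back to step 0 the product of T identical factors, w ** T. This is an illustrative assumption; a real RNN multiplies Jacobian matrices and includes the activation's derivative, but the repeated product behaves the same way qualitatively. The function name is hypothetical.

```python
def gradient_through_time(w, steps):
    """d h_T / d h_0 for a one-unit *linear* RNN h_t = w * h_{t-1}.

    The chain rule contributes one factor of w per unrolled time step.
    """
    grad = 1.0
    for _ in range(steps):
        grad *= w  # one chain-rule factor per time step
    return grad

for steps in (1, 10, 50):
    print(steps, gradient_through_time(0.5, steps), gradient_through_time(1.5, steps))
```

With a factor of 0.5 the gradient after 50 steps falls below 1e-15 (vanishing); with 1.5 it exceeds 1e8 (exploding). Only a factor of exactly 1 keeps it stable, which is hard to maintain with learned weights.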

So far I’ve been talking about RNNs in general. The LSTM is a specific implementation of an RNN that introduces a more complex neuron, called a cell, with gates that control the flow of information, including a forget gate. You can learn more about the details of LSTMs at the link under the diagram above. Basically, the forget gate looks at the current input and the previous output and decides how much of the signal stored in the cell to keep. It outputs a number between 0 and 1 (via a sigmoid function) that scales the stored signal: values near 1 let a relevant signal pass through almost unchanged from one step to the next, and values near 0 get rid of a signal that is no longer relevant. Because the stored signal is carried forward by this gated, additive update rather than by repeated multiplication with the same weights, the gradient can flow back across many more time steps without vanishing, which mitigates the problems above.
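Here is a minimal sketch of one step of a single-unit LSTM cell in a common formulation (forget, input, and output gates plus a candidate value). All parameter values are hypothetical and hand-picked to isolate the forget gate's role, and `lstm_step` and `run` are illustrative names, not a library API.

```python
import math

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

def lstm_step(x, h_prev, c_prev, p):
    """One step of a single-unit LSTM cell.

    `p` maps each gate name to (weight on input, weight on previous output, bias).
    """
    def gate(name, act):
        w_x, w_h, b = p[name]
        return act(w_x * x + w_h * h_prev + b)

    f = gate("forget", sigmoid)       # in (0, 1): how much old memory to keep
    i = gate("input", sigmoid)        # in (0, 1): how much new candidate to write
    g = gate("candidate", math.tanh)  # proposed new memory content
    o = gate("output", sigmoid)       # in (0, 1): how much memory to expose
    c = f * c_prev + i * g            # gated, additive update of the cell state
    h = o * math.tanh(c)
    return h, c

def run(forget_bias, steps=20):
    """Feed one 'event' at t = 0, then zeros; return the final cell state."""
    p = {
        "forget": (0.0, 0.0, forget_bias),
        "input": (10.0, 0.0, -5.0),   # opens only when x == 1
        "candidate": (2.0, 0.0, 0.0),
        "output": (0.0, 0.0, 0.0),
    }
    h, c = 0.0, 0.0
    for t in range(steps):
        x = 1.0 if t == 0 else 0.0
        h, c = lstm_step(x, h, c, p)
    return c

print(run(forget_bias=5.0))   # forget gate near 1: memory of the event is kept
print(run(forget_bias=-5.0))  # forget gate near 0: memory is erased almost at once
```

With the forget gate held near 1, most of the first input survives 20 steps later; near 0, it is wiped out after a single step.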

Here are some additional details about CNNs/RNNs if the reader is interested:

Literature/Evidence

Among the works cited earlier, Zhang et al. 2016 used a CNN to classify character sequences. Since text is a 1D signal along the time dimension, they wanted to transform it into 2D to exploit the CNN's strength at processing spatially correlated 2D data. They turned 1D text into 2D via embedding, that is, by representing each character of the text as a vector. So, instead of doing a 2D convolution over an image, they do the convolution in 1D along the time dimension, but with a 2D kernel. The second dimension is the character embedding dimension, because each character is now represented as a vector. So the character sequence labeled “Some Text” in the following image can be represented as a matrix with the time axis going to the right.

Paper: https://arxiv.org/pdf/1509.01626.pdf
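The idea can be sketched with a fixed one-hot embedding and a single 2D kernel that spans the full embedding dimension and slides only along time. The lowercase-only alphabet, the lowercased "some text" input, the hand-built "t then e" kernel, and the function names are all assumptions for brevity; this does not reproduce Zhang et al.'s exact alphabet or architecture.

```python
def one_hot_encode(text, alphabet):
    """Turn a string into a (len(alphabet) x len(text)) matrix.

    One column per character (the time axis), one row per alphabet symbol.
    Characters not in the alphabet become all-zero columns.
    """
    index = {ch: i for i, ch in enumerate(alphabet)}
    matrix = [[0.0] * len(text) for _ in alphabet]
    for t, ch in enumerate(text):
        if ch in index:
            matrix[index[ch]][t] = 1.0
    return matrix

def conv_over_time(matrix, kernel):
    """Slide a 2D kernel (embedding_dim x width) along the time axis only.

    The kernel always covers the whole embedding dimension, so the
    output is 1D: one value per valid time position.
    """
    dim, length = len(matrix), len(matrix[0])
    width = len(kernel[0])
    return [
        sum(kernel[d][k] * matrix[d][t + k] for d in range(dim) for k in range(width))
        for t in range(length - width + 1)
    ]

alphabet = "abcdefghijklmnopqrstuvwxyz"
kernel = [[0.0, 0.0] for _ in alphabet]
kernel[alphabet.index("t")][0] = 1.0  # 't' in the first position...
kernel[alphabet.index("e")][1] = 1.0  # ...followed by 'e'

out = conv_over_time(one_hot_encode("some text", alphabet), kernel)
print(out)
```

The output peaks at time position 5, where "te" occurs in "some text"; the partial match "me" at position 2 produces a weaker response.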

If that is too much matrix notation for your liking, all you need to know is that this method lets the authors use a CNN on text data, something that, according to them, they were the first to attempt, back in 2016. The results they showed were competitive with the state-of-the-art LSTMs of the time.

Fast forward roughly two more years: Bai et al. 2018 showed that their flavor of CNN can remember much longer sequences and be competitive with, and even better than, LSTMs (and other flavors of RNN) on a wide range of tasks. They use causal convolutions, which restrict each neuron to see only the current and past inputs, and they apply dilation to the convolutions, skipping inputs at an interval that grows from layer to layer, to give the CNN a much wider receptive field (it can see further into the past).
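Below is a minimal sketch of a causal dilated 1D convolution and of how stacking layers with doubling dilations widens the receptive field. It is in the spirit of Bai et al. but is not their implementation: it uses a plain stack of single convolution layers, kernel size 2, and all-ones weights, and the function names are hypothetical.

```python
def causal_dilated_conv(x, w, dilation):
    """Causal dilated 1D convolution (a sketch).

    The output at time t only uses x[t], x[t - d], x[t - 2d], ... and never
    future inputs. Positions before the sequence start count as zero
    (equivalent to left-padding), so the output keeps the input's length.
    """
    out = []
    for t in range(len(x)):
        total = 0.0
        for j in range(len(w)):
            idx = t - j * dilation  # look back j * dilation steps
            if idx >= 0:
                total += w[j] * x[idx]
        out.append(total)
    return out

def receptive_field(kernel_size, dilations):
    """Inputs (including the current one) visible to a plain stack of layers."""
    return 1 + (kernel_size - 1) * sum(dilations)

signal = [1.0] + [0.0] * 19  # an impulse at t = 0
h = signal
for d in (1, 2, 4, 8):       # dilation doubles at each layer
    h = causal_dilated_conv(h, [1.0, 1.0], d)
reach = max(t for t, v in enumerate(h) if v != 0.0)
print(reach + 1, receptive_field(2, [1, 2, 4, 8]))
```

Four layers are enough for the impulse to influence 16 time steps, matching the formula. Adding a layer with dilation 16 would double that, whereas undilated layers of kernel size 2 each add only one step.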

My Work

Reading all the above success stories got me excited that maybe I could apply these CNN tricks to beat the LSTM. One of the reasons people might prefer a CNN over an LSTM is training time. The current generation of popular deep learning hardware is basically Nvidia graphics cards, which are optimized to process data such as 2D arrays with extreme parallelism and speed, and CNNs exploit this because the convolution at every position can be computed at the same time. LSTMs, on the other hand, process inputs sequentially, with each time step waiting on the previous one, so the hardware cannot speed them up nearly as much, especially during training.
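A rough back-of-the-envelope sketch of why parallel hardware favors the CNN: an RNN's chain of dependent computations grows linearly with sequence length, while a dilated causal CNN with kernel size 2 and dilations 1, 2, 4, ... needs only about log2(length) layers to cover the sequence, and every position within a layer can be computed in parallel. This counts only sequential depth and ignores per-step cost, memory, and everything else that determines real training time; the function name is hypothetical.

```python
def tcn_layers_needed(seq_len):
    """Layers (kernel size 2, dilations 1, 2, 4, ...) needed so the
    receptive field covers seq_len time steps."""
    layers, reach = 0, 1
    while reach < seq_len:
        reach += 2 ** layers  # kernel size 2: each layer adds its dilation
        layers += 1
    return layers

for seq_len in (16, 64, 256):
    # An RNN needs seq_len dependent steps; the CNN's depth is the layer count.
    print(seq_len, "sequential RNN steps vs", tcn_layers_needed(seq_len), "conv layers")
```

Quadrupling the sequence length quadruples the RNN's sequential steps but adds only two convolution layers.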

I had access to a dataset of benign URLs that ordinary people visit daily for information and entertainment, along with some malicious URLs that bad actors on the internet use as a communication channel to a command-and-control center. These malicious URLs are usually randomly generated strings, and I wanted to detect which URLs are of the randomly generated type. The LSTM is known to be competitive with the state of the art on this task, so it gave me a baseline for comparison. I implemented the CNN techniques from the Zhang et al. and Bai et al. papers to see whether a CNN could beat the LSTM baseline in terms of efficacy. Long story short, no matter what I tried, I was not able to get the character CNN to beat the LSTM. I tried a learned embedding, a fixed one-hot embedding, a causal receptive field, and dilated convolutions. Although some experiments were promising, none surpassed the simple LSTM.

While the CNNs couldn’t beat the LSTM in efficacy, there was a drastic difference in training time between the two architectures. This could be due to the aforementioned Nvidia hardware optimization and the sequential nature of the LSTM. The character CNNs I played with achieved accuracy comparable to the LSTM in a fraction of the time. The CNN accuracy increased rapidly and then plateaued, while the LSTM kept improving slowly over longer training.

Conclusion/Possible Alternatives

Thinking back, when I first started this problem, I was full of hope, but when I reviewed the results of my experiments, my hopes were dashed. No matter what I tried, I could not find a CNN that outperformed the LSTM. Am I disappointed? Somewhat. Will this stop me from trying character CNNs on other sequence problems? Absolutely not! I am still hopeful that I will be able to use them for other types of sequence problems, especially when I care about how much time I spend training the neural network.

One thing that dawned on me after I completed this exercise is that the dataset may have affected the results. Because the characters in the malicious URLs are random, I didn’t see any improvement when I restricted the receptive field to be causal, which makes sense: those characters are supposed to be random and independent of both past and future. Also, many of the URLs are pretty short. So in retrospect, it may not have made much sense in this case to expect a CNN to capture more long-range information than the LSTM, because few URLs were long enough for that to matter.

Given my experiments and the dataset I used, the main conclusion I arrived at is this: character CNNs can be comparable to LSTMs, but that does not mean they will always beat LSTMs. Others may say I have not optimized my CNN, and that is true, but I did not optimize my LSTM either; both were just vanilla models with small architectural tweaks and not much hyperparameter tuning. I am still hopeful that one of these days I will find a dataset where a character CNN beats an LSTM. And of course, a literature search will turn up a sea of possible alternatives that may be even better. People are combining CNNs and LSTMs, using attention mechanisms, etc., so there are always new avenues to explore. There is a clear trend in the literature of moving away from LSTMs, but they have not disappeared, and experiments like these hint at why. It seems the CNN vs. LSTM game is not over yet!

Image Source: https://giphy.com/gifs/elementary-Mt2RXYBhsicXS
