Luc Nies
Published in Orikami blog
6 min read · Jul 31, 2018


Diagnosing Myocardial Infarction using Long Short-Term Memory networks (LSTMs)

Source: djekova.info

Introduction

A couple of months ago Keras released the CuDNNLSTM and CuDNNGRU layers, which are special implementations of the regular LSTM and GRU layers backed by NVIDIA’s cuDNN library. This means that if you have access to a CUDA GPU, training recurrent neural networks just got a whole lot faster.

In this blog post I’ll show a simple implementation of an LSTM network in Keras using the new CuDNNLSTM layer. I will also use this opportunity to explore the PTB Diagnostic ECG Database, a database containing ECG data of 290 subjects, and try to automatically recognize patients who suffered a myocardial infarction from the healthy controls.

Long Short-Term Memory

First of all, what are Long Short-Term Memory (LSTM) networks? The short answer is: they are special networks capable of “remembering” previous inputs. Each neuron, or LSTM node, in such a network maintains an internal state based on previous input, and uses its current state to make predictions about new input data. During training it learns what it should remember and what it should forget in order to make proper predictions. This property is why LSTMs are well suited for classification tasks on time series data such as natural language processing or … the classification of ECG signals!

If you are more interested in the long answer and want to know more about the inner workings of LSTMs, I really recommend reading Colah’s blog on this topic, where he explains it expertly.

PhysioNet and the PTB Diagnostic ECG Database

PhysioNet hosts several large data sets of physiological signals and has released a related open-source software package, the PhysioToolkit, which contains many useful tools for easy processing of those data sets. One of the data sets is the PTB Diagnostic ECG Database. It contains ECG records of 290 subjects, of which 52 are healthy controls and 148 have had a myocardial infarction (more commonly known as a heart attack); the remaining 90 subjects suffer from other heart diseases. For the sake of this example we will only use the recordings of the patients with myocardial infarction and the healthy controls.

Preparation

We will be using the Python implementation of the WaveForm DataBase (wfdb) Software Package from the PhysioToolkit to download the data from PhysioBank and interact with it. You can install it via pip:

pip install wfdb

Besides wfdb we will also be using TensorFlow, Keras, numpy, pandas, scikit-learn, and matplotlib. As these packages are pretty standard in the field of machine learning, I assume you have already installed them and won’t cover their installation.

Before we start coding, first import all the used packages:

Getting the data

The wfdb package supports downloading PhysioNet databases. First get a list of available records in the ptbdb database and then download those files. After the download (this takes a couple of minutes), let’s load the first record and see what data it contains.

A record contains several attributes. The ones we will be using are:

  • record.p_signal: the raw sensor values from the ECG. It is a numpy array of sig_len by n_channels. The signal length (number of samples) per recording can differ, but the number of channels is always 15.
  • record.sig_name: the names of each channel.
  • record.sig_len: the length (number of samples) of the recording.
  • record.comments: a list of strings containing extra information on the subject. It also includes the reason for admission (the fifth item), which we will use as a label.

Let’s get all these attributes as metadata and load them into a pandas dataframe for easy handling. We’re not loading the raw signal data yet, to save some memory.

Data exploration: Comparing the data

We are now ready to start doing some data science! Let’s first compare the ECGs of a healthy control and an infarction patient:

Comparison of the first channel of the ECG of a Healthy control and a subject who had a Myocardial infarction.

As can be seen from this image, the ECG signal of the healthy control is a lot more constant and less noisy than that of the myocardial infarction patient. This difference is visible across multiple channels and multiple patients. The code block below shows the code which allows you to compare more patients and channels with each other. If you do, you’ll notice that the ECGs of the myocardial infarction patients generally look a lot noisier and have less pronounced spikes.

Loading the sensor data of two subjects and plotting them side by side.

Creating training and test set

To train and validate the LSTM network, we first split the available data into a training and a test set. The test set subjects are not included in the training set. Since not all recordings have the same length, all the signal data is divided into equally sized windows. In this case the window size is 2048 samples, which is large enough to always include at least two heartbeats. The code block below shows the splitting of the data into training and test sets, and the splitting of each signal into smaller windows/sequences.

Since some subjects have multiple recordings and each recording is split into separate windows, we need to keep track of which subject each sequence in the test set belongs to. This is stored in record_list.
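The windowing and the subject-level split can be sketched as below. Splitting on subjects rather than windows guarantees no subject leaks from the training set into the test set; the 20% test fraction is my choice:

```python
import numpy as np

WINDOW_SIZE = 2048

def make_windows(signal, window_size=WINDOW_SIZE):
    """Cut a (sig_len, n_channels) signal into non-overlapping windows.

    Samples that do not fill a complete final window are dropped.
    Returns an array of shape (n_windows, window_size, n_channels).
    """
    n_windows = signal.shape[0] // window_size
    trimmed = signal[:n_windows * window_size]
    return trimmed.reshape(n_windows, window_size, signal.shape[1])

def split_subjects(subject_ids, test_fraction=0.2, seed=1337):
    """Split subjects (not windows!) into train and test groups,
    so no subject appears in both sets."""
    rng = np.random.RandomState(seed)
    subjects = np.array(sorted(set(subject_ids)))
    rng.shuffle(subjects)
    n_test = int(len(subjects) * test_fraction)
    return subjects[n_test:], subjects[:n_test]
```

With these helpers, each subject’s recordings are windowed with make_windows and routed to the training or test set according to which group the subject landed in.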

If you want to reproduce the same results as I got, you can set the seed of the random number generator to this pseudo-randomly chosen seed:

np.random.seed(1337)

Building the LSTM

We will be using a simple 3-layered LSTM network with dropout after each layer. The number of LSTM nodes per layer starts at 256 and is halved in each subsequent layer. The first two LSTM layers return the whole output sequence, while the last LSTM layer only returns the last step of its output sequence, thus dropping the temporal dimension.
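A sketch of that architecture in Keras. Only the layer sizes and the return_sequences settings follow from the description above; the dropout rate, the sigmoid output layer and the Adam optimizer are my assumptions:

```python
from keras.models import Sequential
from keras.layers import CuDNNLSTM, Dropout, Dense

def build_model(window_size=2048, n_channels=15):
    model = Sequential([
        # First two LSTM layers return the full output sequence
        CuDNNLSTM(256, return_sequences=True,
                  input_shape=(window_size, n_channels)),
        Dropout(0.5),
        CuDNNLSTM(128, return_sequences=True),
        Dropout(0.5),
        # Last LSTM layer returns only its final step,
        # dropping the temporal dimension
        CuDNNLSTM(64),
        Dropout(0.5),
        # Single sigmoid unit: probability of myocardial infarction
        Dense(1, activation='sigmoid'),
    ])
    model.compile(loss='binary_crossentropy', optimizer='adam',
                  metrics=['accuracy'])
    return model
```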

Training the network

Since the data is split into windows, the network is trained on windowed data, learning the label for each window/sequence. As there is a large class imbalance (there are 16235 infarction sequences but only 3675 control sequences), learning the healthy controls is going to be harder than learning the infarction. We therefore adjust the sample weight of each sequence and give the control sequences a larger weight.

Training this network does not take that long, around 5 minutes on a GTX1080.
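The class weighting described above can be computed with a small helper like this one. Inverse class-frequency weighting is one common choice; the batch size and epoch count in the commented-out fit call are assumptions:

```python
import numpy as np

def balanced_sample_weights(labels):
    """Weight each sequence inversely to its class frequency, so the
    minority class (healthy controls) contributes as much to the loss
    as the majority class (infarction)."""
    labels = np.asarray(labels)
    weights = np.zeros(len(labels), dtype=float)
    for cls in np.unique(labels):
        mask = labels == cls
        weights[mask] = len(labels) / (2.0 * mask.sum())
    return weights

if __name__ == "__main__":
    # With the class counts from this data set, control windows get
    # roughly 4.4x the weight of infarction windows
    y = np.array([1] * 16235 + [0] * 3675)
    w = balanced_sample_weights(y)
    print(w[y == 0][0] / w[y == 1][0])
    # model.fit(X_train, y_train, sample_weight=w,
    #           batch_size=32, epochs=10, validation_split=0.1)
```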

Predicting

As the ECG data of each subject is divided into smaller sequences, we first need to predict a label for each sequence. As soon as we have those predictions we can group them per patient and do patient-level prediction. In this case we simply take the mean of the sequence-level labels and use that for our final patient-level label. If the average is 0.5 or lower, meaning that at least half of the sequences are classified as healthy, we classify the patient as healthy. If the average is more than 0.5, the patient is diagnosed with a myocardial infarction.

The code block below shows how to do the predictions, calculate the labels and use classification_report to calculate the precision and recall.
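A sketch of that aggregation step. Here record_list holds the subject id of each test window, as set up earlier, and the 0.5 threshold matches the rule above:

```python
import numpy as np
from sklearn.metrics import classification_report

def patient_level_labels(sequence_predictions, record_list):
    """Aggregate per-window predictions to one label per subject.

    A subject is labelled infarction (1) when the mean over their
    windows exceeds 0.5, and healthy (0) otherwise.
    """
    preds = np.asarray(sequence_predictions, dtype=float)
    subjects = np.asarray(record_list)
    labels = {}
    for subject in np.unique(subjects):
        labels[subject] = int(preds[subjects == subject].mean() > 0.5)
    return labels

# With per-patient predictions and ground truth in hand:
# print(classification_report(y_true_per_patient, y_pred_per_patient))
```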

Results

The results can differ a bit depending on the division of the test and training sets, but you can expect a precision and recall of at least 0.90 for myocardial infarction. For the healthy controls these numbers are a bit lower, around 0.7 for both precision and recall.

Results when setting the seed to 1337.

Conclusion

The lower score on the healthy controls is not strange considering there was a lot less data available for healthy subjects. Moreover, in medical settings it is often preferable to have more false positives than to miss an actual true positive. That does not take away the fact that there are still several ways to improve the performance of the model, as we did not do any preprocessing or data augmentation, nor did we play around with hyper-parameters such as the learning rate. Despite the lack of such optimizations, I think that for such a simple LSTM network we got pretty good results. This shows that LSTMs are a very powerful tool for classification tasks involving sequential data.

Care to learn more cool stuff or collaborate with Orikami data scientists on personalized healthcare projects you or we work on? Send an e-mail to Luc@orikami.nl or call +31 24 3010100.
