Alzheimer Detection based on Images of MRI Segmentation

Claudia Quintana Wong
Published in The Startup · Jun 28, 2020
(Image source: https://www.alztennessee.org/system/news/view/36/diabetes-research-leads-to-possible-alzheimers-treatment)

Alzheimer’s is a progressive disease, where dementia symptoms gradually worsen over a number of years. In its early stages, memory loss is mild, but with late-stage Alzheimer’s, individuals lose the ability to carry on a conversation and respond to their environment.

Although current Alzheimer’s treatments cannot stop Alzheimer’s from progressing, they can temporarily slow the worsening of dementia symptoms and improve quality of life for those with Alzheimer’s and their caregivers. Image processing plays an important role in the early detection of Alzheimer’s disease, so that patients can be treated before irreversible changes occur in the brain.

In this project, I have chosen the Alzheimer’s Dataset, which contains data collected from various websites, with every label verified. The main goal is to build an end-to-end model to predict the stage of Alzheimer’s from MRI images.

Exploring data

The data consists of Magnetic Resonance Imaging (MRI) images presented in the training and test set. The images are classified into four different classes according to the stage of the disease:

  • Mild Demented (0)
  • Moderate Demented (1)
  • Non Demented (2)
  • Very Mild Demented (3)

If you are running outside of Kaggle, the following code will help you download the dataset. You should replace kaggle_username and kaggle_key with your actual credentials.

import os

# Install the official Kaggle API client
!pip install git+https://github.com/Kaggle/kaggle-api.git --upgrade

# Replace with your actual Kaggle credentials
credentials = {"username":"kaggle_username","key":"kaggle_key"}
os.environ['KAGGLE_USERNAME'] = credentials["username"]
os.environ['KAGGLE_KEY'] = credentials["key"]

# Download and extract the dataset
!kaggle datasets download -d tourist55/alzheimers-dataset-4-class-of-images
!unzip alzheimers-dataset-4-class-of-images.zip

Once the dataset has been downloaded, let's find some insights. The dataset contains 6,400 images: 5,121 in the training set and 1,279 in the test set, distributed as shown below.

Distribution of classes: Mild Demented (0), Moderate Demented (1), Non Demented (2), Very Mild Demented (3)

In order to implement a supervised model in PyTorch, the data should be stored in PyTorch datasets and data loaders. DataLoaders allow loading data in batches, which saves RAM. We will use the ImageFolder method to load the images into a Dataset and apply randomly chosen transformations while loading images from the training dataset. These transformations are applied randomly and dynamically each time a particular image is loaded, as sketched below.
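A minimal sketch of this step is shown below. The folder names and the image size are assumptions based on the layout of the Kaggle dataset; the exact transformations used in the notebook may differ.

import torchvision.transforms as T
from torchvision.datasets import ImageFolder

# Assumed folder layout after unzipping the Kaggle dataset
train_dir = 'Alzheimer_s Dataset/train'
test_dir = 'Alzheimer_s Dataset/test'

# Random transformations applied on the fly each time a training image is loaded
train_tfms = T.Compose([
    T.Resize((224, 224)),
    T.RandomHorizontalFlip(),
    T.RandomRotation(10),
    T.ToTensor(),
])

# No random augmentation for evaluation data
test_tfms = T.Compose([
    T.Resize((224, 224)),
    T.ToTensor(),
])

train_ds = ImageFolder(train_dir, transform=train_tfms)
test_ds = ImageFolder(test_dir, transform=test_tfms)
print(train_ds.classes)  # one class per subfolder (four stages of the disease)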

When building deep learning models, it is a common practice to divide the data into train, validation, and test subsets. As the original dataset is only divided into train and test, we will take some examples from the training data as validation.
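As a rough sketch of the split and the data loaders (the validation size and batch size below are arbitrary choices, not necessarily the ones used in the notebook):

import torch
from torch.utils.data import random_split, DataLoader

val_size = 500                          # arbitrary number of validation examples
train_size = len(train_ds) - val_size
train_split, val_split = random_split(train_ds, [train_size, val_size])

batch_size = 64
train_dl = DataLoader(train_split, batch_size, shuffle=True, num_workers=2, pin_memory=True)
val_dl = DataLoader(val_split, batch_size * 2, num_workers=2, pin_memory=True)
test_dl = DataLoader(test_ds, batch_size * 2, num_workers=2, pin_memory=True)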

Once the data loader for every subset has been created, we are ready to start implementing the model.

Models

I am going to present just one of the various models implemented. More details about how the different models were implemented can be found in this notebook: https://jovian.ml/claudiaqw/alzheimer-detection. We will be implementing a deep model based on transfer learning.

In transfer learning, the neural network is trained in two stages:

  1. pretraining, where the network is generally trained on a large-scale benchmark dataset representing a wide diversity of labels/categories (e.g., ImageNet);
  2. fine-tuning, where the pretrained network is further trained on the specific target task of interest, which may have fewer labeled examples than the pretraining dataset. The pretraining step helps the network learn general features that can be reused on the target task.

The model implementation:

In this case, I have used the pretrained weights of the DenseNet161 architecture, which was trained on the ImageNet dataset. I remove the last layer, which was in charge of classifying the ImageNet classes, and add on top of the pretrained layers a new linear classifier suited to the specific problem we are solving now. This last layer is trained on our specific task.
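A sketch of this idea is shown below. The exact classifier head used in the notebook may differ; here I simply replace DenseNet161's classifier with a single linear layer for the four classes.

import torch.nn as nn
from torchvision import models

class AlzheimerDenseNet(nn.Module):
    def __init__(self, num_classes=4):
        super().__init__()
        # Load DenseNet161 with weights pretrained on ImageNet
        self.network = models.densenet161(pretrained=True)
        # Replace the original 1000-class ImageNet classifier with our own head
        in_features = self.network.classifier.in_features
        self.network.classifier = nn.Linear(in_features, num_classes)

    def forward(self, xb):
        return self.network(xb)

model = AlzheimerDenseNet(num_classes=4)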

Training

The next step is to train the model so that it learns the proper weights for the added layer and the pretrained weights are adjusted to our task. The code below shows how to implement training and evaluation. I have implemented two different ways of training: the train method uses a fixed learning rate, while train_one_cycle uses a learning rate scheduler instead, which changes the learning rate after every batch of training.
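Below is a simplified sketch of the two training loops. The function names train and train_one_cycle follow the post; the choice of Adam as optimizer and the reduced bookkeeping (no validation metrics, no gradient clipping) are assumptions made to keep the sketch short.

import torch
import torch.nn.functional as F

def train(model, train_dl, epochs, lr):
    # Fixed learning rate for the whole training run
    optimizer = torch.optim.Adam(model.parameters(), lr)
    for epoch in range(epochs):
        model.train()
        for images, labels in train_dl:
            loss = F.cross_entropy(model(images), labels)
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()

def train_one_cycle(model, train_dl, epochs, max_lr):
    optimizer = torch.optim.Adam(model.parameters(), max_lr)
    # One-cycle policy: the learning rate is updated after every batch
    sched = torch.optim.lr_scheduler.OneCycleLR(
        optimizer, max_lr, epochs=epochs, steps_per_epoch=len(train_dl))
    for epoch in range(epochs):
        model.train()
        for images, labels in train_dl:
            loss = F.cross_entropy(model(images), labels)
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
            sched.step()  # adjust the learning rate after each batch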

We obtain the following learning stats by training the model implemented above with train_one_cycle. Details are shown below:

It can be seen that the model reaches a very high accuracy on the validation dataset. However, the accuracy decreases on the test dataset. This is a common issue in image processing tasks: it means our model has learned patterns in the training data but is not able to generalize to new data. This can happen for two main reasons: the model is overfitting, or the test data is very different from the training data. That is why it is useful to analyze the learning curves, which show the model's behaviour during the training phase on the training and validation data.
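For reference, here is a minimal sketch of how such learning curves can be plotted, assuming the training loop records the per-epoch losses in a history list of dictionaries (the keys 'train_loss' and 'val_loss' are assumptions):

import matplotlib.pyplot as plt

def plot_losses(history):
    # history is assumed to be a list of per-epoch dicts with 'train_loss' and 'val_loss'
    train_losses = [h['train_loss'] for h in history]
    val_losses = [h['val_loss'] for h in history]
    plt.plot(train_losses, '-x', label='training')
    plt.plot(val_losses, '-o', label='validation')
    plt.xlabel('epoch')
    plt.ylabel('loss')
    plt.legend()
    plt.title('Loss vs. number of epochs')
    plt.show()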

In this case, the closeness of the loss curves in the last training stage suggests that the model is not overfitting. So, a good way to increase accuracy is to try different architectures, maybe deeper ones, and vary the hyperparameters. Data augmentation strategies can also help the performance of the model. It is all about trying new tricks.

Results

Several models were implemented, trained, and evaluated on this dataset. The images below show the learning curves corresponding to the model presented in this post, which achieved higher performance than the others implemented.

From the table above, we can conclude that the approaches based on transfer learning achieved higher performance than simpler ones. This demonstrates, once more, the importance of building models on top of pretrained architectures in visual recognition, especially when our dataset is relatively small. Though the performance is not bad, the model accuracy could still be improved with more elaborate strategies :).

More details about the implementation can be found here: https://jovian.ml/claudiaqw/alzheimer-detection

The goal of this post is to show you how easily you can build deep models in PyTorch and how useful they are for solving everyday situations. You can make a great difference in the world by solving real-world problems like this one.
