Predicting Alzheimer’s Disease Using the U-Net Algorithm

A basic TensorFlow and Keras implementation for beginners


What is Alzheimer’s Disease?

Alzheimer’s disease is an irreversible degeneration of the brain that causes disruptions in memory, cognition, personality, and other functions that eventually lead to death from complete brain failure.

Worldwide, at least 50 million people are believed to be living with Alzheimer’s disease or other dementias.

One in 10 people aged 65 and older has Alzheimer’s dementia. It is the sixth-leading cause of death in the United States.

It results in progressive loss of tissue throughout the brain. In particular, an area of the brain called the hippocampus tends to show the most rapid loss of tissue earliest in the disease course.

Figure 1: Hippocampus localisation in the human brain

The hippocampus is essential for forming new memories, such as what one ate for lunch or a recent conversation. The progressive shrinkage of the hippocampus is responsible for the short-term memory loss that is the hallmark symptom of Alzheimer’s.

To learn more about how Alzheimer’s changes the brain and the role of the hippocampus, see the video:

Can Alzheimer’s be prevented?

According to specialists, it is difficult to tell from the earliest symptoms whether a person will be diagnosed with Alzheimer’s disease. However, it can be predicted using AI algorithms and sophisticated sensor-based tools.

Researchers have developed an Artificial Intelligence (AI) algorithm that can accurately predict the risk of Alzheimer’s disease and diagnose it using brain Magnetic Resonance Imaging (MRI) combined with demographic and clinical information.

The AI strategy is based on deep learning, a type of machine learning. Machine learning is an AI application that enables a computer to learn from data and improve with experience.

Predicting Alzheimer’s disease justifies building AI-based systems and devices that identify patients at risk, speed up diagnosis, and make it early enough to start the necessary treatment.

Using MRI scans

The researchers used MRI scans of the brain, demographics, and clinical information from individuals with Alzheimer’s disease as well as from individuals with normal brain function. They then developed a novel deep learning model to predict Alzheimer’s disease risk and showed that it could accurately predict disease status in other, independent cohorts.

Let’s see how an AI algorithm can process the data and train a model to predict Alzheimer’s disease.

Alzheimer’s Disease Detection Using MRI and U-Net Architecture

We will build an algorithm based on the U-Net architecture, the same kind of model that can help detect tumors in the lungs or brain.

This project will help you understand the U-Net architecture and how to build a model that visualises the hippocampus and detects Alzheimer’s using MRI.

The content of the project is listed as follows:

  • Dataset preparation and preprocessing
  • Model Building
  • Model Training
  • Model Evaluation and Testing
  • Model Prediction

Before we start developing our model, we must learn how to prepare and manipulate our dataset.

1. Data Preprocessing

Data preprocessing means preparing (cleaning and organizing) data so that it is suitable for building and training models. In simple words, data preprocessing is a data mining technique that transforms raw data into an understandable and readable format. It cleans, formats and organizes the raw data, making it ready to go for the model.

In this project, there are five significant steps in data preprocessing to prepare our data.

1.1 Acquire the Dataset

The dataset can be downloaded from the Academic Torrents website, which can be found here:

In this project, we will use the NIfTI (Neuroimaging Informatics Technology Initiative) format, which can store data with different meanings. Imaging data, statistical values and other data (any vector, matrix, label set or mesh) can be saved in a NIfTI-1 *.nii or *.hdr/*.img file.

MR images: magnetic resonance imaging is a type of scan that uses strong magnetic fields and radio waves to produce detailed images of the inside of the body (brain, breasts, heart, blood vessels and so on).

1.2 Importing libraries

First, we need to install all the required modules: tensorflow, keras, matplotlib and numpy, as well as other libraries specific to the project. The os module ships with Python and only needs to be imported.

In addition, one of the most important modules is “nibabel”, which is needed to read the NIfTI format.

Read more about Python libraries for Data Science here:

1.3 Loading and reading the dataset

At this stage, we need to:

Store the path to our image dataset in a variable using the os module (“os.path.join”).

Import “nibabel” and load the dataset (note: “nibabel” does not load the image array right away; it waits until the data array is requested with the get_fdata() method).

1.4 Standardize images

Standardizing images is a critical preprocessing step in computer vision. In general, models train faster on smaller images, and training time adds up when the images are larger or more complicated. Moreover, many deep learning architectures require all input images to have the same size, which is not the case for most acquired data.

We therefore establish a base size for all images fed into the model and resize the dataset to this minimum image size.

The minimum image size is set to (32, 32, 1) (width, height and channel).
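
Here is a minimal sketch of bringing one 2D slice to the base size. The resize method is not specified in this article, so the nearest-neighbour sampling, the min-max scaling and the function name standardize_slice are assumptions made for illustration.

```python
import numpy as np

def standardize_slice(slice_2d, size=(32, 32)):
    """Resize a 2D slice to `size` (nearest neighbour) and scale it to [0, 1]."""
    h, w = slice_2d.shape
    # Pick, for every output pixel, the nearest source row and column.
    rows = np.arange(size[0]) * h // size[0]
    cols = np.arange(size[1]) * w // size[1]
    resized = slice_2d[np.ix_(rows, cols)].astype(np.float32)
    # Min-max scaling; the small epsilon avoids division by zero on flat slices.
    resized = (resized - resized.min()) / (resized.max() - resized.min() + 1e-8)
    # Add the trailing channel axis expected by Keras: (32, 32) -> (32, 32, 1).
    return resized[..., np.newaxis]

slice_2d = np.random.rand(50, 36) * 255.0   # hypothetical MRI slice
x = standardize_slice(slice_2d)
print(x.shape)
```

A library resize with interpolation would give smoother results; the point here is only that every image leaves this step with the same shape and value range.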

Many other preprocessing techniques can be used to get your images ready for training. Removing the background color from the images reduces noise. Other projects may require brightening or darkening the images. Data augmentation enlarges the dataset with perturbed versions of the existing images (scaling, rotation, de-colorizing, de-texturizing and so on). In short, any adjustment applied to the dataset counts as preprocessing. Selecting the appropriate techniques for a given dataset and problem builds the intuition for which ones are needed in other projects.

1.5 Splitting dataset

The final step is to split the dataset into two separate sets: a training set and a test set.

Figure 2 : Dataset Splitting

The “Train set” is used to train the model and the “test set” is used to test and evaluate the model.

Usually, the dataset is split into 70% train set and 30% test set or 80% train set and 20% test set.

In the code, the data is split using scikit-learn: “from sklearn.model_selection import train_test_split”.
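
A short sketch of the 80% / 20% split with train_test_split. The arrays below are random placeholders with hypothetical sizes, standing in for the preprocessed images and their masks.

```python
import numpy as np
from sklearn.model_selection import train_test_split

# Hypothetical data: 100 images of shape (32, 32, 1) with matching masks.
images = np.random.rand(100, 32, 32, 1)
masks = np.random.randint(0, 2, size=(100, 32, 32, 1))

# test_size=0.2 gives the 80% / 20% split; random_state fixes the shuffle.
X_train, X_test, y_train, y_test = train_test_split(
    images, masks, test_size=0.2, random_state=42
)
print(X_train.shape, X_test.shape)
```

Passing images and masks in the same call keeps each image paired with its own mask after shuffling.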

2. Model Building

We use the U-Net algorithm to build the model.

Figure 3 : U-net architecture (example for 32x32 pixels in the lowest resolution).

How does U-Net work?

It consists of a contracting path (left side) and an expansive path (right side).

Contracting/down sampling path

The contracting path follows the typical architecture of a convolutional network. It is composed of four blocks, each of which consists of:

· 3x3 Convolution Layer + activation function (with batch normalization)

· 3x3 Convolution Layer + activation function (with batch normalization)

· 2x2 Max Pooling

NB: the number of feature maps doubles at each pooling, starting with 64 feature maps for the first block, 128 for the second, and so on. The purpose of this contracting path is to capture the context of the input image so that segmentation is possible. This coarse contextual information is then transferred to the upsampling path by means of skip connections.


Bottleneck

This part of the network sits between the contracting and expanding paths. The bottleneck is built from just two convolutional layers (with batch normalization), with dropout.

Expanding/up sampling path

The expanding path is also composed of four blocks. Each of these blocks is composed of:

· Deconvolution layer with stride 2

· Concatenation with the corresponding cropped feature map from the contracting path

· 3x3 Convolution layer + activation function (with batch normalization)

· 3x3 Convolution layer + activation function (with batch normalization)

At the final layer a 1x1 convolution is used to map each 64-component feature vector to the desired number of classes. In total, the network has 23 convolutional layers.
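
The three parts described above can be sketched in Keras as follows. This is an illustrative sketch, not the exact model of this project: the ReLU activation, the dropout rate of 0.5, the single sigmoid output channel and the helper names conv_block and build_unet are assumptions. It also uses padding="same", so the skip connections need no cropping.

```python
from tensorflow import keras
from tensorflow.keras import layers

def conv_block(x, filters):
    """Two 3x3 convolutions, each followed by batch normalization and ReLU."""
    for _ in range(2):
        x = layers.Conv2D(filters, 3, padding="same")(x)
        x = layers.BatchNormalization()(x)
        x = layers.Activation("relu")(x)
    return x

def build_unet(input_shape=(32, 32, 1), base_filters=64, depth=4):
    inputs = keras.Input(shape=input_shape)
    x, skips = inputs, []
    # Contracting path: conv block, remember the output, then 2x2 max pooling.
    for i in range(depth):
        x = conv_block(x, base_filters * 2**i)
        skips.append(x)
        x = layers.MaxPooling2D(2)(x)
    # Bottleneck: two convolutions with dropout.
    x = conv_block(x, base_filters * 2**depth)
    x = layers.Dropout(0.5)(x)
    # Expanding path: stride-2 transposed convolution, skip concatenation, conv block.
    for i in reversed(range(depth)):
        x = layers.Conv2DTranspose(base_filters * 2**i, 2, strides=2, padding="same")(x)
        x = layers.Concatenate()([x, skips[i]])
        x = conv_block(x, base_filters * 2**i)
    # 1x1 convolution with sigmoid: one foreground probability per pixel.
    outputs = layers.Conv2D(1, 1, activation="sigmoid")(x)
    return keras.Model(inputs, outputs)

model = build_unet()
print(model.output_shape)
```

With a 32x32 input and four pooling steps, the bottleneck works on 2x2 feature maps, and the output mask has the same width and height as the input.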


· Red box → the input image size

· Blue box → the left side of U-Net (contracting path)

· Green box → the right side of U-Net (expansive path)

· Yellow box → the bottleneck layer

The final step in model building is to choose the optimizer and the cost function used by the model.

Different types of optimizer can be applied at this stage.

· Cost function

The objective of machine learning and deep learning is to reduce the difference between the predicted output and the actual output. This difference is measured by a cost function, also called a loss function.

Our goal is to minimize the cost function by finding optimized values for the weights.

To achieve this, we run multiple iterations with different weights until we find the minimum cost. This is gradient descent.

Gradient descent is an iterative machine learning optimization algorithm that reduces the cost function, which helps models make accurate predictions.

We calculate the gradient, ∂C/∂W, which is the partial derivative of the cost with respect to a weight.

α is the learning rate; it scales how far the weights are adjusted along the gradient.

Figure 4 : Gradient Descent

W is the weights of the neurons, α is the learning rate, C is the cost and ∂C/∂W is the gradient.
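
With these symbols, one gradient descent step is W ← W − α · ∂C/∂W. The sketch below applies it to a toy cost with a single weight; the cost function, the starting weight and the learning rate are hypothetical values chosen only to show the update.

```python
def cost(w):
    """Toy cost with its minimum at w = 3 (hypothetical example)."""
    return (w - 3.0) ** 2

def gradient(w):
    """Derivative dC/dW of the toy cost."""
    return 2.0 * (w - 3.0)

w = 0.0        # initial weight
alpha = 0.1    # learning rate
for step in range(100):
    w = w - alpha * gradient(w)   # W <- W - alpha * dC/dW

print(round(w, 4), round(cost(w), 8))
```

Each step moves the weight against the gradient, so the cost shrinks until the weight settles near the minimum.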

What is learning rate?

The learning rate is probably the most important aspect of gradient descent and of other optimizers as well.

It controls how much the weights are adjusted with respect to the loss gradient. The learning rate is a hyperparameter chosen before training, whereas the weights themselves are usually randomly initialized.

Each iteration computes the cost for the current weights and adjusts them, so that the cost moves towards a minimum.

The picture below shows the optimization curve.

Figure 5 : Optimization Curve

The different types of gradient descent are:

· Batch Gradient Descent or Vanilla Gradient Descent

· Stochastic Gradient Descent

· Mini batch Gradient Descent
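
The three variants differ only in how many samples feed each weight update. The sketch below fits a small linear model with the same loop and three batch sizes; the data, the learning rate and the function name gradient_descent are hypothetical.

```python
import numpy as np

def gradient_descent(X, y, batch_size, alpha=0.05, epochs=200, seed=0):
    """Fit y ~ X @ w with mean squared error; batch_size selects the variant."""
    rng = np.random.default_rng(seed)
    w = np.zeros(X.shape[1])
    for _ in range(epochs):
        order = rng.permutation(len(X))                # reshuffle every epoch
        for start in range(0, len(X), batch_size):
            idx = order[start:start + batch_size]
            error = X[idx] @ w - y[idx]
            grad = 2.0 * X[idx].T @ error / len(idx)   # dC/dw on this batch
            w -= alpha * grad
    return w

rng = np.random.default_rng(1)
X = rng.normal(size=(64, 2))
y = X @ np.array([2.0, -1.0])                          # noiseless targets

w_batch = gradient_descent(X, y, batch_size=len(X))    # batch (vanilla)
w_sgd = gradient_descent(X, y, batch_size=1)           # stochastic
w_mini = gradient_descent(X, y, batch_size=8)          # mini-batch
print(w_batch, w_sgd, w_mini)
```

Batch gradient descent makes one update per epoch, stochastic gradient descent makes one per sample, and mini-batch sits in between.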

To learn more about gradient descent, the videos made by Andrew Ng are recommended.


The objective of every optimizer is to change the attributes of your model, such as the weights and the learning rate, in order to reduce the loss.

Other optimizers based on gradient descent are also widely used. Here are a few of them:

· Adagrad: adapts the learning rate to individual features, which means that some of the weights in your model will have different learning rates than others.

· RMSprop: adjusts the learning rate automatically and chooses a different learning rate for each parameter.

· Adam: a stochastic gradient descent method that adapts the learning rate of each parameter using estimates of the first and second moments of the gradients.
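
To make the first and second moments concrete, here is a minimal NumPy sketch of the Adam update on a toy cost. It illustrates the update rule only and is not the Keras implementation; the toy cost, the step count and the name adam_minimize are assumptions, while beta1, beta2 and eps use commonly quoted default values.

```python
import numpy as np

def adam_minimize(grad_fn, w, alpha=0.1, beta1=0.9, beta2=0.999, eps=1e-8, steps=500):
    m = np.zeros_like(w)   # first moment: running mean of the gradients
    v = np.zeros_like(w)   # second moment: running mean of the squared gradients
    for t in range(1, steps + 1):
        g = grad_fn(w)
        m = beta1 * m + (1 - beta1) * g
        v = beta2 * v + (1 - beta2) * g**2
        m_hat = m / (1 - beta1**t)   # bias correction for the zero initialization
        v_hat = v / (1 - beta2**t)
        w = w - alpha * m_hat / (np.sqrt(v_hat) + eps)   # per-parameter step size
    return w

# Toy cost C(w) = sum((w - target)^2), so the gradient is 2 * (w - target).
target = np.array([3.0, -2.0])
w = adam_minimize(lambda w: 2.0 * (w - target), np.zeros(2))
print(w)
```

Dividing by the square root of the second moment is what gives every parameter its own effective learning rate.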

Comparison between various optimizers

Figure 6 : Comparison between various optimizers

According to the comparison above, “Adam” is a computationally efficient optimizer that also requires little memory.

For that reason we use “Adam”, one of the most popular gradient descent optimization algorithms, which is also easy to use:

model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])

Additionally, it is necessary to configure the specifications as follows:

Loss: Alzheimer’s detection uses a sigmoid activation in the final layer, which yields an output between 0 and 1 that is read as either 1 or 0 (demented or not demented). Therefore, binary_crossentropy is the most suitable loss function.

metrics: accuracy is the metric used to measure the prediction accuracy at every epoch.
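
To see what these two settings measure, the sketch below computes binary cross-entropy and accuracy by hand with NumPy. The labels and predictions are made-up values, and clipping with a small eps is an assumption made here to keep the logarithm finite.

```python
import numpy as np

def binary_crossentropy(y_true, y_pred, eps=1e-7):
    """Mean of -[y*log(p) + (1-y)*log(1-p)], with p clipped away from 0 and 1."""
    p = np.clip(y_pred, eps, 1.0 - eps)
    return float(np.mean(-(y_true * np.log(p) + (1.0 - y_true) * np.log(1.0 - p))))

y_true = np.array([1.0, 0.0, 1.0, 0.0])
good = np.array([0.9, 0.1, 0.8, 0.2])   # confident and correct
bad = np.array([0.1, 0.9, 0.2, 0.8])    # confident and wrong

print(binary_crossentropy(y_true, good))
print(binary_crossentropy(y_true, bad))

# Accuracy thresholds the sigmoid output at 0.5 before comparing.
accuracy = float(np.mean((good > 0.5) == (y_true == 1.0)))
print(accuracy)
```

The loss is small when the predicted probabilities agree with the labels and grows quickly when the model is confidently wrong, which is why it suits a sigmoid output.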

3. Model Training

At this step, we use the data to incrementally improve our model’s ability to predict whether the person is demented or not.

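#Fit the model
# X_train and y_train are assumed names for the training images and their labels
history = model.fit(X_train, y_train,
                    batch_size = 8,
                    epochs = 50,
                    validation_split = 0.2)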

The model trains for 50 epochs with a batch size of 8.

The batch size is the number of samples (a sample is a single row of data) processed before the model is updated.

The number of epochs is the number of complete passes through the training dataset.

One epoch is completed when the entire dataset has passed forward and backward through the neural network exactly once. The difficulty is choosing a suitable number of epochs for the application in order to avoid underfitting and overfitting.

Figure 7 : Underfitting & Overfitting problems

Note: It is usually not practical to pass the entire dataset through the neural network at once. Therefore, the dataset is divided into a number of batches.
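
As a worked example, here is how batch size, epochs and validation split translate into the number of weight updates. The dataset size of 200 samples is a hypothetical figure; the other values are the ones used in the training call.

```python
import math

n_samples = 200            # hypothetical size of the training array passed to fit()
validation_split = 0.2
batch_size = 8
epochs = 50

# Keras holds out the last 20% of the samples for validation.
n_val = int(n_samples * validation_split)
n_train = n_samples - n_val

# One weight update per batch; the last batch may be smaller than batch_size.
updates_per_epoch = math.ceil(n_train / batch_size)
total_updates = updates_per_epoch * epochs
print(n_train, n_val, updates_per_epoch, total_updates)
```

A smaller batch size means more updates per epoch, and more epochs mean more passes over the same training samples.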

Once the model is completely trained, plotting the loss and accuracy helps to interpret how well the model performs.

4. Model Evaluation and Testing

After training, you can look at the validation loss and the training loss in order to evaluate the model.

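# Plot accuracy and loss per epoch
import matplotlib.pyplot as plt
import numpy as np
f, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4))
t = f.suptitle('U-NET Performance', fontsize=12)
f.subplots_adjust(top=0.85, wspace=0.3)
max_epoch = len(history.history['accuracy'])+1
epoch_list = list(range(1,max_epoch))
# Left panel: training and validation accuracy
ax1.plot(epoch_list, history.history['accuracy'], label='Train Accuracy')
ax1.plot(epoch_list, history.history['val_accuracy'], label='Validation Accuracy')
ax1.set_xticks(np.arange(1, max_epoch, 5))
ax1.set_ylabel('Accuracy Value')
l1 = ax1.legend(loc="best")
# Right panel: training and validation loss
ax2.plot(epoch_list, history.history['loss'], label='Train Loss')
ax2.plot(epoch_list, history.history['val_loss'], label='Validation Loss')
ax2.set_xticks(np.arange(1, max_epoch, 5))
ax2.set_ylabel('Loss Value')
l2 = ax2.legend(loc="best")
plt.show()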
Figure 8 : Accuracy and Loss value

As shown, the validation loss and the training loss stay in sync. This indicates that the model is not overfitting: the validation loss is decreasing, and there is hardly any gap between training and validation loss throughout the training phase.

The accuracy of the trained model reaches 95.24%, as can be seen in the figure below:

Figure 10 : Accuracy value

Last but not least, it is time to reconstruct the test images using the Keras predict() function and see how well your model reconstructs the test data.

Figure 11 : Mask Images

The figures above show that the model did a great job of reconstructing the mask images.

5. Save the Model

Let’s now save the trained model. It is an important step when you are working with Deep Learning.

#save model
model_save_path = os.path.join(path_main,'model')
# recent Keras versions may require a '.keras' or '.h5' suffix on this path
model.save(model_save_path)

You can load the saved model at any time with just one line of code and continue training from where it stopped:

#load model
model = keras.models.load_model(model_save_path)

6. Model Prediction

During this step, a classifier labels the dataset with two categories, demented and not demented. For this, it is recommended to use an unsupervised algorithm (K-means clustering) to derive our labels.

So, first, what is a clustering algorithm?

Clustering algorithms are unsupervised. They resemble classification algorithms, but the basis is different: they group the data without being given labels.

The K-means clustering algorithm is an unsupervised algorithm used to segment the area of interest from the background. It partitions the given data into K clusters based on K centroids.

The objective is to classify the data into 2 classes (demented or not demented) so that the model can classify a new dataset and detect anomalies.
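
Here is a minimal NumPy sketch of K-means with K = 2. The features that feed the clustering are not specified in this article, so the single feature used below (a hippocampus area in pixels, counted from each predicted mask) and all the numbers are purely hypothetical.

```python
import numpy as np

def kmeans(features, k=2, iterations=20, seed=0):
    """Plain Lloyd's algorithm: assign to nearest centroid, then recompute centroids."""
    rng = np.random.default_rng(seed)
    centroids = features[rng.choice(len(features), size=k, replace=False)]
    for _ in range(iterations):
        distances = np.linalg.norm(features[:, None, :] - centroids[None, :, :], axis=2)
        labels = distances.argmin(axis=1)
        for j in range(k):
            if np.any(labels == j):   # keep the old centroid if a cluster is empty
                centroids[j] = features[labels == j].mean(axis=0)
    return labels, centroids

# Hypothetical feature: hippocampus area in pixels, counted from each predicted mask.
rng = np.random.default_rng(1)
small = rng.normal(60.0, 5.0, size=(20, 1))    # shrunken hippocampus
large = rng.normal(110.0, 5.0, size=(20, 1))   # preserved hippocampus
features = np.vstack([small, large])

labels, centroids = kmeans(features)
print(centroids.ravel())
```

K-means only returns cluster indices; deciding which cluster corresponds to “demented” still requires looking at the centroids.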

Figure 12 : Demented and not Demented Images

Go Further !

This tutorial can help you understand how to read MRI images in NIfTI format, then analyse, preprocess and feed them into a model, using a brain MRI dataset. It showed you one of the popular CNN algorithms: U-Net.

You can modify the architecture and try to improve the predictions both quantitatively and qualitatively. You can use data augmentation to increase the size of the dataset, and you can choose another classifier to separate the dataset into two classes.

So, you are free to choose any architecture and any classifier.

How AI is transforming the future of healthcare

Figure 13 : AI in the healthcare field

AI in medicine refers to the use of artificial intelligence technology / automated processes in the diagnosis and treatment of patients who require care. Whilst diagnosis and treatment may seem like simple steps, there are many other background processes that must take place in order for a patient to be properly taken care of, for example:

· Gathering of data through patient interviews and tests

· Processing and analysing results

· Using multiple sources of data to come to an accurate diagnosis

· Determining an appropriate treatment method (often presenting options)

· Preparing and administering the chosen treatment method

· Patient monitoring, etc.

In conclusion, then, whilst it’s unlikely that machines will replace or eradicate the need for human doctors any time soon, those already in or considering a medical profession should be willing to adapt, learn and grow alongside technological advancements.

Bill Gates (1996) “We always overestimate the change that will occur in the next two years and underestimate the change that will occur in the next ten.”

I have tried to keep the article as exact and clear as possible.

Your comments, questions and feedback are welcome.

Contact me on : Github, Medium, Linkedin, Twitter, Gmail

I hope this will be helpful!

See you in the next tutorial!

Written by

Microelectronic engineer student| Coach Kids’ Robotic | Interested in Embedded system field, AI & Machine learning field
