A machine learning survival kit for doctors
Or: how to estimate the age of your brain with MRI data and Artificial Intelligence
A 30 to 60 minute crash course to understand artificial intelligence in practice… and its pitfalls.
Artificial Intelligence (AI) is on everyone’s lips today, and healthcare is one of the industries that raises the highest hopes regarding its potential benefits. Will AI eventually replace medical staff completely? Or will it allow practitioners to focus on more interesting, value-added tasks? No one knows what AI’s exact role and place in the care pathway will be.
Physicians and researchers will not become programmers or data scientists overnight, nor will they be replaced by them. But they will need to understand what AI actually is, how it works, what it can achieve today, what is still out of reach, and, importantly, who data scientists are and how they work; this is the purpose of this article. You will learn the fundamental concepts of machine learning, such as cross-validation and overfitting, as well as the most common difficulties and pitfalls encountered in practice.
We demonstrate the process of creating an algorithm that estimates the physiological age of a subject’s brain from magnetic resonance imaging (MRI) data. You will learn how to create state-of-the-art algorithms, from entry-level linear regression to advanced deep neural networks that automatically rediscover known features of brain aging, such as cortical atrophy and leukoaraiosis. Do not worry if you are not a radiologist: this article should be accessible to everyone. From a scientific point of view, identifying the physiological age of the brain could improve our understanding of neurodegenerative diseases. At the end of this article, we show that the brain age estimated by our algorithm for MRIs of patients diagnosed with Alzheimer’s disease is around 6 years older than it “should” be.
We also publish a Colab notebook, accessible with this link, in which you can execute code to reproduce some of the experiments and get a better feel for what data science is in practice.
We have compiled a lot of knowledge about data science and machine learning in this article, along with an in-depth case study, so find a comfy chair, take your time, and welcome to this dive into the world of AI.
At Owkin we are always looking for exciting new projects at the frontier of medicine and machine learning. Do not hesitate to reach out to us at email@example.com.
In a recent post, Michael Jordan, a leading figure in machine learning (no, not the basketball player), tries to demystify the current hype around AI:
Most of what is being called “AI” today, particularly in the public sphere, is what has been called “Machine Learning” for the past several decades. Machine Learning is an algorithmic field that blends ideas from statistics, computer science and many other disciplines to design algorithms that process data, make predictions and help make decisions.
In fact, brain age estimation can be treated as a supervised machine learning problem, a type of problem at which data scientists excel. The goal of any supervised machine learning problem is to create an algorithm that outputs a value Y (in our case, the age) given some inputs X (in our case, the MRI). The key aspect of machine learning algorithms is that they are “trained” on real-world data rather than designed according to expert-defined rules. If you can cast a medical question into an X and a Y, you are already half a data scientist.
When confronted with such problems, data scientists will always take a similar approach, whatever the X and Y:
- Get the data and clean it
- Analyze the data and extract features relevant to the problem
- Design a validation strategy
- Train an algorithm on the data, analyze the errors and interpret the results
- Iterate until the algorithm performs as well as possible
We will go through these steps one by one, describing data-science best practices, highlighting the most useful tools in machine learning, and showing the most common mistakes you can find (even in published literature!)
Hardware, Software, Knowledge
Before talking about Machine Learning, we should talk about machines. While any laptop will do for projects involving only a few megabytes of Excel spreadsheets, our project would be very painful to handle without adequate hardware, due to the size and complexity of the dataset (here, 3D images).
As such, a good machine for this project would have:
- An SSD to store and quickly load the data: by the end of the project we had ~500GB of data
- A lot of RAM (e.g. 128GB) to perform in-memory operations on large chunks of the data
- Many processor cores (e.g. 32 cores) to benefit from parallelization, and
- At least one high-performance graphics card (GPU). GPUs have proved crucial for deep learning, a fascinating subfield of machine learning that has contributed greatly to the current buzz around AI
On the software side, we used the Python programming language. There is a long-standing rivalry in data science between Python and R: while R is widely used by statisticians, Python is now favored by the machine learning community.
Python is an open-source language that can be extended through packages, which are easy to install and contain many useful functions. Prominent examples of such packages in machine learning are scikit-learn, XGBoost and TensorFlow, which are used by virtually every data scientist and supported by major stakeholders in both academia and industry. Packages are constantly created, improved and updated by a very active community through collaboration tools such as GitHub. For instance, the scikit-learn repository shows 1000+ currently open issues, 4000 closed issues, and roughly 6000 code modifications since its inception!
To get a first glimpse of how Python works, we invite you to execute the introductory code in the Colab notebook.
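If you would rather read than run, here is a tiny snippet of our own (not taken from the notebook) that gives the flavor of the language; the ages are made up.

```python
# Store a few hypothetical subject ages in a Python list
ages = [23, 31, 47, 62, 28]

# Built-in functions compute simple statistics
mean_age = sum(ages) / len(ages)
oldest = max(ages)

print(f"{len(ages)} subjects, mean age {mean_age:.1f}, oldest {oldest}")
# prints "5 subjects, mean age 38.2, oldest 62"
```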
Finally, to be perfectly ready, you need… skills in machine learning. Data science requires mastering both theoretical aspects (applied mathematics, statistics, algorithms) and practical aspects (programming, database management). Data scientists can thus come from very different backgrounds: they generally studied mathematics or computer science, but it frequently happens that researchers, consultants or people in other analytical corporate roles who wanted to develop machine learning tools for their research or work pivoted completely into the field and became data scientists.
As a research community, machine learning practitioners embrace the open science mindset: rather than classic peer-reviewed journals, they favor open conferences such as NIPS, ICML or ICLR, publish their work openly on arXiv, discuss it on Reddit, and often share their code on GitHub.
For this project, we used two publicly available and anonymized datasets of healthy subjects. The first, let’s call it Dataset A, was collected in three different London hospitals and contains data from nearly 600 subjects. The second, Dataset B, contains data from more than 1,200 subjects from 25 hospitals across the US, China and Germany. A single MRI may consist of multiple images representing different physical properties (T1, T2, FLAIR, DTI, etc.), called sequences. In our experiments, we only use the most common of these: the T1 anatomical sequence.
This part was easy for us, as the brain image datasets we used were already collected, curated and usable from a legal and regulatory perspective. Medical imaging also benefits from a standard format (DICOM), which is not yet the case for EHR, genomics data or digital pathology. Compiling a medical dataset is, as we all know, the most difficult task for physicians and researchers. Having an open, precise discussion with all involved parties on what data is needed, how to access it, and how to ensure patient privacy will improve the chances of compiling a usable dataset.
Any data science project begins with data cleaning. It is not the most exciting task, and may be very time consuming, especially for retrospective studies where data acquisition protocols and quality controls were not set up to answer a specific question. To clean the data, one first needs to explore it.
Dataset A consists of a spreadsheet with demographic information (including patient age) and a large zipped folder containing all the MRIs. The images are stored in a format called NIfTI, a popular alternative to DICOM. Each subject is supposed to have an ID that can be found both in the spreadsheet and in the MRI filename. However, looking at the data more closely, many issues become apparent.
- Some subjects have a NIfTI file but do not appear in the spreadsheet.
- Conversely, some subjects appear in the spreadsheet but have no associated NIfTI file.
- Some IDs are duplicated within the spreadsheet: two subjects have the same ID, but different age, sex, height, etc. It is therefore impossible to know which one to associate with the NIfTI file.
- The name of the hospital of origin is not in the spreadsheet but hidden in the name of the NIfTI file.
- Finally, and more difficult to spot, some NIfTI files are problematic: images may be recorded at very low resolution, or be partially or totally cropped, etc.
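The first four issues (missing images, missing spreadsheet rows, duplicated IDs, hospital hidden in the filename) lend themselves to automatic checks. Below is a minimal pandas sketch; the filename pattern `<hospital>_<id>.nii.gz`, the IDs and the hospital names are hypothetical, and the real datasets use their own conventions. Image-quality problems such as low resolution or cropping are not covered here, since they require inspecting the images themselves.

```python
import pandas as pd

# Made-up demographics spreadsheet (note the duplicated ID "s02")
sheet = pd.DataFrame({
    "id": ["s01", "s02", "s02", "s03", "s05"],
    "age": [25, 40, 63, 31, 52],
    "sex": ["F", "M", "F", "F", "M"],
})
# Made-up NIfTI filenames, assuming the pattern "<hospital>_<id>.nii.gz"
nifti_files = ["hospA_s01.nii.gz", "hospB_s02.nii.gz",
               "hospA_s03.nii.gz", "hospC_s04.nii.gz"]

# Recover the hospital of origin and the subject ID from each filename
files = pd.DataFrame({"path": nifti_files})
parts = files["path"].str.replace(".nii.gz", "", regex=False).str.split("_")
files["hospital"] = parts.str[0]
files["id"] = parts.str[1]

# Duplicated IDs are ambiguous: drop every row that shares an ID
sheet = sheet[~sheet["id"].duplicated(keep=False)]

# An outer merge with indicator=True labels rows "both", "left_only" or "right_only"
merged = sheet.merge(files, on="id", how="outer", indicator=True)
in_sheet_only = merged.loc[merged["_merge"] == "left_only", "id"].tolist()
in_files_only = merged.loc[merged["_merge"] == "right_only", "id"].tolist()

# Keep only subjects present in both sources
clean = merged[merged["_merge"] == "both"].drop(columns="_merge")
print("no image:", in_sheet_only)            # spreadsheet rows without a NIfTI file
print("not in spreadsheet:", in_files_only)  # NIfTI files without demographics
print(clean)
```

Dropping both copies of a duplicated ID is a deliberate, conservative choice: as noted above, there is no way to know which row belongs to the image.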
It may seem like a mess, but all in all, this dataset is actually not particularly messy. Such errors and inconsistencies are very common in medical datasets. They are often caused by the transfer of data through complex, and not necessarily compatible, data systems, manual data modification, multi-centric studies, and so on. We encountered similar issues with Dataset B as well… welcome to the world of dataset cleaning! At the end of the day, we obtained 563 “clean” subjects from Dataset A (down from 600, 94%) and 1034 subjects from Dataset B (down from 1200, 86%), which is still a sizable dataset. We created a single spreadsheet with 1597 rows and 5 columns: an ID, the age and the gender of the subject, the MRI’s hospital of origin, and the path to the associated NIfTI file. You can explore the spreadsheet in our Colab notebook.
We can observe the following demographics in the subject cohort: 55% are women, the youngest is 18 years old (we removed children from the datasets), the oldest one is 87 years old, with quartiles at 22, 27, and 48 years of age.
So, you think we are ready to talk about machine learning? Well… not yet! When opening the NIfTI files containing the MRIs, we observe very high heterogeneity among the images: resolution, voxel values (a voxel is a 3D pixel), field of view, orientation, etc. Therefore, some form of normalization is needed so that the images are actually comparable to one another. We would face the same problem with any type of data, not just MRIs!
Fortunately for us, the neuroscience community has developed a set of software tools to normalize brain MRIs. We decided to use ANTs, as well as some home-made Python scripts, to perform the following normalization steps on each MRI.
- Resample. We set the resolution to 1 mm³ per voxel.
- N4 bias field correction. This operation removes non-biological signal artifacts from the image that are caused by the magnetic field of the MRI scanner. You can see an example here.
- Co-registration to a unique brain template. This operation transforms all the images to have the same orientation, field of view, center of mass, etc. As a brain template, we used the standard MNI152, which is the “average” image of 152 brains oriented for optimal readability.
- Skull-stripping. We remove the skull bone from the image and keep only the brain. Why? Simply because we want to estimate the age of the brain, not the age of a whole head!
- Voxel Intensity Normalization. In an MRI, the intensity value of a voxel is arbitrary (from 0 to 100, or from 0 to 100000…); only contrast matters. We used a popular technique called white stripe normalization. The idea is to detect the white matter intensity value and set this value to 1.
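The last step can be illustrated with a deliberately simplified sketch: estimate the white matter peak from the brain’s intensity histogram and divide by it, so that the peak lands at 1. The peak-finding heuristic (the histogram mode of the upper half of brain intensities, assuming white matter is the brightest large tissue on a skull-stripped T1) is our own assumption; the published white stripe method differs in its details.

```python
import numpy as np

def normalize_to_white_matter(volume, brain_mask, n_bins=200):
    """Rescale intensities so that the white matter peak sits at 1.

    Simplified sketch (not the published white stripe algorithm): we assume
    white matter is the brightest large tissue class on a skull-stripped T1,
    and take the histogram mode of the upper half of brain intensities.
    """
    brain = volume[brain_mask]
    upper = brain[brain > np.median(brain)]
    counts, edges = np.histogram(upper, bins=n_bins)
    peak_bin = np.argmax(counts)
    wm_peak = 0.5 * (edges[peak_bin] + edges[peak_bin + 1])  # bin center
    return volume / wm_peak
```

Usage on a synthetic "brain" with a grey matter class around 60 and a white matter class around 100 (arbitrary scanner units):

```python
rng = np.random.default_rng(0)
shape = (60, 60, 60)
volume = np.where(rng.random(shape) < 0.5,
                  rng.normal(60, 4, shape), rng.normal(100, 4, shape))
mask = np.ones(shape, dtype=bool)
normalized = normalize_to_white_matter(volume, mask)
print(np.median(normalized[volume > 80]))  # close to 1
```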
All these operations are computationally expensive, requiring up to 5 minutes per subject…around 5 days in total to preprocess the entire dataset! Parallelizing these operations allowed us to “only” wait for a night. We now have clean and comparable images with a clean spreadsheet of MRI metadata. Clear? Yes, finally! But before we design an algorithm able to estimate the age of the brain, we must first discuss…
What do doctors know about brain aging?
In fact, not that much! It is impossible for physicians to determine the precise age of the subject from the brain image alone. However, radiologists do know how to find anatomical features associated with normal brain aging on an MRI.
Three main features are associated with the aging process, and all are visible on a T1 MRI sequence.
- Atrophy, a decrease in grey matter thickness.
- Leukoaraiosis, which appears as white matter hypointensities in the T1 sequence.
- Ventricle dilation: as a consequence of atrophy, cerebrospinal fluid builds up in the brain ventricles.
To simplify these explanations, we decided to restrict ourselves to the T1 anatomical sequence. However, in clinical practice, radiologists assess brain aging on more than just one MRI sequence. For instance, leukoaraiosis is assessed on a T2 sequence, such as FLAIR, where it is more contrasted and appears as white matter hyperintensities. There are other sequences that can help detect structural features associated with aging, e.g. T2* for studying cerebral microbleeds (small hemorrhages in the brain). In the future, features from functional MRI (fMRI) scans, such as connectivity and network integrity, may also prove to be relevant for brain aging studies.
Several works have already investigated brain age prediction from anatomical data, as well as its link with brain disorders or genetics. James Cole, a research fellow at King’s College London, has written an excellent series of papers on the topic, of which Cole et al., 2017 is the most similar to our work. A much larger study (Kaufmann et al., 2018) on around 37,000 patients is currently under review. Both works are nicely covered in this article from Quanta Magazine (August 2018).
We are ready to start with algorithms! From here things will be getting a bit more difficult, so make sure you have enough time to go through to the end :).
Defining the Problem: Predicting Y from X
A good practice in machine learning is to start with a simple baseline algorithm to get a grasp of the complexity of your problem. So, we first decided not to use the whole 3D MRIs, but a simpler reflection of their content, in the form of their histogram of voxel intensities.
For a computer, an MRI is a 3D grid of values — the voxel grayscale intensity values — where low and high values are often represented as black and white, respectively. The histogram of an image is the histogram of these values. Another way to put it: it is the count of voxels that have a grayscale value within a given range. A typical histogram is shown below. On the x-axis you can see the grayscale values ranging from 0 to 1, and on the y-axis the total count of voxels for each value. Values have been grouped in small intervals, visually represented by the columns in the figure, called bins. We set the number of bins to 200, meaning the interval [0, 1] is split into 200 equally sized intervals. See our Colab notebook to create your own histogram.
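Computing such a histogram takes a single NumPy call. A minimal sketch, assuming the normalized volume and a boolean brain mask are already NumPy arrays (loading the NIfTI files, for example with the nibabel library, is left out):

```python
import numpy as np

def intensity_histogram(volume, brain_mask, n_bins=200):
    """Return the feature vector X: number of brain voxels per intensity bin.

    The interval [0, 1] is split into n_bins equal-width bins; note that
    np.histogram with range=(0, 1) ignores values falling outside [0, 1].
    """
    values = volume[brain_mask]
    counts, _edges = np.histogram(values, bins=n_bins, range=(0.0, 1.0))
    return counts

# Toy example: a random 3D "image" in which every voxel belongs to the brain
rng = np.random.default_rng(0)
volume = rng.random((20, 20, 20))          # values in [0, 1)
mask = np.ones(volume.shape, dtype=bool)
X = intensity_histogram(volume, mask)
print(X.shape, X.sum())                    # (200,) 8000
```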
The idea is that homogeneous tissues have similar grayscale values. So, to determine the quantity and proportion of grey matter and white matter, you can count the number of voxels with similar values. Indeed, you can see 2 peaks on the figure, from left to right: grey matter (low intensity values) and white matter (high intensity values). We know that brain aging is correlated with grey matter atrophy… and the quantity of grey matter is related to the size of the first peak. Let’s explore this further!
Going back to our X and Y notation, we replaced our X with vectors (just sequences of numbers) of length 200. In mathematical notation, for a single MRI, X = [X₁, X₂, …, X₂₀₀], where Xᵢ is the number of voxels in the i-th bin. We call this concise description of the MRI a feature vector. The main concept behind machine learning algorithms is that they are not a sequence of hand-crafted rules indicating how to go from X to Y; instead, they “learn” these rules themselves from data, i.e. from many examples of X and Y. This is one of the main differences between Deep Blue, the IBM chess-playing computer, and the more recent AlphaGo, the first computer program able to defeat a professional Go player.
To be more precise, “training” an algorithm means searching for a function F, the algorithm, such that F(X), the prediction of the algorithm, is as close as possible to Y, the true value, for all the pairs (X, Y) in the dataset. In practice, we search for F within a large family of functions (linear functions, decision trees, neural networks, etc.) and choose the one with the smallest average error between the prediction F(X) and the true value Y. In short, “learning” is nothing but an optimization problem: minimizing an error. In our case, we have 1597 pairs of the form (X = histogram, Y = age), and we try to minimize the absolute error |F(X) − Y|. For example, if the prediction F(X) is 23 years of age and the true value Y is 21 years of age, the absolute error is 2 years. If you understand this paragraph, you understand supervised machine learning.
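The error we minimize, averaged over subjects, is the mean absolute error (MAE). A short sketch of its definition, reproducing the 23-versus-21 example; scikit-learn also provides the same metric as `sklearn.metrics.mean_absolute_error`.

```python
import numpy as np

def mean_absolute_error(y_true, y_pred):
    """Average of |F(X) - Y| over all subjects, in years."""
    y_true = np.asarray(y_true, dtype=float)
    y_pred = np.asarray(y_pred, dtype=float)
    return np.mean(np.abs(y_pred - y_true))

# The example from the text: predicting 23 for a 21-year-old is a 2-year error
print(mean_absolute_error([21], [23]))  # 2.0
# Three hypothetical subjects: errors of 2, 5 and 1 years, i.e. (2 + 5 + 1) / 3
print(mean_absolute_error([21, 60, 35], [23, 55, 36]))
```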
Cross-validation: splitting training and test sets
We now arrive at one of the most crucial steps of the machine learning workflow: how can we evaluate the performance of an algorithm? The fundamental machine learning technique addressing this question is called cross-validation. We randomly split our dataset into two parts: a training set and a test set (also called validation set). The training set is used to train the algorithm, and the test set is only used to compute its performance. The idea is to evaluate how well the algorithm generalizes to new data it never saw during training, making the test set a proxy for performance on new real-world data. Each time we report an absolute error in this article, it is the error on the test set, never on the training set.
A difficult question we have to ask ourselves is how to choose the relative sizes of the training set and the test set. 50/50? 75/25? 90/10? The larger the training set, the better the algorithm, because it is trained on more data. However, the larger the test set, the more reliable the performance estimate, because the test set is more representative of the real world.
We used a procedure called k-fold cross-validation: we randomly split the dataset into 5 chunks, or folds, of equal size, and repeat the training 5 times, each time using a different fold as the test set and the four others as the training set (equivalent to an 80/20 split). This forces us to train and test the algorithm 5 times, which can take hours, but it makes the results much more reliable. Fortunately, since the fold evaluations are independent, cross-validation parallelizes easily (if you have the computational resources available).
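In scikit-learn, 5-fold cross-validation is a few lines with `KFold` and `cross_val_score`. The sketch below uses random stand-in data of the same size as ours (1597 subjects, 200 features) and made-up "ages", so the numbers it prints mean nothing clinically.

```python
import numpy as np
from sklearn.linear_model import LinearRegression
from sklearn.model_selection import KFold, cross_val_score

# Synthetic stand-in for our data: 1597 subjects x 200 histogram features
rng = np.random.default_rng(0)
X = rng.random((1597, 200))
y = 20 + 60 * X[:, 0] + rng.normal(0, 5, size=1597)  # made-up "ages"

# 5 shuffled folds: each subject lands in the test fold exactly once
cv = KFold(n_splits=5, shuffle=True, random_state=0)
scores = cross_val_score(LinearRegression(), X, y, cv=cv,
                         scoring="neg_mean_absolute_error")

fold_maes = -scores  # scikit-learn reports negated errors ("higher is better")
print("MAE per fold:", np.round(fold_maes, 2))
print("Average MAE:", fold_maes.mean())
```

Passing `n_jobs=-1` to `cross_val_score` runs the independent folds in parallel.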
In medical studies, cross-validation is often not enough. Results reported in papers are commonly overestimated, as one may iterate many times over the same dataset to obtain the best cross-validated performance. Therefore, best practice is to validate an algorithm on an external, independent dataset provided by another institution or hospital. Lack of transferability from one hospital to another, or from one population to another, may be a major flaw of machine learning algorithms, yet so far there is no equivalent of clinical trials imposed on AI algorithms by the FDA or its counterparts.
Finding the right model
There exists a large variety of models, i.e. families of functions F, and choosing the right one for your problem requires some experience. Again, good practice is to begin with the simplest effective method, and for this, linear models are always a good baseline.
In a linear model, the prediction F(X) is a weighted sum of the values of X: F(X) = (W₁ * X₁) + (W₂ * X₂) + … + (W₂₀₀ * X₂₀₀) + β, where Wᵢ is a “weight” associated with feature Xᵢ, and β is a constant additive term often called the bias. For instance, if X₁₂₀ corresponds to the bin of the grey matter peak, one possible (very) simple linear model that relates the predicted age to the amount of grey matter in the MRI might be F(X) = 100 + (-10 * X₁₂₀), where β = 100, W₁₂₀ = -10, and all other Wᵢ = 0. This function seems to make sense: the more grey matter (large X₁₂₀), the lower the predicted age.
In the case of linear regression, simple linear algebra performed on the training data finds the best possible weights Wᵢ (strictly speaking, ordinary least squares finds the weights minimizing the squared error rather than the absolute error). We used the scikit-learn Python package to train our linear models on each of our 5 cross-validation folds. Try it yourself in our Colab notebook!
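To connect the formula with scikit-learn’s vocabulary: after fitting, `coef_` holds the weights Wᵢ and `intercept_` holds β. In this sketch, made-up data are generated exactly by the toy model F(X) = 100 + (-10 * X₁₂₀), and ordinary least squares recovers its weights because the data are noiseless.

```python
import numpy as np
from sklearn.linear_model import LinearRegression

# Toy data generated by the "very simple" model: age = 100 - 10 * X_120
rng = np.random.default_rng(0)
X = rng.random((500, 200))        # 500 hypothetical subjects, 200 features
y = 100 - 10 * X[:, 119]          # X_120 is column 119 (Python counts from 0)

model = LinearRegression().fit(X, y)  # ordinary least squares
print(round(model.intercept_, 3))     # beta, close to 100
print(round(model.coef_[119], 3))     # W_120, close to -10
```

With real, noisy histograms, the fitted weights are spread over many bins and are much harder to interpret one by one.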
Training a linear model on the MRI intensity histogram feature vectors gives us mean absolute errors (in years) of 8.49, 9.53, 9.29, 8.89, and 9.22 on each of the 5 folds, so an average error of 9.08 years. It is not great…but not that bad either! With this extremely simple algorithm, we can predict from brain scans whether a patient is younger or older than 50 years of age with an accuracy of 84%.
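The average error is simply the mean of the five per-fold errors. The article does not detail how the 84% figure was computed; one natural reading, which we assume here, is to threshold both predicted and true ages at 50 and count how often the two sides agree. The function name `above_50_accuracy` is ours.

```python
import numpy as np

# Per-fold test MAEs reported above for the linear model on histograms
fold_maes = [8.49, 9.53, 9.29, 8.89, 9.22]
print(np.mean(fold_maes))  # about 9.08

def above_50_accuracy(y_true, y_pred, threshold=50):
    """Fraction of subjects placed on the correct side of the age threshold."""
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)
    return np.mean((y_pred >= threshold) == (y_true >= threshold))

# Hypothetical example: 3 of the 4 younger/older decisions are correct
print(above_50_accuracy([25, 45, 55, 70], [30, 52, 58, 66]))  # 0.75
```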
The hypothesis of a linear relation between an MRI histogram and the age is, of course, simplistic. An algorithm taking into account a non-linear relationship may be able to provide more accurate predictions. Gradient tree boosting is one of the most popular and efficient non-linear choices for F(X). Gradient boosted trees are a sequence of decision trees iteratively built to minimize the error. You can find a more complete introduction here.
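To make "a sequence of decision trees iteratively built to minimize the error" concrete, here is a minimal gradient tree boosting loop for the squared error, built on scikit-learn’s `DecisionTreeRegressor`. It is a teaching sketch: libraries such as CatBoost or XGBoost add many refinements (regularization, smarter tree construction, other loss functions such as the absolute error).

```python
import numpy as np
from sklearn.tree import DecisionTreeRegressor

def fit_gradient_boosting(X, y, n_trees=100, learning_rate=0.1, max_depth=3):
    """Minimal gradient tree boosting for the squared error (teaching sketch).

    Start from a constant prediction, then repeatedly fit a small tree to the
    current residuals y - F(X) (the negative gradient of the squared error)
    and add a damped version of that tree to the ensemble.
    """
    y = np.asarray(y, dtype=float)
    baseline = y.mean()
    prediction = np.full(len(y), baseline)
    trees = []
    for _ in range(n_trees):
        residuals = y - prediction
        tree = DecisionTreeRegressor(max_depth=max_depth).fit(X, residuals)
        prediction += learning_rate * tree.predict(X)
        trees.append(tree)
    return {"baseline": baseline, "learning_rate": learning_rate, "trees": trees}

def predict_gradient_boosting(model, X):
    """Constant start plus every tree's damped correction."""
    prediction = np.full(len(X), model["baseline"])
    for tree in model["trees"]:
        prediction += model["learning_rate"] * tree.predict(X)
    return prediction

# Made-up non-linear relation that no straight line can capture
rng = np.random.default_rng(0)
X = rng.uniform(0, 10, size=(300, 1))
y = 40 + 10 * np.sin(X[:, 0])
model = fit_gradient_boosting(X, y)
train_mae = np.mean(np.abs(predict_gradient_boosting(model, X) - y))
print(train_mae)  # small compared with the spread of y
```

This is a training error: as the next sections show, only the error on held-out data tells you whether the model generalizes.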
While deep neural networks, and their ability to tackle complex tasks, are nowadays oft-publicized in popular science articles and well-known even outside the machine learning community, gradient tree boosting remains little known outside the data science community. Gradient-boosted trees are often a key part of the winning solutions in international data science competitions organized on platforms such as Kaggle or DREAM, and are often extremely hard to beat.
Using the CatBoost Python library, we get a much better result: our mean absolute error drops to just 5.71 years, which is already closer to state-of-the-art performance (4.16 years as reported in Cole et al. 2017). We may be tempted to stop here, publish our algorithm and validate it in clinics. But wait! We are making a big mistake which is, unfortunately, quite common… even in peer-reviewed literature.
Avoiding a common pitfall
As mentioned earlier, the range of intensities of the voxels in MRIs has no biological meaning and varies greatly from one MRI scanner to another. In the cross-validation procedure, we randomly split subjects between the training and the test set. So, for each hospital, there will be on average 80% of images from this hospital in the training set and 20% in the test set. But what would be the consequence of focusing on randomizing not the subjects but the hospitals, and consequently the MRI scanner? In this hospital split setting, the test set would contain not just new patients, but data from scanners never seen during training.
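In scikit-learn, a hospital split can be obtained with `GroupKFold`, which keeps all subjects from a given hospital in the same fold. A toy sketch with made-up hospital labels:

```python
import numpy as np
from sklearn.model_selection import GroupKFold

# Made-up example: 12 subjects scanned in 6 hospitals (2 subjects each)
X = np.arange(12).reshape(-1, 1)
hospital = np.array(["A", "A", "B", "B", "C", "C",
                     "D", "D", "E", "E", "F", "F"])

splits = list(GroupKFold(n_splits=3).split(X, groups=hospital))
for train_idx, test_idx in splits:
    train_hospitals = set(hospital[train_idx])
    test_hospitals = set(hospital[test_idx])
    # No hospital (hence no scanner) is shared between training and test sets
    assert train_hospitals.isdisjoint(test_hospitals)
    print("test hospitals:", sorted(test_hospitals))
```

With real data, the same splitter can be handed to `cross_val_score(model, X, y, cv=GroupKFold(n_splits=5), groups=hospital)`.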
Once we split by hospital, the mean absolute errors of our linear regression and gradient tree boosting models increase by around 5 and 6 years, respectively, up to 14.22 and 11.52 years. Even a naive algorithm, F(X) = 27, that predicts all MRIs to be from 27-year-old subjects (the median age of the dataset) would get an average error of around 14 years. In the hospital split setting, our trained algorithms do not seem to perform much better than simply reporting the median age of the dataset. Somehow, in the random split setting, our algorithms must have been able to use the fact that a subject came from a given hospital to accurately predict the subject’s age.
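A constant-prediction baseline like F(X) = 27 is worth computing systematically. scikit-learn ships one as `DummyRegressor`; the median is the natural constant because it minimizes the mean absolute error. A sketch on made-up, right-skewed ages (many young subjects, few old ones, loosely like our cohort); for brevity it is evaluated in-sample, whereas in practice you would fit it on the training folds only.

```python
import numpy as np
from sklearn.dummy import DummyRegressor

# Made-up, right-skewed ages between 18 and 87
rng = np.random.default_rng(0)
ages = np.clip(18 + rng.exponential(scale=15, size=1000), 18, 87)
X = np.zeros((len(ages), 1))   # the baseline ignores the features entirely

baseline = DummyRegressor(strategy="median").fit(X, ages)
pred_median = baseline.predict(X)
pred_mean = np.full(len(ages), ages.mean())

mae_median = np.mean(np.abs(pred_median - ages))
mae_mean = np.mean(np.abs(pred_mean - ages))
print(f"median baseline MAE: {mae_median:.2f}, mean baseline MAE: {mae_mean:.2f}")
```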
The following figures show that a more careful data analysis would have prevented this mistake. The first figure shows the distribution of ages per hospital: most hospitals have a bias in their recruitment. Some have only young subjects in their datasets, others only old ones.
In the next figure, we show the average histogram of the subjects in each hospital, with each curve representing a different hospital.
And here we get our answer: while the white matter peaks are quite well aligned across MRIs from different hospitals, the grey matter peaks are spread widely from hospital to hospital due to the use of different scanners. Because of this, it was quite easy for the algorithms to 1) detect the source hospital from the histogram and 2) use this information to constrain the age prediction to the range of ages recorded in the dataset provided by this hospital.
To remove this effect, we decided to revisit the last step of our preprocessing pipeline: the intensity normalization procedure. We moved from white stripe normalization (method v1), which only fixes the white matter peak, to a new home-made method (method v2) that also fixes the grey matter peak. Now, the white matter and grey matter signals are centered on fixed values, and indeed the per-hospital average histograms look much better, as shown in the next figure.
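The home-made method v2 is not described in detail, so the following is only one simple way to fix two landmarks, not the authors’ implementation: an affine rescaling that sends the grey matter peak and the white matter peak to fixed target values. The targets 0.6 and 1.0 are arbitrary, and we assume both peaks have already been estimated for each image (for example from histogram modes inside tissue masks).

```python
import numpy as np

def two_peak_normalize(volume, gm_peak, wm_peak, gm_target=0.6, wm_target=1.0):
    """Affine rescaling sending the grey matter peak to gm_target and the
    white matter peak to wm_target (targets chosen for illustration)."""
    scale = (wm_target - gm_target) / (wm_peak - gm_peak)
    return gm_target + (volume - gm_peak) * scale

# Two hypothetical scanners whose tissue peaks sit at different raw intensities
scanner_a = np.array([300.0, 500.0])   # [grey matter peak, white matter peak]
scanner_b = np.array([350.0, 520.0])
print(two_peak_normalize(scanner_a, 300.0, 500.0))  # both peaks land on 0.6 and 1.0
print(two_peak_normalize(scanner_b, 350.0, 520.0))  # same targets, other scanner
```

Fixing two landmarks instead of one removes both an offset and a scale difference between scanners, which is what the misaligned grey matter peaks called for.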
As expected, when we re-run the linear and non-linear models, we get much better results when using cross-validation over hospital splits. We still observe that the non-linear gradient boosting models are more powerful than the baseline linear regression.
Note that looking at a random split is not necessarily irrelevant. After all, you may be happy to have an algorithm which is able to take into account the specificities of the scanner the image comes from, as long as it does not integrate selection bias from the hospital.
- Spend time to analyze your data.
- Begin with simple approaches as baselines.
- Non-linear models can be a powerful tool if used properly.
- Be very careful with cross-validation in multi-centric studies, when you have several samples per patient, or when you have a small sample size. Best practice is to have an external and independent validation set. Splitting data is not always as easy as it sounds…
Going further with tissue segmentation
By reducing the entire MRI to a histogram, we lost the opportunity to use any spatial information about the structure of the brain. As a next step towards a more performant and interpretable algorithm, we used another software package, FSL FAST, to segment each MRI into grey matter, white matter, and cerebrospinal fluid (CSF). The segmentation is based on the voxel values and yields convincing results that you can explore in our Colab notebook. Once again, we changed our feature vector X. It is now no longer an encoding of the proportions of voxel intensities (200 features/sequence), but rather a 3D grid of 4 possible values: 0 for the voxels in the background, 1 for CSF, 2 for grey matter, and 3 for white matter.
Based on these segmentations, we can compute a precise grey matter volume per subject, and compare it to the subject’s age. As expected, the figure below shows a negative correlation between the two values (Pearson correlation r=-0.75), i.e. the less grey matter, the older the subject.
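With the label convention above and 1 mm³ voxels, the grey matter volume is just a voxel count, and NumPy’s `corrcoef` gives the Pearson correlation. The per-subject volumes and ages below are made up for illustration.

```python
import numpy as np

BACKGROUND, CSF, GREY, WHITE = 0, 1, 2, 3   # label convention used in the text
VOXEL_VOLUME_MM3 = 1.0                      # images resampled to 1 mm³ voxels

def grey_matter_volume_ml(segmentation):
    """Grey matter volume in millilitres (1 mL = 1000 mm³)."""
    n_voxels = np.count_nonzero(segmentation == GREY)
    return n_voxels * VOXEL_VOLUME_MM3 / 1000.0

# Hypothetical per-subject grey matter volumes (mL) and ages (years)
gm_volumes = np.array([720.0, 690.0, 650.0, 610.0, 590.0])
ages = np.array([22.0, 30.0, 45.0, 60.0, 71.0])
r = np.corrcoef(gm_volumes, ages)[0, 1]
print(round(r, 2))  # negative: the less grey matter, the older the subject
```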
We can go beyond this simple scatter plot and compute the local correlation of grey matter volume with age (within a 1 cm³ ball centered on each voxel). Again, the figure below shows a negative correlation in the cortex, but this analysis also yields quite surprising results: there are regions where grey matter seems to be positively correlated with age, meaning that older subjects would have more grey matter in certain zones… seemingly impossible. Actually, a radiologist would recognize these regions: they show a periventricular distribution of leukoaraiosis, a lesion found in this specific region of the brain. The FSL FAST software we used was not able to distinguish it from grey matter, because both are hypointense in the T1 sequence. While an error, this analysis actually points to a zone that we could use to better predict age. And had these lesions not been known before, this “error” could have led to their discovery!
Generalization and overfitting
This new X allows us to introduce two of the most important notions in machine learning: generalization and over-/under-fitting. Generalization is the ability of an algorithm to maintain its predictive performance on unseen data. The cross-validation process we described earlier is one way to evaluate how an algorithm generalizes from the training data to the test data.
Fit denotes how well the model describes the data. If the performance of the algorithm is low on the training set, and low on the test set, the model under-fits: the algorithm is too simple to describe the data and its predictions are far from reality. In the other extreme, if the performance is high on the training set, but low on the test set, the model overfits: it fails to generalize on new data. Usually, this implies that the model is too complex for the available amount of training data. You can find a very famous plot that explains overfitting here, but we can illustrate it with our own dataset.
Indeed, our new X is very prone to overfitting for a simple reason: an MRI commonly contains around 8 million voxels… and we have only 1597 subjects! It is very likely that a complex algorithm will find spurious correlations between these 8 million values and the 1597 ages of the subjects. In simpler words, it is almost certain that the algorithm will find at least one pattern with no biological meaning that helps it determine age within the training dataset… but does not generalize to the test set.
There are different strategies to fight overfitting: adding more data (increasing the sample size), reducing the complexity of the algorithm, adding constraints and regularization to penalize model complexity during training, removing useless features, stopping training early if overfitting is observed on a held-out part of the training data, etc. In the case of brain scans, we have a simple lever to reduce the complexity of the model: we can downsample the MRI to obtain a smaller image with fewer voxels. By lowering the resolution of the image, and thus the size of the feature vector X, we reduce the probability of finding spurious correlations.
Below, we show the mean absolute error of a linear regression model trained on MRIs downsampled with different downscaling factors. A downscaling factor of 10 means we divided the resolution by a factor of 10 along each axis, and thus the number of voxels in X by a factor of 10³. When the downscaling factor is too low (on the left), there are too many voxels and the algorithm overfits. When it is too high (on the right), the MRI is too blurry: the algorithm no longer gets enough information from the image and underfits.
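One way to downsample by an integer factor is block averaging. Since segmentation labels (0 to 3) are categories, averaging them directly would be meaningless, so this sketch assumes each tissue is first turned into a binary mask, whose block average is the local fraction of that tissue. The article does not say exactly how the downsampling was done; this is an illustration under that assumption, on a random stand-in segmentation.

```python
import numpy as np

def block_downsample(volume, factor):
    """Average non-overlapping factor x factor x factor blocks of a 3D array.

    Voxels at the edges that do not fill a complete block are cropped, so each
    output voxel summarizes exactly factor**3 input voxels.
    """
    nx, ny, nz = (s // factor for s in volume.shape)
    cropped = volume[: nx * factor, : ny * factor, : nz * factor]
    blocks = cropped.reshape(nx, factor, ny, factor, nz, factor)
    return blocks.mean(axis=(1, 3, 5))

# Hypothetical segmentation with labels 0-3 (0 background, 1 CSF, 2 grey, 3 white)
seg = np.random.default_rng(0).integers(0, 4, size=(60, 70, 60))
# Average a binary grey matter mask: each output voxel is a grey matter fraction
gm_fraction = block_downsample((seg == 2).astype(float), factor=10)
print(seg.shape, "->", gm_fraction.shape)   # (60, 70, 60) -> (6, 7, 6)
print(seg.size // gm_fraction.size)         # 1000, i.e. 10**3 fewer voxels
```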
By reducing the number of voxels in X, we observe much better results than with histograms: an MAE of 4.65 years with random-split cross-validation and an MAE of 6.11 years with hospital-split cross-validation. The smaller drop in performance between the random and hospital splits (1.46 years, compared to 2.77 years for histograms) is easily explained: we completely removed the variability of grayscale values, as each voxel can now only take one of 4 values: background, white matter, grey matter or CSF.
You can observe that gradient tree boosting did not perform much better than linear regression on this new X. In fact, decision trees struggle to deal with such high dimensional data, and have been eclipsed in the field of computer vision by deep learning, which we will explore in the next chapter.
How much data do I need?
This question is extremely common, and unfortunately the best answer in practice is generally: it depends. Data scientists commonly work with thousands to millions of data points, and they often forget that in clinical studies each data point is a real person, and that recruiting even one hundred patients can be extremely difficult.
A frequent answer to this question is “Well I don’t know, but the more data you have, the better the performance will be.” Keeping the properties of the data fixed, more data is always better than less data. How much more performance each new piece of data brings is a harder question to answer. We illustrate this effect below by training multiple linear regressions with an increasing number of subjects in the training set.
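A minimal simulation of this effect, assuming synthetic “subjects” whose age is a noisy linear function of 50 features (all numbers are invented for illustration and are not the article’s data). The data-generating process stays fixed while the training set grows; the test error falls quickly at first, then levels off near the irreducible noise.

```python
import numpy as np

def fit(X, y):
    """Least-squares linear regression coefficients (with intercept)."""
    A = np.column_stack([X, np.ones(len(X))])
    coef, *_ = np.linalg.lstsq(A, y, rcond=None)
    return coef

def predict(coef, X):
    return np.column_stack([X, np.ones(len(X))]) @ coef

rng = np.random.default_rng(0)
n_features = 50
true_w = rng.normal(size=n_features)

def make_subjects(n):
    """Hypothetical subjects: features plus a noisy linear 'age' target."""
    X = rng.normal(size=(n, n_features))
    age = 40 + X @ true_w + rng.normal(scale=5.0, size=n)
    return X, age

X_test, y_test = make_subjects(1000)
sizes = [60, 100, 200, 500, 1000, 2000]
curve = []
for n in sizes:
    X_train, y_train = make_subjects(n)
    test_mae = float(np.mean(np.abs(predict(fit(X_train, y_train), X_test) - y_test)))
    curve.append(test_mae)
    print(f"{n:>5} subjects -> test MAE {test_mae:.2f}")
```

With 50 features, the first few dozen extra subjects bring large gains, while going from 1,000 to 2,000 subjects barely moves the error: the marginal value of each new data point shrinks.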
This is the last step of our journey. So far, we have used data-manipulation techniques to extract features we thought could be relevant to age prediction: voxel intensity histograms and tissue segmentation. Deep learning takes another approach, using a family of functions called neural networks, which work directly on the raw data. They are able to identify the most relevant features for a given task without human input.
In 2012, neural networks equipped with a special mathematical operation called convolution set a milestone in the ImageNet large-scale image classification challenge, surpassing previous approaches by a wide margin. The associated paper (Krizhevsky et al., 2012) has already been cited more than 27,000 times and has radically transformed virtually every area of computer vision. Since then, deep learning and neural networks have been game changers in many other computer science fields, such as speech recognition, natural language understanding and reinforcement learning.
Convolutional Neural Networks (CNNs) are built from sequences of elements called layers, which consist of convolutions and non-linearities. Their architecture can be complex, but as before, CNNs are just mathematical functions that are optimized to minimize an error. In non-mathematical terms, a CNN’s organization shows some similarities with the structure of the visual cortex of mammals. In the superficial layers (i.e. close to the input), basic visual information such as edges and shapes is extracted, while deeper layers (i.e. close to the output) reflect higher-level information (such as tissue type, brain areas, gender, etc.).
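To make “a convolution followed by a non-linearity” concrete, here is a single CNN-style layer written in plain NumPy. The hand-crafted edge kernel is purely illustrative: a real CNN learns thousands of such kernels from data, and libraries such as TensorFlow implement them far more efficiently.

```python
import numpy as np

def conv2d(image: np.ndarray, kernel: np.ndarray) -> np.ndarray:
    """'Valid' 2D convolution as used in CNNs (strictly a cross-correlation:
    the kernel is not flipped)."""
    kh, kw = kernel.shape
    h, w = image.shape
    out = np.empty((h - kh + 1, w - kw + 1))
    for i in range(out.shape[0]):
        for j in range(out.shape[1]):
            # Weighted sum of the pixels under the kernel window.
            out[i, j] = np.sum(image[i:i + kh, j:j + kw] * kernel)
    return out

def relu(x):
    """The most common non-linearity: keep positive values, zero the rest."""
    return np.maximum(x, 0.0)

# Toy 6x6 "image": dark on the left half, bright on the right half.
image = np.zeros((6, 6))
image[:, 3:] = 1.0

# Vertical edge detector (hand-made here; learned in a real CNN).
kernel = np.array([[-1.0, 0.0, 1.0],
                   [-1.0, 0.0, 1.0],
                   [-1.0, 0.0, 1.0]])

layer_output = relu(conv2d(image, kernel))
print(layer_output)  # every row is [0. 3. 3. 0.]: the layer fires only at the edge
```

Stacking many such layers lets later kernels combine edges into shapes and, eventually, into task-relevant patterns.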
We do not aim to go much deeper into neural networks in this article. However, you can play around with a simple live neural network here, or with a simple CNN for handwritten digit classification here. If you want to go further, you can read Michael Nielsen’s online book, or follow the famous online machine learning course by Andrew Ng from Stanford University, which is rated 4.9/5 by… 80,000 students!
Designing the architecture of a CNN (the number of layers, the structure of convolutions, the kind of non-linearities used, etc.) and training the model is much more difficult and time-consuming than for linear models. In fact, there are so many possibilities and options to explore that researchers are now working on automated machine learning (AutoML), casting the problem of finding the best architecture… as a learning problem itself!
In this work, we simplified the problem by reducing each MRI from around 200 slices in the axial dimension to only 10 slices. These slices correspond to a 1 cm (10 × 1 mm) axial slab at the level of the ventricles, where atrophy, ventricular dilation and leukoaraiosis can all potentially be detected. To accelerate training, we used a high-performance graphics card (GTX 1080 Ti) and the TensorFlow Python library. Training a single CNN, even with only 10 slices per subject, takes one to three hours, and tuning the hyperparameters can take days. In comparison, training linear models on the voxel intensity histograms took only a few seconds.
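Extracting such a slab is a simple slicing operation. The sketch below assumes the last array axis is the axial axis with 1 mm spacing; the volume shape and the ventricle-level index are hypothetical (in practice that index would come from registration to a template, which the article does not detail). Keeping the 10 slices as the channels of a 2D CNN input is one possible design, not necessarily the one the authors used.

```python
import numpy as np

def axial_slab(volume: np.ndarray, center: int, n_slices: int = 10) -> np.ndarray:
    """Return n_slices contiguous axial slices centred on `center`.

    Assumes the last axis of `volume` is axial with 1 mm spacing,
    so the default 10 slices cover a 1 cm slab. The output keeps the
    slices on the last axis (channels-last layout)."""
    start = center - n_slices // 2
    if start < 0 or start + n_slices > volume.shape[-1]:
        raise ValueError("slab extends outside the volume")
    return volume[:, :, start:start + n_slices]

# Hypothetical 200-slice volume where each slice stores its own index.
mri = np.zeros((180, 220, 200))
mri[:] = np.arange(200)

ventricle_level = 110  # hypothetical slice index at the level of the ventricles
slab = axial_slab(mri, ventricle_level)
print(slab.shape)  # (180, 220, 10)
```

Going from about 200 slices to 10 divides the input size by roughly 20, which is a large part of why training remains feasible in hours rather than days.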
After spending a few hours searching for a good baseline architecture, we designed a simple CNN with 10 convolutional layers and around 5 million parameters, which obtained a mean absolute error of 4.57 years on the random split and 6.94 years on the hospital split. We then improved our models using a common trick called data augmentation: a simple way to simulate more data than you have by slightly deforming your dataset, adding small distortions such as rotation, zoom or blur to every image. To make our network more robust to scanner variability, we also slightly varied voxel intensities across parts of the MRIs. With this data augmentation, we achieved mean absolute errors of 4.27 and 6.14 years on the random and hospital splits, respectively. You can load the trained model and use it in our Colab notebook.
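The article does not list its exact transformations or their parameters, so the following is only a NumPy sketch of the idea. Small random translations serve as a simplified stand-in for rotation and zoom (true rotation and zoom could use, for example, `scipy.ndimage.rotate` and `scipy.ndimage.zoom`), a 3×3 box blur plays the role of blur, and a smooth left-to-right intensity gradient mimics scanner variability. All ranges are arbitrary assumptions.

```python
import numpy as np

def box_blur(image: np.ndarray, size: int = 3) -> np.ndarray:
    """Average each pixel with its neighbours (borders padded by edge values)."""
    pad = size // 2
    padded = np.pad(image, pad, mode="edge")
    out = np.zeros(image.shape, dtype=float)
    for di in range(size):
        for dj in range(size):
            out += padded[di:di + image.shape[0], dj:dj + image.shape[1]]
    return out / size ** 2

def augment(image: np.ndarray, rng: np.random.Generator) -> np.ndarray:
    """Return a randomly perturbed copy of a 2D slice."""
    out = image.astype(float)
    # Small random translation of up to 3 pixels (crude stand-in for rotation/zoom).
    shift = tuple(int(s) for s in rng.integers(-3, 4, size=2))
    out = np.roll(out, shift=shift, axis=(0, 1))
    # Mild blur half of the time.
    if rng.random() < 0.5:
        out = box_blur(out)
    # Smooth intensity gradient across the slice, between -10% and +10%.
    gain = np.linspace(rng.uniform(0.9, 1.1), rng.uniform(0.9, 1.1), out.shape[1])
    return out * gain[None, :]

rng = np.random.default_rng(0)
slice_ = rng.random((64, 64))  # hypothetical MRI slice
augmented = [augment(slice_, rng) for _ in range(4)]
```

Applied on the fly during training, each epoch shows the network a slightly different version of every slice, so it cannot simply memorize exact voxel values.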
The current state-of-the-art CNNs have been trained on datasets of millions of images, with rich sets of metadata used as prediction targets. While these pre-trained CNNs have never seen an MRI during training, they have learned to recognize very complex patterns within images, which allows them to distinguish image characteristics easily, even for difficult problems such as predicting dog breed or vehicle make. Transfer learning is a widely used technique that consists of fine-tuning such pre-trained CNNs on completely new tasks. We used one of these CNNs, called ResNet50: starting from its pre-trained state, we fine-tuned it on our brain scan dataset to predict age instead of dogs, cats, cars, etc. Below we present the complete set of results for these different flavors of CNN architecture and training.
Focusing only on a 1 cm axial slab, CNNs trained on raw data are able to match the performance of a linear model trained on the whole tissue-segmented MRI. There are many ways we could improve performance using neural networks: we could use more slices, design a 3D convolutional neural network, or use domain adaptation techniques to limit the multi-centric effect. We have no doubt that, with more data, deep learning methods would outperform linear models by a wider margin; but even with transfer learning, such models often require large amounts of data, which are difficult to obtain in healthcare.
As a final trick, we averaged the predictions of two of our best algorithms: the CNN with data augmentation and the linear model on the segmented MRI. Much like collaboration between experts, ensembling brings an additional boost in performance, highlighting that the two models are different but complementary.
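Averaging two models is essentially a one-liner. The made-up predictions below (not the article’s data) show why it helps when the two models err in different directions.

```python
import numpy as np

def mae(y_true, y_pred):
    return float(np.mean(np.abs(np.asarray(y_true) - np.asarray(y_pred))))

# Hypothetical ages and predictions for five test subjects (invented numbers).
age         = np.array([25.0, 31.0, 44.0, 58.0, 70.0])
pred_cnn    = np.array([29.0, 28.0, 47.0, 53.0, 66.0])
pred_linear = np.array([22.0, 34.0, 49.0, 60.0, 63.0])

# Simple ensemble: the unweighted mean of the two predictions.
pred_ensemble = (pred_cnn + pred_linear) / 2

print(f"CNN      {mae(age, pred_cnn):.2f}")       # CNN      3.80
print(f"linear   {mae(age, pred_linear):.2f}")    # linear   4.00
print(f"ensemble {mae(age, pred_ensemble):.2f}")  # ensemble 2.30
```

By the triangle inequality, the MAE of the average can never exceed the average of the two MAEs; it beats the better of the two models only when their errors partly cancel, which is what “different but complementary” means in practice.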
In the table below, we present more metrics for our final model. Because our cohort is very young (median age of 27 years), our model is more reliable for young subjects than for older ones.
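Reporting the error separately for each age band is a simple way to check this kind of imbalance. A minimal sketch, with hypothetical band edges and made-up predictions (the article’s table is not reproduced here):

```python
import numpy as np

def mae_by_age_group(age, pred, edges):
    """MAE within each age interval [edges[i], edges[i+1])."""
    age = np.asarray(age, dtype=float)
    pred = np.asarray(pred, dtype=float)
    results = {}
    for lo, hi in zip(edges[:-1], edges[1:]):
        mask = (age >= lo) & (age < hi)
        if mask.any():  # skip empty bands rather than reporting NaN
            results[f"{lo}-{hi}"] = float(np.mean(np.abs(pred[mask] - age[mask])))
    return results

# Invented example: the model is accurate for young subjects, less so for older ones.
print(mae_by_age_group([20, 25, 60, 70], [22, 24, 66, 62], edges=[18, 40, 90]))
# {'18-40': 1.5, '40-90': 7.0}
```

A single overall MAE can hide such differences, especially when, as here, most subjects fall in the youngest band.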
Looking into the black box
While performance metrics may be objectively convincing, they are often not enough to generate trust. Algorithms such as CNNs, with their millions of parameters, are difficult to understand and often considered black boxes: data goes in and predictions come out. As the data scientist cannot explicitly describe the correlations the CNN found during training to make its predictions, it is hard to explain their success… or why they sometimes fail. Worse, it has been shown that CNNs sometimes make absurdly wrong predictions when imperceptible noise is added to their input (Nguyen et al., 2014).
Interpretability of CNNs is an active area of research. As a first glimpse into the black box we created, we used an occlusion technique to find the regions of the MRI that matter most for the CNN to make accurate age predictions. The idea is to occlude a small area (4 cm² in our case) of the test set images and observe the corresponding increase in mean absolute error. If the performance drops a lot, the occluded area was important for the algorithm. In the figure below, the pink regions are associated with the largest increases in error.
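The occlusion procedure can be sketched in a few lines of NumPy. Assuming 1 mm pixels, a 4 cm² square is 20 × 20 pixels. The fill value (0, i.e. background), the stride, and the toy “model” are assumptions for illustration. The toy model looks only at one region, so occluding that region should be the only thing that hurts.

```python
import numpy as np

def mae(y_true, y_pred):
    return float(np.mean(np.abs(y_true - y_pred)))

def occlusion_map(model, images, ages, patch=20, stride=20, fill=0.0):
    """Increase in test MAE when each patch x patch square is masked in every image.

    model: any callable mapping an array of shape (n, H, W) to n predicted ages."""
    baseline = mae(ages, model(images))
    _, h, w = images.shape
    rows = range(0, h - patch + 1, stride)
    cols = range(0, w - patch + 1, stride)
    heat = np.zeros((len(rows), len(cols)))
    for a, i in enumerate(rows):
        for b, j in enumerate(cols):
            occluded = images.copy()
            occluded[:, i:i + patch, j:j + patch] = fill  # hide the same square in all images
            heat[a, b] = mae(ages, model(occluded)) - baseline
    return heat

# Toy setup: the "age" is encoded only in the 20x20 square at rows/cols 40-60.
rng = np.random.default_rng(0)
ages = rng.uniform(20, 80, size=8)
images = rng.random((8, 100, 100))
images[:, 40:60, 40:60] = ages[:, None, None] / 100

def toy_model(x):
    """Hypothetical model that reads age from the central square only."""
    return 100 * x[:, 40:60, 40:60].mean(axis=(1, 2))

heat = occlusion_map(toy_model, images, ages)
print(np.unravel_index(np.argmax(heat), heat.shape))  # (2, 2): the central square
```

With a real CNN the map is less clear-cut, and it depends on the patch size and fill value, which is why occlusion is only a first glimpse rather than a full explanation.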
The occlusion map of the youngest subjects (on the left) revealed that regions close to the ventricles were important for the prediction, and indeed we know that ventricles are narrowest in younger people and dilate with aging.
The occlusion map of the oldest subjects revealed the importance of the insula in this slice, on both sides. This is consistent with results of Good et al., 2001, where brain aging was described as a heterogeneous process, with some regions known to show “accelerated” aging, such as… the insula. Moreover, deep gray matter (thalami and the internal capsule) was associated with higher prediction errors when occluded, consistent with the findings of Fama & Sullivan, 2015.
Several more sophisticated techniques to interpret deep learning algorithms are currently being studied; they are nicely summarized in this post.
Using our model to predict Alzheimer’s disease
So far, we have used age prediction from brain scans as a pretext to introduce you to the fundamentals of machine learning. But estimating the physiological age of the brain could also help develop a better understanding of neurodegenerative diseases such as Alzheimer’s disease. As a final experiment, we applied our models to 489 subjects from the ADNI database. These subjects are split into two categories: normal controls (269 subjects) and patients with Alzheimer’s disease (220 subjects).
We downloaded the data, cleaned it, passed it through our preprocessing pipeline and applied the trained linear regression and CNN models. Subjects in this database were much older (75 years of age on average) than those in the dataset we used for training (35 years on average), and, as one would expect, our models consistently underestimated the age of healthy subjects. This failure highlights an important limitation of machine learning models: if they are not trained on a representative sample of the population, they may perform very badly on unseen subjects. This lack of transferability may be one of the most serious flaws of machine learning in healthcare.
Interestingly, however, when plotting the distributions of the differences between the brain age predicted by the final algorithm (linear regression + CNN) and the reported age of each subject, we found that this gap was on average 6 years larger in patients with Alzheimer’s disease than in healthy subjects. Although the model never saw a single MRI from a subject with Alzheimer’s disease during training, its output distinguishes subjects with Alzheimer’s disease from healthy controls with an ROC-AUC of 0.76. This tends to confirm that the brains of subjects with Alzheimer’s disease contain features correlated with accelerated brain aging (Gaser et al., 2013), corroborating the hypothesis that brain age could be a new biomarker for neurodegenerative diseases (Rafel et al., 2017; Koutsouleris et al., 2014; Cole et al., 2015).
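Both quantities in this analysis are easy to compute once predictions are available. The sketch below uses invented values for four controls and four patients (not ADNI data) and computes the ROC-AUC with its rank interpretation: the probability that a randomly chosen patient has a larger brain-age gap than a randomly chosen control.

```python
import numpy as np

def brain_age_gap(predicted_age, reported_age):
    """Positive values mean the brain 'looks' older than the subject's reported age."""
    return np.asarray(predicted_age, dtype=float) - np.asarray(reported_age, dtype=float)

def roc_auc(scores_pos, scores_neg):
    """Area under the ROC curve: the probability that a random positive scores
    higher than a random negative, counting ties as 1/2."""
    pos = np.asarray(scores_pos, dtype=float)[:, None]
    neg = np.asarray(scores_neg, dtype=float)[None, :]
    return float(np.mean((pos > neg) + 0.5 * (pos == neg)))

# Invented values: the model underestimates everyone's age, but less so for patients.
ctrl_age  = np.array([72.0, 75.0, 78.0, 70.0])
ctrl_pred = np.array([60.0, 66.0, 65.0, 61.0])
ad_age    = np.array([74.0, 77.0, 73.0, 79.0])
ad_pred   = np.array([70.0, 69.0, 72.0, 69.0])

gap_ctrl = brain_age_gap(ctrl_pred, ctrl_age)
gap_ad = brain_age_gap(ad_pred, ad_age)
print(gap_ad.mean() - gap_ctrl.mean())  # 5.0
print(roc_auc(gap_ad, gap_ctrl))        # 0.875
```

Note that a shift applied equally to every subject, such as the systematic underestimation in this older population, cancels out in the difference between groups and leaves the AUC unchanged. This is why the gap can still separate the two groups even though the absolute age predictions are biased.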
The journey of prediction with machine learning can go from deploying simple linear models in seconds to building complex CNNs that train over days. In both cases, interpretability is crucial, particularly in medicine where understanding the data features most correlated with a predictive model’s performance can actually lead to the discovery of new biological mechanisms or biomarkers.
However, researchers must be careful when building and analyzing these models. Even when cross-validation is done properly, stellar performance on a limited number of use cases does not guarantee good generalization. This is especially true if the algorithm was not trained on a representative sample of the overall population, and biases are frequent in practice. While difficult from a regulatory and organizational perspective, investing time in collecting and cleaning datasets is key to the success of machine learning projects.
We hope that this article has given you a taste of what AI looks like in practice, and that it encouraged you to become a “machine teacher”! At Owkin, we always welcome new exciting projects, so if you have an X, a Y and some motivation, you know what to do!
Several members of the Owkin team contributed to this work, including Simon Jégou*, Paul Herent*, Olivier Dehaene and Thomas Clozel. We also thank Dr. Roger Stupp, Dr. Julien Savatovsky, and Olivier Elemento, PhD, for their active support, Sylvain Toldo and Valentin Amé for their work on the figures, as well as Sebastian Schwarz and Eric Tramel for their edits on the manuscript.
- Simon Jégou, data scientist: firstname.lastname@example.org
- Paul Herent, radiologist: email@example.com
- Partnership team: