Machine Learning with a Heart: Predicting Heart Disease

Machine learning (ML) is generating considerable buzz in the healthcare industry. Organizations from payers to healthcare providers around the world are taking advantage of ML today. In this post, I will demonstrate a use case and show how we can harness the power of ML and apply it to real-world problems. We’ll walk through a simple baseline model for predicting heart disease from patient data: how to load the data, explore it, and make some predictions.

In a previous discussion I shared my notes on Deep Learning Book Part I: Applied Math and Machine Learning Basics. This is a good segue into applying some of the concepts discussed in the book. Let’s take a look at the Warm Up: Machine Learning with a Heart competition hosted by DrivenData.


The figure above is a great visual of a machine learning project from A to Z. The first step, before we begin coding, is to understand the problem we are trying to solve and obtain the available data. In this project, we will work with publicly available data from DrivenData.

About Heart Disease

Preventing heart disease is important. Good data-driven systems for predicting heart disease can improve the entire research and prevention process, making sure that more people can live healthy lives.

In the United States, the Centers for Disease Control and Prevention is a good resource for information about heart disease. According to their website:

  • About 610,000 people die of heart disease in the United States every year–that’s 1 in every 4 deaths.
  • Heart disease is the leading cause of death for both men and women. More than half of the deaths due to heart disease in 2009 were in men.
  • Coronary heart disease (CHD) is the most common type of heart disease, killing over 370,000 people annually.
  • Every year about 735,000 Americans have a heart attack. Of these, 525,000 are a first heart attack and 210,000 happen in people who have already had a heart attack.
  • Heart disease is the leading cause of death for people of most ethnicities in the United States, including African Americans, Hispanics, and whites. For American Indians or Alaska Natives and Asians or Pacific Islanders, heart disease is second only to cancer.

For more information, see the CDC’s page on preventing heart disease.

Problem description

The goal is to predict the binary class heart_disease_present, which represents whether or not a patient has heart disease:

  • 0 represents no heart disease present
  • 1 represents heart disease present


There are 14 columns in the dataset, where the patient_id column is a unique and random identifier. The remaining 13 features are described in the section below.

  • slope_of_peak_exercise_st_segment (type: int): the slope of the peak exercise ST segment, an electrocardiography read out indicating quality of blood flow to the heart
  • thal (type: categorical): results of thallium stress test measuring blood flow to the heart, with possible values normal, fixed_defect, reversible_defect
  • resting_blood_pressure (type: int): resting blood pressure
  • chest_pain_type (type: int): chest pain type (4 values)
  • num_major_vessels (type: int): number of major vessels (0-3) colored by flourosopy
  • fasting_blood_sugar_gt_120_mg_per_dl (type: binary): fasting blood sugar > 120 mg/dl
  • resting_ekg_results (type: int): resting electrocardiographic results (values 0,1,2)
  • serum_cholesterol_mg_per_dl (type: int): serum cholesterol in mg/dl
  • oldpeak_eq_st_depression (type: float): oldpeak = ST depression induced by exercise relative to rest, a measure of abnormality in electrocardiograms
  • sex (type: binary): 0: female, 1: male
  • age (type: int): age in years
  • max_heart_rate_achieved (type: int): maximum heart rate achieved (beats per minute)
  • exercise_induced_angina (type: binary): exercise-induced chest pain (0: False, 1: True)

Step 2: Clean and Prepare

In a perfect world, a dataset is a perfectly curated collection of observations with no missing values or anomalies. In reality, real-world data comes in all shapes and sizes: it can be messy, which means it needs to be cleaned and wrangled. Data cleaning is a necessary part of nearly every data science problem. Machine learning models learn from data, so it is crucial that the data you feed them is preprocessed and refined for the problem you want to solve. This includes data cleaning, preprocessing, feature engineering, and so on.
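As a minimal sketch of this preparation step (using a tiny inline sample in place of the real DrivenData files, whose exact contents are assumed here), we can load the data with pandas, check for missing values, and one-hot encode the categorical thal column:

```python
import io

import pandas as pd

# A tiny inline sample standing in for DrivenData's train_values.csv
# (columns and format are assumptions for illustration).
sample = io.StringIO(
    "patient_id,thal,age,sex,max_heart_rate_achieved\n"
    "p001,normal,45,1,150\n"
    "p002,reversible_defect,61,0,123\n"
)
features = pd.read_csv(sample, index_col="patient_id")

# Check for missing values before modeling.
print(features.isnull().sum().sum())  # total count of missing cells

# The categorical `thal` column needs encoding before a linear model can use it.
features = pd.get_dummies(features, columns=["thal"])
print(features.columns.tolist())
```

The same get_dummies call works on the full training file once it is read from disk.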

Visual Exploratory Data Analysis (EDA)

It’s time to visualize our data with a little help from the seaborn package.

From the frequency plot of heart disease below, we see that the two classes (‘Heart Disease’ and ‘No Heart Disease’) are approximately balanced, with 45% of observations having heart disease and the remaining population not having heart disease.

The data is relatively well-balanced, so we won’t take any steps here to equalize the classes.
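A quick way to check this balance (sketched here with synthetic labels matching the roughly 45/55 split described above, rather than the real train_labels.csv) is a pandas value count; sns.countplot(x=labels) would draw the corresponding frequency plot:

```python
import pandas as pd

# Synthetic labels with the approximate 45/55 split described above.
labels = pd.Series([1] * 45 + [0] * 55, name="heart_disease_present")

# Class proportions; sns.countplot(x=labels) would draw the frequency plot.
balance = labels.value_counts(normalize=True)
print(balance)
```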

The pairplot above lets us see the distribution of, and relationships between, the numerical variables. The diagonal shows kernel density plots of the rough distributions of the two populations, while the off-diagonal scatter plots show the relationship between each pair of variables. We can make a couple of observations from the pairplot:

  • Resting blood pressure tends to increase with age regardless of heart disease.
  • We can see that max heart rates are significantly lower for people without heart disease.

Steps 3 and 4: Train and Test the Model

Model selection is the process of combining data and prior information to select among a group of statistical models. When it comes to classic first pass models, few can contend with logistic regression. This linear model is fast to train, easy to understand, and typically does pretty well “out of the box”.

The scikit-learn logistic regression model works well combined with the Pipeline and GridSearchCV tools, which help streamline model training and hyperparameter optimization.

Sklearn’s pipeline functionality makes it easier to repeat commonly occurring steps in your modeling process.

Pipeline and GridSearchCV preprocessing

GridSearchCV constructs a grid of all the combinations of parameters, tries each combination, and then reports back the best combination/model.
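A sketch of this setup on synthetic stand-in data (make_classification rather than the real heart data; the parameter grid here is an assumption) might look like:

```python
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import GridSearchCV
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler

# Synthetic stand-in for the heart disease features and labels.
X, y = make_classification(n_samples=180, n_features=13, random_state=0)

pipe = Pipeline([
    ("scale", StandardScaler()),                    # scale features first
    ("clf", LogisticRegression(solver="liblinear")),  # supports l1 and l2
])

# Grid over the regularization strength C and the penalty type.
grid = GridSearchCV(
    pipe,
    param_grid={"clf__C": [0.01, 0.1, 1, 10], "clf__penalty": ["l1", "l2"]},
    scoring="accuracy",
    cv=5,
)
grid.fit(X, y)
print(grid.best_params_, round(grid.best_score_, 3))
```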

The figure above shows confusion matrices across the grid of C and penalty values.

Looking at the C (inverse regularization strength) plot and printout above, we can see that the highest accuracy is achieved when the parameter C is set to 0.1 and the penalty parameter is set to ‘l2’ (ridge regression).

From the confusion matrices above, we see that the most ‘accurate’ model tends to make the worst kind of mistake. As above, correct predictions appear on the main diagonal, whereas all off-diagonal values correspond to incorrect classifications.

Accuracy measures how well all cases are classified. The two types of misclassification:

  • labeling a healthy person as unhealthy,
  • labeling an unhealthy person as healthy

are not equally bad. It’s definitely worse to mislabel a patient as healthy if they actually have heart disease, as their heart disease would go untreated, and they may continue with an unhealthy diet or pursue dangerous activity levels.

In terms of classification metrics, that kind of prediction mistake is a False Negative. The alternative mistake (a False Positive) is to label someone as unhealthy when they are actually healthy, which would lead to someone unnecessarily changing their diet and lifestyle, which would be unpleasant, but not potentially lethal, as in the other case.

Step 5: Improve

One way to improve our model is to reduce the number of features in the data matrix by keeping only those with the highest predictive value.

Let’s build a model that takes the cost of mistakes into account.
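One simple way to do this, sketched below on synthetic stand-in data, is to have GridSearchCV optimize recall instead of accuracy, so the selected model misses as few sick patients as possible:

```python
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import GridSearchCV
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler

# Synthetic stand-in for the heart data.
X, y = make_classification(n_samples=180, n_features=13, random_state=0)

pipe = Pipeline([
    ("scale", StandardScaler()),
    ("clf", LogisticRegression(solver="liblinear")),
])

# scoring="recall" selects the parameters that miss the fewest positive (sick)
# cases, trading some false positives for fewer false negatives.
grid = GridSearchCV(
    pipe,
    param_grid={"clf__C": [0.01, 0.1, 1, 10], "clf__penalty": ["l1", "l2"]},
    scoring="recall",
    cv=5,
)
grid.fit(X, y)
print(grid.best_params_)
```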

The number of significant features is less than the total number of features, so the unimportant features are eliminated.

The l1 logistic regression (Lasso Regression) was smart enough to assign 0 importance to features below the significance threshold.
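This behavior is easy to demonstrate on synthetic data: with a sufficiently strong L1 penalty (the small C value here is an assumption, not the competition setting), several coefficients shrink to exactly zero:

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.preprocessing import StandardScaler

# Synthetic data where only some features carry signal.
X, y = make_classification(n_samples=180, n_features=13, n_informative=5,
                           n_redundant=2, random_state=0)
X = StandardScaler().fit_transform(X)

# A strong L1 penalty (small C) drives uninformative coefficients to exactly zero.
clf = LogisticRegression(penalty="l1", C=0.1, solver="liblinear").fit(X, y)
print(int((clf.coef_ == 0).sum()), "of", clf.coef_.size, "coefficients are zero")
```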

The confusion matrices directly below as well as the confusion matrices a bit further up were made by fitting classifiers with 30 different slices of the data and generating a confusion matrix for each data slice.

The confusion matrices below come from a classifier with parameters {C=1, penalty='l1'} that used only the 13 features the recursive feature elimination found to be significant, with parameters selected by the recall scoring measure rather than simple accuracy.

Over the 30 runs, on test samples of 45 people, this model labeled on average 3.27 unhealthy people as healthy, slightly better than the earlier model’s 3.53. In exchange, it labeled 2.63 healthy people as unhealthy, versus 1.80 for the earlier model.
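The evaluation scheme described above can be sketched as follows, again on synthetic stand-in data, averaging confusion matrices over 30 random train/test slices with 45-person test sets:

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import confusion_matrix
from sklearn.model_selection import train_test_split

# Synthetic stand-in for the heart data.
X, y = make_classification(n_samples=180, n_features=13, random_state=0)

# Fit a classifier on 30 random slices and collect a confusion matrix for each.
mats = []
for seed in range(30):
    X_tr, X_te, y_tr, y_te = train_test_split(X, y, test_size=45,
                                              random_state=seed)
    clf = LogisticRegression(penalty="l1", C=1,
                             solver="liblinear").fit(X_tr, y_tr)
    mats.append(confusion_matrix(y_te, clf.predict(X_te)))

mean_cm = np.mean(mats, axis=0)
print(mean_cm)  # rows: true class, columns: predicted class
```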

Considering the differing costs of mistakes, this is an improvement.


In summary, we demonstrated a use case and showed how we can harness the power of ML and apply it to real-world problems. The Warm Up: Machine Learning with a Heart competition provides a good dataset for practicing ML algorithms.