Stroke Prediction Machine Learning Model

Gabe Walter
Jun 11, 2023


Introduction

Each year, almost 800,000 Americans suffer a stroke. According to the World Stroke Organization, 1 in 4 of us will have a stroke at some point in our lives. Strokes are very often the result of a combination of health factors and lifestyle choices. Knowing this, if we are able to figure out which behaviors are causing the most risk, we may be able to predict strokes before they happen.

Imports

The first thing we need to do in any Python project is import the necessary libraries. (Note: some of these were added later in the project as the need arose.)

import pandas as pd
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
from sklearn.preprocessing import LabelEncoder
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, confusion_matrix, ConfusionMatrixDisplay
from imblearn.over_sampling import RandomOverSampler
from imblearn.over_sampling import SMOTE
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler
import xgboost as xgb

Data

As with any project, we also need our data. Let’s read our CSV in as a DataFrame and check it out.

df = pd.read_csv("stroke data.csv")
df.head()

As we can see, the data contains personal attributes (such as age, marital status, and occupation), health conditions (such as heart disease and hypertension), and lifestyle choices (like smoking status). There are also five unnamed metrics, metric_1 through metric_5. These might seem useless, but we can still check later whether any of them correlate with having a stroke.

Let’s also check out the statistics on our data:

df.describe()

df.isnull().sum()
id                    0
gender                0
age                   0
married               0
hypertension          0
heart_disease         0
occupation            0
residence             0
metric_1              0
metric_2           1462
metric_3              0
metric_4              0
metric_5              0
smoking_status    13292
stroke                0
dtype: int64

These outputs show that we have some work to do cleaning up our data before we build the model. For example, there are impossible values, like a minimum age of -10 years. There are also missing values: 1,462 in metric_2 and 13,292 in smoking_status. We can clean that up.

# Clean the data: drop unnecessary columns, remove impossible values (age = -10), and convert text columns to numbers
df.drop("id", axis = 1, inplace = True)
df.drop(df[(df["age"] < 0)].index, inplace = True)

# Keeping rows with a missing smoking_status scored higher than dropping them.
# fillna() returns a new Series, so the result must be assigned back; a string
# placeholder keeps the column uniformly text for LabelEncoder
df["smoking_status"] = df["smoking_status"].fillna("Unknown")

# Encode every text (object) column as integers
labeler = LabelEncoder()
c = df.select_dtypes(include = "object").columns
df[c] = df[c].apply(labeler.fit_transform)

# Drop the remaining rows with missing values (the gaps in metric_2)
df.dropna(inplace = True)

We can drop the id column since it doesn’t add anything to the dataset. We can also drop any rows where age is less than zero, since that data is clearly erroneous. We could drop the rows with missing smoking_status values, but after completing the model I found that keeping them, filled in with a placeholder, returned higher accuracy scores. Next, we use LabelEncoder() to transform all of our object-type columns into numerical columns so the model can take them in. Finally, dropna() removes the rows with missing metric_2 values. Now our data is clean, with no missing values.

df.isnull().sum()
gender            0
age               0
married           0
hypertension      0
heart_disease     0
occupation        0
residence         0
metric_1          0
metric_2          0
metric_3          0
metric_4          0
metric_5          0
smoking_status    0
stroke            0
dtype: int64
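A minimal sketch of the same cleaning steps, run on a small made-up DataFrame (the values below are hypothetical, not rows from the real dataset). It highlights a pandas detail that is easy to miss: fillna() returns a new Series instead of changing the column in place, so its result has to be assigned back. The "Unknown" placeholder is an assumption; it keeps the column uniformly text so that LabelEncoder can encode it.

```python
import numpy as np
import pandas as pd
from sklearn.preprocessing import LabelEncoder

# Hypothetical toy data with the same kinds of problems as the real dataset
df = pd.DataFrame({
    "id": [1, 2, 3, 4, 5],
    "gender": ["Male", "Female", "Female", "Male", "Female"],
    "age": [67.0, -10.0, 45.0, 80.0, 30.0],
    "metric_2": [28.1, 33.0, np.nan, 24.5, 22.0],
    "smoking_status": ["smokes", np.nan, "never smoked", np.nan, "formerly smoked"],
    "stroke": [1, 0, 0, 1, 0],
})

# Drop the identifier column and any rows with an impossible (negative) age
df = df.drop(columns="id")
df = df[df["age"] >= 0].copy()

# fillna() returns a new Series, so assign the result back to the column
df["smoking_status"] = df["smoking_status"].fillna("Unknown")

# Encode every text column as integers, with a fresh encoder per column
text_cols = df.select_dtypes(include="object").columns
df[text_cols] = df[text_cols].apply(lambda col: LabelEncoder().fit_transform(col))

# Drop whatever is still missing (here, the gap in metric_2)
df = df.dropna()

print(df)
print(df.isnull().sum().sum())
```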

Variables

We probably don’t want to use every column in our data in our model, because some variables are much more strongly correlated with stroke than others. We could find the best ones manually by testing many different combinations of variables in the model, but there is an easier way: df.corr() computes the correlation between every pair of columns in the DataFrame. I also used seaborn’s heatmap to make the result easier to read.

#Find best variables to use in model
plt.figure(figsize=(12,12))
sns.heatmap(df.corr()[["stroke"]].sort_values(by = "stroke"), annot = True, cmap="Greens")

We now have a heatmap of each variable’s correlation with stroke. As expected, age, heart disease, and hypertension are among the most strongly positively correlated. A surprising result is the negative correlation between stroke and smoking_status, since smoking is a well-known cause of strokes. This number should be read with caution, though: smoking_status was label-encoded, so its integer codes follow the alphabetical order of the category names, and the sign of a correlation computed on those codes depends on that arbitrary ordering rather than on a real relationship. The broader lesson still holds: check the data with df.corr() instead of building a model on assumptions about which variables matter.
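To illustrate why the sign is arbitrary, here is a small sketch with made-up data (the category names and outcomes are hypothetical, not taken from the stroke dataset). The same categorical column is mapped to integers in two equally valid orders, and the Pearson correlation with the outcome flips sign.

```python
import pandas as pd

# Hypothetical toy data: a nominal category and a binary outcome
status = pd.Series(["former", "former", "never", "never", "smokes", "smokes"])
stroke = pd.Series([1, 1, 0, 0, 0, 1])

# Two equally valid integer encodings of the same categories
alphabetical = {"former": 0, "never": 1, "smokes": 2}  # sorted order, as LabelEncoder uses
reordered = {"never": 0, "former": 1, "smokes": 2}

corr_alpha = status.map(alphabetical).corr(stroke)
corr_reorder = status.map(reordered).corr(stroke)
print(round(corr_alpha, 3), round(corr_reorder, 3))  # -0.408 0.408
```

Because the codes carry no natural order, one-hot encoding (or comparing stroke rates per category) is usually a more reliable way to judge a nominal variable like this.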

Model

Now we can start building our model. The first thing to do is to prepare our data to be fed into the model.

# Split data
best_variables = [k for k, v in df.corr()["stroke"].items() if v > 0 and k != "stroke"]
X = df[best_variables]
y = df["stroke"]
Xtrain, Xtest, ytrain, ytest = train_test_split(X, y, test_size=0.2, random_state = 42)

After some experimentation, I found that the best accuracy scores came from using only the variables positively correlated with stroke, so we select those as best_variables. Then we split our data into training and test sets.
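With a rare positive class, a purely random split can leave the training and test sets with slightly different stroke rates. The split above does not do this, but train_test_split accepts a stratify argument that keeps the class ratio the same in both sets. A hedged sketch on synthetic, imbalanced data from make_classification (a stand-in, not the stroke data):

```python
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split

# Synthetic, imbalanced stand-in data: roughly 5% positives
X, y = make_classification(n_samples=2000, weights=[0.95], random_state=42)

# stratify=y preserves the positive rate in both splits
Xtrain, Xtest, ytrain, ytest = train_test_split(
    X, y, test_size=0.2, random_state=42, stratify=y
)

print(f"overall: {y.mean():.3f}, train: {ytrain.mean():.3f}, test: {ytest.mean():.3f}")
```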

Next, we have to balance our data. Strokes are rare in this dataset, since it covers ordinary people of all ages rather than only stroke patients. Without balancing, the model sees so few stroke cases that it can score well on accuracy simply by predicting “no stroke” for almost everyone, while rarely catching an actual stroke. There are many ways to handle this; we are going to use oversampling, specifically a technique called SMOTE (Synthetic Minority Oversampling Technique). I originally planned to use RandomOverSampler, which simply duplicates existing stroke cases, but SMOTE returned higher accuracy scores. SMOTE instead creates new, synthetic stroke cases: it picks an existing stroke case, finds its nearest neighbors among the other stroke cases, and generates a new point somewhere on the line between the case and one of those neighbors. Here is an example of oversampling if you are unfamiliar.

Image from https://www.mastersindatascience.org/learning/statistics-data-science/undersampling/
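The core idea of SMOTE fits in a few lines of NumPy. This is a simplified illustration, not imblearn’s implementation: for each synthetic row it picks a random minority sample, one of its k nearest minority neighbors, and a random point on the segment between them. The function name smote_sketch and its parameters are hypothetical.

```python
import numpy as np


def smote_sketch(X_min, n_new, k=5, seed=0):
    """Create n_new synthetic rows by interpolating between minority samples.

    X_min: 2-D array containing only minority-class rows (hypothetical helper).
    """
    rng = np.random.default_rng(seed)
    # Pairwise Euclidean distances between all minority samples
    dists = np.linalg.norm(X_min[:, None, :] - X_min[None, :, :], axis=2)
    np.fill_diagonal(dists, np.inf)               # a point is not its own neighbor
    k = min(k, len(X_min) - 1)
    neighbors = np.argsort(dists, axis=1)[:, :k]  # indices of the k nearest neighbors

    synthetic = np.empty((n_new, X_min.shape[1]))
    for i in range(n_new):
        base = rng.integers(len(X_min))           # random minority sample
        nb = neighbors[base, rng.integers(k)]     # one of its nearest neighbors
        gap = rng.random()                        # position along the segment, in [0, 1)
        synthetic[i] = X_min[base] + gap * (X_min[nb] - X_min[base])
    return synthetic


# Usage: four minority points in 2-D, six synthetic points generated from them
X_min = np.array([[1.0, 1.0], [1.2, 0.9], [0.8, 1.1], [1.1, 1.3]])
new_rows = smote_sketch(X_min, n_new=6, k=2)
print(new_rows.shape)  # (6, 2)
```

Every synthetic row lies between two real minority rows, which is why SMOTE tends to generalize better than exact duplicates. Whatever implementation is used, it should only see training rows.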
# Balance the data. SMOTE achieved slightly higher scores than RandomOverSampler, so SMOTE is used

# oversample = RandomOverSampler(sampling_strategy="minority")
# Xtrain, ytrain = oversample.fit_resample(Xtrain, ytrain)
# Xtest, ytest = oversample.fit_resample(Xtest, ytest)

smote = SMOTE()
Xtrain, ytrain = smote.fit_resample(Xtrain, ytrain)
Xtest, ytest = smote.fit_resample(Xtest, ytest)

Now we can create our model. To start, let’s use a Logistic Regression model, since we are trying to predict a binary outcome (stroke or no stroke). We will also use StandardScaler, since our features are on very different scales; it rescales each feature to a mean of 0 and a standard deviation of 1.
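StandardScaler transforms each feature as z = (x - mean) / standard deviation, using statistics learned during fit. A small sketch with made-up values (a hypothetical age column and a 0/1 flag, not real rows) shows the effect:

```python
import numpy as np
from sklearn.preprocessing import StandardScaler

# Hypothetical features on very different scales: age in years and a 0/1 flag
X = np.array([[25.0, 0.0], [40.0, 1.0], [55.0, 0.0], [70.0, 1.0]])

scaler = StandardScaler()
X_scaled = scaler.fit_transform(X)

# Each column now has mean 0 and standard deviation 1
print(scaler.mean_)
print(X_scaled.mean(axis=0))
print(X_scaled.std(axis=0))
```

Inside a Pipeline, the scaler is fit on the training data only and then applies those same statistics to the test data.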

#Create and train model
pipe = Pipeline([("std", StandardScaler()), ("lr", LogisticRegression())])
pipe.fit(Xtrain, ytrain)

# Make predictions
ypred = pipe.predict(Xtest)

# Score model
print("Accuracy:", pipe.score(Xtest, ytest))
print("\nClassification report:\n", classification_report(ytest, ypred))
ConfusionMatrixDisplay.from_estimator(pipe, Xtest, ytest)

From this, we can see our model did pretty well, with an 82.8% accuracy score. The confusion matrix shows 6,419 true negatives and 7,280 true positives, against 2,845 misclassified samples (false positives and false negatives combined). Let’s see whether a different model can increase these scores.
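As a quick check, the reported accuracy follows directly from the confusion-matrix counts above: correct predictions divided by all predictions.

```python
# Confusion-matrix counts reported for the logistic regression model
true_negatives = 6419
true_positives = 7280
misclassified = 2845  # false positives + false negatives combined

total = true_negatives + true_positives + misclassified
accuracy = (true_negatives + true_positives) / total
print(total, round(accuracy, 3))  # 16544 0.828
```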

I also tried RandomForestClassifier, but it returned scores similar to the LogisticRegression model. However, the popular XGBoost library improved the accuracy considerably. After some experimenting with the hyperparameters, I was able to get a much higher score.

# Second model: RandomForestClassifier scored no better than LogisticRegression, so use XGBoost

# Create model
model = xgb.XGBClassifier(n_estimators=1000, max_depth=5, learning_rate=0.1, objective="binary:logistic")

# Fit model
model.fit(Xtrain, ytrain)

# Make predictions
ypred = model.predict(Xtest)

# Score model
print("Accuracy:", model.score(Xtest, ytest))
print("\nClassification report:\n", classification_report(ytest, ypred))
ConfusionMatrixDisplay.from_estimator(model, Xtest, ytest)

Now we have a model that is roughly 95.5% accurate at predicting whether someone will have a stroke. Keep in mind that this is measured on the SMOTE-balanced test set, which includes synthetic stroke cases, so accuracy on real, imbalanced data may differ. The model is also outstandingly precise when it predicts a stroke: there are 7,606 true positives compared to only 84 false positives. If we were, say, an insurance company, information like this could be extremely valuable.
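“Accurate when it predicts a stroke” is what precision measures: the share of predicted strokes that were actual strokes in the test set. Using the counts above:

```python
# Counts reported for the XGBoost model's stroke predictions
true_positives = 7606
false_positives = 84

# Precision = TP / (TP + FP)
precision = true_positives / (true_positives + false_positives)
print(round(precision, 3))  # 0.989
```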

Conclusion

Overall, I was able to achieve very high accuracy scores with my models. I wrote a few concluding thoughts in my Jupyter Notebook, so I will leave them here. If I had wanted to get the most out of the models, I could have done further hyperparameter tuning; specifically, RandomizedSearchCV or GridSearchCV would likely have boosted the results. I also could have experimented with different ways of cleaning the data. For example, metric_2 had missing values, and I was considering replacing them with the column’s mean to see the effect. All in all, however, I would say this was a successful project.
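A hedged sketch of what RandomizedSearchCV could look like, applied to the logistic regression pipeline on synthetic data from make_classification rather than the stroke data. The search range for C (sampled with scipy.stats.loguniform) and the choice of F1 as the scoring metric are assumptions; for XGBClassifier the same pattern would apply to parameters such as n_estimators, max_depth, and learning_rate.

```python
from scipy.stats import loguniform
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import RandomizedSearchCV
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler

# Synthetic stand-in data
X, y = make_classification(n_samples=1000, random_state=0)

pipe = Pipeline([("std", StandardScaler()), ("lr", LogisticRegression())])

# Sample 10 values of the regularization strength C, scoring each with 3-fold CV
search = RandomizedSearchCV(
    pipe,
    param_distributions={"lr__C": loguniform(1e-3, 1e2)},
    n_iter=10,
    cv=3,
    scoring="f1",
    random_state=0,
)
search.fit(X, y)
print(search.best_params_, round(search.best_score_, 3))
```

GridSearchCV works the same way but tries every combination in a fixed grid, which gets expensive quickly as more parameters are added.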

Gabe Walter

Data Science student interested in Machine Learning, AI, and Finance