🧠 Stroke Prediction using Random Forest

Sidharth Pandita (hackerdawn)
5 min read · May 16, 2021

Image credits: heart.org

A stroke occurs when the blood supply to part of your brain is interrupted, preventing brain tissue from getting oxygen and nutrients. Due to this, brain cells begin to die in minutes. We’ll use 11 features of a person to predict whether they will get a stroke or not.

Importing Libraries

Let’s first import the required libraries. If you don’t have a particular library installed, run the command ‘pip install <package_name>’ to install it.
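For the libraries used here, that amounts to the following (note that PyPI package names can differ from import names; imbalanced-learn is the package that provides imblearn):

pip install numpy pandas matplotlib seaborn scikit-learn imbalanced-learn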

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.preprocessing import LabelEncoder
from imblearn.over_sampling import SMOTE
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier

Exploring the Dataset

Let us load the dataset we downloaded from Kaggle. We will use the pd.read_csv() function to do so.

df = pd.read_csv('./stroke/healthcare-dataset-stroke-data.csv')

We’ll print the head of the dataset to see what the data looks like.

df.head()
Head of Dataframe

Let’s get the dimensions of the dataset using shape.

df.shape
Row & Column count respectively

We will now use info() to get non-null counts and data types of different columns.

df.info()
Column info

Let’s get summary statistics for the numeric columns in the dataset. The .T transposes the output of describe() so each column’s stats sit on their own row, which is easier to scan.

df.describe().T
Stats of Numeric columns

Let’s see the null value count in different columns. As you can see in the output, bmi is the only column containing null values (201).

total = df.isnull().sum().sort_values(ascending=False)
percent = (df.isnull().sum()/df.isnull().count()).sort_values(ascending=False)
missing_data = pd.concat([total, percent], axis=1, keys=['Total', 'Percent'])
missing_data
Null counts in descending order

Let’s visualize the null values using a heatmap.

f, ax = plt.subplots(nrows=1, ncols=1, figsize=(16,5))
hm = sns.heatmap(df.T.isna(), cmap='binary_r')
hm.set_yticklabels(hm.get_ymajorticklabels(), fontsize = 14)
ax.set_title('Missing Values')
plt.show()
Null count Heatmap

We’ll fill the null values in the bmi column with the mean of the values in the column. Now, let us see if there are any null values still left. As shown in the output, no null values are left in any column.

df['bmi'].fillna(df['bmi'].mean(), inplace=True)
df.isnull().sum()
Column-wise Null counts
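Mean imputation is the simplest option, but the mean is sensitive to outliers; for a skewed column like bmi, the median is a common, more robust alternative (a variation, not what we do above):

df['bmi'].fillna(df['bmi'].median(), inplace=True)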

We’ll plot a heatmap to see the correlation between features of the dataset. Since we plot absolute values, the numbers measure the strength of each relationship regardless of its sign: the higher the number inside a box, the stronger the correlation.

plt.figure(figsize=(7,7))
# numeric_only restricts the correlation to numeric columns, which recent
# pandas versions require while the dataframe still contains string columns
corr_map = sns.heatmap(df.corr(numeric_only=True).abs(), annot=True, cmap="autumn_r")
Correlation Heatmap

Let’s see the counts of observations in each bin for categorical features using countplot.

fig,axes = plt.subplots(4,2,figsize = (15,15))
fig.suptitle("Count plot for categorical features")
#gender
sns.countplot(ax=axes[0,0],data=df,x='gender')
#smoking_status
sns.countplot(ax=axes[0,1],data=df,x='smoking_status')
#heart_disease
sns.countplot(ax=axes[1,0],data=df,x='heart_disease')
#ever_married
sns.countplot(ax=axes[1,1],data=df,x='ever_married')
#work_type
sns.countplot(ax=axes[2,0],data=df,x='work_type')
#Residence_type
sns.countplot(ax=axes[2,1],data=df,x='Residence_type')
#hypertension
sns.countplot(ax=axes[3,0],data=df,x='hypertension')
#stroke
sns.countplot(ax=axes[3,1],data=df,x='stroke')
plt.show()
Visualizing categorical values

We’ll visualize the distribution of ‘avg_glucose_level’ using a histplot. We can clearly see a right skew in the distribution.

fig = plt.figure(figsize=(7,7))
sns.histplot(df.avg_glucose_level, color='green', label='avg_glucose_level', kde=True, bins=60)
plt.legend()
Distribution of ‘avg_glucose_level’

We’ll visualize the distribution of ‘bmi’ using a histplot.

fig = plt.figure(figsize=(7,7))
sns.histplot(df.bmi, color='orange', label='bmi', kde=True)
plt.legend()
Distribution of ‘bmi’

Data Preprocessing

We will take the log of the ‘avg_glucose_level’ column to handle the skewness.

df['avg_glucose_level'] = np.log(df['avg_glucose_level'])
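As a quick check (an addition to the original walkthrough), pandas’ skew() summarizes the asymmetry in a single number; values close to 0 indicate a roughly symmetric distribution:

print('skewness after log transform:', df['avg_glucose_level'].skew())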

The ‘id’ column is just a row identifier and carries no predictive information. So, we’ll drop it.

df = df.drop('id',axis=1)

Let’s see the counts of values in the ‘gender’ column.

df['gender'].value_counts()
Value counts in ‘gender’

As the value ‘Other’ occurs only once in the ‘gender’ column, we’ll drop it.

df.drop(df[df['gender'] == 'Other'].index, inplace=True)
df['gender'].unique()
Unique values in ‘gender’

The values in columns ‘gender’, ‘ever_married’, ‘work_type’, ‘Residence_type’, ‘smoking_status’ are categorical. We’ll encode them using LabelEncoder().

We can see in the output that all categorical values have been encoded with values between 0 and n_classes-1.

# Categorical columns
object_cols = ["gender", "ever_married", "work_type", "Residence_type", "smoking_status"]
label_encoder = LabelEncoder()
for col in object_cols:
    label_encoder.fit(df[col])
    df[col] = label_encoder.transform(df[col])
df.head()
After using LabelEncoder()
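One caveat with the loop above: because a single LabelEncoder is re-fitted for each column, only the last column’s mapping survives in label_encoder.classes_. If you ever need to decode the numbers back into categories, a variant that keeps one encoder per column works better (a sketch, meant to replace the loop above rather than run after it):

encoders = {col: LabelEncoder().fit(df[col]) for col in object_cols}
for col, enc in encoders.items():
    df[col] = enc.transform(df[col])
# later, e.g. encoders['gender'].inverse_transform([0, 1]) recovers the original labels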

We will divide the data between X and Y such that X contains all the columns except ‘stroke’ and Y only contains ‘stroke’.

#Separating input and target data
X = df.drop(['stroke'],axis=1)
Y = df['stroke']
#Printing shapes of input and target
print('X Shape', X.shape)
print('Y Shape',Y.shape)
Shapes of X and Y

We will find out the number of 0’s and 1’s in the ‘stroke’ column.

df['stroke'].value_counts()
Value counts in ‘stroke’

We can clearly see that there is a class imbalance problem here: there are 4860 no-stroke cases but only 249 stroke cases. We’ll use SMOTE to handle this class imbalance; with sampling_strategy=0.1, the minority class is oversampled until it reaches 10% of the size of the majority class.

smote = SMOTE(sampling_strategy=0.1)
X, Y = smote.fit_resample(X, Y)
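To confirm the resampling did what we expect (a check added to the original code), look at the class counts again; the minority class should now be about 10% of the majority, i.e. roughly 486 against 4860:

Y.value_counts()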

We’ll now split X and Y into train and test sets.

x_train,x_test,y_train,y_test = train_test_split(X,Y,test_size=0.2,random_state=0)
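One optional refinement (not part of the original post): passing stratify=Y keeps the class ratio identical across the train and test splits, which is worth doing when the target is imbalanced.

x_train, x_test, y_train, y_test = train_test_split(X, Y, test_size=0.2, random_state=0, stratify=Y)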

Creating & Fitting the Model

We will use RandomForestClassifier as our model. We’ll then fit the model using the training data, i.e. x_train and y_train.

model = RandomForestClassifier(random_state=0)
model.fit(x_train, y_train)
Default Parameters of the model

Let’s see our model’s score on the test data. For a classifier, score() returns accuracy, so a score of 0.94 means the model predicts the correct class for 94% of the test samples. Keep in mind that accuracy can look flattering on imbalanced data, so per-class metrics are worth checking too (see the sketch below).

model.score(x_test,y_test)
Score
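Because the classes are still imbalanced even after SMOTE, a per-class breakdown tells you more than accuracy alone. A minimal sketch (an addition to the original post) using scikit-learn’s built-in metrics:

from sklearn.metrics import classification_report, confusion_matrix
preds = model.predict(x_test)
print(confusion_matrix(y_test, preds))       # rows = actual class, columns = predicted class
print(classification_report(y_test, preds))  # precision, recall and F1 for each class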

Predicting the Target

We will now make predictions on the test data. The output is an array of 0’s and 1’s, where 0 denotes no stroke and 1 denotes a stroke.

y_pred = model.predict(x_test)
print(y_pred)
Prediction Array
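If you need a risk score rather than a hard 0/1 label (again, an addition to the original walkthrough), a random forest also exposes class probabilities:

probs = model.predict_proba(x_test)[:, 1]  # estimated probability of class 1 (stroke) per sample
print(probs[:10])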

We have completed the stroke prediction. If you liked this tutorial, do leave a clap! Hit Follow to join the community.
