🧠Stroke Prediction using Random Forest
A stroke occurs when the blood supply to part of your brain is interrupted, preventing brain tissue from getting oxygen and nutrients. Due to this, brain cells begin to die in minutes. We’ll use 11 features of a person to predict whether they will get a stroke or not.
Importing Libraries
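Let’s first import the required libraries. If you don’t have a particular library installed, run the command ‘pip install <package_name>’ to install it. Note that sklearn is installed under the package name scikit-learn and imblearn under imbalanced-learn.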
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.preprocessing import LabelEncoder
from imblearn.over_sampling import SMOTE
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier
Exploring the Dataset
Let us load the dataset we downloaded from Kaggle using the pd.read_csv() function.
df = pd.read_csv('./stroke/healthcare-dataset-stroke-data.csv')
We’ll print the head of the dataset to see what the data looks like.
df.head()
Let’s get the dimensions of the dataset using shape.
df.shape
We will now use info() to get non-null counts and data types of different columns.
df.info()
Let’s get the stats of the numeric columns in the dataset.
df.describe().T
Let’s see the null value count in different columns. As you can see in the output, bmi is the only column containing null values (201).
total = df.isnull().sum().sort_values(ascending=False)
percent = (df.isnull().sum()/df.isnull().count()).sort_values(ascending=False)
missing_data = pd.concat([total, percent], axis=1, keys=['Total', 'Percent'])
missing_data
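To make the missing-value summary concrete, here is a minimal sketch on a made-up four-row frame. The `toy` DataFrame and its values are illustrative assumptions, not the Kaggle data. It also shows that `isnull().mean()` gives the same fraction as dividing the null count by the row count.

```python
import numpy as np
import pandas as pd

# Hypothetical toy data: one missing bmi value out of four rows
toy = pd.DataFrame({
    "age": [67, 61, 80, 49],
    "bmi": [36.6, np.nan, 32.5, 34.4],
})

total = toy.isnull().sum().sort_values(ascending=False)
# Mean of a boolean mask = fraction of nulls per column
percent = toy.isnull().mean().sort_values(ascending=False)
missing_toy = pd.concat([total, percent], axis=1, keys=["Total", "Percent"])
print(missing_toy)
```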
Let’s visualize the null values using a heatmap.
f, ax = plt.subplots(nrows = 1, ncols = 1, figsize=(16,5))
hm = sns.heatmap(df.T.isna(), cmap='binary_r')
hm.set_yticklabels(hm.get_ymajorticklabels(), fontsize = 14)
ax.set_title('Missing Values')
plt.show()
We’ll fill the null values in the bmi column with the mean of the values in the column. Now, let us see if there are any null values still left. As shown in the output, no null values are left in any column.
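# Assign the result back; inplace=True on a selected column triggers warnings in newer pandas versions
df['bmi'] = df['bmi'].fillna(df['bmi'].mean())
df.isnull().sum()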
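As a small illustration of what mean imputation does, the sketch below fills one gap in a made-up series (the `toy_bmi` values are assumptions chosen so the arithmetic is easy to follow). The mean is computed over the non-null values only, so filling with it leaves the column mean unchanged.

```python
import numpy as np
import pandas as pd

# Hypothetical bmi values with one gap
toy_bmi = pd.Series([20.0, np.nan, 30.0, 40.0], name="bmi")

mean_bmi = toy_bmi.mean()           # NaN is skipped: (20 + 30 + 40) / 3
filled = toy_bmi.fillna(mean_bmi)   # returns a new Series, toy_bmi is untouched
print(filled.tolist())
```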
We’ll plot a heatmap of the absolute correlations between the numeric features of the dataset. The higher the number inside a box, the stronger the correlation, whether positive or negative.
plt.figure(figsize=(7,7))
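# numeric_only=True (pandas 1.5+) skips the text columns; the name corr_map avoids shadowing the built-in map
corr_map = sns.heatmap(df.corr(numeric_only=True).abs(),annot=True,cmap="autumn_r")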
Let’s see the counts of observations in each category of the categorical features using countplot.
fig,axes = plt.subplots(4,2,figsize = (15,15))
fig.suptitle("Count plot for categorical features")
#gender
sns.countplot(ax=axes[0,0],data=df,x='gender')
#smoking_status
sns.countplot(ax=axes[0,1],data=df,x='smoking_status')
#heart_disease
sns.countplot(ax=axes[1,0],data=df,x='heart_disease')
#ever_married
sns.countplot(ax=axes[1,1],data=df,x='ever_married')
#work_type
sns.countplot(ax=axes[2,0],data=df,x='work_type')
#Residence_type
sns.countplot(ax=axes[2,1],data=df,x='Residence_type')
#hypertension
sns.countplot(ax=axes[3,0],data=df,x='hypertension')
#stroke
sns.countplot(ax=axes[3,1],data=df,x='stroke')
plt.show()
We’ll visualize the distribution of ‘avg_glucose_level’ using a histplot. We can clearly see a right skew in the distribution.
fig = plt.figure(figsize=(7,7))
sns.histplot(df.avg_glucose_level,color='green',label='avg_glucose_level',kde=True,bins=60)
plt.legend()
We’ll visualize the distribution of ‘bmi’ using a histplot.
fig = plt.figure(figsize=(7,7))
sns.histplot(df.bmi,color='orange',label='bmi',kde=True)
plt.legend()
Data Preprocessing
We will take the log of the ‘avg_glucose_level’ column to handle the skewness.
df['avg_glucose_level'] = np.log(df['avg_glucose_level'])
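To see why a log transform helps with right skew, here is a minimal sketch that measures skewness before and after. It uses synthetic log-normal values as a stand-in for the glucose column (an assumption for illustration; the real column may behave differently).

```python
import numpy as np
import pandas as pd

rng = np.random.default_rng(0)
# Hypothetical right-skewed values standing in for avg_glucose_level
values = pd.Series(rng.lognormal(mean=4.6, sigma=0.4, size=5000))

skew_before = values.skew()          # clearly positive for right-skewed data
skew_after = np.log(values).skew()   # close to zero once the long tail is compressed
print(f"skew before: {skew_before:.2f}, after log: {skew_after:.2f}")
```

Running `df['avg_glucose_level'].skew()` before and after the transform gives the same kind of check on the actual data.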
The column ‘id’ doesn’t provide any useful context. So, we’ll drop it.
df = df.drop('id',axis=1)
Let’s see the counts of values in the ‘gender’ column.
df['gender'].value_counts()
As the value ‘Other’ occurs only once in the ‘gender’ column, we’ll drop it.
df.drop(df[df['gender'] == 'Other'].index, inplace = True)
df['gender'].unique()
The values in columns ‘gender’, ‘ever_married’, ‘work_type’, ‘Residence_type’, ‘smoking_status’ are categorical. We’ll encode them using LabelEncoder().
We can see in the output that all categorical values have been encoded with values between 0 and n_classes-1.
#Categorical columns
object_cols = ["gender","ever_married","work_type","Residence_type","smoking_status"]
label_encoder = LabelEncoder()
for col in object_cols:
    label_encoder.fit(df[col])
    df[col] = label_encoder.transform(df[col])
df.head()
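The sketch below shows what LabelEncoder does to a single column, using a made-up `toy_gender` series (an assumption for illustration). Classes are sorted alphabetically and mapped to 0, 1, and so on; `inverse_transform` recovers the original labels.

```python
import pandas as pd
from sklearn.preprocessing import LabelEncoder

# Hypothetical categorical column
toy_gender = pd.Series(["Male", "Female", "Female", "Male"])

encoder = LabelEncoder()
codes = encoder.fit_transform(toy_gender)   # fit and transform in one step

print(list(encoder.classes_))               # sorted class labels
print(codes.tolist())                       # integer code for each row
print(list(encoder.inverse_transform(codes)))
```

Because the loop above reuses one encoder for every column, only the last column's mapping survives in `classes_`. If you need to decode values later, one option is to keep a separate encoder per column, for example in a dictionary keyed by column name.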
We will divide the data between X and Y such that X contains all the columns except ‘stroke’ and Y only contains ‘stroke’.
#Separating input and target data
X = df.drop(['stroke'],axis=1)
Y = df['stroke']
#Printing shapes of input and target
print('X Shape', X.shape)
print('Y Shape',Y.shape)
We will find out the number of 0’s and 1’s in the ‘stroke’ column.
df['stroke'].value_counts()
We can clearly see that there is a class imbalance problem here, as there are 4860 cases of no-stroke but only 249 cases of stroke. We’ll use SMOTE to handle this class imbalance.
smote = SMOTE(sampling_strategy=0.1)
X, Y = smote.fit_resample(X, Y)
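With sampling_strategy=0.1, SMOTE adds synthetic stroke rows until the minority class is about 10% the size of the majority class (roughly 486 stroke rows for 4860 no-stroke rows). The core idea is to interpolate between a minority sample and one of its nearest minority neighbours. The sketch below is a hand-written NumPy illustration of that idea, not imblearn's implementation; the function name `smote_like_samples` and the toy points are assumptions.

```python
import numpy as np

def smote_like_samples(minority, n_new, k=3, seed=0):
    """Create n_new synthetic rows by interpolating between minority neighbours."""
    rng = np.random.default_rng(seed)
    # Pairwise Euclidean distances between minority rows
    diffs = minority[:, None, :] - minority[None, :, :]
    dists = np.sqrt((diffs ** 2).sum(axis=2))
    np.fill_diagonal(dists, np.inf)              # a point is not its own neighbour
    neighbours = np.argsort(dists, axis=1)[:, :k]

    base = rng.integers(0, len(minority), size=n_new)           # random starting rows
    picked = neighbours[base, rng.integers(0, k, size=n_new)]   # one neighbour per row
    gap = rng.random((n_new, 1))                                # position along the segment
    return minority[base] + gap * (minority[picked] - minority[base])

# Hypothetical 2-feature minority class
minority = np.array([[1.0, 1.0], [1.2, 0.9], [0.8, 1.1], [1.1, 1.3], [0.9, 0.7]])
synthetic = smote_like_samples(minority, n_new=10)
print(synthetic.shape)
```

Every synthetic row lies on a line segment between two real minority rows, which is why SMOTE tends to fill in the minority region rather than duplicate existing points.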
We’ll now split the X, Y into train and test.
x_train,x_test,y_train,y_test = train_test_split(X,Y,test_size=0.2,random_state=0)
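With an imbalanced target, a plain random split can leave the train and test sets with slightly different stroke rates. As an optional variation that is not part of the original tutorial, `train_test_split` accepts a `stratify` argument that preserves the class proportions. The sketch below uses made-up arrays (`X_toy`, `y_toy`) to show the effect.

```python
import numpy as np
from sklearn.model_selection import train_test_split

# Hypothetical imbalanced target: 900 zeros and 100 ones
y_toy = np.array([0] * 900 + [1] * 100)
X_toy = np.arange(len(y_toy)).reshape(-1, 1)

_, _, y_tr, y_te = train_test_split(
    X_toy, y_toy, test_size=0.2, random_state=0, stratify=y_toy
)
# Share of positive labels in each split
print(y_tr.mean(), y_te.mean())
```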
Creating & Fitting the Model
We will use RandomForestClassifier as our model. We’ll then fit the model using the training data, i.e., x_train and y_train.
model = RandomForestClassifier(random_state=0)
model.fit(x_train,y_train)
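Let’s see our model’s score on the test data. For a classifier, score() returns the mean accuracy. As you can see in the output, the score is 0.94, which means the model correctly predicts 94% of the test labels. Because the classes are still imbalanced after resampling, accuracy alone can look better than the model’s ability to detect strokes.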
model.score(x_test,y_test)
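Accuracy can be misleading when one class dominates. The sketch below uses made-up labels (an assumption for illustration, not the tutorial's results) to show that a "model" that always predicts no-stroke still scores high accuracy while finding none of the stroke cases. Recall and the confusion matrix expose this.

```python
import numpy as np
from sklearn.metrics import accuracy_score, confusion_matrix, recall_score

# Hypothetical test labels: 95 no-stroke (0) and 5 stroke (1)
y_true = np.array([0] * 95 + [1] * 5)
# A baseline that always predicts no-stroke
y_always_zero = np.zeros_like(y_true)

acc = accuracy_score(y_true, y_always_zero)   # share of correct predictions
rec = recall_score(y_true, y_always_zero)     # share of true strokes that were found
cm = confusion_matrix(y_true, y_always_zero)  # rows: actual class, columns: predicted class
print(acc, rec)
print(cm)
```

On the real data, the same functions can be called with y_test and the model's predictions on x_test.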
Predicting the Target
We will predict the test data now. The output is an array of 0’s and 1’s where 0 denotes a no-stroke and 1 denotes a stroke.
y_pred = model.predict(x_test)
print(y_pred)
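Besides hard 0/1 labels, a random forest can also return class probabilities through `predict_proba`, which is useful if you want to pick your own decision threshold. The sketch below is self-contained and uses synthetic data from `make_classification` (an assumption for illustration; `X_demo`, `y_demo`, and `clf` are hypothetical names, not the tutorial's variables).

```python
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier

# Synthetic binary classification data standing in for the stroke features
X_demo, y_demo = make_classification(n_samples=200, n_features=5, random_state=0)
clf = RandomForestClassifier(n_estimators=20, random_state=0).fit(X_demo, y_demo)

proba = clf.predict_proba(X_demo[:5])   # shape (5, 2); columns follow clf.classes_
labels = clf.predict(X_demo[:5])        # class with the highest probability per row
print(proba)
print(labels)
```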
We have completed the stroke prediction.