FastAPI: Employee Attrition Prediction

Deploy a boosting model using FastAPI and predict employee attrition

Hitesh Sharma
9 min read · Jul 6, 2023

What is an API?

An application programming interface (API) is a computing interface that defines how multiple software intermediaries interact. It defines the kinds of calls and requests that can be made, how to make them, the data formats to use, and the conventions to follow. The API user does not get access to the whole dataset or source code, yet can still get all the information they need.

API as an intermediary

Here, users interact through their devices with an API, which acts as an intermediary between them and the web servers or databases. Interacting directly with complex servers can be difficult for most users; APIs help bypass these difficulties.

Let us build an intuition for APIs with the help of an analogy.

Think of a waiter in a restaurant. The waiter acts as an intermediary between the customer and the kitchen. The customer interacts only with the waiter, who provides all the information needed to choose the right food, and can then enjoy the meal without worrying about how it is prepared in the kitchen. Similarly, an API acts as an intermediary between end users and servers or databases.
Waiter acts as an intermediary between the customer and the kitchen

Why should one choose FastAPI?

FastAPI logo
https://fastapi.tiangolo.com/

FastAPI is a modern, fast (high-performance) web framework for building APIs with Python, based on standard Python type hints and Pydantic. In production it runs on ASGI servers such as Uvicorn and Hypercorn, optionally managed by Gunicorn. It is one of the fastest Python frameworks available, reportedly able to serve up to 9000 requests per second, and it is designed to be easy to use and learn.

The key features are:

  • Fast: Very high performance, on par with NodeJS and Go.
  • Fast to code: Increase the speed to develop features by about 200% to 300%.
  • Fewer bugs: Reduce about 40% of human (developer) induced errors.
  • Intuitive: Great editor support. Completion everywhere. Less time debugging.
  • Easy: Designed to be easy to use and learn. Less time reading docs.
  • Short: Minimize code duplication. Multiple features from each parameter declaration. Fewer bugs.
  • Robust: Get production-ready code. With automatic interactive documentation.

Read more about FastAPI on its official website: https://fastapi.tiangolo.com/

Employee Attrition Prediction Implementation

Problem Statement: Predict employee attrition using CatBoost and XGBoost.

Learning Procedure:

  • Explore the employee attrition dataset.
  • Apply CatBoost and XGBoost on the dataset.
  • Tune the model hyperparameters to improve accuracy.
  • Evaluate the model using suitable metrics.

What is Employee Attrition?

Employee attrition is the gradual reduction in employee numbers: the size of the workforce diminishes over time because employees are leaving faster than they are hired. It happens when employees retire, resign, or simply aren’t replaced. Although employee attrition can be company-wide, it may also be confined to specific parts of a business.

Employee attrition can happen for several reasons. These include unhappiness about employee benefits or the pay structure, a lack of employee development opportunities, and even poor conditions in the workplace.

To know more about the factors that lead to employee attrition, refer here.

Gradient Boosted Decision Trees:

  • Gradient boosted decision trees (GBDTs) are one of the most important machine learning models.
  • GBDTs originate from AdaBoost, an algorithm that ensembles weak learners and uses the majority vote, weighted by their individual accuracy, to solve binary classification problems. The weak learners in this case are decision trees with a single split, called decision stumps.
  • Some of the widely used gradient boosted decision tree libraries are XGBoost, CatBoost and LightGBM.
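The weighted majority vote described above can be sketched in a few lines of plain Python. This is only an illustration of the voting step, not a full AdaBoost implementation: the three stumps, their thresholds, their error rates and the sample values are all hypothetical, and the weight formula is the standard AdaBoost one, alpha = 0.5 * ln((1 - err) / err).

```python
import math

def make_stump(feature_index, threshold):
    """Decision stump: a tree with a single split, voting +1 or -1."""
    def predict(x):
        return 1 if x[feature_index] > threshold else -1
    return predict

def learner_weight(error_rate):
    # Standard AdaBoost weight: more accurate stumps get a larger say
    return 0.5 * math.log((1.0 - error_rate) / error_rate)

def weighted_vote(stumps, alphas, x):
    # Sum the votes, each scaled by its stump's weight, and take the sign
    score = sum(alpha * stump(x) for stump, alpha in zip(stumps, alphas))
    return 1 if score >= 0 else -1

# Hypothetical stumps on (age, monthly income, years at company)
stumps = [make_stump(0, 30), make_stump(1, 5000), make_stump(2, 2)]
errors = [0.30, 0.40, 0.45]          # hypothetical training error rates
alphas = [learner_weight(e) for e in errors]

sample = (25, 6000, 3)
votes = [stump(sample) for stump in stumps]
prediction = weighted_vote(stumps, alphas, sample)
print(votes, prediction)
```

Two of the three stumps vote +1 here, but the most accurate stump votes -1 and carries enough weight to decide the outcome, which is the difference between a weighted and a plain majority vote.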

Dataset:

The dataset used for this mini-project is the HR Employee Attrition dataset. It is a fictional dataset created by IBM data scientists. There are 35 features and 1470 records.

There are numerical features such as:

  • Age
  • DistanceFromHome
  • EmployeeNumber
  • PerformanceRating

There are several categorical features such as:

  • JobRole
  • EducationField
  • Department
  • BusinessTravel

The dependent (target) feature is ‘attrition’, which takes the values Yes/No.

!wget -qq https://cdn.iisc.talentsprint.com/CDS/Datasets/wa_fn_usec_hr_employee_attrition_tsv.csv

Install CatBoost:

!pip -qq install catboost

Import Required Packages:

import numpy as np
import pandas as pd
import seaborn as sns
from matplotlib import pyplot as plt
from sklearn.metrics import roc_auc_score, accuracy_score, confusion_matrix, f1_score, ConfusionMatrixDisplay
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import MinMaxScaler
from lightgbm import LGBMClassifier
from xgboost import XGBClassifier
from catboost import CatBoostClassifier, metrics
import warnings
warnings.filterwarnings("ignore")
plt.style.use('fivethirtyeight')
pd.set_option('display.max_columns', 100)
%matplotlib inline

from hyperopt import hp, tpe, Trials, STATUS_OK
from hyperopt import fmin
from fastapi import FastAPI
import pickle
from pydantic import BaseModel

Load the Dataset:

ibm_df = pd.read_csv('/content/wa_fn_usec_hr_employee_attrition_tsv.csv')
ibm_df.head()

Data Exploration:

  • Check for missing values.
  • Do we have a target label imbalance?

Description:

description = pd.DataFrame(index=['observations(rows)', 'percent missing', 'dtype', 'range'])
numerical = []
categorical = []
for col in ibm_df.columns:
    obs = ibm_df[col].size
    n_nan = ibm_df[col].isna().sum()
    # Express the share of missing values as a percentage, not a fraction
    p_nan = round(100 * n_nan / obs, 2)
    num_nan = f'{p_nan}% ({n_nan}/{obs})'
    dtype = 'categorical' if ibm_df[col].dtype == object else 'numerical'
    if dtype == 'numerical':
        numerical.append(col)
        rng = f'{ibm_df[col].min()}-{ibm_df[col].max()}'
    else:
        categorical.append(col)
        rng = f'{len(ibm_df[col].unique())} labels'
    description[col] = [obs, num_nan, dtype, rng]

#numerical.remove('employeecount')
numerical.remove('standardhours')
pd.set_option('display.max_columns', 100)
display(description)

Check for Outliers:

ibm_df.boxplot(rot = 90)
plt.show()

Handling outliers:

outlier_colms = ['monthlyincome', 'numcompaniesworked', 'stockoptionlevel', 'performancerating', 'totalworkingyears',
                 'trainingtimeslastyear', 'yearsatcompany', 'yearsincurrentrole', 'yearssincelastpromotion', 'yearswithcurrmanager']
ibm_df1 = ibm_df.copy()

def handle_outliers(df, colm):
    '''Clip outliers in `colm` to the lower and upper whisker values.'''
    q1 = df[colm].quantile(0.25)
    q3 = df[colm].quantile(0.75)
    iqr = q3 - q1
    lower_bound = q1 - (1.5 * iqr)
    upper_bound = q3 + (1.5 * iqr)
    # Vectorized equivalent of looping over the rows and capping each value
    df[colm] = df[colm].clip(lower=lower_bound, upper=upper_bound)
    return df

for colm in outlier_colms:
    ibm_df1 = handle_outliers(ibm_df1, colm)
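To make the whisker rule concrete, here is a small worked example in plain Python. The income values are made up for illustration, and `statistics.quantiles` with `method="inclusive"` is used as a stand-in for the pandas quartiles (both interpolate linearly between neighbouring values).

```python
import statistics

# Hypothetical monthly incomes with one obvious outlier
incomes = [2000, 2500, 3000, 3500, 4000, 4500, 5000, 20000]

# Quartiles with linear interpolation between data points
q1, _median, q3 = statistics.quantiles(incomes, n=4, method="inclusive")
iqr = q3 - q1

# Whisker bounds: 1.5 * IQR beyond the first and third quartiles
lower_bound = q1 - 1.5 * iqr
upper_bound = q3 + 1.5 * iqr

# Cap every value to the whisker range
clipped = [min(max(value, lower_bound), upper_bound) for value in incomes]

print(q1, q3, iqr)               # 2875.0 4625.0 1750.0
print(lower_bound, upper_bound)  # 250.0 7250.0
print(clipped[-1])               # 7250.0
```

Only the 20000 entry falls outside the whiskers, so it is the only value that changes.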

Recheck for outliers:

ibm_df1.boxplot(rot = 90)
plt.show()

Target label imbalance:

attrition_values = ibm_df1['attrition'].value_counts()
attrition_values

Count of unique values in Attrition column

plt.bar(attrition_values.index, attrition_values.values)
plt.title('IBM Attrition Label Imbalance')
plt.xlabel('Whether IBM Employee Left')
plt.ylabel('Count of Employees')
plt.show()

Plot pairplot:

features = ['monthlyincome', 'attrition', 'yearsatcompany', 'yearswithcurrmanager', 'joblevel', 'totalworkingyears']
pairplot = sns.pairplot(ibm_df1[features], diag_kind='kde', hue='attrition')
plt.show()

From the results, we can see that the data has an imbalance in the target labels: there are roughly six No labels for every Yes. The effect of the imbalance shows up in the pairplots, where the Yes markers in the scatter plots are all but drowned out, though this would be less of a problem if the classes were more distinct. When testing the model, it is therefore sensible to look at the confusion matrix and see how well the model performs on the minority class, Yes.
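A quick way to quantify an imbalance like this is to compute the negative-to-positive ratio and the accuracy a model would get by always predicting the majority class. The label counts below are hypothetical, chosen only to give a clean 6:1 ratio; they are not the counts from the dataset.

```python
from collections import Counter

# Hypothetical labels with a 6:1 imbalance (not the real dataset counts)
labels = ["No"] * 60 + ["Yes"] * 10
counts = Counter(labels)

# Ratio of majority (negative) to minority (positive) examples
neg_to_pos = counts["No"] / counts["Yes"]

# Accuracy of a trivial model that always predicts "No"
baseline_accuracy = counts["No"] / len(labels)

print(neg_to_pos)                   # 6.0
print(round(baseline_accuracy, 3))  # 0.857
```

The baseline shows why accuracy alone is misleading here: a model that never predicts Yes already scores about 86%. A ratio of this kind is also a common starting point for a positive-class weight such as `scale_pos_weight`.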

Explore Correlation:

plt.figure(figsize = (10, 8))
# Correlate numerical columns only; object columns raise an error in pandas 2.x
sns.heatmap(ibm_df1.corr(numeric_only=True))
plt.title('Correlation Among IBM Employee Attrition Numerical Features')
plt.show()
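Each cell of the heatmap is a Pearson correlation coefficient, which is the default method of `DataFrame.corr`. As a rough sketch of what a single cell represents, the coefficient can be computed by hand; the three short series below are invented for illustration.

```python
import math

def pearson(xs, ys):
    """Pearson correlation: covariance scaled by both standard deviations."""
    n = len(xs)
    mean_x = sum(xs) / n
    mean_y = sum(ys) / n
    cov = sum((x - mean_x) * (y - mean_y) for x, y in zip(xs, ys))
    var_x = sum((x - mean_x) ** 2 for x in xs)
    var_y = sum((y - mean_y) ** 2 for y in ys)
    return cov / math.sqrt(var_x * var_y)

# Hypothetical values: two series that rise together, one that falls
years_at_company = [1, 3, 5, 7, 9]
years_with_manager = [0, 2, 4, 6, 8]
distance_from_home = [9, 7, 5, 3, 1]

r_up = pearson(years_at_company, years_with_manager)
r_down = pearson(years_at_company, distance_from_home)
print(r_up, r_down)
```

Values near +1 or -1 mark strongly related pairs of features, and values near 0 mark pairs with little linear relationship.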

Preparing the data for CatBoost:

split_data = ibm_df1.copy()

Creating a copy of the data.

split_dummy = pd.get_dummies(split_data[categorical], drop_first=True)
split_dummy.head()

Handling categorical features.
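To show what dummy encoding with `drop_first=True` produces, here is a plain-Python sketch of the idea. It is not the pandas implementation; `one_hot_drop_first` is a hypothetical helper that sorts the categories, drops the first one, and emits a 0/1 column for each of the rest.

```python
def one_hot_drop_first(column_name, values):
    """Encode a categorical column as 0/1 dummies, dropping the first category."""
    categories = sorted(set(values))
    kept = categories[1:]  # the first category becomes the implicit baseline
    return [
        {f"{column_name}_{category}": int(value == category) for category in kept}
        for value in values
    ]

attrition_rows = one_hot_drop_first("attrition", ["Yes", "No", "No"])
print(attrition_rows)
# [{'attrition_Yes': 1}, {'attrition_Yes': 0}, {'attrition_Yes': 0}]

overtime_rows = one_hot_drop_first("overtime", ["No", "Yes"])
print(overtime_rows)
# [{'overtime_Yes': 0}, {'overtime_Yes': 1}]
```

A two-valued column such as attrition collapses to a single `attrition_Yes` column, which is why the target is renamed from `attrition_Yes` back to `attrition` later on.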

split_data = pd.concat([split_data, split_dummy], axis=1)
split_data.drop(columns = categorical, inplace=True)
split_data.head(5)

Concatenate the dummy variables to the dataframe and remove the original categorical columns.

split_data.rename(columns={'attrition_Yes': 'attrition'}, inplace=True)
split_data.head()

Rename target column

X = split_data.drop('attrition', axis=1)

Features

y = split_data['attrition']

Target label

X.shape, y.shape
X_train, X_test, y_train, y_test = train_test_split(X, y, stratify=y, test_size=0.2)
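The `stratify=y` argument keeps the Yes/No proportion the same in the training and test sets, which matters with an imbalanced target. The sketch below illustrates the idea in plain Python; `stratified_split` is a hypothetical helper, not scikit-learn's implementation, and the labels are made up.

```python
import random
from collections import defaultdict

def stratified_split(labels, test_size=0.2, seed=0):
    """Split indices so each class contributes the same fraction to the test set."""
    rng = random.Random(seed)
    by_label = defaultdict(list)
    for index, label in enumerate(labels):
        by_label[label].append(index)

    train_idx, test_idx = [], []
    for indices in by_label.values():
        rng.shuffle(indices)
        n_test = round(len(indices) * test_size)
        test_idx.extend(indices[:n_test])
        train_idx.extend(indices[n_test:])
    return sorted(train_idx), sorted(test_idx)

# Hypothetical imbalanced labels: 60 "No" and 10 "Yes"
labels = ["No"] * 60 + ["Yes"] * 10
train_idx, test_idx = stratified_split(labels)

yes_in_test = sum(labels[i] == "Yes" for i in test_idx)
print(len(train_idx), len(test_idx), yes_in_test)  # 56 14 2
```

Without stratification, a random 20% sample of so few Yes rows could easily end up with a noticeably different class ratio than the training set.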

Apply CatBoost

Here is the official documentation of CatBoost.

# Create CatBoost model
cboost = CatBoostClassifier(learning_rate = 1,
                            depth = 1,
                            scale_pos_weight = 6,
                            l2_leaf_reg = 8,
                            border_count = 65
                            )
# Model training
cboost.fit(X_train, y_train,
           cat_features=None)

Model performance

# Model performance on all sets
predictions = cboost.predict(X_test)
test_preds = cboost.predict_proba(X_test)[:,1]
train_preds = cboost.predict_proba(X_train)[:,1]

train_auc = roc_auc_score(y_train, train_preds)
test_auc = roc_auc_score(y_test, test_preds)
accuracy = accuracy_score(y_test, predictions)

F1 Score

F1 = f1_score(y_test, predictions)
print(F1)

0.5245901639344263
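The F1 score is the harmonic mean of precision and recall on the positive (Yes) class. As a worked example of the arithmetic, the counts below are hypothetical and are not taken from this model's confusion matrix.

```python
def precision_recall_f1(tp, fp, fn):
    """Compute precision, recall and F1 from confusion-matrix counts."""
    precision = tp / (tp + fp)   # of the predicted Yes, how many were right
    recall = tp / (tp + fn)      # of the actual Yes, how many were found
    f1 = 2 * precision * recall / (precision + recall)
    return precision, recall, f1

# Hypothetical counts: 30 true positives, 20 false positives, 10 false negatives
precision, recall, f1 = precision_recall_f1(tp=30, fp=20, fn=10)
print(precision, recall, round(f1, 3))  # 0.6 0.75 0.667
```

Because true negatives do not appear in the formula, F1 is not inflated by the large No class the way accuracy is.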

Confusion Matrix

cm = confusion_matrix(y_test, predictions, labels=cboost.classes_)
disp = ConfusionMatrixDisplay(confusion_matrix=cm,
                              display_labels=cboost.classes_)
disp.plot()
plt.show()
cboost_results = {'accuracy': accuracy,
                  'model': 'CatBoost',
                  'f1_score': F1,
                  'training auc score': train_auc,
                  'test auc score': test_auc}
cboost_results

FastAPI Implementation

Installing uvicorn and setting up ngrok

!pip install fastapi nest-asyncio pyngrok uvicorn

!mkdir -p /drive/ngrok-ssh
%cd /drive/ngrok-ssh
!wget https://bin.equinox.io/c/4VmDzA7iaHb/ngrok-stable-linux-amd64.zip -O ngrok-stable-linux-amd64.zip
!unzip -u ngrok-stable-linux-amd64.zip
!cp /drive/ngrok-ssh/ngrok /ngrok
!chmod +x /ngrok

Enter your ngrok authentication token

!/ngrok authtoken YOUR_TOKEN  # replace YOUR_TOKEN with your ngrok auth token

Create pickle file for CatBoost model

Pkl_Filename1 = "model_catboost.pkl"

with open(Pkl_Filename1, 'wb') as file:
    pickle.dump(cboost, file)
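The save-and-load round trip can be tried without a trained model. In this sketch a plain dictionary stands in for the classifier, and the file name and contents are hypothetical; the same `pickle.dump` and `pickle.load` calls apply to the real model object.

```python
import os
import pickle
import tempfile

# A plain dict stands in for the trained model in this illustration
artifact = {"threshold": 0.5, "feature_names": ["age", "monthlyincome"]}

# Write the object to a file in a temporary directory
path = os.path.join(tempfile.mkdtemp(), "model_demo.pkl")
with open(path, "wb") as file:
    pickle.dump(artifact, file)

# Read it back, as the API does once at startup
with open(path, "rb") as file:
    restored = pickle.load(file)

print(restored == artifact)  # True
print(restored is artifact)  # False
```

The restored object is an equal copy rather than the same object. Since unpickling can execute arbitrary code, only load pickle files from sources you trust.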

BaseModel class in Pydantic

from pydantic import BaseModel

class Attrition(BaseModel):
    age: int
    dailyrate: int
    distancefromhome: int
    education: int
    environmentsatisfaction: int
    hourlyrate: int
    jobinvolvement: int
    joblevel: int
    jobsatisfaction: int
    monthlyincome: int
    monthlyrate: int
    numcompaniesworked: int
    percentsalaryhike: int
    performancerating: int
    relationshipsatisfaction: int
    standardhours: int
    stockoptionlevel: int
    totalworkingyears: int
    trainingtimeslastyear: int
    worklifebalance: int
    yearsatcompany: int
    yearsincurrentrole: int
    yearssincelastpromotion: int
    yearswithcurrmanager: int
    businesstravel_Travel_Frequently: int
    businesstravel_Travel_Rarely: int
    department_ResearchDevelopment: int
    department_Sales: int
    educationfield_LifeSciences: int
    educationfield_Marketing: int
    educationfield_Medical: int
    educationfield_Other: int
    educationfield_TechnicalDegree: int
    gender_Male: int
    jobrole_HumanResources: int
    jobrole_LaboratoryTechnician: int
    jobrole_Manager: int
    jobrole_ManufacturingDirector: int
    jobrole_ResearchDirector: int
    jobrole_ResearchScientist: int
    jobrole_SalesExecutive: int
    jobrole_SalesRepresentative: int
    maritalstatus_Married: int
    maritalstatus_Single: int
    overtime_Yes: int

    # Pydantic v1 style; in Pydantic v2 the equivalent setting is
    # model_config = {"json_schema_extra": {...}}
    class Config:
        # Illustrative values only; in a real request at most one dummy
        # column per original category should be set to 1
        schema_extra = {
            "example": {
                "age": 43,
                "dailyrate": 823,
                "distancefromhome": 6,
                "education": 3,
                "environmentsatisfaction": 1,
                "hourlyrate": 81,
                "jobinvolvement": 2,
                "joblevel": 5,
                "jobsatisfaction": 3,
                "monthlyincome": 16581,
                "monthlyrate": 2571,
                "numcompaniesworked": 7,
                "percentsalaryhike": 13,
                "performancerating": 3,
                "relationshipsatisfaction": 4,
                "standardhours": 80,
                "stockoptionlevel": 0,
                "totalworkingyears": 21,
                "trainingtimeslastyear": 2,
                "worklifebalance": 3,
                "yearsatcompany": 16,
                "yearsincurrentrole": 12,
                "yearssincelastpromotion": 6,
                "yearswithcurrmanager": 14,
                "businesstravel_Travel_Frequently": 0,
                "businesstravel_Travel_Rarely": 1,
                "department_ResearchDevelopment": 0,
                "department_Sales": 1,
                "educationfield_LifeSciences": 0,
                "educationfield_Marketing": 1,
                "educationfield_Medical": 0,
                "educationfield_Other": 1,
                "educationfield_TechnicalDegree": 0,
                "gender_Male": 1,
                "jobrole_HumanResources": 0,
                "jobrole_LaboratoryTechnician": 1,
                "jobrole_Manager": 0,
                "jobrole_ManufacturingDirector": 1,
                "jobrole_ResearchDirector": 0,
                "jobrole_ResearchScientist": 1,
                "jobrole_SalesExecutive": 0,
                "jobrole_SalesRepresentative": 1,
                "maritalstatus_Married": 0,
                "maritalstatus_Single": 1,
                "overtime_Yes": 0,
            }
        }

Setting up FastAPI

from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
import pickle

app1 = FastAPI()
# app1.add_middleware(
#     CORSMiddleware,
#     allow_origins=['*'],
#     allow_credentials=True,
#     allow_methods=['*'],
#     allow_headers=['*'],
# )

@app1.on_event("startup")
async def load_model():
    # Declare both models as globals so the /predict handler can use them
    global model1, model2
    with open("model_catboost.pkl", "rb") as f:
        model1 = pickle.load(f)
    # Assumes an XGBoost model was trained and pickled to this file in the same way
    with open("model_xgboost.pkl", "rb") as f:
        model2 = pickle.load(f)

@app1.get('/')
async def index():
    return {'message': 'This is the homepage of the API '}


@app1.post('/predict')
async def get_employee_attrition(data: Attrition):
    received = data.dict()
    # Pydantic keeps fields in declaration order, so the values line up
    # with the feature order listed in the Attrition class
    features = [list(received.values())]
    pred_catboost = model1.predict(features).tolist()[0]
    pred_xgboost = model2.predict(features).tolist()[0]
    return {'prediction_catboost': pred_catboost,
            'prediction_xgboost': pred_xgboost}
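The handler above relies on the features reaching the models in exactly the order used during training. One way to make that order explicit is a small helper that builds the row from a fixed list of names. This is a sketch: `to_feature_row` and `FEATURE_ORDER` are hypothetical names, and the list is truncated to three features for brevity.

```python
# Hypothetical, truncated feature list; the real one would hold every model input
FEATURE_ORDER = ["age", "dailyrate", "distancefromhome"]

def to_feature_row(received, feature_order):
    """Return the values of `received` in the order the model expects."""
    missing = [name for name in feature_order if name not in received]
    if missing:
        raise KeyError(f"missing features: {missing}")
    return [received[name] for name in feature_order]

# The payload keys arrive in a different order than the model expects
payload = {"distancefromhome": 6, "age": 43, "dailyrate": 823}
row = to_feature_row(payload, FEATURE_ORDER)
print(row)  # [43, 823, 6]
```

With a helper like this, the row would be wrapped in a list (`[row]`) and passed to each model's `predict`, and a missing field would fail loudly instead of silently shifting columns.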

Run FastAPI from Colab

from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware

import nest_asyncio
from pyngrok import ngrok
import uvicorn

ngrok_tunnel = ngrok.connect(8000)
print('Public URL:', ngrok_tunnel.public_url)
nest_asyncio.apply()
uvicorn.run(app1, port=8000)

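Click on the public URL created by ngrok to interact with the app. Appending /docs to the URL opens FastAPI's automatic interactive documentation, where the /predict endpoint can be tried out.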
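The API can also be called from code rather than the browser. The sketch below only builds the POST request with the standard library and does not send it; `build_predict_request` is a hypothetical helper, the base URL is a placeholder for the ngrok public URL, and the payload is truncated (a real call needs every field of the Attrition model).

```python
import json
import urllib.request

def build_predict_request(base_url, payload):
    """Build a JSON POST request for the /predict endpoint (not sent here)."""
    body = json.dumps(payload).encode("utf-8")
    return urllib.request.Request(
        url=base_url.rstrip("/") + "/predict",
        data=body,
        headers={"Content-Type": "application/json"},
        method="POST",
    )

# Placeholder URL and truncated payload, for illustration only
payload = {"age": 43, "dailyrate": 823, "distancefromhome": 6}
request = build_predict_request("http://localhost:8000/", payload)

print(request.get_method(), request.full_url)
# POST http://localhost:8000/predict

# To actually send it while the server is running:
# with urllib.request.urlopen(request) as response:
#     print(json.loads(response.read()))
```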
