CART

Deniz Gunay
20 min read · Sep 10, 2023


CART (Classification and Regression Trees) is a popular machine learning algorithm used for both classification and regression tasks. It is a type of decision tree algorithm that recursively splits the data into subsets based on the values of input features, ultimately creating a tree-like structure that can be used for prediction.

CART — (Classification Tree)

Advantages of CART:

  1. Interpretability: Decision trees, including CART, are highly interpretable models. You can easily understand and visualize the decision-making process of the model, making it a valuable tool for explaining predictions to stakeholders.
  2. No Assumptions about Data: CART does not make any assumptions about the distribution of data or the relationships between features. It can handle both categorical and numerical features without requiring feature scaling.
  3. Handles Non-linearity: CART can capture complex non-linear relationships in the data, which makes it suitable for a wide range of problems.
  4. Feature Importance: CART can provide information about feature importance, helping you identify which features are most relevant for making predictions.
  5. Versatility: It can be used for both classification and regression tasks, making it a versatile algorithm.

Disadvantages of CART:

  1. Overfitting: CART decision trees are prone to overfitting, especially when they are allowed to grow too deep. Overfitting occurs when the tree becomes too complex and fits the training data noise, resulting in poor generalization to new, unseen data.
  2. Instability: Small changes in the training data can lead to significantly different tree structures. This can result in instability, making the model sensitive to variations in the data.
  3. Bias Towards Dominant Classes: In classification tasks with imbalanced classes, CART may produce biased trees that favor the majority class.
  4. Greedy Algorithm: CART uses a greedy approach, which means it makes locally optimal decisions at each split. This may not always lead to globally optimal trees.

Some Hyperparameters for CART:

  1. max_depth: This parameter controls the maximum depth of the decision tree. It can help prevent overfitting by limiting the tree's complexity.
  2. min_samples_split: It specifies the minimum number of samples required to split an internal node. Increasing this value prevents the tree from splitting nodes that contain only a few samples, reducing the likelihood of overfitting.
  3. min_samples_leaf: This parameter sets the minimum number of samples required to be in a leaf node. It can be used to control the tree's size and prevent small leaf nodes that capture noise.
  4. max_features: It determines the maximum number of features considered for splitting at each node. It can help reduce the model's complexity and improve generalization.
  5. criterion: This parameter defines the impurity measure used for splitting nodes. For classification, common options are "gini" for Gini impurity and "entropy" for information gain. For regression, "squared_error" (mean squared error, formerly "mse") is typically used.
  6. min_impurity_decrease: It specifies a threshold for splitting nodes based on impurity decrease. A split is only made if it decreases the impurity by at least this threshold.
  7. class_weight (for classification): It allows you to assign weights to classes, addressing imbalanced class issues.

Hyperparameter tuning is crucial for optimizing the performance of CART models, and it often involves techniques like cross-validation to find the best combination of hyperparameters for your specific problem. max_depth and min_samples_split are often a good starting point for hyperparameter tuning since adjusting these two hyperparameters allows you to find the right balance between model complexity and generalization for your specific problem.
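As a quick illustration of how these hyperparameters are passed in scikit-learn, here is a minimal sketch. The dataset and the hyperparameter values below are arbitrary and only for illustration; they are not tuned for any particular problem.

from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import cross_val_score
from sklearn.tree import DecisionTreeClassifier

X_demo, y_demo = load_breast_cancer(return_X_y=True)

#an arbitrary, untuned configuration that just shows where each hyperparameter goes
tree = DecisionTreeClassifier(max_depth=4,
                              min_samples_split=10,
                              min_samples_leaf=5,
                              criterion="gini",
                              random_state=42)

print(cross_val_score(tree, X_demo, y_demo, cv=5, scoring="accuracy").mean())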

Now let’s examine these two parameters in a little more detail.

max_depth Hyperparameter:

  • max_depth controls the maximum depth or height of the decision tree.
  • A smaller max_depth value creates a shallow tree with fewer splits.
  • A larger max_depth value allows the tree to be deeper and more complex.

Impact of max_depth:

a) Small max_depth (e.g., 1 or 2):
  - Creates a simple decision boundary.
  - May underfit the data as it may not capture complex patterns.
  - Less prone to overfitting.

b) Medium max_depth (e.g., 3 or 4):
  - Balances complexity and ability to capture patterns.
  - Often a good starting point for tuning.

c) Large max_depth (e.g., 10 or more):
  - Creates a complex decision boundary.
  - May overfit the data by capturing noise.
  - More prone to overfitting.

min_samples_split Hyperparameter:

  • min_samples_split controls the minimum number of samples required to split an internal node.
  • A smaller min_samples_split value allows nodes to split with fewer samples.
  • A larger min_samples_split value requires more samples for a node to split.

Impact of min_samples_split:

a) Small min_samples_split (e.g., 2):
  - Allows nodes to split even with very few samples.
  - Creates a more complex tree with many splits.
  - More prone to overfitting.

b) Medium min_samples_split (e.g., 10 or 20):
  - Balances complexity and generalization.
  - Often a good choice for preventing overfitting.

c) Large min_samples_split (e.g., 100 or more):
  - Requires a significant number of samples for nodes to split.
  - Creates a simpler tree with fewer splits.
  - Helps prevent overfitting.

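Here is a small sketch of this trade-off on a synthetic dataset (the data and the exact settings are only for illustration): a very shallow, heavily constrained tree tends to underfit, a fully grown tree tends to fit the training folds almost perfectly while scoring worse on the validation folds, and moderate settings usually land in between.

from sklearn.datasets import make_classification
from sklearn.model_selection import cross_validate
from sklearn.tree import DecisionTreeClassifier

X_demo, y_demo = make_classification(n_samples=1000, n_features=20, random_state=0)

#compare a shallow, a medium, and an unconstrained configuration
for max_depth, min_samples_split in [(1, 100), (4, 20), (None, 2)]:
    tree = DecisionTreeClassifier(max_depth=max_depth,
                                  min_samples_split=min_samples_split,
                                  random_state=0)
    cv = cross_validate(tree, X_demo, y_demo, cv=5, scoring="accuracy", return_train_score=True)
    print(f"max_depth={max_depth}, min_samples_split={min_samples_split} -> "
          f"train={cv['train_score'].mean():.3f}, validation={cv['test_score'].mean():.3f}")
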
What is the loss function in CART?

For classification problems, ‘gini’ or ‘entropy’ is generally used; for regression, ‘squared_error’ (formerly called ‘mse’) is generally used. It is determined by the criterion parameter. The default criterion value for classification is ‘gini’, and the default value for regression is ‘squared_error’.
Gini and entropy are measures of impurity. The more diverse the labels of the samples in a node, the higher its Gini and entropy. For example, let’s say we have two separate nodes, each containing five samples. Assume the labels in the first node are (1 0 1 0 1) while the labels in the second node are (1 1 1 1 1). Since the second node has less impurity, in other words less heterogeneity, its Gini and entropy values will be lower than those of the first node.

Loss function (Gini)
Loss function (Entropy)
Loss function (MSE)
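To make this concrete, here is a small sketch that computes both impurity measures for the two example nodes above, using the standard formulas Gini = 1 − Σpᵢ² and Entropy = −Σpᵢ·log₂(pᵢ). The helper functions below are written just for this illustration.

import numpy as np

def gini(labels):
    _, counts = np.unique(labels, return_counts=True)
    p = counts / counts.sum()
    return 1 - np.sum(p ** 2)

def entropy(labels):
    _, counts = np.unique(labels, return_counts=True)
    p = counts / counts.sum()
    return -np.sum(p * np.log2(p))

node1 = [1, 0, 1, 0, 1]   # mixed node: high impurity
node2 = [1, 1, 1, 1, 1]   # pure node: zero impurity

print(gini(node1), entropy(node1))   # 0.48  ~0.971
print(gini(node2), entropy(node2))   # 0.0   -0.0  (both effectively zero)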

Coding

Let’s do some coding using the diabetes dataset!

NOTE: Since we will model a classification decision tree here, we will import DecisionTreeClassifier from scikit-learn. However, for regression problems we must import DecisionTreeRegressor.
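For reference, the regression counterpart is used the same way. A minimal sketch on synthetic data (only for illustration; in recent scikit-learn versions the regression criterion is called "squared_error", while older versions used "mse"):

from sklearn.datasets import make_regression
from sklearn.tree import DecisionTreeRegressor

X_reg, y_reg = make_regression(n_samples=200, n_features=5, noise=10, random_state=0)

#criterion="squared_error" is the default criterion for regression trees
reg_tree = DecisionTreeRegressor(max_depth=4, criterion="squared_error", random_state=0).fit(X_reg, y_reg)
print(reg_tree.predict(X_reg[:3]))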

################################################
# Decision Tree Classification: CART
################################################

import warnings
import joblib
import pydotplus
import numpy as np
import pandas as pd
import seaborn as sns
from matplotlib import pyplot as plt
from sklearn.tree import DecisionTreeClassifier, export_graphviz, export_text
from sklearn.metrics import classification_report, roc_auc_score
from sklearn.model_selection import train_test_split, GridSearchCV, cross_validate, validation_curve
from sklearn.preprocessing import MinMaxScaler, RobustScaler
from sklearn.impute import KNNImputer
from skompiler import skompile
import graphviz

pd.set_option('display.max_columns', None)
pd.set_option('display.width', 500)

warnings.simplefilter(action='ignore', category=Warning)

df = pd.read_csv("datasets/diabetes.csv")




####################################
# FUNCTIONS
####################################

def outlier_thresholds(dataframe, col, q1=.05, q3=.95, decimal=3):
    quartile1 = dataframe[col].quantile(q1)
    quartile3 = dataframe[col].quantile(q3)
    iqr = quartile3 - quartile1
    low_limit = round(quartile1 - (iqr * 1.5), decimal)
    up_limit = round(quartile3 + (iqr * 1.5), decimal)
    return low_limit, up_limit



def replace_with_thresholds(dataframe, col_name, q1=.05, q3=.95, lower_limit=None, upper_limit=None):
    low_limit, up_limit = outlier_thresholds(dataframe, col_name, q1, q3)
    if lower_limit is not None:
        dataframe.loc[(dataframe[col_name] < lower_limit), col_name] = lower_limit
    else:
        dataframe.loc[(dataframe[col_name] < low_limit), col_name] = low_limit

    if upper_limit is not None:
        dataframe.loc[(dataframe[col_name] > upper_limit), col_name] = upper_limit
    else:
        dataframe.loc[(dataframe[col_name] > up_limit), col_name] = up_limit




def plot_importance(model, features, num: int = 0, save=False):
    if num <= 0:
        num = features.shape[1]
    feature_imp = pd.DataFrame({'Value': model.feature_importances_, 'Feature': features.columns})
    plt.figure(figsize=(10, 10))
    sns.set(font_scale=1)
    sns.barplot(x="Value", y="Feature",
                data=feature_imp.sort_values(by="Value", ascending=False)[0:num])
    plt.title('Features')
    plt.tight_layout()
    if save:
        plt.savefig('importances.png')
    plt.show()




def val_curve_params(model, X, y, param_name, param_range, scoring="accuracy", cv=10):
    train_score, test_score = validation_curve(
        model, X=X, y=y, param_name=param_name, param_range=param_range, scoring=scoring, cv=cv)

    mean_train_score = np.mean(train_score, axis=1)
    mean_test_score = np.mean(test_score, axis=1)

    plt.plot(param_range, mean_train_score,
             label="Training Score", color='b')

    plt.plot(param_range, mean_test_score,
             label="Validation Score", color='g')

    plt.title(f"Validation Curve for {type(model).__name__}")
    plt.xlabel(f"Number of {param_name}")
    plt.ylabel(f"{scoring}")
    plt.tight_layout()
    plt.legend(loc='best')
    plt.show(block=True)




def tree_graph(model, col_names, file_name):
    tree_str = export_graphviz(model, feature_names=col_names, filled=True, out_file=None)
    graph = pydotplus.graph_from_dot_data(tree_str)
    graph.write_png(file_name)








###########################################
# DATA PREPROCESSING
###########################################

#Replace outliers with the IQR-based threshold values.
cols = [col for col in df.columns if col != "Outcome"]
for col in cols:
    replace_with_thresholds(df, col)



#Columns that cannot contain zero
problematic_cols = [col for col in df.columns if col not in ["Pregnancies",'DiabetesPedigreeFunction','Outcome']]


#Now replace these zeros with NaN
for col in problematic_cols:
    df[col] = df[col].replace(0, np.nan)


# Filling NaN values by using KNN Imputer
scaler=MinMaxScaler()
df=pd.DataFrame(scaler.fit_transform(df), columns=df.columns)
imputer=KNNImputer(n_neighbors=5)
df=pd.DataFrame(imputer.fit_transform(df), columns=df.columns)
df=pd.DataFrame(scaler.inverse_transform(df), columns=df.columns)







##################################################
# FEATURE ENGINEERING
##################################################



df.loc[(df["Age"] <= 18 ), "NEW_AGE"] = "young"
df.loc[(df["Age"] > 18 ) & (df["Age"] <= 24), "NEW_AGE"] = "adult"
df.loc[(df["Age"] > 24 ) & (df["Age"] <= 59), "NEW_AGE"] = "mid_adult"
df.loc[(df["Age"] > 59), "NEW_AGE"] = "senior"



df.loc[(df["BMI"] < 18.5) , "BMI_CAT"] ="underweight"
df.loc[(df["BMI"] >= 18.5) & (df["BMI"] < 24.9) , "BMI_CAT"] ="normal"
df.loc[(df["BMI"] >= 24.9) & (df["BMI"] < 29.9) , "BMI_CAT"]="overweight"
df.loc[(df["BMI"] >= 29.9) , "BMI_CAT"] ="obese"



df.loc[(df["Insulin"] < 15) , "INSULIN_CAT"] ="low"
df.loc[(df["Insulin"] >= 15) & (df["Insulin"] < 166) , "INSULIN_CAT"] ="normal"
df.loc[(df["Insulin"] >= 166) , "INSULIN_CAT"] ="high"


# One Hot Encoding
ohe_cols = [col for col in df.columns if 10 >= df[col].nunique() > 2]
df= pd.get_dummies(df,columns= ohe_cols, drop_first=True)


X = df.drop(["Outcome"], axis=1)
y = df["Outcome"]

We also have some additional candidate features. We will select the useful ones by using the feature_selecter() function below:

def feature_selecter(input_x, y, candidate_features_dict: dict, candidate_features_id: list, best_features: list, best_accuracy=0, verbose=True):
    if not candidate_features_id:
        return best_accuracy, best_features

    best_x = input_x
    best_feature = -1

    if best_accuracy == 0:
        cart_model = DecisionTreeClassifier(random_state=17).fit(input_x, y)

        cv_results = cross_validate(cart_model,
                                    input_x, y,
                                    cv=5,
                                    scoring="accuracy")

        best_accuracy = cv_results["test_score"].mean()

    if verbose:
        print(f"best accuracy(old) = {best_accuracy}")
        #print(candidate_features_id)

    for feature in candidate_features_id:
        X = input_x.copy(deep=True)

        # define your candidate feature here!
        if feature == 0:
            X[candidate_features_dict[feature]] = X["Insulin"] * X["Glucose"]
        elif feature == 1:
            X[candidate_features_dict[feature]] = X["Glucose"] / (X["Insulin"] + 0.0001)
        elif feature == 2:
            X[candidate_features_dict[feature]] = X["Age"] * X["Pregnancies"]
        elif feature == 3:
            X[candidate_features_dict[feature]] = X["Age"] / (X["Pregnancies"] + 0.0001)
        elif feature == 4:
            X[candidate_features_dict[feature]] = X["Age"] * X["Pregnancies"] * X["Glucose"]
        elif feature == 5:
            X[candidate_features_dict[feature]] = X["Glucose"] / (X["Age"] + 0.0001)
        elif feature == 6:
            X[candidate_features_dict[feature]] = X["Insulin"] / (X["Age"] + 0.0001)
        elif feature == 7:
            X[candidate_features_dict[feature]] = X["BMI"] * X["Pregnancies"]
        elif feature == 8:
            X[candidate_features_dict[feature]] = X["BMI"] * X["Age"]
        elif feature == 9:
            X[candidate_features_dict[feature]] = X["BMI"] * X["Age"] * X["Pregnancies"]
        elif feature == 10:
            X[candidate_features_dict[feature]] = X["BMI"] * X["Glucose"]
        elif feature == 11:
            X[candidate_features_dict[feature]] = X["DiabetesPedigreeFunction"] * X["Insulin"]
        elif feature == 12:
            X[candidate_features_dict[feature]] = X["SkinThickness"] * X["Insulin"]
        elif feature == 13:
            X[candidate_features_dict[feature]] = X["Pregnancies"] / (X["Age"] + 0.0001)
        elif feature == 14:
            X[candidate_features_dict[feature]] = X["Glucose"] + X["Insulin"] + X["SkinThickness"]
        elif feature == 15:
            X[candidate_features_dict[feature]] = X["BloodPressure"] / (X["Glucose"] + 0.0001)

        cart_model = DecisionTreeClassifier(random_state=17).fit(X, y)

        cv_results = cross_validate(cart_model,
                                    X, y,
                                    cv=5,
                                    scoring="accuracy")

        accuracy = cv_results["test_score"].mean()
        if accuracy > best_accuracy:
            best_accuracy = accuracy
            best_feature = feature
            best_x = X

    if best_feature == -1:
        return best_accuracy, best_features

    best_features.append(best_feature)
    candidate_features_id.remove(best_feature)

    if verbose:
        print(f"best accuracy(new) = {best_accuracy}")
        print(f"added feature = {best_feature}", end='\n\n')
        #print(best_features)

    return feature_selecter(best_x, y, candidate_features_dict, candidate_features_id, best_features, best_accuracy, verbose)

Then, run the feature_selecter() function:

candidate_features = {0: "new_glucoseXinsulin",
                      1: "new_glucose/insulin",
                      2: "new_ageXpreg",
                      3: "new_age/preg",
                      4: "new_ageXpregXglucose",
                      5: "new_glucose/age",
                      6: "new_insulin/age",
                      7: "new_bmiXpreg",
                      8: "new_bmiXage",
                      9: "new_bmiXageXpreg",
                      10: "new_bmiXglucose",
                      11: "new_degreeXinsulin",
                      12: "new_skinXinsulin",
                      13: "new_preg/age",
                      14: "new_glucose+insulin+skin",
                      15: "new_blood/glucose"}

accuracy, new_features = feature_selecter(X,y,candidate_features, list(candidate_features.keys()), best_features=[])
'''
best accuracy(old) = 0.7291401409048468
best accuracy(new) = 0.7539003480179951
added feature = 15

best accuracy(old) = 0.7539003480179951
'''


#See which feature is selected.
for feature in new_features:
    print(candidate_features[feature])
#new_blood/glucose



#Add this newly created feature.
X["new_blood/glucose"] = X["BloodPressure"]/(X["Glucose"]+0.0001)






##################################################
# MODEL BUILDING AND EVALUATION
##################################################


#Build the CART model.
cart_model = DecisionTreeClassifier(random_state=17).fit(X, y)


#Evaluate with 5-Fold CV
cv_results = cross_validate(cart_model,
                            X, y,
                            cv=5,
                            scoring=["accuracy", "precision", "recall", "f1", "roc_auc"])

print(f"Accuracy : {cv_results['test_accuracy'].mean()}") # 0.7539
print(f"Precision : {cv_results['test_precision'].mean()}") # 0.6467
print(f"Recall : {cv_results['test_recall'].mean()}") # 0.6605
print(f"F1 Score : {cv_results['test_f1'].mean()}") # 0.6515
print(f"ROC AUC : {cv_results['test_roc_auc'].mean()}") # 0.7322







##################################################
# HYPERPARAMETER OPTIMIZATION
##################################################

#We will use GridSearchCV to find the optimal hyperparameters.
#First, see the current parameters.
cart_model.get_params()
'''
{'ccp_alpha': 0.0,
'class_weight': None,
'criterion': 'gini',
'max_depth': None,
'max_features': None,
'max_leaf_nodes': None,
'min_impurity_decrease': 0.0,
'min_samples_leaf': 1,
'min_samples_split': 2,
'min_weight_fraction_leaf': 0.0,
'random_state': 17,
'splitter': 'best'}
'''




#Let's use GridSearchCV for the max_depth and min_samples_split parameters.
#Initially, max_depth = None and min_samples_split = 2. We will try max_depth
#from 1 to 10 and min_samples_split from 2 to 19, and find the parameter
#values that give the highest accuracy.
cart_params = {'max_depth': range(1, 11),
               "min_samples_split": range(2, 20)}

cart_best_grid = GridSearchCV(cart_model,
                              cart_params,
                              cv=5,
                              n_jobs=-1,
                              verbose=True,
                              scoring="accuracy").fit(X, y)
'''
Fitting 5 folds for each of 180 candidates, totalling 900 fits
'''

print(cart_best_grid.best_params_)
#{'max_depth': 7, 'min_samples_split': 4}




#So it seems that when max_depth = 7 and min_samples_split = 4, we obtain the
#highest accuracy. Let's see the accuracy after this hyperparameter tuning.
print(cart_best_grid.best_score_) # accuracy : 0.7629







################################################
# FINAL MODEL
################################################

#Build the final model
cart_final = DecisionTreeClassifier(**cart_best_grid.best_params_, random_state=17).fit(X, y)

#See our newly updated parameters, max_depth = 7, min_samples_split = 4
cart_final.get_params()
'''
{'ccp_alpha': 0.0,
'class_weight': None,
'criterion': 'gini',
'max_depth': 7,
'max_features': None,
'max_leaf_nodes': None,
'min_impurity_decrease': 0.0,
'min_samples_leaf': 1,
'min_samples_split': 4,
'min_weight_fraction_leaf': 0.0,
'random_state': 17,
'splitter': 'best'}
'''



cv_results = cross_validate(cart_final,
                            X, y,
                            cv=5,
                            scoring=["accuracy", "precision", "recall", "f1", "roc_auc"])

print(f"Accuracy : {cv_results['test_accuracy'].mean()}") # 0.7629
print(f"Precision : {cv_results['test_precision'].mean()}") # 0.6674
print(f"Recall : {cv_results['test_recall'].mean()}") # 0.6681
print(f"F1 Score : {cv_results['test_f1'].mean()}") # 0.6621
print(f"ROC AUC : {cv_results['test_roc_auc'].mean()}") # 0.7667

As a result, the final performance metrics are as follows:

Accuracy  : 0.7629
Precision : 0.6674
Recall    : 0.6681
F1 Score  : 0.6621
ROC AUC   : 0.7667

################################################
# FEATURE IMPORTANCE
################################################

#See the most important features.
plot_importance(cart_final, X, save=False)
#IMAGE IS BELOW (importance.png)
################################################
# ANALYZING MODEL COMPLEXITY WITH LEARNING CURVES
################################################

val_curve_params(cart_final, X, y, "max_depth", range(1, 11), scoring="accuracy")
#IMAGE IS BELOW (validation_curve.png)

With GridSearchCV, we found that the optimal max_depth value is 7. Looking at the validation curve above, you might wonder why the optimal max_depth is not 8. This is because the curve above varies only the max_depth parameter, while GridSearchCV searched over both max_depth and min_samples_split. In other words, GridSearchCV evaluated a different grid of scores than the one plotted above, and the best combination on that grid is max_depth = 7 and min_samples_split = 4.
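If you want to see this from GridSearchCV's perspective, you can summarize its own score table. A minimal sketch (assuming cart_best_grid from the previous step is still in memory; the variable name results is just for this illustration):

#best mean CV accuracy reached for each max_depth, over all min_samples_split values
results = pd.DataFrame(cart_best_grid.cv_results_)
print(results.groupby("param_max_depth")["mean_test_score"].max())
#the highest value in this summary appears at max_depth = 7, matching best_params_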

Now let’s continue.

################################################
# VISUALIZING THE DECISION TREE
################################################


#We can visualize the decision tree we obtained using the CART algorithm.
#The tree_graph function will create an image named "cart_final.png" in
#the directory.
tree_graph(model=cart_final, col_names=X.columns, file_name="cart_final.png")
# IMAGE IS SAVED AS "cart_final.png" (look at the directory)
################################################
# EXTRACTING DECISION RULES
################################################

#We can also express the decision tree in the "cart_final.png" picture above
#with text or code.

#Text
tree_rules = export_text(cart_final, feature_names=list(X.columns))
print(tree_rules)
'''
|--- Insulin <= 127.80
| |--- Glucose <= 123.50
| | |--- BMI <= 50.90
| | | |--- Insulin <= 110.10
| | | | |--- DiabetesPedigreeFunction <= 0.75
| | | | | |--- new_blood/glucose <= 0.64
| | | | | | |--- new_blood/glucose <= 0.64
| | | | | | | |--- class: 0.0
| | | | | | |--- new_blood/glucose > 0.64
| | | | | | | |--- class: 1.0
| | | | | |--- new_blood/glucose > 0.64
| | | | | | |--- Insulin <= 36.50
| | | | | | | |--- class: 0.0
| | | | | | |--- Insulin > 36.50
| | | | | | | |--- class: 0.0
| | | | |--- DiabetesPedigreeFunction > 0.75
| | | | | |--- BMI <= 32.70
| | | | | | |--- class: 0.0
| | | | | |--- BMI > 32.70
| | | | | | |--- DiabetesPedigreeFunction <= 0.88
| | | | | | | |--- class: 1.0
| | | | | | |--- DiabetesPedigreeFunction > 0.88
| | | | | | | |--- class: 0.0
| | | |--- Insulin > 110.10
| | | | |--- Insulin <= 110.30
| | | | | |--- class: 1.0
| | | | |--- Insulin > 110.30
| | | | | |--- BMI <= 45.40
| | | | | | |--- BMI <= 33.85
| | | | | | | |--- class: 0.0
| | | | | | |--- BMI > 33.85
| | | | | | | |--- class: 0.0
| | | | | |--- BMI > 45.40
| | | | | | |--- class: 1.0
| | |--- BMI > 50.90
| | | |--- class: 1.0
| |--- Glucose > 123.50
| | |--- BMI <= 26.30
| | | |--- class: 0.0
| | |--- BMI > 26.30
| | | |--- DiabetesPedigreeFunction <= 0.71
| | | | |--- new_blood/glucose <= 0.62
| | | | | |--- BMI <= 27.65
| | | | | | |--- class: 1.0
| | | | | |--- BMI > 27.65
| | | | | | |--- BMI <= 36.20
| | | | | | | |--- class: 0.0
| | | | | | |--- BMI > 36.20
| | | | | | | |--- class: 1.0
| | | | |--- new_blood/glucose > 0.62
| | | | | |--- class: 0.0
| | | |--- DiabetesPedigreeFunction > 0.71
| | | | |--- class: 1.0
|--- Insulin > 127.80
| |--- Glucose <= 154.50
| | |--- NEW_AGE_mid_adult <= 0.50
| | | |--- SkinThickness <= 25.60
| | | | |--- Age <= 69.50
| | | | | |--- class: 0.0
| | | | |--- Age > 69.50
| | | | | |--- class: 0.0
| | | |--- SkinThickness > 25.60
| | | | |--- Insulin <= 154.50
| | | | | |--- BMI <= 42.05
| | | | | | |--- Pregnancies <= 7.00
| | | | | | | |--- class: 0.0
| | | | | | |--- Pregnancies > 7.00
| | | | | | | |--- class: 1.0
| | | | | |--- BMI > 42.05
| | | | | | |--- class: 1.0
| | | | |--- Insulin > 154.50
| | | | | |--- Insulin <= 281.80
| | | | | | |--- Glucose <= 127.50
| | | | | | | |--- class: 0.0
| | | | | | |--- Glucose > 127.50
| | | | | | | |--- class: 1.0
| | | | | |--- Insulin > 281.80
| | | | | | |--- class: 0.0
| | |--- NEW_AGE_mid_adult > 0.50
| | | |--- BMI <= 26.25
| | | | |--- class: 0.0
| | | |--- BMI > 26.25
| | | | |--- Age <= 42.50
| | | | | |--- Insulin <= 133.80
| | | | | | |--- Pregnancies <= 4.00
| | | | | | | |--- class: 1.0
| | | | | | |--- Pregnancies > 4.00
| | | | | | | |--- class: 1.0
| | | | | |--- Insulin > 133.80
| | | | | | |--- Insulin <= 146.30
| | | | | | | |--- class: 0.0
| | | | | | |--- Insulin > 146.30
| | | | | | | |--- class: 1.0
| | | | |--- Age > 42.50
| | | | | |--- Insulin <= 143.80
| | | | | | |--- BloodPressure <= 74.00
| | | | | | | |--- class: 1.0
| | | | | | |--- BloodPressure > 74.00
| | | | | | | |--- class: 0.0
| | | | | |--- Insulin > 143.80
| | | | | | |--- Pregnancies <= 0.50
| | | | | | | |--- class: 0.0
| | | | | | |--- Pregnancies > 0.50
| | | | | | | |--- class: 1.0
| |--- Glucose > 154.50
| | |--- BMI_CAT_overweight <= 0.50
| | | |--- Insulin <= 544.00
| | | | |--- SkinThickness <= 13.50
| | | | | |--- class: 0.0
| | | | |--- SkinThickness > 13.50
| | | | | |--- BMI <= 22.60
| | | | | | |--- class: 0.0
| | | | | |--- BMI > 22.60
| | | | | | |--- BMI <= 46.10
| | | | | | | |--- class: 1.0
| | | | | | |--- BMI > 46.10
| | | | | | | |--- class: 1.0
| | | |--- Insulin > 544.00
| | | | |--- BloodPressure <= 69.00
| | | | | |--- class: 1.0
| | | | |--- BloodPressure > 69.00
| | | | | |--- class: 0.0
| | |--- BMI_CAT_overweight > 0.50
| | | |--- NEW_AGE_mid_adult <= 0.50
| | | | |--- class: 0.0
| | | |--- NEW_AGE_mid_adult > 0.50
| | | | |--- SkinThickness <= 28.90
| | | | | |--- class: 1.0
| | | | |--- SkinThickness > 28.90
| | | | | |--- class: 0.0
'''





#Python
print(skompile(cart_final.predict).to('python/code'))
'''
(((((((0 if x[15] <= 0.6409004628658295 else 1) if x[15] <=
0.6422942876815796 else 0 if x[4] <= 36.5 else 0) if x[6] <=
0.7525000274181366 else 0 if x[5] <= 32.70000076293945 else 1 if x[6] <=
0.8784999847412109 else 0) if x[4] <= 110.0999984741211 else 1 if x[4] <=
110.29999923706055 else (0 if x[5] <= 33.849998474121094 else 0) if x[5
] <= 45.39999961853027 else 1) if x[5] <= 50.89999961853027 else 1) if
x[1] <= 123.5 else 0 if x[5] <= 26.300000190734863 else ((1 if x[5] <=
27.65000057220459 else 0 if x[5] <= 36.20000076293945 else 1) if x[15] <=
0.6183468401432037 else 0) if x[6] <= 0.7114999890327454 else 1) if x[4
] <= 127.79999923706055 else (((0 if x[7] <= 69.5 else 0) if x[3] <=
25.59999942779541 else ((0 if x[0] <= 7.0 else 1) if x[5] <=
42.04999923706055 else 1) if x[4] <= 154.5 else (0 if x[1] <= 127.5 else
1) if x[4] <= 281.8000030517578 else 0) if x[8] <= 0.5 else 0 if x[5] <=
26.25 else ((1 if x[0] <= 4.0 else 1) if x[4] <= 133.8000030517578 else
0 if x[4] <= 146.29999542236328 else 1) if x[7] <= 42.5 else (1 if x[2] <=
74.0 else 0) if x[4] <= 143.8000030517578 else 0 if x[0] <= 0.5 else 1) if
x[1] <= 154.5 else ((0 if x[3] <= 13.5 else 0 if x[5] <=
22.59999942779541 else 1 if x[5] <= 46.10000038146973 else 1) if x[4] <=
544.0 else 1 if x[2] <= 69.0 else 0) if x[11] <= 0.5 else 0 if x[8] <=
0.5 else 1 if x[3] <= 28.899999618530273 else 0)
'''




#SQL
print(skompile(cart_final.predict).to('sqlalchemy/sqlite'))
'''
SELECT CASE WHEN (x5 <= 127.79999923706055) THEN CASE WHEN (x2 <= 123.5) THEN CASE WHEN (x6 <= 50.89999961853027) THEN CASE WHEN (x5 <= 110.0999984741211) THEN CASE WHEN (x7 <= 0.7525000274181366) THEN CASE WHEN (x16 <= 0.6422942876815796) THEN CASE WHEN (x16 <= 0.6409004628658295) THEN 0 ELSE 1 END ELSE 0 END ELSE CASE WHEN (x6 <= 32.70000076293945) THEN 0 ELSE CASE WHEN (x7 <= 0.8784999847412109) THEN 1 ELSE 0 END END END ELSE CASE WHEN (x5 <= 110.29999923706055) THEN 1 ELSE CASE WHEN (x6 <= 45.39999961853027) THEN 0 ELSE 1 END END END ELSE 1 END ELSE CASE WHEN (x6 <= 26.300000190734863) THEN 0 ELSE CASE WHEN (x7 <= 0.7114999890327454) THEN CASE WHEN (x16 <= 0.6183468401432037) THEN CASE WHEN (x6 <= 27.65000057220459) THEN 1 ELSE CASE WHEN (x6 <= 36.20000076293945) THEN 0 ELSE 1 END END ELSE 0 END ELSE 1 END END END ELSE CASE WHEN (x2 <= 154.5) THEN CASE WHEN (x9 <= 0.5) THEN CASE WHEN (x4 <= 25.59999942779541) THEN 0 ELSE CASE WHEN (x5 <= 154.5) THEN CASE WHEN (x6 <= 42.04999923706055) THEN CASE WHEN (x1 <= 7.0) THEN 0 ELSE 1 END ELSE 1 END ELSE CASE WHEN (x5 <= 281.8000030517578) THEN CASE WHEN (x2 <= 127.5) THEN 0 ELSE 1 END ELSE 0 END END END ELSE CASE WHEN (x6 <= 26.25) THEN 0 ELSE CASE WHEN (x8 <= 42.5) THEN CASE WHEN (x5 <= 133.8000030517578) THEN 1 ELSE CASE WHEN (x5 <= 146.29999542236328) THEN 0 ELSE 1 END END ELSE CASE WHEN (x5 <= 143.8000030517578) THEN CASE WHEN (x3 <= 74.0) THEN 1 ELSE 0 END ELSE CASE WHEN (x1 <= 0.5) THEN 0 ELSE 1 END END END END END ELSE CASE WHEN (x12 <= 0.5) THEN CASE WHEN (x5 <= 544.0) THEN CASE WHEN (x4 <= 13.5) THEN 0 ELSE CASE WHEN (x6 <= 22.59999942779541) THEN 0 ELSE 1 END END ELSE CASE WHEN (x3 <= 69.0) THEN 1 ELSE 0 END END ELSE CASE WHEN (x9 <= 0.5) THEN 0 ELSE CASE WHEN (x4 <= 28.899999618530273) THEN 1 ELSE 0 END END END END END AS y
FROM data
'''



#MS Excel
print(skompile(cart_final.predict).to('excel'))
#A1=IF((x5<=127.79999923 ...922 chars skipped... 9618530273),1,0)))))








################################################
# PREDICTION USING PYTHON CODES
################################################

#We can directly predict an X value we enter by using the tree rules we obtained.
#The X value here can be a list; it does not need to be a pandas DataFrame.
def predict_with_rules(x):
    return (((((((0 if x[15] <= 0.6409004628658295 else 1) if x[15] <=
        0.6422942876815796 else 0 if x[4] <= 36.5 else 0) if x[6] <=
        0.7525000274181366 else 0 if x[5] <= 32.70000076293945 else 1 if x[6] <=
        0.8784999847412109 else 0) if x[4] <= 110.0999984741211 else 1 if x[4] <=
        110.29999923706055 else (0 if x[5] <= 33.849998474121094 else 0) if x[5
        ] <= 45.39999961853027 else 1) if x[5] <= 50.89999961853027 else 1) if
        x[1] <= 123.5 else 0 if x[5] <= 26.300000190734863 else ((1 if x[5] <=
        27.65000057220459 else 0 if x[5] <= 36.20000076293945 else 1) if x[15] <=
        0.6183468401432037 else 0) if x[6] <= 0.7114999890327454 else 1) if x[4
        ] <= 127.79999923706055 else (((0 if x[7] <= 69.5 else 0) if x[3] <=
        25.59999942779541 else ((0 if x[0] <= 7.0 else 1) if x[5] <=
        42.04999923706055 else 1) if x[4] <= 154.5 else (0 if x[1] <= 127.5 else
        1) if x[4] <= 281.8000030517578 else 0) if x[8] <= 0.5 else 0 if x[5] <=
        26.25 else ((1 if x[0] <= 4.0 else 1) if x[4] <= 133.8000030517578 else
        0 if x[4] <= 146.29999542236328 else 1) if x[7] <= 42.5 else (1 if x[2] <=
        74.0 else 0) if x[4] <= 143.8000030517578 else 0 if x[0] <= 0.5 else 1) if
        x[1] <= 154.5 else ((0 if x[3] <= 13.5 else 0 if x[5] <=
        22.59999942779541 else 1 if x[5] <= 46.10000038146973 else 1) if x[4] <=
        544.0 else 1 if x[2] <= 69.0 else 0) if x[11] <= 0.5 else 0 if x[8] <=
        0.5 else 1 if x[3] <= 28.899999618530273 else 0)




#However, we added some new columns during feature engineering (for example
#NEW_AGE and new_blood/glucose). While we originally had 8 independent columns,
#we now have 16. Therefore, an X value that we enter as input must also
#contain these new features. That's why we write a function called
#converter(), which converts an original 8-feature sample into a
#16-feature one.
def converter(inp):
    if inp[7] <= 18:
        inp += [1, 1]
    elif (inp[7] > 18) & (inp[7] <= 24):
        inp += [0, 0]
    elif (inp[7] > 24) & (inp[7] <= 59):
        inp += [1, 0]
    elif inp[7] >= 60:
        inp += [0, 1]

    if inp[5] < 18.5:
        inp += [0, 0, 1]
    elif (inp[5] >= 18.5) & (inp[5] < 24.9):
        inp += [0, 0, 0]
    elif (inp[5] >= 24.9) & (inp[5] < 29.9):
        inp += [0, 1, 0]
    elif inp[5] >= 29.9:
        inp += [1, 0, 0]

    if inp[4] < 15:
        inp += [1, 0]
    elif (inp[4] >= 15) & (inp[4] < 166):
        inp += [0, 1]
    elif inp[4] >= 166:
        inp += [0, 0]

    inp.append(inp[2]/(inp[1]+0.0001))
    return inp





#Let's make an example.
x = [12, 13, 20, 23, 4, 55, 12, 7]
x = converter(x)
print(x)
# [12, 13, 20, 23, 4, 55, 12, 7, 1, 1, 1, 0, 0, 1, 0, 1.5384497042330443]
print(predict_with_rules(x)) # 1





#Another example.
x = [6, 148, 70, 35, 0, 30, 0.62, 50]
x = converter(x)
print(x)
# [6, 148, 70, 35, 0, 30, 0.62, 50, 1, 0, 1, 0, 0, 1, 0, 0.4729726533968558]
print(predict_with_rules(x)) # 0









################################################
# SAVING AND LOADING MODEL
################################################

#Let's say you created a good model and you want to share it with
#other people. In this case, you can save the final model as a pkl file
#and share that file.
#Save model as pkl file
joblib.dump(cart_final, "cart_final.pkl")



#You can also load a model with a pkl extension in your directory.
#Load model from pkl file
cart_model_from_disc = joblib.load("cart_final.pkl")



#And you can use the model..
x = [12, 13, 20, 23, 4, 55, 12, 7]
x = converter(x)
#Since you are using a scikit-learn model, you also need to convert the input to a pandas DataFrame.
x = pd.DataFrame(x).T
print(cart_model_from_disc.predict(x)) # [1.]
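If you prefer, you can also build the single-row DataFrame with the original column names (a small sketch, assuming the X frame from the feature engineering step is still in memory), so that the input matches the feature names the model was fitted with:

x = converter([12, 13, 20, 23, 4, 55, 12, 7])
x_df = pd.DataFrame([x], columns=X.columns)
print(cart_model_from_disc.predict(x_df)) # [1.]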
