How is feature importance calculated in Decision Trees? With an example
Understanding the math behind it
Decision Tree is amongst the most popular ML algorithms and is used as the weak learner in most bagging & boosting techniques, be it RandomForest or Gradient Boosting.
Note: A basic understanding of Decision Trees is required to move ahead
A great advantage of the sklearn implementation of Decision Tree is feature_importances_, which helps us understand which features are actually helpful compared to others.
But why do we need feature importance in the first place?
For this, you can visit my last post on why Interpretability & Explainability in AI are important and what the consequences can be if they are ignored.
Hoping you have read the above post, we can now proceed to understand the math behind the feature importance calculation. First of all, assume that:
We have a binary classification problem to predict whether an action is ‘Valid’ or ‘Invalid’
We have got 3 features, namely Response Size, Latency & Total Impressions
We have trained a DecisionTreeClassifier on the training data
The training data has 2k samples, both classes with equal representation
So, we have a trained model with us. Now we will jump to calculating feature_importances_. But before that, let's look at the structure of the decision tree we have trained.
The code snippet for training & preprocessing has been skipped as that is not the goal of this post. Though, the below code snippet can help you visualize your trained model as above:
import matplotlib.pyplot as plt
from sklearn import tree

# dt_model is a trained DecisionTreeClassifier object
# df is the training dataset
fig = plt.figure(figsize=(15, 7))
_ = tree.plot_tree(dt_model, feature_names=df.columns, filled=True)
Understanding a few parameters
The Gini index is used as the impurity measure.
value in the above diagram is the count of samples from each class at every node, i.e. if value=[24,47], the current node received 24 samples from class 1 & 47 from class 2.
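As a quick sanity check, the Gini impurity printed at each node can be reproduced from value directly (a minimal sketch; the helper function is ours, not part of sklearn):

```python
def gini(counts):
    """Gini impurity: 1 minus the sum of squared class proportions."""
    total = sum(counts)
    return 1.0 - sum((c / total) ** 2 for c in counts)

print(round(gini([1000, 1000]), 3))  # 0.5 at the balanced root
print(round(gini([24, 47]), 3))      # 0.448, matching the node with value=[24,47]
```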
Calculating feature importance involves 2 steps:
1. Calculate the importance of each node
2. Calculate each feature's importance using the importance of the nodes splitting on that feature
So, for calculating feature importance, we first need to calculate every node's importance in the Decision Tree.
How to do that?
Importance_Nodeₖ =
(%_of_samples_reaching_Nodeₖ × Impurity_Nodeₖ
− %_of_samples_reaching_left_subtree_of_Nodeₖ × Impurity_left_subtree_of_Nodeₖ
− %_of_samples_reaching_right_subtree_of_Nodeₖ × Impurity_right_subtree_of_Nodeₖ) / 100
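The formula can be sketched as a small Python helper (the function name and argument order are our own; the sample counts and impurities would be read off the fitted tree):

```python
def node_importance(n_total, n_node, gini_node,
                    n_left, gini_left, n_right, gini_right):
    """Importance of one node, per the formula above: the weighted
    impurity of the node minus the weighted impurities of its
    left & right subtrees."""
    pct = lambda n: 100.0 * n / n_total  # % of samples reaching a node
    return (pct(n_node) * gini_node
            - pct(n_left) * gini_left
            - pct(n_right) * gini_right) / 100.0

# Root node of our example: 2000 of 2000 samples, Gini 0.5,
# children with 1047 (Gini 0.086) and 953 (Gini 0.0) samples
print(round(node_importance(2000, 2000, 0.5, 1047, 0.086, 953, 0.0), 3))  # 0.455
```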
Let’s calculate the importance of each node (going left → right, top → bottom)
- 1st Node: value=[1000,1000]
=(100 × 0.5 − 52.35 × 0.086 − 47.65 × 0) / 100
=(50 − 4.5)/100 = 0.455
How did we get 100, 52.35 & 47.65 in the above equation?
100 × 2000/2000 (root) = 100%
100 × 1047/2000 (left subtree) = 52.35%
100 × 953/2000 (right subtree) = 47.65%
Follow the same logic for the rest of the nodes.
- 2nd Node: value=[1000,47]
=(52.35 × 0.086 − 48.8 × 0 − 3.55 × 0.448)/100
=(4.5 − 1.59)/100 = 0.0291
where 48.8% = 100 × 976/2000 (left subtree) & 3.55% = 100 × 71/2000 (right subtree)
- 3rd Node: value=[24,47]
=(3.55 × 0.448 − 2.4 × 0.041 − 1.15 × 0)/100
=(1.59 − 0.098)/100 = 0.0149
- 4th Node: value=[1,47]
=(2.4 × 0.041 − 0 − 0)/100
=0.098/100 = 0.00098
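Putting the four calculations together in a few lines (the numbers are read off the tree above; the node labels are ours). Note that the node importances sum to the root impurity of 0.5, because every leaf in this tree is pure:

```python
n_total = 2000

# (samples_at_node, gini, samples_left, gini_left, samples_right, gini_right)
nodes = {
    "1st": (2000, 0.500, 1047, 0.086, 953, 0.000),
    "2nd": (1047, 0.086,  976, 0.000,  71, 0.448),
    "3rd": (  71, 0.448,   23, 0.000,  48, 0.041),
    "4th": (  48, 0.041,    1, 0.000,  47, 0.000),  # splits into two pure leaves
}

# weighted impurity decrease at each node (same formula, using fractions)
importance = {
    name: (n * g - nl * gl - nr * gr) / n_total
    for name, (n, g, nl, gl, nr, gr) in nodes.items()
}
for name, imp in importance.items():
    print(name, round(imp, 5))

# sums to the root impurity (0.5) since all leaves are pure
print(round(sum(importance.values()), 3))
```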
The 1st step is done; we now move on to calculating the importance of every feature present.
The formula is simple,
feature importance for feature K =
Σ importance of nodes splitting on feature K / Σ importance of all nodes
That means the numerator is the summation of the node importances of all nodes that split on the particular feature K, and the denominator is the summation of all node importances.
Importance of Total Impressions:
As only the 1st Node splits on Total Impressions, in the numerator we consider only the node importance of the 1st Node
=0.455/(0.455 + 0.0291 + 0.0149 + 0.00098)
=0.455/0.5 = 0.91
Importance of Response Size:
Considering the 2nd & 3rd Nodes in the numerator
=(0.0291 + 0.0149)/(0.455 + 0.0291 + 0.0149 + 0.00098)
=0.044/0.5 = 0.088
Importance of Latency:
=0.00098/(0.455 + 0.0291 + 0.0149 + 0.00098)
=0.00098/0.5 = 0.002
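The same three ratios in code (node importances taken from the step above; the split-to-feature mapping follows the tree):

```python
node_imp = {"1st": 0.455, "2nd": 0.0291, "3rd": 0.0149, "4th": 0.00098}

# which nodes split on which feature, per the tree above
splits_on = {
    "Total Impressions": ["1st"],
    "Response Size":     ["2nd", "3rd"],
    "Latency":           ["4th"],
}

total = sum(node_imp.values())
feature_importance = {
    feat: sum(node_imp[n] for n in node_list) / total
    for feat, node_list in splits_on.items()
}
# normalizing by the total makes the importances sum to 1
print({feat: round(v, 3) for feat, v in feature_importance.items()})
```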
Hence, we can see that Total Impressions is the most critical feature, followed by Response Size, and the three importances sum to 1 as expected.
Does our answer match the one given by sklearn?
As we can see, the values look pretty much the same in the bar plot.
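Since the post's dataset isn't included, here's how you would read the same numbers from sklearn on any fitted model (the toy data below is made up for illustration; the label is driven entirely by the 3rd feature, so it should dominate the importances):

```python
import numpy as np
from sklearn.tree import DecisionTreeClassifier

rng = np.random.default_rng(42)
X = rng.normal(size=(2000, 3))    # stand-ins for Response Size, Latency, Total Impressions
y = (X[:, 2] > 0).astype(int)     # label depends only on the 3rd feature

dt_model = DecisionTreeClassifier(random_state=0).fit(X, y)

# feature_importances_ is exactly the normalized sum of
# node importances that we computed by hand above
for name, imp in zip(["Response Size", "Latency", "Total Impressions"],
                     dt_model.feature_importances_):
    print(f"{name}: {imp:.3f}")
```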