Detailed Explanation of Random Forest Feature Importance Bias

Mohammed Saad
7 min read · Dec 7, 2021

Many data science practitioners use Random Forest for their experiments, predictions, and decision making. It is a powerful model that can be quite accurate and easy to understand, and it works with most data types without requiring feature scaling or special handling of imbalanced data.

Beyond its predictive ability, Random Forest can return metrics that help interpret the model and reveal which features are most predictive. Those metrics are called “Feature Importance”.

Random Forest’s feature importance is widely used among data scientists, and you have probably used it yourself. But have you ever noticed that the resulting feature ranking sometimes did not match your domain knowledge?

Feature Importance is biased

Scikit-learn published an article explaining this issue. I am not going to redo that experiment; instead, I will explain the results and help you confidently understand how feature importance is calculated.

The scikit-learn article shows that after adding a completely random feature to the existing list of features, Random Forest ranked that random feature as the most important one.

Shocking, right?

scikit-learn experiment that shows a random feature being ranked as #1 feature

Well, it turns out that Random Forest feature importance is biased towards high-cardinality features, i.e. features with a large number of categories (a large number of unique values).
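Here is a minimal sketch of that effect on a purely synthetic dataset (the data, feature names, and numbers below are my own illustration, not the scikit-learn experiment): one genuinely informative but low-cardinality feature competes against a completely random, high-cardinality one, and the random column still grabs a substantial share of the importance even though it carries no signal.

```python
import numpy as np
from sklearn.ensemble import RandomForestClassifier

rng = np.random.RandomState(0)
n = 1000

# One genuinely informative but low-cardinality feature...
informative = rng.randint(0, 2, size=n)       # only 2 unique values
# ...and one completely random, high-cardinality feature.
random_noise = rng.randn(n)                   # ~1000 unique values, pure noise

# The label depends only on the informative feature (plus a little label noise).
y = np.where(rng.rand(n) < 0.9, informative, 1 - informative)

X = np.column_stack([informative, random_noise])
rf = RandomForestClassifier(n_estimators=100, random_state=0).fit(X, y)

# The random, high-cardinality column typically picks up a surprisingly
# large share of the MDI importance despite being unrelated to the label.
print(dict(zip(["informative (2 values)", "random noise (many values)"],
               rf.feature_importances_.round(3))))
```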

Now it’s time to revisit all the business decisions you made based on those metrics, but before doing that, let’s understand why Random Forest’s feature importance is biased.

How does Random Forest Work?

Random Forest consists of multiple decision trees that work together to predict the output, so in order to understand Random Forest, we need to understand how a decision tree works.

If I had to explain a decision tree, I would do it in three steps.

1- Pick a Feature

2- Pick a Splitting point

3- Iterate

Let’s pick a simple example and try to apply those steps. Imagine that we have 1’s and 0’s, each of which can be colored blue or red, and we want to separate the 1’s and 0’s by their colors; take a look at the following image to understand.

At first, we started with some 1’s and 0’s colored blue and red, and we identified that we have two features: [color, number].

At the first level we picked the color feature and determined the splitting point (red or blue); at the second level (the iteration) we picked the number feature and determined the splitting point (> 0 or ≤ 0). By following that procedure, we were able to identify each number by its color.

Here we were able to pick the best feature and the best split ourselves because it is just a simple example, but for a real-world problem we need an automated rule that we can rely on to pick them for us.
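As a quick illustration of how such an automated rule looks in practice, here is a hedged sketch that lets scikit-learn’s DecisionTreeClassifier make those picks on a made-up encoding of the toy example (the encoding and labels are placeholders, since the original figure is not reproduced here):

```python
from sklearn.tree import DecisionTreeClassifier, export_text

# Hypothetical encoding of the toy example:
# color: 0 = red, 1 = blue; number: a placeholder value (-1 or 1);
# label: which group each item belongs to.
X = [[0, 1], [0, -1], [1, 1], [1, -1],
     [0, 1], [0, -1], [1, 1], [1, -1]]
y = ["red-one", "red-zero", "blue-one", "blue-zero",
     "red-one", "red-zero", "blue-one", "blue-zero"]

tree = DecisionTreeClassifier(max_depth=2).fit(X, y)
print(export_text(tree, feature_names=["color", "number"]))
# The printout shows one feature used at the first level and the other at the
# second level: the pick-a-feature / pick-a-split / iterate loop, automated.
```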

Gini Impurity

Let’s imagine that you have a bucket full of bananas and apples. We say that your bucket is impure because it contains more than one class of items; it is not purely full of bananas or purely full of apples.

Gini impurity is a way to measure how pure your bucket is.

So in our bucket example we have two classes (bananas and apples). Let’s say there are 4 bananas (|S1|) and 3 apples (|S2|), so we have 7 items in total, and that is |S|.

The Gini impurity of a set S with classes S1 and S2 is

G(S) = |S1|/|S| * (1 - |S1|/|S|) + |S2|/|S| * (1 - |S2|/|S|)

We can calculate the Gini impurity of the bucket by applying this equation as follows.

G(S) = 3/7 * (1 - 3/7) + 4/7 * (1 - 4/7) ≈ 0.49
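If you want to double-check that number, here is a small helper (my own sketch, not from the original article) that computes Gini impurity from raw class counts:

```python
def gini(counts):
    """Gini impurity from a list of class counts, e.g. [4, 3] for 4 bananas and 3 apples."""
    total = sum(counts)
    probs = [c / total for c in counts]
    return sum(p * (1 - p) for p in probs)

print(round(gini([4, 3]), 2))  # bucket with 4 bananas, 3 apples -> ~0.49
print(round(gini([3, 3]), 2))  # perfectly mixed bucket          -> 0.5 (maximum)
print(round(gini([7, 0]), 2))  # only bananas                    -> 0.0 (pure)
```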

So now we know that the Gini impurity of the bucket is about 0.49, but is that high or low?

Let’s take a look at the Gini Impurity Plot.

It shows that the highest Gini impurity you can get is 0.5, which happens when each class in your bucket has the same probability: with 3 bananas and 3 apples you get the maximum Gini impurity. If you have even one more banana, there is a dominating class and the bucket becomes slightly purer; the more items the dominating class has, the less impure your bucket is.

The impurity becomes 0 when there is one class of items in your bucket, ex: only bananas.

How does Decision Tree use Gini Impurity?

The decision tree uses Gini impurity to determine which split will make the bucket purer. The impurity of a split is

G(split) = |S_left|/|S| * G(S_left) + |S_right|/|S| * G(S_right)

At every level the decision tree splits the data into a left subtree and a right subtree, so we combine the Gini impurities of the two sides by weighting each one by the fraction of items it contains.

For every feature, the tree tries every candidate split, calculates the resulting Gini impurity, compares it to the current impurity, and picks the split that reduces the impurity the most.

And we say that this split on that feature decreased the impurity by x.
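Here is a hedged sketch of that split-selection loop in plain Python (illustrative only; scikit-learn’s actual implementation works on sorted feature values in optimized Cython code):

```python
from collections import Counter

def gini(labels):
    """Gini impurity of a list of class labels: sum over classes of p * (1 - p)."""
    n = len(labels)
    return sum((c / n) * (1 - c / n) for c in Counter(labels).values())

def split_impurity(feature, labels, threshold):
    """Weighted Gini impurity after splitting on `feature <= threshold`."""
    left  = [y for x, y in zip(feature, labels) if x <= threshold]
    right = [y for x, y in zip(feature, labels) if x >  threshold]
    n = len(labels)
    return len(left) / n * gini(left) + len(right) / n * gini(right)

def best_split(feature, labels):
    """Try every midpoint between consecutive unique values and keep the best."""
    values = sorted(set(feature))
    candidates = [(a + b) / 2 for a, b in zip(values, values[1:])]
    return min(candidates, key=lambda t: split_impurity(feature, labels, t))

print(best_split([1, 2, 3, 4], ["a", "a", "b", "b"]))  # -> 2.5 (a perfect split)
```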

MDI

Random Forest uses MDI to calculate feature importance. MDI stands for Mean Decrease in Impurity: for each feature, it measures the average decrease in impurity that the feature’s splits introduced across all the decision trees while they were being constructed.
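In scikit-learn, this is what the feature_importances_ attribute of a fitted forest exposes. As a hedged sanity check (the exact internals may vary between versions), the forest-level importances should match the average of the per-tree importances:

```python
import numpy as np
from sklearn.datasets import load_iris
from sklearn.ensemble import RandomForestClassifier

X, y = load_iris(return_X_y=True)
rf = RandomForestClassifier(n_estimators=50, random_state=0).fit(X, y)

# feature_importances_ is MDI: each tree's impurity decreases, averaged over trees.
per_tree = np.array([tree.feature_importances_ for tree in rf.estimators_])
print(np.allclose(rf.feature_importances_, per_tree.mean(axis=0)))  # expected: True
print(rf.feature_importances_.round(3))
```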

That seems logical, so why is this way of calculating feature importance biased towards high-cardinality features?

Let's look at the following example

In this example, we have two features [x1, x2 ] and one output [y], where feature x1 has 10 unique values, feature x2 has 2 unique values, and the output can be either 1 or 0.

If we calculate the impurity of our current bucket without doing any split, we get 0.5. Why? We have two classes, 1’s and 0’s, with 5 items per class, so we get the maximum Gini impurity.

If we pick feature x2 and try to determine a splitting point, we notice that there is only one possible splitting point, between 4 and 5.

This splitting point splits our bucket into two pure buckets, each containing only 1’s or only 0’s.

So the Gini Impurity of this split will be 0.

Now if we pick feature x1, there are many splitting points but the best splitting point will be between 5 and 6, giving us an impurity of 0.

So now the only difference between those features is that one has more splitting points than the other, but they both decrease the impurity by the same amount.

Did you notice it? If not, don’t worry, just follow along.

What if we swapped two of the 1’s with two of the 0’s, so that we get the following data?

If we calculate the Gini impurity of the new data, we notice that it doesn’t change: we still have two classes with the same number of items each, so the Gini impurity is still 0.5.

Now let’s pick feature x2 and try to find the best splitting point. Oops, we have no choice: there is only one splitting point, between 4 and 5.

Let’s calculate the Gini impurity of that split.

Left SubTree :

P(class = 1 ) = 2/5 = 0.4 , P(class = 0 ) = 3/5 = 0.6

G(left subtree) = 0.4*(0.6) + 0.6*(0.4) = 0.48

Right SubTree:

P(class = 1 ) = 3/5 = 0.6 , P(class = 0 ) = 2/5 = 0.4

G(right subtree) = 0.6*(0.4) + 0.4*(0.6) = 0.48

G(split) = 0.5 * 0.48 + 0.5 * 0.48 = 0.48

So that split decreased the impurity by 0.02.

Now let’s take a look at feature x1. It has many splitting points; I am going to pick the one that isolates three of the 0’s on one side, which I argue is the best one.

Let's calculate the Gini impurity for that split.

Left SubTree :

The Gini impurity is simply 0, since this subtree consists purely of 0’s, but to prove it let’s follow the calculations.

P(class = 1 ) = 0/3 = 0 , P(class = 0 ) = 3/3 = 1

G(left subtree) = 0*(1) + 1*(0) = 0

Right SubTree:

P(class = 1 ) = 5/7 = 0.71 , P(class = 0 ) = 2/7 = 0.29

G(right subtree) = 0.71*(0.29) + 0.29*(0.71) = 0.41

G(split) = 3/10 * 0 + 7/10 * 0.41 = 0.29

So feature x1 was able to decrease the impurity from 0.5 to 0.29.
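To verify the arithmetic, here is a hedged sketch using one possible dataset that matches the description above (the exact x1 and x2 values are placeholders, since the original table is an image; only the class counts matter):

```python
from collections import Counter

def gini(labels):
    n = len(labels)
    return sum((c / n) * (1 - c / n) for c in Counter(labels).values())

def best_split_impurity(feature, labels):
    """Best (lowest) weighted Gini impurity over all candidate thresholds."""
    values = sorted(set(feature))
    best = float("inf")
    for a, b in zip(values, values[1:]):
        t = (a + b) / 2
        left  = [y for x, y in zip(feature, labels) if x <= t]
        right = [y for x, y in zip(feature, labels) if x >  t]
        weighted = (len(left) / len(labels) * gini(left)
                    + len(right) / len(labels) * gini(right))
        best = min(best, weighted)
    return best

# Placeholder data consistent with the example: 10 rows, five 1's and five 0's,
# with two 1's and two 0's swapped relative to the original perfectly separable data.
x1 = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]   # 10 unique values -> 9 candidate splits
x2 = [4, 4, 4, 4, 4, 5, 5, 5, 5, 5]    # 2 unique values  -> 1 candidate split
y  = [0, 0, 0, 1, 1, 1, 1, 1, 0, 0]

print(round(gini(y), 2))                      # 0.5  (starting impurity)
print(round(best_split_impurity(x2, y), 2))   # 0.48 (only one split available)
print(round(best_split_impurity(x1, y), 2))   # 0.29 (best of nine candidate splits)
```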

Can you see now why feature importance based on MDI is biased? The more splitting points a feature has, the more likely it is that one of them happens to give a better split, even if that feature is totally random.

Better options? Permutation importance and drop-column importance are more reliable than MDI feature importance, and they can be used with any machine learning model, not only Random Forests.
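For example, here is a hedged sketch using scikit-learn’s permutation_importance, which measures how much a held-out score drops when a feature’s values are shuffled:

```python
from sklearn.datasets import load_breast_cancer
from sklearn.ensemble import RandomForestClassifier
from sklearn.inspection import permutation_importance
from sklearn.model_selection import train_test_split

X, y = load_breast_cancer(return_X_y=True, as_frame=True)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)

rf = RandomForestClassifier(n_estimators=100, random_state=0).fit(X_train, y_train)

# Shuffle each feature on held-out data and measure the drop in accuracy.
result = permutation_importance(rf, X_test, y_test, n_repeats=10, random_state=0)
for name, mean in sorted(zip(X.columns, result.importances_mean),
                         key=lambda t: t[1], reverse=True)[:5]:
    print(f"{name:25s} {mean:.3f}")
```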
