Neural Network Feature Importance with fastai

Miguel Pinto
Aug 6, 2018


Structured data problems are very common in machine learning. A good example is predicting sales from a wide range of predictors such as store, item, day of week, distance to the nearest competitor, or even the weather forecast for the day. Variables like store, item and day of week are treated as categorical, whereas variables like temperature or distance to the nearest competitor are treated as continuous.

The fastai library is built on PyTorch and has all the tools to make fitting a model of this kind (and many others) very easy. You can see this notebook for a great example, and the Practical Deep Learning for Coders course for a detailed explanation.

Assuming you are already familiar with the fastai library, this is how I implemented a function to compute feature importance:

The basic idea is to randomly shuffle one column at a time and compute the loss with that column shuffled. The feature importance is then the difference between the loss obtained with the column shuffled and the original loss.
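This idea is commonly called permutation importance. Below is a minimal, framework-free sketch of it in plain Python, not the fastai function described in this post: `permutation_importance` and `mse` are hypothetical names, the model is assumed to be any callable mapping one row of features to a prediction, and mean squared error stands in for whatever loss the model was actually trained with.

```python
import random
from typing import Callable, List

Row = List[float]


def mse(preds: List[float], targets: List[float]) -> float:
    """Mean squared error, a stand-in for the model's real loss."""
    return sum((p - t) ** 2 for p, t in zip(preds, targets)) / len(targets)


def permutation_importance(
    predict: Callable[[Row], float],
    rows: List[Row],
    targets: List[float],
    seed: int = 0,
) -> List[float]:
    """For each column: loss with that column shuffled minus the original loss."""
    rng = random.Random(seed)
    loss0 = mse([predict(r) for r in rows], targets)  # original loss
    importances = []
    for col in range(len(rows[0])):
        # Permute one column across rows; all other columns stay untouched.
        values = [r[col] for r in rows]
        rng.shuffle(values)
        shuffled = [r[:col] + [v] + r[col + 1:] for r, v in zip(rows, values)]
        importances.append(mse([predict(r) for r in shuffled], targets) - loss0)
    return importances


if __name__ == "__main__":
    data_rng = random.Random(42)
    rows = [[data_rng.uniform(-1, 1), data_rng.uniform(-1, 1)] for _ in range(500)]
    targets = [3 * r[0] for r in rows]
    # Toy "model" that only reads column 0, so column 1 should score 0.
    print(permutation_importance(lambda r: 3 * r[0], rows, targets))
```

In the toy run, column 1 gets an importance of exactly 0 because the model never reads it, while column 0 gets a clearly positive score.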

In the code above, loss0 is the original loss. We then loop over each categorical variable (line 13) and each continuous variable (line 20). Inside these main loops, another loop iterates over the mini-batches, computes the loss for each mini-batch and stores it in a list. The feature importance for that variable (lines 20 and 26) is then the drop in performance, i.e. how much the loss increases relative to loss0.
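The loop structure just described (one outer loop over categorical columns, one over continuous columns, and an inner loop over mini-batches whose losses are collected in a list) can be sketched without fastai as follows. This is not the author's code: `nn_importance_sketch`, `batch_loss` and `shuffle_col` are hypothetical names, the model is assumed to take a categorical row and a continuous row, MSE stands in for the real loss, and the column is shuffled within each mini-batch. That last point is an assumption, since the post does not say whether shuffling happens per batch or over the whole set.

```python
import random
from typing import Callable, Dict, List

# Hypothetical model signature: (categorical row, continuous row) -> prediction.
Model = Callable[[List[int], List[float]], float]


def batch_loss(model: Model, cats, conts, ys) -> float:
    """MSE of one mini-batch (stand-in for the model's real loss)."""
    preds = [model(c, x) for c, x in zip(cats, conts)]
    return sum((p - y) ** 2 for p, y in zip(preds, ys)) / len(ys)


def shuffle_col(rows, col, rng):
    """Copy of `rows` with column `col` permuted within this batch."""
    values = [r[col] for r in rows]
    rng.shuffle(values)
    return [r[:col] + [v] + r[col + 1:] for r, v in zip(rows, values)]


def nn_importance_sketch(model: Model, x_cat, x_cont, y, cat_names, cont_names,
                         bs: int = 64, seed: int = 0) -> Dict[str, float]:
    rng = random.Random(seed)

    def mean_batch_loss(cat_col=None, cont_col=None) -> float:
        losses = []  # one loss per mini-batch
        for s in range(0, len(y), bs):
            cats, conts, ys = x_cat[s:s + bs], x_cont[s:s + bs], y[s:s + bs]
            if cat_col is not None:
                cats = shuffle_col(cats, cat_col, rng)
            if cont_col is not None:
                conts = shuffle_col(conts, cont_col, rng)
            losses.append(batch_loss(model, cats, conts, ys))
        return sum(losses) / len(losses)

    loss0 = mean_batch_loss()  # original loss, nothing shuffled
    imp = {}
    for i, name in enumerate(cat_names):    # loop over categorical variables
        imp[name] = mean_batch_loss(cat_col=i) - loss0
    for j, name in enumerate(cont_names):   # loop over continuous variables
        imp[name] = mean_batch_loss(cont_col=j) - loss0
    # Most important variables first.
    return dict(sorted(imp.items(), key=lambda kv: kv[1], reverse=True))
```

Averaging the per-batch losses gives a smaller final batch slightly more weight than a single pass over all rows would, which is usually negligible for ranking features.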

The rationale is that shuffling a column breaks the relationship between that feature and the model's output. If the loss increases a lot (i.e. the model's predictions become less accurate), that variable must be an important predictor. Conversely, if the loss stays almost the same, the predictor contributes little: it can probably be removed without a decrease in accuracy, and removing it could even improve accuracy after retraining the model.

In practice, assuming you already have m, md, cat_vars and cont_vars and the code above, you just need to run:

fi = nn_feat_importance(m, md, cat_vars, cont_vars)
# fi is assumed to be a pandas DataFrame with 'cols' (names) and 'imp' (importances)
fi.plot(x='cols', y='imp', kind='barh', figsize=(12,7), legend=False)

The result should be something like the following image:

Figure 1. Feature importance example

The image above is what I obtained for the Store Item Demand Forecasting Challenge, a playground competition on Kaggle. The goal there is to predict sales for several stores and items over the 3 months following the given period. Figure 1 shows that item and store are the most important features, followed by day of week, month, and so on. Variables like day and is_month_start/end are probably not important for this particular problem.

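Note, however, that we shuffle one variable at a time when computing the feature importance, so you cannot remove all the variables with low importance at once and expect the model to perform the same. It can happen, but it is not the rule. Multicollinearity matters here: if two features carry largely the same information, shuffling one leaves the other intact, so each can look unimportant on its own even though removing both would hurt the model.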
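A small hypothetical experiment illustrates the multicollinearity caveat (plain Python, toy data, MSE as the loss, not taken from the Kaggle results above). A feature is duplicated and the toy model splits its weight evenly between the two identical copies. Shuffling either copy alone then raises the loss only a quarter as much as shuffling the single feature in an equivalent one-feature model, so each copy looks less important than the information it actually carries.

```python
import random


def shuffle_loss_increase(predict, rows, y, col, seed=0):
    """Loss (MSE) with column `col` shuffled minus the original loss."""
    def mse(rs):
        return sum((predict(r) - t) ** 2 for r, t in zip(rs, y)) / len(y)

    values = [r[col] for r in rows]
    random.Random(seed).shuffle(values)  # same seed and length -> same permutation
    shuffled = [r[:col] + [v] + r[col + 1:] for r, v in zip(rows, values)]
    return mse(shuffled) - mse(rows)


rng = random.Random(7)
x = [rng.uniform(0, 10) for _ in range(1000)]
y = list(x)  # the target equals the single informative signal

# One copy of the feature, with full weight on it.
single = shuffle_loss_increase(lambda r: r[0], [[v] for v in x], y, col=0)

# Two identical copies (perfect multicollinearity), weight split evenly.
dup_rows = [[v, v] for v in x]
dup = shuffle_loss_increase(lambda r: 0.5 * r[0] + 0.5 * r[1], dup_rows, y, col=0)

print(f"single feature: {single:.3f}, one of two copies: {dup:.3f}")
print(f"ratio: {dup / single:.3f}")  # about 0.25
```

The ratio follows from the prediction error being halved, (0.5 x' + 0.5 x) - x = 0.5 (x' - x), which quarters the squared error. Real features are rarely perfect duplicates, but the same dilution appears in weaker form whenever predictors are correlated.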



Miguel Pinto

PhD student (Remote sensing, Meteorology), ML/DL enthusiast, fastai student, competition master at Kaggle, pianist/composer
