Explaining Model Pipelines With InterpretML

Marius Vadeika
Feb 9, 2020

Preface

Statistics and machine learning

Model interpretation (or explainability) has recently gained attention in machine learning: ML models can be very accurate, but it is often also relevant to know what happens behind the scenes when a model makes predictions. This topic involves both statistics and machine learning, so let’s start with a theoretical discussion and then jump to a practical application.

Classical statistics can be divided into descriptive and inferential statistics. Descriptive statistics explores the observed sample data, while statistical inference is concerned with making propositions about a population using data drawn from that population through some form of sampling.

Machine learning is sometimes called applied or glorified statistics, and there are many other opinions. There is some overlap between statistics and machine learning, but the distinctions between the fields are mostly in application. In the 2001 paper Statistical Modeling: The Two Cultures, Leo Breiman separated the field of statistical modeling into two communities:

  1. Data modeling. This community assumes that the data are generated by a stochastic data model and typically focuses on estimating parameters from the data and using the model for information. Evaluation is based on goodness-of-fit tests and residual analysis.
  2. Algorithmic modeling. This community treats the data mechanism as unknown. Input variables x are used to predict the response variable y. Evaluation is based on predictive accuracy [1].
Two purposes in analyzing data: prediction (predicting the response variable y) and information (extracting information about how the input variables x are associated with the response variable y). Leo Breiman, Statistical Modeling: The Two Cultures.

Machine learning cares about model performance, but explainability falls into the information category (the first community, statistics). So maybe there is something to be learned from statisticians? R. A. Fisher specified three main aspects to consider for a valid inference:

  1. Model specification. George E. P. Box noted that “all models are wrong, but some are useful.” This step involves setting up candidate models and selecting one or more of them. Picking candidate models is not straightforward; it is typically a subjective process that draws on experience.
  2. Estimation of model parameters. Model parameters are estimated or learned from the data (internal configuration) and are required to make predictions (for example, a linear model estimates coefficients and a neural network estimates weights) [2]. In practical terms, estimating model parameters in Python usually means calling a model’s .fit() method.
  3. Estimation of precision. Here it is important to set aside some data for model evaluation and to select error metrics that measure the model’s accuracy. ML algorithms may also require hyperparameter tuning. Hyperparameters are usually specified by the user (external configuration) and control how the model parameters are estimated. Modern ML algorithms have anywhere from a few to hundreds of hyperparameters. A standard approach for selecting hyperparameters is to run k-fold cross-validation and optimize a desired metric [3].
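To make the parameter/hyperparameter distinction concrete, here is a minimal, library-free sketch of k-fold cross-validation. It uses a one-feature ridge regression as a swapped-in toy model (not something from this article) on synthetic data: the slope w is the parameter learned by the “fit” step, while the penalty lam is the hyperparameter chosen by cross-validated mean squared error.

```python
import random


def fit_ridge_1d(xs, ys, lam):
    """Learn the parameter w of y ~ w * x under an L2 penalty lam (the hyperparameter)."""
    sxy = sum(x * y for x, y in zip(xs, ys))
    sxx = sum(x * x for x in xs)
    return sxy / (sxx + lam)


def mse(w, xs, ys):
    return sum((y - w * x) ** 2 for x, y in zip(xs, ys)) / len(xs)


def k_fold_indices(n, k):
    """Split range(n) into k contiguous folds whose sizes differ by at most one."""
    sizes = [n // k + (1 if i < n % k else 0) for i in range(k)]
    folds, start = [], 0
    for size in sizes:
        folds.append(list(range(start, start + size)))
        start += size
    return folds


def cv_score(xs, ys, lam, k=5):
    """Mean validation MSE over k folds for one hyperparameter value."""
    scores = []
    for val_idx in k_fold_indices(len(xs), k):
        held_out = set(val_idx)
        train = [i for i in range(len(xs)) if i not in held_out]
        # the ".fit()" step: estimate w on the training folds only
        w = fit_ridge_1d([xs[i] for i in train], [ys[i] for i in train], lam)
        scores.append(mse(w, [xs[i] for i in val_idx], [ys[i] for i in val_idx]))
    return sum(scores) / len(scores)


# Synthetic i.i.d. data (so contiguous folds are fine): y = 2x + noise, seeded
rng = random.Random(0)
xs = [rng.uniform(-1, 1) for _ in range(100)]
ys = [2 * x + rng.gauss(0, 0.3) for x in xs]

grid = [0.0, 1.0, 10.0, 100.0]
best_lam = min(grid, key=lambda lam: cv_score(xs, ys, lam))
w_final = fit_ridge_1d(xs, ys, best_lam)  # refit on all data with the chosen value
```

In scikit-learn the same loop is usually delegated to GridSearchCV with cv=5; the hand-written version only shows which quantity is learned and which is searched over.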

In statistics, the goal of modeling is to approximate and understand the data-generating process. If different ML algorithms (SVM, k-NN, XGBoost, etc.) were fitted on the same data set, they would also explain the data differently. None of them is regarded as the “true model” that generated the data; each is rather an abstraction of the empirical data that can help answer certain questions. Moreover:

“The model does not represent a belief about or a commitment to the data generation process. Its purpose is purely functional. No ML practitioner would be prepared to testify to the “validity” of a model; this has no meaning in Machine Learning, since the model is really only instrumental to its performance.”[4]

To sum up, machine learning cares about performance, and feature relationships with the target are often treated as a black box. But as soon as model explainability becomes important, more care is needed (put on your statistician’s hat). For instance, some statistical models assume feature independence, while a data scientist may deem this unimportant. If there are 10 important X predictors with a high degree of multicollinearity (let’s say Pearson correlation ρ > 0.9), they may split the importance of the same underlying driver among themselves, and the inner workings of such a model will be harder to discern than those of a model with fewer features.
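To make the “divided importance” effect concrete, here is a minimal NumPy sketch. It uses ridge regression, a swapped-in linear model rather than anything from this article, on synthetic data in which two predictors are noisy copies of one hypothetical underlying driver. With a single copy in the model, that copy receives roughly the driver’s full coefficient; with both copies, the coefficient is shared between them.

```python
import numpy as np

rng = np.random.default_rng(0)
n = 1000
driver = rng.normal(size=n)                  # the single underlying driver
x1 = driver + rng.normal(scale=0.1, size=n)  # two noisy copies of the driver
x2 = driver + rng.normal(scale=0.1, size=n)
y = 3.0 * driver + rng.normal(scale=0.5, size=n)


def ridge(X, y, lam=1.0):
    """Closed-form ridge coefficients (no intercept; the synthetic data are ~zero-mean)."""
    return np.linalg.solve(X.T @ X + lam * np.eye(X.shape[1]), X.T @ y)


rho = np.corrcoef(x1, x2)[0, 1]                       # Pearson correlation of the copies
coef_single = ridge(x1[:, None], y)                   # one copy: gets the full weight
coef_pair = ridge(np.column_stack([x1, x2]), y)       # both copies: weight is split
```

The pair’s coefficients add up to about the single-copy coefficient, but each one alone looks only half as important. Tree ensembles typically split importance scores between correlated features as well, though less evenly.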

InterpretML

The InterpretML framework provides one explanation workflow for both glassbox and blackbox models.

InterpretML is a new framework backed by Microsoft, built on a simple idea: bring existing model interpretation frameworks into one place and package them in a practical way. It works with both types of models: glassbox (interpretable, e.g. linear models and decision trees) and blackbox (non-interpretable, e.g. gradient boosting and SVM).

Another interesting thing InterpretML brings is an implementation of a glassbox model, the Explainable Boosting Machine (EBM). The authors claim that it is designed not only to be interpretable but also to achieve accuracy comparable to popular algorithms such as Random Forest and XGBoost. You can learn more about its mechanics in the InterpretML repository and the accompanying research paper [5, 6].
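According to the InterpretML documentation, an EBM is a generalized additive model with optional pairwise interactions, g(E[y]) = β0 + Σ f_j(x_j) + Σ f_ij(x_i, x_j), where each shape function f_j is learned with boosting. Because the prediction is a plain sum, each feature’s contribution can be read off directly. The sketch below illustrates only that additive prediction step for a binary target; the intercept, bins and shape-function values are hypothetical numbers rather than a trained EBM, and interaction terms are omitted.

```python
import math

# Hypothetical shape functions: log-odds contribution per feature value (or bin).
# A real EBM learns these with boosting; these numbers are made up for illustration.
INTERCEPT = -3.0
SHAPES = {
    "countrycode": {1: -0.8, 2: 0.4, 3: 1.2},
    "hour": {"night": -0.3, "day": 0.2},
}


def explain_additive(row):
    """Return (probability of class 1, per-feature log-odds contributions)."""
    contributions = {feature: SHAPES[feature][row[feature]] for feature in SHAPES}
    logit = INTERCEPT + sum(contributions.values())
    return 1 / (1 + math.exp(-logit)), contributions


prob, contrib = explain_additive({"countrycode": 3, "hour": "day"})
```

Here the logit is -3.0 + 1.2 + 0.2 = -1.6, so the predicted probability is about 0.17, and the contributions dictionary is itself the local explanation.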

EDA

Data

For the following experiment, the Predict Ad Clicks data set from HackerEarth was selected. To download the data, visit the link (login required); to access the code, clone this repository.

Predict Ad Clicks variable descriptions.

For the sake of the experiment, a smaller sample of 100,000 rows was selected. This data set is convenient to experiment on because it has only a handful of predictors (5 categorical and 4 numerical) for a binary classification task. More details:

  • ID is unique for every row and of object type, so it won’t be used;
  • datetime contains a detailed timestamp and can be preprocessed in different ways to maximize its utility;
  • siteid, offerid, category and merchant will be used as numeric columns (even though they actually represent high-cardinality categories);
  • countrycode, browserid and devid are categorical and will therefore be preprocessed and converted to numeric;
  • click is the target variable to be classified; 1 means the ad was clicked.

The data sample is randomly split into training (70%) and holdout (30%) sets. The split keeps approximately the same proportion of the target variable in both sets (96% 0s and 4% 1s).
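A stratified split is one way to guarantee that both sets keep the same class proportions. The sketch below is a library-free illustration on a hypothetical toy target with 4% positives; in practice, scikit-learn’s train_test_split(X, y, test_size=0.3, stratify=y) does the same job.

```python
import random
from collections import defaultdict


def stratified_split(labels, test_size=0.3, seed=42):
    """Return (train_idx, test_idx) so each class keeps ~the same share in both sets."""
    by_class = defaultdict(list)
    for i, label in enumerate(labels):
        by_class[label].append(i)
    rng = random.Random(seed)
    train_idx, test_idx = [], []
    for idx in by_class.values():
        rng.shuffle(idx)                        # random order within each class
        n_test = round(len(idx) * test_size)
        test_idx.extend(idx[:n_test])
        train_idx.extend(idx[n_test:])
    return sorted(train_idx), sorted(test_idx)


def click_rate(labels, idx):
    return sum(labels[i] for i in idx) / len(idx)


# Hypothetical toy target mimicking the imbalance above: 40 clicks out of 1,000 rows
labels = [1] * 40 + [0] * 960
train_idx, test_idx = stratified_split(labels)
```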

Numerical features

In this step, let’s briefly check how the features vary against the target variable. Popular visualization libraries like matplotlib and seaborn could be used, but that would be too mainstream for a blog post. Let’s try something different: the Altair visualization library has somewhat intimidating syntax but produces good-looking plots.

For the numerical features (again, they are actually just high-cardinality categorical features), density plots should do the trick. With visualization libraries, it’s sometimes simpler to write a loop than to make a facet plot:

Density plots with Altair library.

Altair density plots produce quite choppy distributions (other libraries smooth them out more by default). The distributions are more or less aligned across the numerical features, but there is a notable bump in category for values between 35,000 and 50,000:

Numerical features distributions colored by target variable (whether clicked).
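How smooth a density curve looks depends largely on the kernel bandwidth (Altair’s transform_density exposes a bandwidth parameter for this). As a library-free sketch of what such a density transform computes, not Altair’s actual implementation, the NumPy code below evaluates a Gaussian kernel density estimate per target class on synthetic stand-in data, once with Scott’s rule for the bandwidth and once with a deliberately small bandwidth.

```python
import numpy as np


def gaussian_kde_1d(values, grid, bandwidth=None):
    """Evaluate a Gaussian kernel density estimate of `values` on `grid`.

    Without an explicit bandwidth, Scott's rule (std * n ** -0.2) is used.
    Smaller bandwidths give choppier curves, larger ones smoother curves.
    """
    values = np.asarray(values, dtype=float)
    if bandwidth is None:
        bandwidth = values.std(ddof=1) * len(values) ** (-1 / 5)
    z = (grid[:, None] - values[None, :]) / bandwidth
    kernel = np.exp(-0.5 * z**2) / np.sqrt(2 * np.pi)
    return kernel.mean(axis=1) / bandwidth


# Hypothetical stand-in for one numeric column, split by the click target
rng = np.random.default_rng(1)
no_click = rng.normal(30_000, 8_000, size=2_000)
click = rng.normal(42_000, 4_000, size=200)

grid = np.linspace(0, 70_000, 701)
density = {label: gaussian_kde_1d(vals, grid) for label, vals in {0: no_click, 1: click}.items()}
choppy = gaussian_kde_1d(click, grid, bandwidth=200)  # same data, tiny bandwidth
```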

Categorical features

Categorical features with low cardinality can be conveyed with a heatmap. Since the target click is imbalanced (there are far fewer clicks than non-clicks), the key here is to display relative frequencies by group (within each click class, 0 and 1, the frequencies sum to 1 if there are no missing values):

Heatmaps displaying categorical features against relative frequency by clicks.
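The relative frequencies behind such a heatmap can be computed with pandas’ crosstab, normalizing within each click class. A minimal sketch on hypothetical toy rows:

```python
import pandas as pd

# Hypothetical toy rows; the real data set has many more rows and categories
df = pd.DataFrame({
    "countrycode": ["a", "a", "b", "c", "c", "d", "a", "b", "c", "d"],
    "click":       [0,   0,   0,   1,   0,   1,   0,   0,   1,   0],
})

# normalize="columns" divides each click-class column by its own total,
# so the shares within click == 0 and within click == 1 each sum to 1
rel_freq = pd.crosstab(df["countrycode"], df["click"], normalize="columns")
```

The resulting table (categories as rows, click classes as columns) can be passed to any heatmap function; here country code c accounts for two thirds of the clicks but only one seventh of the non-clicks.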

In this case, categorical features seem to be really helpful in separating clicks vs no clicks. For instance, all these factors appear to contribute to more clicks:

  • c and d countrycode;
  • Google Chrome and InternetExplorer browserid;
  • and Desktop devid.

Note that this data set is most likely not fully cleaned: the same browser appears under different names as separate categories (IE, Internet Explorer, etc.), but this will not be accounted for here.
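If the labels were to be consolidated, an alias mapping applied before encoding would be enough. A minimal pandas sketch; the alias list is hypothetical and covers only the Internet Explorer variants named above.

```python
import pandas as pd

# Hypothetical mapping from variant spellings to one canonical label;
# the real data may contain other variants that would need adding here.
BROWSER_ALIASES = {
    "IE": "InternetExplorer",
    "Internet Explorer": "InternetExplorer",
}

browsers = pd.Series(["IE", "Internet Explorer", "InternetExplorer", "Google Chrome"])
clean = browsers.replace(BROWSER_ALIASES)  # values not in the mapping are left untouched
```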

Model pipeline

A typical model pipeline is more complex than just calling the .fit() and .predict() methods, so let’s construct a simple yet realistic pipeline. It’s important to look ahead while doing that: if any methods in the model interpretation framework do not support missing values or categorical features, this should be addressed before the model is built. For instance, InterpretML’s ExplainableBoostingClassifier (in the version used here) raises the following error if any of the features contain missing values:

ValueError: Missing values are currently not supported.

The pipeline will consist of the following steps:

  1. Transform the datetime feature into numeric features (day, hour, etc.) so that a model can make full use of the date variable. For that, a custom DateTransformer() will be used (inspired by this blog post and StackOverflow answer).
  2. Categorical features must be encoded in a numerical representation, and there are various ways to do that. Instead of creating many dummy variables, let’s see whether boosting algorithms are able to extract useful patterns themselves. Here OrdinalEncoder() from the category_encoders library was selected because, at the time of writing, the scikit-learn version did not support missing values or out-of-sample (unseen) categories.
  3. Three features in the data set contain missing values, and ML algorithms often do not support them. One option is to impute the missing values. Since this data set mixes categorical and numerical features, a single rule such as imputing with the mean, median or mode may have unwanted side effects. For this experiment, SimpleImputer(strategy='constant', fill_value=0) should do the trick by replacing missing values with 0 (again, letting the algorithm do all the thinking). Note that fill_value is only used with strategy='constant'; with the default strategy ('mean') it is ignored.
  4. After the data transformations, the final step in the pipeline is the predictive algorithm (also called an estimator). For comparison, two pipelines will be built: one with ExplainableBoostingClassifier and the other with LGBMClassifier.
Pipeline for ExplainableBoostingClassifier.
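The DateTransformer code itself lives in the repository, so the following is a hypothetical sketch of such a step, not the author’s implementation. It follows scikit-learn’s fit/transform convention using pandas only; the datetime strings and siteid values are made up. It also adds a day-of-week feature, in line with the article’s own remark that dayofweek would be more useful than day.

```python
import pandas as pd


class DateTransformer:
    """Hypothetical pipeline step that expands a datetime column into numeric parts.

    In a real scikit-learn pipeline it would usually also inherit from
    sklearn.base.BaseEstimator and TransformerMixin (for clone() and fit_transform()).
    """

    def __init__(self, column="datetime"):
        self.column = column

    def fit(self, X, y=None):
        return self  # nothing to learn: the expansion is stateless

    def transform(self, X):
        X = X.copy()                              # never modify the caller's frame
        dt = pd.to_datetime(X.pop(self.column))   # drop the raw column, keep it parsed
        X["month"] = dt.dt.month
        X["day"] = dt.dt.day
        X["hour"] = dt.dt.hour
        X["dayofweek"] = dt.dt.dayofweek          # Monday = 0 ... Sunday = 6
        return X


# Hypothetical sample rows
sample = pd.DataFrame({
    "datetime": ["2017-01-14 09:42:09", "2017-01-17 23:05:51"],
    "siteid": [4709696.0, 5189467.0],
})
features = DateTransformer().fit(sample).transform(sample)
```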

Model validation

After building the two pipelines, it’s important to check whether the models are useful for prediction. The goal is to build explainable models that are not fitted to noise (i.e. do not overfit) and are able to generalize out of sample. Models that approximate the data well should also provide relevant insight into its underlying structure.

Without going into the details of accuracy metrics, one handy function for validating a model is scikit-learn’s classification_report(). Here it was applied to both models on the training (70k) and holdout (30k) sets, and the results were compressed into a single data frame:

Accuracy metrics of two models: LightGBM (lgb) and Explainable Boosting Machine (ebm).

Training set metrics are displayed here just to see how well the models learned the training data. LightGBM can easily overfit smaller data sets, and here it shows higher metric values on the training set than the Explainable Boosting Machine.

On the other hand, both models show very similar accuracy on the holdout set. This is an imbalanced classification task: the majority class (0) is predicted nearly perfectly, while the minority class (1) is predicted reasonably well.
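As a sketch of what classification_report tabulates, here is a plain-Python computation of per-class precision, recall and F1 on a hypothetical imbalanced toy example. It shows why overall accuracy alone says little in a task like this one.

```python
def per_class_report(y_true, y_pred):
    """Precision, recall, F1 and support per class: the core of a classification report."""
    report = {}
    for cls in sorted(set(y_true) | set(y_pred)):
        pairs = list(zip(y_true, y_pred))
        tp = sum(t == cls and p == cls for t, p in pairs)
        fp = sum(t != cls and p == cls for t, p in pairs)
        fn = sum(t == cls and p != cls for t, p in pairs)
        precision = tp / (tp + fp) if tp + fp else 0.0
        recall = tp / (tp + fn) if tp + fn else 0.0
        f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
        report[cls] = {"precision": precision, "recall": recall, "f1": f1, "support": tp + fn}
    return report


# Hypothetical toy labels: 96 non-clicks and 4 clicks
y_true = [0] * 96 + [1] * 4
y_pred = [0] * 95 + [1] + [1, 1, 0, 0]  # one false positive, two missed clicks
report = per_class_report(y_true, y_pred)
accuracy = sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)
```

Accuracy comes out at 0.97, yet only half of the clicks are recovered (recall 0.5 for class 1). With scikit-learn, classification_report(y_true, y_pred, output_dict=True) returns the same kind of numbers as a nested dict, which pd.DataFrame can turn into a table.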

Model interpretation

In this section, the InterpretML and SHAP Python libraries will be tested on the previously created pipelines. Unfortunately, model interpretation frameworks are not fully compatible with pipelines and require workarounds.

Let’s start with InterpretML’s ClassHistogram(), which is useful for some EDA. There is a caveat: the data set can’t contain missing values, which means it has to go through the pipeline’s preprocessing steps first. Therefore, let’s create a preprocessed training set and then visualize it.

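import pandas as pd
from interpret import show
from interpret.data import ClassHistogram

# Apply only the first three (preprocessing) steps of the EBM pipeline;
# feature_names holds the column names after preprocessing (defined earlier)
X_t_prep = pd.DataFrame(data=pipeline_ebm[0:3].transform(X_t), columns=feature_names)
hist = ClassHistogram().explain_data(X_t_prep, y_t, name='Train Data')
show(hist)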

This creates a dashboard that displays interactive Plotly histograms colored by click counts. The supplied preprocessed data set contains only numeric features; for example, browserid is now represented by numbers from 1 to 12. Two of the browserid categories show a noticeably high percentage of clicks (but with this encoding it’s unclear which ones):

ClassHistogram() output.

The next thing to check is global explanations with both InterpretML and SHAP. Since pipelines are not supported directly, the fitted estimator has to be extracted from each pipeline:

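import shap

# InterpretML: global explanation of the fitted EBM step
ebm_global = pipeline_ebm['model'].explain_global()
show(ebm_global)
# SHAP: explain the fitted LightGBM step on the preprocessed training data
# (assumes pipeline_lgb uses the same preprocessing steps as pipeline_ebm)
explainer = shap.TreeExplainer(pipeline_lgb['model'])
shap_values = explainer.shap_values(X_t_prep)
shap.summary_plot(shap_values, X_t_prep, plot_type="bar", plot_size=(10, 5))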

The feature importance summaries show that two categorical features, countrycode and browserid, are very important. Their ability to separate clicks was already seen in the EDA section. The two estimators disagree slightly on feature importance, but both agree that month is the least important feature, with no importance at all. There is a good reason for that: it has no variance (all records in the sample come from the same month).

Feature importance: ExplainableBoostingClassifier explained with InterpretML vs LGBMClassifier explained with SHAP.
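A zero-variance feature such as month can be detected before modeling. A minimal pandas sketch on a hypothetical preprocessed sample; scikit-learn’s VarianceThreshold (default threshold 0) does the same for numeric arrays.

```python
import pandas as pd


def constant_columns(df):
    """Names of columns holding a single distinct value (missing values ignored)."""
    return [col for col in df.columns if df[col].nunique(dropna=True) <= 1]


# Hypothetical preprocessed sample: every row comes from the same month
X_demo = pd.DataFrame({"month": [1, 1, 1, 1], "day": [14, 15, 17, 17], "hour": [9, 23, 5, 12]})
to_drop = constant_columns(X_demo)
X_reduced = X_demo.drop(columns=to_drop)
```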

Now let’s look at the influence of a single feature on the target variable. Clear relationships can typically be seen for the most important features. To view them, InterpretML requires showing the global dashboard again, while SHAP has a dependence_plot function:

# InterpretML
show(ebm_global)
# SHAP: shap_values[0] holds the contributions toward class 0 (no click)
shap.dependence_plot(ind="countrycode", shap_values=shap_values[0], features=X_t_prep, interaction_index=None)

In this particular case, InterpretML has a more appealing visualization, but both tell the same story, just in opposite directions: the SHAP plot uses shap_values[0], i.e. contributions toward class 0 (no click), while EBM scores are expressed for the click class. The SHAP dependence plot shows that country codes 1 and 5 strongly push the prediction toward 0 (no click), and InterpretML also shows that the same country codes have the strongest influence.

Single feature explanations: ExplainableBoostingClassifier with InterpretML vs LGBMClassifier with SHAP

Finally, to compare local explanations, one random observation was selected.

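ind = [69950]  # position of one observation in the preprocessed training set
# InterpretML: local explanation from the fitted EBM
ebm_local = pipeline_ebm['model'].explain_local(X_t_prep.iloc[ind], y_t.iloc[ind], name='Local')
show(ebm_local)
# SHAP: force plot for class 0 (no click), the same class index used above
shap.initjs()
shap.force_plot(explainer.expected_value[0], shap_values[0][ind, :], X_t_prep.iloc[ind])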

EBM predicted 0.35 and LightGBM 0.88, while the true value was 1 (clicked). Both the SHAP and InterpretML plots show that countrycode was the main driver of their respective explanations. day = 17 has some effect on the decision, but this data sample covers only one month, and in a more realistic application a variable such as dayofweek should be more useful.

Local explanations: ExplainableBoostingClassifier with InterpretML vs LGBMClassifier with SHAP

The downside of SHAP’s so-called “force plot” is that the names of the features with the smallest impact are not visible.
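Both the force plot and InterpretML’s local bar chart are additive: the per-feature contributions sum to the gap between the prediction and a baseline. The brute-force sketch below computes exact Shapley values, the quantity SHAP is built on, for a hypothetical three-feature toy function, treating an “absent” feature as set to a single baseline value. SHAP’s explainers define absence differently (for example, through background data or tree paths), so this illustrates the principle rather than TreeExplainer itself.

```python
from itertools import combinations
from math import factorial


def shapley_values(f, x, baseline):
    """Exact Shapley values of f at x, with absent features set to their baseline values.

    The cost is exponential in the number of features, so this is only for toy examples.
    """
    n = len(x)

    def value(subset):
        z = [x[i] if i in subset else baseline[i] for i in range(n)]
        return f(z)

    phi = [0.0] * n
    for i in range(n):
        others = [j for j in range(n) if j != i]
        for size in range(n):
            for s in combinations(others, size):
                weight = factorial(size) * factorial(n - size - 1) / factorial(n)
                phi[i] += weight * (value(set(s) | {i}) - value(set(s)))
    return phi


def toy_model(z):
    """Hypothetical model with an interaction between features 0 and 1."""
    return 2.0 * z[0] + 1.0 * z[1] + 3.0 * z[0] * z[1] - 0.5 * z[2]


x, base = [1.0, 1.0, 2.0], [0.0, 0.0, 0.0]
phi = shapley_values(toy_model, x, base)
```

Here toy_model(x) − toy_model(base) = 5, and the interaction term 3·z0·z1 is split equally between features 0 and 1, giving contributions of 3.5, 2.5 and -1.0 that sum to exactly that gap.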

End notes

This blog post briefly discussed what it takes to build an explainable pipeline. To summarize:

  • In the beginning there is data. Any preprocessing step affects the model pipeline, its predictions and its explanations. When building an explainable model, it’s up to the analyst to consider things like inconsistent categorical feature labels, the meaning and usefulness of each feature, missing value treatment, etc., because all of this shows up in the end result.
  • An ordinal encoder doesn’t require creating many columns in the model pipeline, but it makes interpretation afterwards more difficult. Dummy variable encoding would give more clarity.
  • A feature with no variance, month, didn’t contribute to any decisions in the LightGBM or Explainable Boosting Machine (EBM) models. In this particular case it could be dropped, but it would probably become an interesting feature in the real world once there are more months of data.
  • InterpretML introduces the EBM algorithm, which here achieved accuracy comparable to LightGBM and is designed to be interpretable within the InterpretML framework. EBM trains more slowly but makes predictions quickly, which can be a crucial factor for a model in production.
  • The InterpretML library also provides tools like ClassHistogram, which can be useful for EDA, but you may need to sample the data: since it’s interactive and runs on Plotly, it can quickly become slow to load as the data grows.

At the time of writing, InterpretML is in alpha and already shows some compelling features. Some functionality is missing, but this shouldn’t discourage anyone from testing it. The roadmap includes support for missing values, improvements to categorical encoding, an R language interface and more.

Thanks for reading! Do you have examples or suggestions for improvements (the code can be accessed here)? Share them in the comments section. I am also happy to connect on LinkedIn!

References

  1. Statistical Modeling: The Two Cultures
  2. What is the Difference Between a Parameter and a Hyperparameter?
  3. Model Selection and Multimodel Inference: A Practical Information Theoretic Approach, Second Edition (Ch. 1)
  4. Machine Learning vs. Statistics
  5. InterpretML repository
  6. InterpretML: A Unified Framework for Machine Learning Interpretability


Senior Data Scientist at Beyond Analysis. If you like my content consider following me on Medium and connecting on LinkedIn: linkedin.com/in/marius-vadeika/