Interpretable deep learning models for tabular data — Neural GAMs

An intro to Generalized Additive Models (GAMs) and their recent deep-learning versions.

Chun-Hao Chang
6 min read · May 4, 2022

When approaching a new tabular dataset, I always start with an Explainable Boosting Machine (EBM), which is a GAM, to quickly get an accurate, interpretable model that helps me understand the data. A GAM can help reveal whether anything is wrong with (1) the preprocessing, (2) the missing-value imputation, or (3) spurious correlations. It often has accuracy similar to XGBoost but is far more informative.

But what if you need neural-net machinery such as fine-tuning or gradients, or you have a large dataset? EBM is often too slow, or crashes, when a dataset has more than 10k features. So I will introduce two recent papers on deep-learning GAMs: Neural Additive Models (NAM, NeurIPS 2021) and NODE-GAM (ICLR 2022).

Before we start, let’s talk about what a GAM is.

What are Generalized Additive Models (GAMs)?

Here I show the functional forms of several different model classes:

  • The linear model is just a linear combination of weights and features.
  • A Generalized Linear Model (GLM) applies an additional function g, called the link function, to the target y, which generalizes linear regression to binary classification or Poisson regression.
  • A Generalized Additive Model (GAM) replaces the linear combination of weights and features with a non-linear function f_d(x_d) for each feature d. Note that pairwise feature interactions are still not allowed: each feature has to produce a score on its own, and the scores are summed together.
  • GA2M is like a GAM but it allows up to pairwise feature interactions, and there are no 3rd or higher-order interactions allowed.
  • Finally, a full-complexity model such as a decision tree or deep neural network can use any feature interaction in the data.
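The functional forms described in the list above can be written out roughly as follows. This is a sketch in standard notation rather than a reproduction of the original figure: the intercept w_0, the feature count D, and the pairwise terms f_{de} are symbols I am assuming here, and in a GLM the link g is strictly applied to the expected value of y.

```latex
\begin{align*}
\text{Linear model:} \quad & y = w_0 + \sum_{d=1}^{D} w_d x_d \\
\text{GLM:} \quad & g(y) = w_0 + \sum_{d=1}^{D} w_d x_d \\
\text{GAM:} \quad & g(y) = w_0 + \sum_{d=1}^{D} f_d(x_d) \\
\text{GA}^2\text{M:} \quad & g(y) = w_0 + \sum_{d=1}^{D} f_d(x_d) + \sum_{d < e} f_{de}(x_d, x_e) \\
\text{Full-complexity model:} \quad & y = f(x_1, \dots, x_D)
\end{align*}
```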

Why is a GAM interpretable?

A GAM is interpretable because each individual feature term can be plotted as a graph. For example, for the feature term f_j(x_j), the x-axis shows the feature value x_j and the y-axis shows the function value f_j(x_j). Below is an example from Caruana et al, 2016, in which they find an interesting pattern in a pneumonia dataset.

A GAM graph for the feature age on a pneumonia dataset, where the target is the patient’s mortality risk. There is a strong increase in risk around 65, which is likely due to lifestyle changes before and after retirement. The jump at 80 and the drop at 100 are likely because clinicians treat patients differently below and above these thresholds.

The GAM graph of f(age) above shows the risk score changing around ages 65, 80, and 100, which suggests a retirement effect and treatment effects.
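To make "each feature term is a graph" concrete, here is a minimal sketch that fits a GAM with piecewise-constant shape functions by backfitting and then reads off the curve for one feature. It is not how EBM or the neural GAMs below are trained. The synthetic data (a risk jump at age 65), the function name fit_binned_gam, and all parameters are assumptions made for illustration.

```python
import numpy as np

def fit_binned_gam(X, y, n_bins=10, n_rounds=10):
    """Fit y ~ intercept + sum_d f_d(x_d) with piecewise-constant f_d (backfitting)."""
    n_samples, n_features = X.shape
    intercept = y.mean()
    # Quantile bin edges per feature (interior edges only), then a bin index per value.
    inner_q = np.linspace(0, 1, n_bins + 1)[1:-1]
    edges = [np.quantile(X[:, j], inner_q) for j in range(n_features)]
    bins = np.stack([np.digitize(X[:, j], edges[j]) for j in range(n_features)], axis=1)
    shapes = np.zeros((n_features, n_bins))  # shapes[j, b] = f_j on bin b
    for _ in range(n_rounds):
        for j in range(n_features):
            fitted = intercept + shapes[np.arange(n_features), bins].sum(axis=1)
            # Partial residual: what is left for feature j to explain.
            partial = y - fitted + shapes[j, bins[:, j]]
            for b in range(n_bins):
                in_bin = bins[:, j] == b
                if in_bin.any():
                    shapes[j, b] = partial[in_bin].mean()
            # Center f_j over the data and move the offset into the intercept.
            shift = shapes[j, bins[:, j]].mean()
            shapes[j] -= shift
            intercept += shift
    return intercept, edges, shapes

# Synthetic data: risk jumps at age 65, plus a weak linear effect of another feature.
rng = np.random.default_rng(0)
age = rng.uniform(20, 100, size=2000)
other = rng.uniform(0, 1, size=2000)
risk = 1.0 * (age > 65) + 0.5 * other + rng.normal(0, 0.1, size=2000)

intercept, edges, shapes = fit_binned_gam(np.column_stack([age, other]), risk)
age_shape = shapes[0]                 # the y-values of the GAM graph for age
jump = age_shape[-1] - age_shape[0]   # oldest bin minus youngest bin
print(np.round(edges[0], 1))
print(np.round(age_shape, 2))
```

Plotting age_shape against the bin edges gives the GAM graph for age, and the step around 65 is visible directly in the curve, with no post-hoc explanation method involved.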

Another famous example is that having asthma lowers the risk of dying in this dataset. It’s likely that clinicians give a higher quality of care to patients with asthma and prevent bad outcomes, which makes those patients appear low risk. This is obviously undesirable at test time, and a GAM can help uncover and fix it.

A GAM graph for the feature asthma, showing that having asthma lowers the risk of mortality. This is likely because clinicians give a higher quality of care to patients with asthma and prevent bad outcomes, which makes those patients appear low risk.

Neural Additive Models: Interpretable Machine Learning with Neural Nets (NeurIPS 2021)

This paper proposes the NAM architecture, which trains a multi-layer perceptron (MLP) for each feature and sums the outputs across features at the end:
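A minimal sketch of that structure, forward pass only and with untrained random weights, might look like the following. The layer sizes and the helper names init_feature_net, feature_net, and nam_forward are my own assumptions, not the paper's code. The point is only that each feature passes through its own small network and the per-feature outputs are summed.

```python
import numpy as np

def init_feature_net(rng, hidden=16):
    """Random weights for one tiny per-feature MLP: 1 -> hidden -> 1."""
    return {
        "W1": rng.normal(size=(1, hidden)),
        "b1": np.zeros(hidden),
        "W2": rng.normal(size=(hidden, 1)) / np.sqrt(hidden),
        "b2": np.zeros(1),
    }

def feature_net(params, x_col):
    """Apply one feature's MLP to a column of shape (n,), returning shape (n,)."""
    hidden = np.maximum(0.0, x_col[:, None] @ params["W1"] + params["b1"])  # ReLU
    return (hidden @ params["W2"] + params["b2"])[:, 0]

def nam_forward(nets, X, bias=0.0):
    """Sum the per-feature outputs; also return them so each can be plotted."""
    contributions = np.stack(
        [feature_net(params, X[:, j]) for j, params in enumerate(nets)], axis=1
    )
    return bias + contributions.sum(axis=1), contributions

rng = np.random.default_rng(0)
X = rng.normal(size=(5, 3))
nets = [init_feature_net(rng) for _ in range(X.shape[1])]
pred, contrib = nam_forward(nets, X, bias=0.5)

# Changing feature 1 only changes feature 1's contribution: no interactions.
X_changed = X.copy()
X_changed[:, 1] += 1.0
pred_changed, contrib_changed = nam_forward(nets, X_changed, bias=0.5)
```

Because the prediction is a plain sum, column j of the contributions is exactly the GAM graph for feature j evaluated at the given inputs.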

One of the key problems they find is that an MLP with the ReLU activation function produces curves that are too smooth, as seen in the GAM graph below. So they propose a new activation function, ExU, and we can see that compared to ReLU (left), the ExU curve (right) is jumpier.
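Here is a small sketch of why an exponentiated weight makes a unit jumpier. It assumes the ExU form h(x) = f(exp(w) * (x - b)), with f a ReLU capped at n, which is how I understand the unit from the NAM paper. The function names and the example values of w and b are illustrative assumptions.

```python
import numpy as np

def relu_n(z, n=1.0):
    """ReLU capped at n: 0 below 0, n above n, identity in between."""
    return np.clip(z, 0.0, n)

def exu(x, w, b, n=1.0):
    """Exp-centered unit: the weight enters through exp(w), so slopes can be very steep."""
    return relu_n(np.exp(w) * (x - b), n)

# With w = 4 the slope is exp(4), roughly 54.6, so the unit goes from 0 to its cap
# within about 0.02 units of x: a near-step function at x = b.
x = np.array([0.49, 0.50, 0.51, 0.52])
steep = exu(x, w=4.0, b=0.5)
# With w = 0 the slope is 1, so the same inputs barely move the output.
gentle = exu(x, w=0.0, b=0.5)
print(steep, gentle)
```

A plain ReLU unit would need a very large learned weight to produce the same step, which gradient descent reaches slowly. Parameterizing the slope as exp(w) makes sharp changes easy to learn.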

But they find that ExU can sometimes overfit, so they train multiple neural nets with different random seeds and take the average. This helps “smooth” out the final curve and achieves better accuracy, as shown in the following.

By averaging the predictions of multiple NAMs, the prediction becomes much smoother.
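The effect of averaging can be illustrated without training anything. In the sketch below, each "model" is simply the true curve plus independent noise, standing in for a jumpy NAM trained with a different seed. That stand-in, the noise level, and the roughness measure are all assumptions for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)
x = np.linspace(0, 1, 200)
true_shape = np.sin(2 * np.pi * x)

# Stand-in for 20 overfit models: the true shape plus seed-specific noise.
n_models = 20
member_curves = true_shape + rng.normal(0, 0.3, size=(n_models, x.size))
ensemble_curve = member_curves.mean(axis=0)

def roughness(curve):
    """Mean absolute change between neighboring points (higher means jumpier)."""
    return np.abs(np.diff(curve)).mean()

member_roughness = np.array([roughness(c) for c in member_curves])
ensemble_roughness = roughness(ensemble_curve)
member_mse = ((member_curves - true_shape) ** 2).mean(axis=1)
ensemble_mse = ((ensemble_curve - true_shape) ** 2).mean()
print(member_roughness.mean(), ensemble_roughness)
print(member_mse.mean(), ensemble_mse)
```

Averaging cancels the seed-specific wiggles, so the ensemble curve is both smoother and closer to the true shape than any single member in this toy setup. The same averaging would also flatten a genuine sharp jump if the members disagreed about where it is.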

Remarks

  • NAM is computationally expensive. It builds an MLP per feature and also requires training tens to hundreds of models to average. This might prevent it from working with more than 100 features.
  • It cannot model pairwise interaction effects the way GA2M does, which can matter on datasets where pairwise interactions are important.
  • Finally, its curve after averaging could be too smooth. It’s argued in this paper that capturing quick changes is important for a GAM to reveal real-world anomalies such as missingness imputation or treatment effects.

NODE-GAM: Neural Generalized Additive Model for Interpretable Deep Learning (ICLR 2022)

NodeGAM is another neural GAM that improves on NAM in 3 ways: (1) it uses differentiable trees, which allow quick, non-linear changes, (2) it uses attention to do feature selection, which lets it scale to a large number of features, and (3) it can model pairwise feature interactions, in a variant called NodeGA2M.

The key differences between a NodeGAM and a multi-layer perceptron are:

  • It uses differentiable oblivious decision trees (ODTs) instead of neurons, since trees help learn jumpier shapes in the GAM graphs.
  • In the input layer, instead of taking a weighted sum of all the input features and passing it through a ReLU, NodeGAM uses differentiable attention with temperature annealing so that each tree takes only 1 feature. This ensures there is no feature interaction in the model.
  • For the connections between layers, NodeGAM uses differentiable gates that only connect trees that belong to the same (set of) features. This also prevents feature interactions.
  • Finally, it uses DenseNet-like skip connections that take all the previous layers’ outputs as inputs. Also, the output layer is a final linear layer that takes the embeddings of all the intermediate layers as inputs. This helps the gradient flow through the model, since the tree response function is similar to a sigmoid, which suffers from vanishing gradients.
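The first two points above can be sketched as a single soft oblivious tree that picks its feature through a temperature-controlled attention vector. This is a simplified stand-in rather than the NodeGAM implementation: I use a plain softmax and a sigmoid where the real model uses other differentiable functions, and the function names, shapes, and parameter values are assumptions.

```python
import numpy as np

def softmax(z, temperature=1.0):
    """Softmax over the last axis; a small temperature pushes it toward one-hot."""
    z = z / temperature
    z = z - z.max(axis=-1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

def soft_oblivious_tree(X, feat_logits, thresholds, leaf_values,
                        temperature=1.0, steepness=10.0):
    """One differentiable oblivious tree that (softly) reads a single feature.

    X: (n, d) inputs; feat_logits: (d,) selection logits shared by all levels;
    thresholds: (depth,) one split per level; leaf_values: (2**depth,) responses.
    """
    select = softmax(feat_logits, temperature)   # (d,) attention over features
    chosen = X @ select                          # (n,) the softly selected feature
    # Soft split decision per level: probability of going right.
    go_right = 1.0 / (1.0 + np.exp(-steepness * (chosen[:, None] - thresholds)))
    # Oblivious tree: every node on a level shares the split, so leaf
    # probabilities are products of per-level left/right probabilities.
    leaf_prob = np.ones((X.shape[0], 1))
    for level in range(thresholds.shape[0]):
        p = go_right[:, level:level + 1]
        leaf_prob = np.concatenate([leaf_prob * (1 - p), leaf_prob * p], axis=1)
    return leaf_prob @ leaf_values, select, leaf_prob

rng = np.random.default_rng(0)
X = rng.normal(size=(5, 3))
feat_logits = np.array([0.2, 1.0, -0.5])
thresholds = np.array([-0.5, 0.5])
leaf_values = np.array([0.0, 1.0, 2.0, 3.0])

# Early in training (high temperature) the tree mixes features...
_, select_soft, _ = soft_oblivious_tree(X, feat_logits, thresholds, leaf_values, temperature=1.0)
# ...and after annealing (low temperature) it reads essentially one feature.
out_hot, select_hot, leaf_prob = soft_oblivious_tree(X, feat_logits, thresholds, leaf_values, temperature=0.01)

# Perturbing a feature the tree did not select leaves the output unchanged.
X_other = X.copy()
X_other[:, 0] += 10.0
out_other, _, _ = soft_oblivious_tree(X_other, feat_logits, thresholds, leaf_values, temperature=0.01)
```

Annealing the temperature lets gradients explore all features early on while guaranteeing that, by the end of training, each tree depends on one feature only, which is what keeps the model a GAM.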

The following is a graphical illustration of NodeGAM.

The NodeGAM architecture. It uses differentiable trees instead of neurons, and each tree is only allowed to take in one feature. Therefore, no pairwise interaction is allowed in the model, so it remains a GAM.

NodeGA2M uses a similar idea to learn pairwise interactions but removes any 3rd or higher-order interactions. This is achieved by limiting each tree to take at most 2 features. And the connections are only allowed among trees with the same sets of features.
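The connection rule can be illustrated with a hard, set-based mask. The real model uses differentiable gates, so this is only a simplified stand-in. The function name connection_mask and the feature names hour, weekday, and temp are hypothetical.

```python
import numpy as np

def connection_mask(prev_tree_features, next_tree_features, max_features=2):
    """mask[i, j] is True when tree j in the next layer may read tree i's output.

    Each tree is described by the features it uses. A connection is allowed only
    when both trees use exactly the same set, so no path through the network can
    ever combine more than max_features features.
    """
    for feats in list(prev_tree_features) + list(next_tree_features):
        if len(set(feats)) > max_features:
            raise ValueError(f"tree uses more than {max_features} features: {feats}")
    return np.array([
        [set(prev) == set(nxt) for nxt in next_tree_features]
        for prev in prev_tree_features
    ])

prev_layer = [("hour",), ("hour", "weekday"), ("temp",)]
next_layer = [("weekday", "hour"), ("hour",)]
mask = connection_mask(prev_layer, next_layer)
print(mask)
```

With max_features=1 the same rule describes the NodeGAM case, where every tree uses a single feature.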

The NodeGA2M architecture. Note there are no 3rd or higher feature interactions allowed in the model. It only has at most 2 colors.

Finally, we also incorporate an attention mechanism between the layers called GAMAtt that further improves the accuracy.

NodeGAM Results

NodeGAM outperforms the Explainable Boosting Machine (EBM), a tree-based GAM, on the 6 large datasets below (millions of samples and thousands of features) when pairwise interaction effects are included (GA2M). NodeGAM improves on EBM by up to 7%.

We show 6 large datasets: 3 classification and 3 regression. The error rate is shown for the top 3 datasets, and the Mean Squared Error (MSE) is shown for the bottom 3 datasets. Lower is better. Rel Imp is the relative improvement of NodeGAM over EBM. A dash indicates an out-of-memory error during model training. With GA2M, NodeGAM improves on EBM by up to 7%.

Explainability of NodeGA2M — a case study on the Bikeshare dataset

Here is an example of the explainability of NodeGA2M. The task in the Bikeshare dataset is to predict bike rental counts in Washington D.C. from 2012 to 2013 from weather, time, season, and so on. The strongest pairwise interaction term learned by NodeGA2M is shown below. We find that more people rent bikes (blue) from 9–10 AM and 4–5 PM from Monday to Thursday. Interestingly, on Friday the spikes happen around 10 AM and 3–4 PM, which suggests people get off work earlier on Friday. Finally, on weekends more people rent bikes around noon. We note that these insights are hard to get from XGBoost, Shapley values, or attention-based methods.

NodeGAM code is available here (written in PyTorch): https://github.com/zzzace2000/nodegam

Takeaway

GAMs are very interpretable and accurate on most tabular datasets. In this blog post, I introduced what a GAM is, how it is applied, and two recent efforts to bring neural nets into GAMs.

Some directions for future GAM research include:

  • How do we adapt GAMs to time-series settings to capture trends?
  • How do we handle correlated features? When features are correlated, their GAM graphs become very similar, and some grouping among features might be needed.
  • Can we apply GAMs to other settings such as reinforcement learning or distribution-shift detection?

That’s it for today!
