Building Interpretable Models with Generalized Additive Models in Python

Ian SHEN
Just another data scientist
Dec 21, 2018

Aim

Building machine/deep learning models that produce high accuracy is getting easier, but when it comes to interpretability, most of them still fall short. In many cases, you need to put more emphasis on understanding the model than on accuracy, and then your cutting-edge deep learning models become inapplicable.

As a powerful yet simple technique, the generalized additive model (GAM) is underused. Few data scientists know it or apply it in their daily work, especially in Python. In this article, you’ll see how to build generalized additive models in Python, and how to use their partial dependence functions to inspect the contribution of each feature. We use two public data sets to build two GAM models: one for classification and one for regression. The pyGAM package is used to train the GAMs, and all the executable code is available on Colab.

This post only briefly explains the theory behind GAMs. For more in-depth knowledge, I highly recommend reading the excellent article by Kim Larsen. There is also a practical blog post by Pablo Oberhauser on getting started with GAMs. The main purposes of this article are as follows:

  • GAM 101
  • How to build GAMs with pyGAM
  • Demonstrate the interpretability of GAMs in both regression and classification

GAM 101

To understand how to build a GAM, we need to know something about its structure and some important concepts in it. The structure of a GAM can be written as:

g(E(Y)) = f1(x1) + f2(x2) + … + fm(xm), Y ~ Distribution

where:

  • g(E(Y)) is the link function that links the expected value of the response to the predictor variables x1, x2, …, xm. It tells how the expected value of the response relates to the predictor variables. GAMs support multiple link functions.
  • f1(x1) + f2(x2) + … + fm(xm) is the functional form with an additive structure, consisting of the terms f1(x1), f2(x2), …, fm(xm). Each term is a smooth, non-parametric function of a single predictor.
  • Distribution refers to the distribution of the response variable Y. It can be any distribution from the exponential family, such as Gaussian, binomial, or Poisson. For example, a GAM with a logit link and a binomial distribution is an additive generalization of logistic regression.

GAMs allow us to easily examine the partial relationships between the response variable and the predictors. First of all, the additive nature ensures that the marginal impact of a single variable does not depend on the others in the model. Moreover, the ability to control the smoothness of the predictor functions helps us obtain a clearer relationship. Partial dependence plots are used to visualize these partial relationships.

How to build GAMs with pyGAM

pyGAM is a package for building GAMs in Python. To the best of my knowledge, it may be the only Python package available for GAMs. pyGAM is on PyPI and can be installed with pip:

pip install pygam

To train a GAM with pyGAM, we need to specify the link function, the functional form and the distribution as follows:

from pygam import GAM, s, f
gam = GAM(s(0, n_splines=5) + s(1) + f(2) + s(3), distribution='gamma', link='log')

pyGAM also has built-in common models with which GAMs can be created even more easily: LinearGAM, LogisticGAM, PoissonGAM, GammaGAM, and InvGaussGAM. Model training is then simplified to:

from pygam import PoissonGAM, s, f
gam = PoissonGAM(s(0, n_splines=5) + s(1) + f(2) + s(3))

Automatic model tuning with `gridsearch()`

Finding the best model requires tuning several key parameters, including n_splines, lam, and constraints. Among them, lam is of great importance to the performance of a GAM: it controls the strength of the regularization penalty on each term. pyGAM has a built-in gridsearch() function that searches over multiple lam values and keeps the model with the lowest generalized cross-validation (GCV) score.
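As a minimal sketch (X and y below are placeholders for your own feature matrix and target), a search over a log-spaced lam grid looks like this:

import numpy as np
from pygam import LinearGAM

# try 11 log-spaced lam values shared across all terms; gridsearch
# keeps the model with the lowest GCV score
gam = LinearGAM(n_splines=10).gridsearch(X, y, lam=np.logspace(-3, 3, 11))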

Partial dependence plots

pyGAM supports partial dependence plots via matplotlib. The partial dependence of each term in a GAM can be visualized together with a 95% confidence interval for the estimated function.
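As a minimal sketch for a single term (assuming gam is an already fitted model):

import matplotlib.pyplot as plt

# evaluate the first term on a grid and plot it with a 95% confidence band
XX = gam.generate_X_grid(term=0)
pdep, confi = gam.partial_dependence(term=0, X=XX, width=0.95)
plt.plot(XX[:, 0], pdep)
plt.plot(XX[:, 0], confi, c='r', ls='--')
plt.show()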

More information about pyGAM can be found in its documentation.

Build Interpretable GAMs

Regression

This data set is about red variants of the Portuguese ‘Vinho Verde’ wine and is available from the UCI machine learning repository. The input features are 11 physicochemical variables that describe the red wine variants from various aspects. The target feature is the quality score, ranging from 0 to 10, which indicates how good the red wine is.

An overview of the data set

Prepare the data
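The data can be loaded directly from the UCI repository (the URL and the ';' separator below are assumptions about how UCI hosts this file):

import pandas as pd

# load the semicolon-separated red wine file from the UCI repository
url = 'https://archive.ics.uci.edu/ml/machine-learning-databases/wine-quality/winequality-red.csv'
redwine = pd.read_csv(url, sep=';')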

# separate the features and the target
redwine_X = redwine.drop(['quality'], axis=1).values
redwine_y = redwine['quality']

Build the model via gridsearch

import numpy as np
from pygam import LinearGAM

# 100 random candidate lam vectors, one value per feature,
# spread on a log scale between exp(-3) and exp(8)
lams = np.random.rand(100, 11)
lams = lams * 11 - 3
lams = np.exp(lams)
print(lams.shape)
gam = LinearGAM(n_splines=10).gridsearch(redwine_X, redwine_y, lam=lams)
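Once the search finishes, the fitted model can be inspected with summary(), which reports the lam chosen for each term along with overall fit statistics such as the GCV score:

# print per-term smoothing strengths and overall fit statistics
gam.summary()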

Partial dependence plots

titles = redwine.columns[0:11]
plt.figure()
fig, axs = plt.subplots(1, 11, figsize=(40, 8))
for i, ax in enumerate(axs):
    XX = gam.generate_X_grid(term=i)
    ax.plot(XX[:, i], gam.partial_dependence(term=i, X=XX))
    ax.plot(XX[:, i], gam.partial_dependence(term=i, X=XX, width=.95)[1], c='r', ls='--')
    if i == 0:
        ax.set_ylim(-30, 30)
    ax.set_title(titles[i])
Partial dependence plots showing factors that affect red wine quality (see Colab)

So far, we have built a linear GAM that can predict the red wine quality score from the physicochemical variables. More importantly, how each of these variables affects the quality score is revealed in the partial dependence plots above. As shown, volatile acidity, chlorides, total sulfur dioxide, density, and pH are negatively correlated with the quality score: the higher the value, the lower the score. On the other hand, the quality score increases as residual sugar and free sulfur dioxide get larger. We also notice that fixed acidity has little influence on the quality score. The impacts of citric acid, sulphates, and alcohol are more complex. For instance, the optimal alcohol level is around 13; values higher or lower than that bring down the quality score.

Classification

The data set contains 30 features that describe characteristics of the cell nuclei present in breast mass images. They are computed from 10 descriptors of a cell nucleus, including radius, texture, and perimeter. Each record is labeled as malignant (M) or benign (B).

An overview of the data set

Prepare the data
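Assuming the data sits in a local data.csv file (the file name used by the Kaggle distribution of the Breast Cancer Wisconsin (Diagnostic) data set), it can be loaded with pandas:

import pandas as pd

# load the breast cancer records; 'data.csv' is an assumed local file name
tumors = pd.read_csv('data.csv')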

# drop the id column
tumors = tumors.drop(['id'], axis=1)
# encode the diagnosis column: malignant = 1, benign = 0
tumors.loc[tumors['diagnosis'] == 'M', 'diagnosis'] = 1
tumors.loc[tumors['diagnosis'] == 'B', 'diagnosis'] = 0
# keep only the 10 'mean' features as predictors
tumors_X = tumors.iloc[:, :11].drop(['diagnosis'], axis=1).values
tumors_y = tumors['diagnosis']

Build the model with LogisticGAM

log_gam = LogisticGAM(n_splines=10).gridsearch(tumors_X, tumors_y)

Check the accuracy of the trained model

log_gam.accuracy(tumors_X, tumors_y)
# 0.9578207381370826
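Beyond overall accuracy, the fitted model can also produce class labels and malignancy probabilities through pyGAM's predict and predict_proba:

# predicted labels and malignancy probabilities for the first five records
print(log_gam.predict(tumors_X[:5]))
print(log_gam.predict_proba(tumors_X[:5]))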

Partial dependence plots

titles = tumors.columns[1:11]
plt.figure()
fig, axs = plt.subplots(1, 10, figsize=(40, 8))
for i, ax in enumerate(axs):
    XX = log_gam.generate_X_grid(term=i)
    ax.plot(XX[:, i], log_gam.partial_dependence(term=i, X=XX))
    ax.plot(XX[:, i], log_gam.partial_dependence(term=i, X=XX, width=.95)[1], c='r', ls='--')
    if i == 0:
        ax.set_ylim(-30, 30)
    ax.set_title(titles[i])
Partial dependence plots for the breast mass image classification (see Colab)

The partial dependence plots reveal how the GAM model arrives at its predictions. Variables that have a positive correlation with the response include mean radius, mean texture, mean area, mean smoothness, mean concave points, and mean symmetry: the higher the value, the more likely the mass is to be malignant. A higher mean perimeter, on the other hand, means the mass is less likely to be malignant.

Conclusion

As a data scientist, you should add GAMs to your arsenal. Their advantage in interpretability can be very useful in many scenarios. Hopefully, this article helps you get to know the technique and try it in your own work.
