Interpretable AI: Linear Regression

Shruti Misra
Jun 6, 2023


Linear regression is one of the simplest predictive methods in AI/ML. The model predicts the output as a weighted sum of the inputs. The formula for linear regression is:

y = w₀ + w₁x₁ + … + wₚxₚ + ϵ,

where y is the output, w represents the weights (w₀ is the intercept), x represents the input features and ϵ is the error between the predictions and the true outcome. To fit a linear regression model, the goal is to find the optimal set of weights that best fit the given input features (x) and outputs (y). One of the most common methods used to estimate the weights is known as ordinary least squares (OLS). I will not go into the details of OLS, but the crux of the method is to find the set of weights that minimizes the squared difference between the actual and the estimated outputs.
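For intuition, here is a minimal NumPy sketch of that least-squares idea on toy data (the data and variable names are purely illustrative, not from the insurance example later in this post):

```python
import numpy as np

# Toy data: 100 samples, 2 features, known true weights (illustrative only)
rng = np.random.default_rng(0)
X = rng.normal(size=(100, 2))
y = 3.0 + 1.5 * X[:, 0] - 2.0 * X[:, 1] + rng.normal(scale=0.1, size=100)

# Prepend a column of ones so the intercept w0 is estimated with the other weights
X_design = np.column_stack([np.ones(len(X)), X])

# OLS: find the weights minimizing the squared difference between actual and
# estimated outputs; lstsq returns the same solution as the closed form
# w = (XᵀX)⁻¹Xᵀy, computed in a numerically stabler way
w, *_ = np.linalg.lstsq(X_design, y, rcond=None)
print(w)  # ≈ [3.0, 1.5, -2.0]
```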

Interpretation

Weights: A property that makes the OLS estimation of weights easy is linearity. Linearity is also what makes linear regression relatively easy to interpret, because each weight indicates the influence an input feature has on the output. If a weight is positive, increasing the corresponding feature by one unit increases the estimated output by that weight (holding all other features fixed), and vice versa for negative weights. The interpretation comes with some nuance depending on whether the input features are numerical or categorical. The “change per one-unit increase/decrease” reading applies to numerical features. For categorical features, changing the feature from the reference category to another category changes the estimated outcome by the corresponding weight. Additionally, if all the input features are 0 (or at the reference category for categorical variables), then the estimated output equals the intercept (w₀).

R²: R² is an important measure indicating the goodness-of-fit for linear regression models. It tells us how much of the total variance in the output is explained by the model. R² values range between 0 and 1. An R² of 0 indicates that the model does not explain the data at all and an R² of 1 indicates that the model explains all the variance in the output. R² values are not perfect and should not be used as the sole indicator of goodness-of-fit. A bad/irrelevant model can have high R² values and vice versa. Therefore, R² should be used in conjunction with other methods such as checking the residual plots for bias and domain knowledge.
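For reference, R² can be computed directly from the residuals; a minimal sketch, assuming y_true and y_pred are NumPy arrays of actual and predicted outputs:

```python
import numpy as np

def r_squared(y_true, y_pred):
    """R² = 1 − (residual sum of squares / total sum of squares)."""
    ss_res = np.sum((y_true - y_pred) ** 2)           # variance left unexplained
    ss_tot = np.sum((y_true - np.mean(y_true)) ** 2)  # total variance in the output
    return 1.0 - ss_res / ss_tot
```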

Feature Importance (t-statistic): The t-statistic is the estimated weight of an input feature scaled by its standard error (t-statistic = weight / standard error of the weight estimate). The more important a feature, the larger its weight (in absolute value) and the higher its t-statistic. On the other hand, the higher the standard error (the uncertainty in the weight estimate), the lower the t-statistic and the less important the input feature.
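With statsmodels, both pieces of this ratio are exposed on the fitted results object, so ranking features by importance can look like the sketch below (the variable name `results` is my assumption for a fitted OLS results object):

```python
# Assumption: `results` is a fitted statsmodels OLS results object
t_stats = results.params / results.bse  # weight / standard error of the weight
print(t_stats)                          # results.tvalues gives the same ratio directly

# Rank features by the magnitude of their t-statistic
print(t_stats.abs().sort_values(ascending=False))
```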

Example

In this example, I used the Medical Cost Personal Dataset from Kaggle, to examine the relationship between insurance charges (y) and a number of independent variables such as age, sex, bmi, number of children, smoking and region. The dataset consisted of a total of 1338 rows, which I split into a training (66%) and a test (33%) dataset. I used dummy encoding to encode all categorical variables (sex, smoker, region). To discuss the interpretability of linear regression, I used the training dataset to fit a linear regression model and analyze its weights, R² and t-statistic.
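A sketch of that setup is below; the file name, split seed and variable names are my assumptions rather than the exact code in the original notebook:

```python
import pandas as pd
import statsmodels.api as sm
from sklearn.model_selection import train_test_split

df = pd.read_csv("insurance.csv")  # Medical Cost Personal Dataset from Kaggle

# Dummy-encode the categorical variables; drop_first keeps one reference category each
X = pd.get_dummies(df.drop(columns="charges"),
                   columns=["sex", "smoker", "region"], drop_first=True)
y = df["charges"]

X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.33, random_state=42)

# statsmodels needs an explicit constant column for the intercept w0
results = sm.OLS(y_train, sm.add_constant(X_train.astype(float))).fit()
print(results.summary())
```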

The above table is the Python statsmodels summary of the linear regression model for insurance charges based on a variety of independent variables. Let’s look at the interpretation for a numerical variable, namely age. The weight (coef in the table) for age is 261.57, which indicates that for an increase in age by one year, the insurance cost rises by $261.57. The same interpretation applies to the bmi variable, except in this case the model indicates that the insurance charges increase by $347.10 for an increase in bmi of 1.

For categorical features, such as sex, the data was dummy coded. If the sex was labeled as ‘male’, a value of 1 was assigned and 0 otherwise, so the reference category is non-male individuals. The weight for sex is 121.12. This means that if the sex variable changes from the reference category, in this case from female to male, the insurance charges increase by $121.12. Similarly, being a smoker increases the charges by $23,700 compared to a non-smoker. For regions, ‘northeast’ was used as the reference category. (I picked it simply because it was the first region in the list; any other region could serve as the reference. In a real-world setting, for example, a government entity analyzing insurance costs for its population might use its own region as the reference and compare the costs in other regions against it.) We can see that the insurance charges drop for all other regions when compared to the northeast.

The adjusted R² value for the model is 74.3%, indicating that the model explains 74.3% of the variance in the insurance charges. We look at the adjusted R² because the raw R² never decreases as features are added: the more features you put in the model, the higher R² will be, even if those features add no meaningful information. The adjusted R² corrects for the number of features.
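The adjustment itself is a small formula; a sketch, where n is the number of samples and p the number of features (statsmodels reports the same quantity as results.rsquared_adj):

```python
def adjusted_r2(r2, n_samples, n_features):
    """Penalize R² for model size: an uninformative extra feature can only
    raise R², but it lowers adjusted R²."""
    return 1.0 - (1.0 - r2) * (n_samples - 1) / (n_samples - n_features - 1)
```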

Whether the individual is a smoker has the highest t-statistic, indicating this input feature’s importance in predicting the insurance charges. Age is the next important feature in influencing how much an individual may pay for insurance in the context of this dataset.

The above graph displays the weight (coefficient) values along with the standard error around each weight. From the graph, it is clear that smoking has a strong positive effect on the insurance costs. This is also indicated by the t-statistic, which serves as an indicator of feature importance (see graph below). The above plot is not the best way to compare the weights, however, because the features are measured on different scales. One way to remedy this is to standardize the features (zero mean and unit standard deviation) before fitting the model.
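One way to do that with scikit-learn is sketched below, reusing the hypothetical variable names from the earlier snippet; the resulting coefficients for the numeric features then read as “change in charges per one standard deviation”:

```python
import statsmodels.api as sm
from sklearn.preprocessing import StandardScaler

numeric_cols = ["age", "bmi", "children"]

# Standardize only the numeric inputs (the dummy variables stay 0/1)
X_scaled = X_train.astype(float).copy()
X_scaled[numeric_cols] = StandardScaler().fit_transform(X_scaled[numeric_cols])

results_scaled = sm.OLS(y_train, sm.add_constant(X_scaled)).fit()
print(results_scaled.params)
```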

Christoph Molnar’s book has another interesting way of visualizing the interpretation of linear regression: effect plots. To generate effect plots, the weights of the model are multiplied by the corresponding feature values to understand how much each weight-feature combination contributes to the prediction. Given below is the effect plot for the insurance dataset.
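A minimal way to compute and plot those effects, again assuming the `results` and `X_train` names from the earlier sketch:

```python
import matplotlib.pyplot as plt

# Effect of feature j for row i = weight_j * x_ij (intercept excluded)
effects = X_train.astype(float) * results.params.drop("const")

effects.plot(kind="box", vert=False)
plt.xlabel("Effect on predicted insurance charges ($)")
plt.tight_layout()
plt.show()
```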

In the graph, the boxplots show the effect that a feature has on the prediction. For example, age, bmi and smoking have a broad range of contribution to the prediction. The effect of age on insurance costs increases as the age increases. This trend is clearly visible when age is plotted against insurance costs (given below).
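That age-versus-charges plot is a couple of matplotlib lines, reusing the `df` frame from the data-loading sketch:

```python
import matplotlib.pyplot as plt

plt.scatter(df["age"], df["charges"], alpha=0.4)
plt.xlabel("Age (years)")
plt.ylabel("Insurance charges ($)")
plt.show()
```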

Now, let’s pick an individual (local) prediction and use what we know about the model to reason about the insurance costs. Consider the following three individuals in the data:

Individual 1: 
age 19.0
bmi 27.9
children 0.0
sex 0.0
smoke 1.0
northwest 0.0
southeast 0.0
southwest 1.0
insurance charge: 16884.924

Individual 2:
age 19.00
bmi 32.11
children 0.00
sex 0.00
smoke 0.00
northwest 1.00
southeast 0.00
southwest 0.00
insurance charge: 2130.6759

Individual 3:
age 60.000
bmi 36.005
children 0.000
sex 0.000
smoke 0.000
northwest 0.000
southeast 0.000
southwest 0.000
insurance charges: 13228.84695

Let’s take a look at their effect plots.

The first two individuals are 19 years old with no children. Individual 1 is a smoker, whereas Individual 2 is a non-smoker who is obese (based on healthy bmi ranges). The insurance costs of the two are very different, and smoking is very likely the main reason for that large difference. The third individual is 60 years old, also obese, a non-smoker and has no children. The cost difference between this individual and Individual 2 (a similar profile except for age) is very likely due to age (the second most influential variable according to the t-statistic). What is interesting to note is the cost difference between Individuals 1 and 3: the charge for Individual 1 is still higher than for Individual 3 despite Individual 3’s higher age. This reinforces that smoking is a more influential variable than age when it comes to insurance costs incurred.

Predictions and Beyond

Now that we have insight into how the linear regression model works, let’s use the model to predict insurance costs on the test data. To do that, I used sklearn’s LinearRegression(). The RMSE for the fitted model was 5923.7, which means the model’s predictions are off by roughly $5,923.70 on average (in the root-mean-square sense). This is not good, given that the average insurance charge in the training data is $13,379.69.
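A sketch of that evaluation step, again with the hypothetical variable names from earlier:

```python
import numpy as np
from sklearn.linear_model import LinearRegression
from sklearn.metrics import mean_squared_error

lr = LinearRegression().fit(X_train, y_train)
y_pred = lr.predict(X_test)

rmse = np.sqrt(mean_squared_error(y_test, y_pred))
print(f"RMSE: {rmse:.1f} vs. mean training charge: {y_train.mean():.2f}")
```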

The question then becomes: is linear regression a good fit for analyzing the given dataset? We can start by checking the assumptions that underlie linear regression.

A correlation matrix of the features shows that the correlation between different features is pretty low, indicating that multicollinearity may not be a problem. Next, we look at the assumption of normality, that is, we check whether the residuals are normally distributed. To do this, we look at the QQ plot.
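Both checks take only a few lines with pandas and statsmodels; a sketch, reusing the earlier names:

```python
import matplotlib.pyplot as plt
import statsmodels.api as sm

# Pairwise correlations between the encoded features
print(X_train.astype(float).corr().round(2))

# QQ plot of the residuals against a fitted normal distribution
sm.qqplot(results.resid, line="45", fit=True)
plt.show()
```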

The QQ plot shows that the residuals deviate from the line quite a bit, indicating the residuals may violate the assumption of normality. Additionally, we test for homoscedasticity to ensure that the variance of the residuals is constant. To do that, we just plot the residuals in a scatterplot. Ideally, the residuals should be uniformly spread.
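A sketch of that residual plot, here drawn against the fitted values (a common convention):

```python
import matplotlib.pyplot as plt

plt.scatter(results.fittedvalues, results.resid, alpha=0.4)
plt.axhline(0, color="grey", linewidth=1)
plt.xlabel("Fitted insurance charges ($)")
plt.ylabel("Residual ($)")
plt.show()
```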

From the above scatterplot, we can see that the residuals are in fact not uniformly spread and show some clustering, which suggests a violation of the assumption of homoscedasticity.

Just by testing some of the assumptions, it is evident that the data provided may not fulfill all the assumptions of a linear regression model, and that is OK! Real data is rarely perfectly linear, and there are ways to extend or transform the approach so that linear models can still be used in a more informed manner (such as generalized linear models and generalized additive models).

Conclusion

Linear regression is a fundamental method for predictive analysis, on which various other methods are built. In this post, I went over how to interpret linear regression models to understand why they predict what they predict. Linear regression models are fairly easy to understand and have a large body of expertise around them, making them easy to learn, teach and implement. Mathematically, they are easy to estimate, with a closed-form solution and a guarantee of finding the optimal weights. However, they are also very basic and break down when the data has non-linearity or interactions between features. In their native form, they often don’t have good predictive performance, because real data is usually non-linear and contains interactions and complexity that linear models are unable to handle. This is where the tradeoff between performance and interpretability comes into play. While linear regression is an easy-to-understand predictive model, its interpretability may come at the cost of performance. Therefore, it is up to the researcher, data scientist, or whoever is working with these models to decide what is tolerable in terms of interpretability and performance.

Code

The notebook for this and the dataset can be found on my Github repository.
