Emulating R regression plots in Python
(Originally posted on my personal blog)
Recently, as a part of my Summer of Data Science 2017 challenge, I took up the task of reading Introduction to Statistical Learning cover-to-cover, including all labs and exercises, and converting the R labs and exercises into Python. While I’m still in the early chapters, I’ve already learned a lot. Some commands can be replicated in Python straightforwardly, while for others it is surprisingly hard to find equivalents without writing custom functions. (Exact equivalents aren’t really necessary, as Python also has powerful features unique to it, but I’m treating this as a learning opportunity.)
One of the simplest R commands that doesn’t have a direct equivalent in Python is plot() for linear regression models (given an lm object, the generic plot() dispatches to plot.lm()). When applied to a fitted linear regression model, it creates a group of diagnostic plots (residuals, QQ, scale-location, leverage) for assessing the fit and the model’s assumptions. Python has a vast array of plotting libraries, but they take a more hands-on approach, so replicating R’s plot() takes some manual work.
Let’s see the example in R with the Auto dataset:
Let’s start with the necessary imports and setup commands:
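A minimal setup sketch, assuming the later steps only need NumPy, pandas, matplotlib and statsmodels. seaborn is left out here because the sketches in this post draw the lowess curves with statsmodels directly. The figure and font sizes are arbitrary choices, not the original settings.

```python
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import statsmodels.formula.api as smf
from statsmodels.graphics.gofplots import ProbPlot
from statsmodels.nonparametric.smoothers_lowess import lowess

# Plot defaults (arbitrary values, adjust to taste)
plt.rcParams["figure.figsize"] = (8, 6)
plt.rc("font", size=12)
plt.rc("axes", titlesize=14, labelsize=12)
```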
Loading the data, and getting rid of
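Here is a minimal sketch of a loading step. It assumes the Auto data is a CSV in which missing values are coded as "?", which is how the ISL lab reads it (read.csv(..., na.strings = "?") followed by na.omit). The helper name load_auto and the inline toy rows are made up for illustration.

```python
import io
import pandas as pd

def load_auto(source):
    """Read Auto-style CSV data, treat '?' as missing, and drop incomplete rows.

    `source` can be a file path or any file-like object pandas accepts.
    """
    df = pd.read_csv(source, na_values="?")
    # reset_index keeps positional and label indices in sync,
    # which makes annotating points by index simpler later on
    return df.dropna().reset_index(drop=True)

# Toy inline CSV (made-up rows) standing in for the real file
toy_csv = io.StringIO(
    "mpg,horsepower,name\n"
    "18.0,130,car a\n"
    "25.0,?,car b\n"
    "15.0,165,car c\n"
)
auto = load_auto(toy_csv)
print(auto)
```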
The fitted linear regression model, using the statsmodels R-style formula API:
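A minimal sketch of this step. It assumes the model regresses mpg on horsepower, which is the simple regression one of ISL’s Chapter 3 applied exercises asks for on the Auto data. To keep the block self-contained, it uses synthetic, made-up data in place of the real Auto file.

```python
import numpy as np
import pandas as pd
import statsmodels.formula.api as smf

# Synthetic stand-in for the Auto data (made-up numbers, seeded)
rng = np.random.default_rng(0)
horsepower = rng.uniform(50, 230, size=200)
mpg = 40 - 0.15 * horsepower + rng.normal(0, 4, size=200)
df = pd.DataFrame({"mpg": mpg, "horsepower": horsepower})

# R-style formula, analogous to lm(mpg ~ horsepower, data = Auto)
model = smf.ols(formula="mpg ~ horsepower", data=df)
model_fit = model.fit()
print(model_fit.params)
```

The formula interface adds the intercept automatically, just as lm() does in R.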
Calculations required for some of the plots:
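The following NumPy sketch shows what these calculations compute. It is written from the textbook formulas rather than taken from the post’s own code. In the statsmodels workflow, the same numbers come from model_fit.fittedvalues, model_fit.resid and model_fit.get_influence(), specifically its hat_matrix_diag, resid_studentized_internal and cooks_distance[0] attributes. The helper name diagnostics is hypothetical.

```python
import numpy as np

def diagnostics(X, y):
    """Quantities behind R's four lm diagnostic plots (hypothetical helper).

    X: (n, p) design matrix including the intercept column.
    y: (n,) response vector.
    """
    n, p = X.shape
    beta, *_ = np.linalg.lstsq(X, y, rcond=None)
    fitted = X @ beta
    resid = y - fitted
    # Leverage = diagonal of the hat matrix H = X (X'X)^-1 X'.
    # With reduced QR X = QR, H = Q Q', so h_i is the squared norm of row i of Q.
    Q, _ = np.linalg.qr(X)
    leverage = np.sum(Q**2, axis=1)
    # Residual standard error s = sqrt(RSS / (n - p))
    s = np.sqrt(resid @ resid / (n - p))
    # Standardized (internally studentized) residuals: e_i / (s * sqrt(1 - h_i))
    std_resid = resid / (s * np.sqrt(1 - leverage))
    # Cook's distance: D_i = r_i^2 / p * h_i / (1 - h_i)
    cooks = std_resid**2 / p * leverage / (1 - leverage)
    return {
        "fitted": fitted,
        "resid": resid,
        "abs_resid": np.abs(resid),
        "std_resid": std_resid,
        "sqrt_abs_std_resid": np.sqrt(np.abs(std_resid)),
        "leverage": leverage,
        "cooks": cooks,
    }

# Synthetic, made-up data: intercept column plus one predictor
rng = np.random.default_rng(1)
x = rng.uniform(50, 230, size=100)
y = 40 - 0.15 * x + rng.normal(0, 4, size=100)
X = np.column_stack([np.ones_like(x), x])
d = diagnostics(X, y)
```

With a fitted statsmodels model, X and y would be model_fit.model.exog and model_fit.model.endog.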
And now, the actual plots:
1. Residual plot
The first plot generated by plot() in R is the residual plot, which draws a scatterplot of fitted values against residuals, with a “locally weighted scatterplot smoothing (lowess)” regression line showing any apparent trend.
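This one can be easily plotted using seaborn’s residplot, with the fitted values as the x parameter and the dependent variable as the response. residplot regresses the response on x and plots the residuals of that fit. For an OLS model with an intercept, regressing the response on its own fitted values gives an intercept of 0 and a slope of 1, so the plotted residuals are exactly the model’s residuals. Setting lowess=True makes sure the lowess regression line is drawn. Additional parameters are passed to the underlying matplotlib scatter and line functions through keyword dictionaries such as line_kws, and titles and labels are set using matplotlib methods. The ; at the end suppresses the output text <matplotlib.text.Text at 0x000000000> that would otherwise appear at the top of the plot. The points with the top 3 absolute residuals are also annotated: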
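Here is a minimal sketch of this plot. Instead of seaborn’s residplot, it uses a matplotlib scatter plus statsmodels’ lowess smoother, which is the function seaborn relies on when lowess=True. This makes the trend line explicit. The helper name residual_plot, the colours and the synthetic data are assumptions.

```python
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import statsmodels.formula.api as smf
from statsmodels.nonparametric.smoothers_lowess import lowess

def residual_plot(fitted, resid, n_annotate=3, ax=None):
    """Residuals vs fitted values with a lowess trend line (hypothetical helper)."""
    fitted = np.asarray(fitted)
    resid = np.asarray(resid)
    if ax is None:
        _, ax = plt.subplots()
    ax.scatter(fitted, resid, alpha=0.5)
    # lowess(endog, exog) returns (x, smoothed y) pairs sorted by x
    smooth = lowess(resid, fitted, frac=2 / 3)
    ax.plot(smooth[:, 0], smooth[:, 1], color="red", lw=1)
    ax.axhline(0, color="grey", ls="--", lw=1)
    ax.set_title("Residuals vs Fitted")
    ax.set_xlabel("Fitted values")
    ax.set_ylabel("Residuals")
    # Annotate the largest absolute residuals with their positional index
    for i in np.argsort(np.abs(resid))[::-1][:n_annotate]:
        ax.annotate(str(i), xy=(fitted[i], resid[i]))
    return ax

# Synthetic, made-up data fitted the same way as the main model
rng = np.random.default_rng(2)
df = pd.DataFrame({"horsepower": rng.uniform(50, 230, size=150)})
df["mpg"] = 40 - 0.15 * df["horsepower"] + rng.normal(0, 4, size=150)
model_fit = smf.ols("mpg ~ horsepower", data=df).fit()
ax = residual_plot(model_fit.fittedvalues, model_fit.resid)
```

The annotations are positional indices, which match the DataFrame labels only when the index is a default 0..n-1 range.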
2. QQ plot
This one shows how well the distribution of residuals fits the normal distribution. It plots the standardized residuals (each residual divided by its estimated standard deviation) against the theoretical normal quantiles. Points far off the diagonal line may warrant further investigation.
For this, I’m using ProbPlot and its qqplot method from the statsmodels graphics API. statsmodels actually has a qqplot function that we can use directly, but it’s not very customizable, hence this two-step approach. Annotations were a bit tricky, as the theoretical quantiles from ProbPlot are already sorted:
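Here is a minimal sketch of the two-step ProbPlot approach. Because the k-th theoretical quantile pairs with the k-th smallest residual, the sketch maps each observation to its rank before annotating. This is one way to handle the sorting, not necessarily the original post’s method. The helper name qq_plot and the synthetic data are assumptions.

```python
import numpy as np
import pandas as pd
import statsmodels.formula.api as smf
from statsmodels.graphics.gofplots import ProbPlot

def qq_plot(std_resid, n_annotate=3):
    """Normal QQ plot of standardized residuals (hypothetical helper)."""
    std_resid = np.asarray(std_resid)
    pp = ProbPlot(std_resid)
    # qqplot returns a Figure; line="45" draws the y = x reference line
    fig = pp.qqplot(line="45", alpha=0.5)
    ax = fig.axes[0]
    ax.set_title("Normal Q-Q")
    ax.set_xlabel("Theoretical quantiles")
    ax.set_ylabel("Standardized residuals")
    # theoretical_quantiles[k] belongs to the k-th smallest residual,
    # so look up each observation's rank in the sorted order
    ranks = np.empty(len(std_resid), dtype=int)
    ranks[np.argsort(std_resid)] = np.arange(len(std_resid))
    for i in np.argsort(np.abs(std_resid))[::-1][:n_annotate]:
        ax.annotate(str(i), xy=(pp.theoretical_quantiles[ranks[i]], std_resid[i]))
    return fig, ax

# Synthetic, made-up data
rng = np.random.default_rng(3)
df = pd.DataFrame({"horsepower": rng.uniform(50, 230, size=150)})
df["mpg"] = 40 - 0.15 * df["horsepower"] + rng.normal(0, 4, size=150)
model_fit = smf.ols("mpg ~ horsepower", data=df).fit()
fig, ax = qq_plot(model_fit.get_influence().resid_studentized_internal)
```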
3. Scale-location plot
This is another residual plot. It shows how the spread of the residuals changes across the fitted values, which you can use to assess heteroscedasticity (non-constant variance).
It’s essentially a scatter plot of the square roots of the absolute standardized residuals against the fitted values, with a lowess regression line. The scatterplot is a standard matplotlib function, and the lowess line comes from seaborn’s regplot. The points with the top 3 square-rooted absolute standardized residuals are also annotated:
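A minimal sketch of the scale-location plot. As in the other sketches here, statsmodels’ lowess stands in for seaborn’s regplot. The helper name scale_location_plot and the synthetic data are assumptions.

```python
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import statsmodels.formula.api as smf
from statsmodels.nonparametric.smoothers_lowess import lowess

def scale_location_plot(fitted, std_resid, n_annotate=3, ax=None):
    """sqrt(|standardized residuals|) vs fitted values (hypothetical helper)."""
    fitted = np.asarray(fitted)
    sqrt_abs = np.sqrt(np.abs(np.asarray(std_resid)))
    if ax is None:
        _, ax = plt.subplots()
    ax.scatter(fitted, sqrt_abs, alpha=0.5)
    # A rising or falling lowess trend hints at non-constant variance
    smooth = lowess(sqrt_abs, fitted, frac=2 / 3)
    ax.plot(smooth[:, 0], smooth[:, 1], color="red", lw=1)
    ax.set_title("Scale-Location")
    ax.set_xlabel("Fitted values")
    ax.set_ylabel("sqrt(|Standardized residuals|)")
    for i in np.argsort(sqrt_abs)[::-1][:n_annotate]:
        ax.annotate(str(i), xy=(fitted[i], sqrt_abs[i]))
    return ax

# Synthetic, made-up data
rng = np.random.default_rng(5)
df = pd.DataFrame({"horsepower": rng.uniform(50, 230, size=150)})
df["mpg"] = 40 - 0.15 * df["horsepower"] + rng.normal(0, 4, size=150)
model_fit = smf.ols("mpg ~ horsepower", data=df).fit()
std_resid = model_fit.get_influence().resid_studentized_internal
ax = scale_location_plot(model_fit.fittedvalues, std_resid)
```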
4. Leverage plot
This plot shows whether any outliers have influence over the regression fit. Points that lie apart from the main group and beyond the “Cook’s distance” lines may have an influential effect on the model fit.
statsmodels has a built-in leverage plot for linear regression, but again, it’s not very customizable. After digging around the source of the statsmodels.graphics package, I found it pretty straightforward to implement from scratch and customize with standard matplotlib functions. There are three parts to this plot. The first is the scatterplot of leverage values (obtained from the statsmodels fitted model using get_influence().hat_matrix_diag) vs. standardized residuals. The second is the lowess regression line for that scatterplot. The third, and trickiest, part is the Cook’s distance lines, which at first I couldn’t figure out how to draw in Python. But statsmodels has Cook’s distance already calculated, so we can use that to annotate the top 3 influencers on the plot:
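A minimal sketch of the leverage plot without the Cook’s distance contours. It takes leverage, standardized residuals and Cook’s distance from get_influence() and labels the three largest Cook’s distances. The helper name leverage_plot and the synthetic data are assumptions.

```python
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import statsmodels.formula.api as smf
from statsmodels.nonparametric.smoothers_lowess import lowess

def leverage_plot(leverage, std_resid, cooks, n_annotate=3, ax=None):
    """Standardized residuals vs leverage (hypothetical helper)."""
    leverage = np.asarray(leverage)
    std_resid = np.asarray(std_resid)
    cooks = np.asarray(cooks)
    if ax is None:
        _, ax = plt.subplots()
    ax.scatter(leverage, std_resid, alpha=0.5)
    smooth = lowess(std_resid, leverage, frac=2 / 3)
    ax.plot(smooth[:, 0], smooth[:, 1], color="red", lw=1)
    ax.axhline(0, color="grey", ls="--", lw=1)
    ax.set_xlim(0, leverage.max() * 1.05)
    ax.set_title("Residuals vs Leverage")
    ax.set_xlabel("Leverage")
    ax.set_ylabel("Standardized residuals")
    # Label the observations with the largest Cook's distance
    for i in np.argsort(cooks)[::-1][:n_annotate]:
        ax.annotate(str(i), xy=(leverage[i], std_resid[i]))
    return ax

# Synthetic, made-up data
rng = np.random.default_rng(6)
df = pd.DataFrame({"horsepower": rng.uniform(50, 230, size=150)})
df["mpg"] = 40 - 0.15 * df["horsepower"] + rng.normal(0, 4, size=150)
model_fit = smf.ols("mpg ~ horsepower", data=df).fit()
influence = model_fit.get_influence()
ax = leverage_plot(
    influence.hat_matrix_diag,
    influence.resid_studentized_internal,
    influence.cooks_distance[0],  # cooks_distance is (distances, p-values)
)
```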
Update: I think I figured out how to draw Cook’s distance (D) contours for D = 0.5 and D = 1. The trick was rearranging the Cook’s distance formula to give the standardized residual as a function of leverage, then plotting that curve at D = 0.5 and D = 1.
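Cook’s distance for an observation can be written as D = r^2 * h / (p * (1 - h)). Here r is the standardized residual, h the leverage and p the number of model parameters (2 for an intercept plus one slope). Solving for r gives r = ±sqrt(D * p * (1 - h) / h), which traces one curve per fixed D in the leverage plot. R’s plot.lm draws contours of this same form. Below is a minimal sketch with hypothetical helper names that adds those curves to an existing matplotlib axis. The lower bound h_min is an arbitrary choice to avoid dividing by zero.

```python
import numpy as np
import matplotlib.pyplot as plt

def cooks_contour(D, p, h):
    """Standardized residual r at which leverage h gives Cook's distance D.

    Rearranged from D = r**2 * h / (p * (1 - h)); returns the positive root.
    """
    h = np.asarray(h, dtype=float)
    return np.sqrt(D * p * (1 - h) / h)

def add_cooks_lines(ax, p, h_max, levels=(0.5, 1.0), h_min=0.001, n=100):
    """Draw +/- Cook's distance contours on ax (hypothetical helper)."""
    ylim = ax.get_ylim()  # contours blow up as h -> 0, so keep the data's y-range
    h = np.linspace(h_min, h_max, n)
    styles = ("--", ":", "-.")
    for k, D in enumerate(levels):
        r = cooks_contour(D, p, h)
        ls = styles[k % len(styles)]
        ax.plot(h, r, color="red", ls=ls, lw=1, label=f"Cook's distance = {D}")
        ax.plot(h, -r, color="red", ls=ls, lw=1)
    ax.set_ylim(ylim)
    ax.legend(loc="upper right")
    return ax

# Demo on made-up leverage / standardized-residual pairs
rng = np.random.default_rng(4)
h_obs = rng.uniform(0.005, 0.05, size=100)
r_obs = rng.normal(size=100)
fig, ax = plt.subplots()
ax.scatter(h_obs, r_obs, alpha=0.5)
add_cooks_lines(ax, p=2, h_max=h_obs.max())
```

For a fitted statsmodels model, p would be len(model_fit.params), and h_max would be the largest value of get_influence().hat_matrix_diag. With many observations and low leverage, the contours can fall entirely outside the plotted range, which just means no point is close to those levels.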
I’ve learned a lot during this code porting, and I’m continuing to learn. Hopefully it will be helpful for others struggling with similar issues.