Introduction to Linear Regression — With Implementation in Python from Scratch

Asra Khalid
Published in The Startup
Jun 8, 2020

Linear regression is one of the most basic machine learning algorithms, and one that every data scientist should know. If you want a short introduction to machine learning and its types, you can read my previous article.

What is regression?

Regression is used to build a model that predicts a dependent (target) variable from a set of independent (explanatory) variables. In the classic regression setting, the variable being predicted is continuous.

What is linear regression?

Linear regression is a supervised machine learning algorithm used to predict a continuous target. It models a linear relationship between the dependent and independent variables by finding the best-fit line. In regression analysis, the independent variable is usually called the predictor (or explanatory) variable, and the dependent variable is called the response (or outcome) variable.

Linear Regression

There are two main types of linear regression: simple linear regression and multiple linear regression.

In simple linear regression, there is one dependent variable and one independent variable. In multiple linear regression, there is one dependent variable and several independent variables, so the fitted model is a plane (or hyperplane) rather than a single line.
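To make the multiple case concrete, here is a minimal NumPy sketch that fits Y = b0 + b1x1 + b2x2 + b3x3 by ordinary least squares with numpy.linalg.lstsq. The data is synthetic, with made-up "true" coefficients, not the article's data set. The column of ones in the design matrix is what produces the intercept b0.

```python
import numpy as np

# Synthetic data with made-up "true" coefficients (not the article's data set)
rng = np.random.default_rng(42)
n_samples = 200
X_multi = rng.normal(size=(n_samples, 3))        # three hypothetical features
true_b0 = 1.5
true_b = np.array([2.0, -1.0, 0.5])
y_multi = true_b0 + X_multi @ true_b + rng.normal(scale=0.1, size=n_samples)

# Prepend a column of ones so the first fitted coefficient is the intercept b0
X_design = np.column_stack([np.ones(n_samples), X_multi])
coef, _, _, _ = np.linalg.lstsq(X_design, y_multi, rcond=None)

b0_hat, b_hat = coef[0], coef[1:]
print("intercept b0:", b0_hat)    # close to 1.5
print("slopes b1..b3:", b_hat)    # close to [2.0, -1.0, 0.5]
```

With a single feature column, the same code reduces to simple linear regression.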

What types of regression are there?

There are many different types of regression. The specific family of regression models we’ll be discussing is called “generalized linear models”. The important thing to know is that, with this family of models, you need to pick the specific type of regression that suits your problem, and that choice depends on the type of data you’re trying to predict.

  • Linear: when you’re predicting a continuous value. (What temperature will it be today?)
  • Logistic: when you’re predicting which category an observation belongs to. (Will a customer leave or not?)
  • Poisson: when you’re predicting a count. (How many ice creams will a store sell this year?)

Linear Regression Formula:

The formula used for finding linear regression is:

$$Y = b_0 + b_1 x + e$$

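This equation describes the model itself. It is different from the cost function (for example, the mean squared error), which measures how far the predictions are from the observed values and is minimized to find the coefficients. In this equation: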

  • Y is the dependent variable we want to predict, also known as the output variable.
  • b0 is the intercept: the value of Y where the line crosses the y-axis (x = 0).
  • b1 is the slope of the line: the change in Y for a one-unit increase in x.
  • x is the independent variable used to predict Y, also known as the input variable.
  • e is the error term: the part of Y that the line does not explain.
  • b0 and b1 are called the model coefficients.
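The formula translates directly into code. This plain-Python sketch, using toy numbers chosen only for illustration, computes the predictions b0 + b1x (the error term e is what remains between these and the observed values) and the mean squared error, a common cost function.

```python
def predict(x, b0, b1):
    """Model prediction y_hat = b0 + b1 * x for every input value."""
    return [b0 + b1 * xi for xi in x]


def mse_cost(x, y, b0, b1):
    """Mean squared error: average of (observed - predicted) squared."""
    y_hat = predict(x, b0, b1)
    return sum((yi - yhi) ** 2 for yi, yhi in zip(y, y_hat)) / len(y)


x_toy = [1.0, 2.0, 3.0]
y_toy = [5.0, 8.0, 10.0]
print(predict(x_toy, 2.0, 3.0))          # [5.0, 8.0, 11.0]
print(mse_cost(x_toy, y_toy, 2.0, 3.0))  # errors 0, 0, -1, so MSE = 1/3
```

Fitting the model means searching for the b0 and b1 that make this cost as small as possible.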

To create your model, you must “learn” the values of these coefficients. Once we’ve learned these coefficients, we can use the model to predict.
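For simple linear regression, the coefficients can be learned from scratch in two ways, both sketched below with NumPy on synthetic data (true b0 = 2, b1 = 3 plus noise, values made up for illustration). The first is the closed-form least-squares solution, b1 = Σ(x − x̄)(y − ȳ) / Σ(x − x̄)² and b0 = ȳ − b1·x̄. The second is gradient descent on the mean squared error. The learning rate and iteration count are assumptions tuned for this toy data, not general recommendations.

```python
import numpy as np


def fit_simple_ols(x, y):
    """Closed-form least-squares estimates for y = b0 + b1 * x."""
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)
    x_mean, y_mean = x.mean(), y.mean()
    b1 = np.sum((x - x_mean) * (y - y_mean)) / np.sum((x - x_mean) ** 2)
    b0 = y_mean - b1 * x_mean
    return b0, b1


def fit_gradient_descent(x, y, lr=0.01, n_iters=20000):
    """Learn b0 and b1 by gradient descent on the MSE cost."""
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)
    b0, b1 = 0.0, 0.0
    n = len(x)
    for _ in range(n_iters):
        error = (b0 + b1 * x) - y              # prediction minus target
        # Partial derivatives of (1/n) * sum(error ** 2)
        grad_b0 = (2.0 / n) * error.sum()
        grad_b1 = (2.0 / n) * (error * x).sum()
        b0 -= lr * grad_b0
        b1 -= lr * grad_b1
    return b0, b1


rng = np.random.default_rng(0)
x = rng.uniform(0, 10, size=100)
y = 2.0 + 3.0 * x + rng.normal(0, 1.0, size=100)

ols = fit_simple_ols(x, y)
gd = fit_gradient_descent(x, y)
print("closed form:     ", ols)
print("gradient descent:", gd)   # converges to the same estimates
```

The closed form is exact and fast for one feature; gradient descent is shown because the same idea scales to models where no closed form is convenient.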

Python Implementation:

Now that we know the basic idea behind linear regression, let’s practice picking the right model for our data set and plotting it. I will demonstrate it in Python on an e-commerce customer data set.

I will be using Jupyter Notebook, but you can use any environment you find most suitable. If you’re on Windows 7 and facing issues during the download, please refer to this article.

Now it’s time for the step-by-step implementation of linear regression in Python.

Data Set Overview:

The data set consists of eight columns. The first three contain basic information about each customer: email, address, and avatar. We will use only the numerical data in our analysis. The other columns are:

  • Avg. Session Length: average length of the in-store style advice sessions.
  • Time on App: average time spent on the app, in minutes.
  • Time on Website: average time spent on the website, in minutes.
  • Length of Membership: how many years the customer has been a member.
  • Yearly Amount Spent: how much the customer spends per year; this is the target variable we want to predict.

Step — 1: Importing Libraries & Data Set

We will start by importing the Python libraries we will use for the analysis and then load the data set.

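import pandas as pd               # data frames and CSV loading
import numpy as np                # numerical arrays
import matplotlib.pyplot as plt   # base plotting library
import seaborn as sns             # statistical plots built on matplotlib
# Jupyter-only magic that renders plots inside the notebook
%matplotlib inline
# Assumes the CSV file is in the current working directory
customers = pd.read_csv("Ecommerce Customers.csv")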

To get a clear idea of the libraries we are using, you can read their documentation:

Pandas: https://pandas.pydata.org/docs/

Numpy: https://numpy.org/doc/

Matplotlib: https://matplotlib.org/3.2.1/contents.html

Seaborn: https://seaborn.pydata.org/

Step — 2: Descriptive Data Analysis

Let’s take a look at the data set.

customers.head()
E-commerce Customer Data set

To take a closer look at the data, I am using the pandas .head() method, which returns the first five rows of the data set. Similarly, the .tail() method returns the last five rows.

customers.shape
.shape output

To find the total number of rows and columns in a data set, you can use the .shape attribute, which returns a (rows, columns) tuple. The data set consists of 500 rows and 8 columns.

customers.describe()

Using the .describe() method, we can see basic summary statistics of the numeric columns, such as the count, mean, standard deviation, minimum, maximum, and percentiles. From the output, we can see that the mean and median (50%) of each column are nearly equal, which suggests that the distributions are not strongly skewed.
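Comparing the mean and the median is a quick heuristic. A slightly more direct check is pandas’ DataFrame.skew(), which returns the sample skewness of each column; values near 0 indicate a roughly symmetric distribution. A minimal sketch with a hypothetical helper, skew_report, that puts the three numbers side by side:

```python
import pandas as pd


def skew_report(frame):
    """Mean, median, and sample skewness for every numeric column."""
    numeric = frame.select_dtypes(include="number")  # ignore text columns
    return pd.DataFrame({
        "mean": numeric.mean(),
        "median": numeric.median(),
        "skew": numeric.skew(),
    })
```

On the article’s data this would be called as skew_report(customers); the Email, Address, and Avatar columns are skipped automatically because they are not numeric.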

customers.info()

The .info() method is very helpful for finding the data type of each variable in a data set. Email, Address, and Avatar are non-numeric variables, while the rest are numeric. Since we won’t use these three columns in our model, we can ignore them. However, if a non-numeric variable is relevant to the prediction (for example, Male/Female), it has to be converted into a numeric form before being fed into the model.
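As a minimal sketch of that conversion, assuming a hypothetical Gender column (the article’s data set has no such column, and the values below are made up), pandas.get_dummies one-hot encodes the categories into 0/1 indicator columns:

```python
import pandas as pd

# Hypothetical frame: "Gender" is a non-numeric column, as in the Male/Female example
frame = pd.DataFrame({
    "Gender": ["Male", "Female", "Female", "Male"],
    "Time on App": [12.5, 11.1, 13.0, 10.2],
})

# One-hot encode; drop_first=True keeps a single indicator column
encoded = pd.get_dummies(frame, columns=["Gender"], drop_first=True, dtype=int)
print(encoded)
```

drop_first=True drops one category (here Female, so Gender_Male = 1 means Male) because the remaining column already carries all the information; keeping both indicators would make them perfectly collinear with the intercept.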

Now, we have a good glimpse of our data set. We will explore it further with the help of graphs.

Step — 3: Exploratory Data Analysis

It is good practice to understand the data first and gather as many insights from it as possible before applying any machine learning algorithm. Exploratory Data Analysis (EDA) is all about making sense of the data at hand before getting your hands dirty with modelling. We will use the Python seaborn library, a high-level statistical plotting library built on matplotlib.

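Before modelling with linear regression, it is worth checking correlations, because uninformative or redundant (highly correlated) features can weaken the model and make its coefficients harder to interpret.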

There are two aspects to it:

  • Features that are weakly correlated with the target variable (here, Yearly Amount Spent) carry little linear information about it, so it is often good practice to remove them. Before removing one, it is wise to also check its p-value: if the p-value is higher than the chosen significance level (alpha), the feature is a candidate to drop.
  • Features that are strongly correlated with each other tend to carry similar information, which makes them redundant. In that case, it is usually best to keep only one of the pair, preferably the one more strongly correlated with the target.
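Both checks can be automated. The sketch below uses scipy.stats.pearsonr to get each feature’s correlation with the target along with its p-value, and scans the feature correlation matrix for strongly correlated pairs. The helper names, alpha = 0.05, and the 0.8 threshold are illustrative assumptions, not fixed rules.

```python
import pandas as pd
from scipy import stats


def correlation_screen(frame, target, alpha=0.05):
    """Pearson r and p-value of each numeric feature against the target."""
    numeric = frame.select_dtypes(include="number")
    rows = []
    for col in numeric.columns.drop(target):
        r, p = stats.pearsonr(numeric[col], numeric[target])
        rows.append({"feature": col, "r": r, "p_value": p, "significant": p <= alpha})
    return pd.DataFrame(rows).set_index("feature")


def redundant_pairs(features, threshold=0.8):
    """Feature pairs whose absolute correlation is at least the threshold."""
    corr = features.select_dtypes(include="number").corr()
    cols = corr.columns
    pairs = []
    for i in range(len(cols)):
        for j in range(i + 1, len(cols)):  # upper triangle only, no self-pairs
            if abs(corr.iloc[i, j]) >= threshold:
                pairs.append((cols[i], cols[j], corr.iloc[i, j]))
    return pairs
```

With the article’s data, these could be called as correlation_screen(df, 'Yearly Amount Spent') and redundant_pairs(df.drop(columns=['Yearly Amount Spent'])).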

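You can find correlations using the pandas .corr() method and visualize the correlation matrix as a heatmap with seaborn. Passing numeric_only=True skips the text columns (Email, Address, Avatar), which pandas 2.0 and later would otherwise reject with an error.

corr = customers.corr(numeric_only=True)
ax = sns.heatmap(corr, annot=True)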
Correlation heatmap

Here, there is a small negative correlation between Time on Website and Yearly Amount Spent, so Time on Website is a candidate to drop after checking its p-value.

To check skewness and the shape of each distribution, it is a good idea to plot distribution graphs. For this, we will create a separate data frame named “df” that contains only the numerical variables.

df = customers.drop(['Email', 'Address', 'Avatar'], axis=1)

Now let’s plot the distribution of each variable.

# sns.distplot is deprecated since seaborn 0.11; kdeplot draws the same
# density curve that distplot(..., hist=False) produced
for column in df.columns:
    plt.figure()               # separate figure for each variable
    sns.kdeplot(df[column])
    plt.title(column)

From the graphs, we can see that each variable looks approximately normally distributed. Now let’s explore the relationships between the variables across the entire data set.

sns.pairplot(customers)

Based on this plot, we can see that the Length of Membership is the most correlated variable with the Yearly Amount Spent.

That’s it for the data exploration section. Congratulations, you’ve done the hard part! Let’s move on to the simpler part: implementing the linear regression model on our data set.

Step — 4: Data Preparation

Now that we've explored the data a bit, let's go ahead and split the data into training and testing sets. Set a variable X equal to the numerical features of the customers and a variable y equal to the “Yearly Amount Spent” column.

y = customers['Yearly Amount Spent']
X = customers[['Avg. Session Length', 'Time on App','Time on Website', 'Length of Membership']]

To divide the data set into training and test sets, we will use the train_test_split function from the scikit-learn library. You can learn more about it at: https://scikit-learn.org/stable/modules/generated/sklearn.model_selection.train_test_split.html

First, import the function.

from sklearn.model_selection import train_test_split

Now, let’s use it on our data set.

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=101)
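Under the hood, a random split just shuffles the row indices and cuts them into two groups; here test_size=0.3 sends about 30% of the rows to the test set, and random_state makes the shuffle reproducible. A from-scratch NumPy sketch with a hypothetical helper, simple_train_test_split (it will not reproduce scikit-learn’s exact shuffle for random_state=101):

```python
import numpy as np


def simple_train_test_split(X, y, test_size=0.3, seed=101):
    """Shuffle row indices, then cut them into a test part and a train part."""
    X = np.asarray(X)
    y = np.asarray(y)
    n_samples = len(X)
    rng = np.random.default_rng(seed)
    shuffled = rng.permutation(n_samples)          # random order of row indices
    n_test = int(np.ceil(test_size * n_samples))   # round the test share up
    test_idx, train_idx = shuffled[:n_test], shuffled[n_test:]
    return X[train_idx], X[test_idx], y[train_idx], y[test_idx]


X_demo = np.arange(40).reshape(20, 2)   # 20 toy rows, 2 features
y_demo = np.arange(20)
X_tr, X_te, y_tr, y_te = simple_train_test_split(X_demo, y_demo, test_size=0.25)
print(X_tr.shape, X_te.shape)  # (15, 2) (5, 2)
```

Indexing X and y with the same shuffled indices is what keeps each row paired with its own target value.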

Step — 5: Training the Model

Now it’s time to train the model on our training data.

  • Import LinearRegression from the scikit-learn library.

You can find the details about the library at: https://scikit-learn.org/stable/modules/generated/sklearn.linear_model.LinearRegression.html

from sklearn.linear_model import LinearRegression
  • Create an instance of a LinearRegression() model named lm
lm = LinearRegression()
  • Train/fit lm on the training data
lm.fit(X_train,y_train)
Model fitting
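As a quick sanity check that the fitted object follows the formula Y = b0 + b1x1 + b2x2 + …, the sketch below fits LinearRegression on small synthetic data (made-up values, not the e-commerce set) and confirms that predict() equals intercept_ plus the dot product of the features with coef_.

```python
import numpy as np
from sklearn.linear_model import LinearRegression

# Synthetic data: true intercept 4.0, true slopes 1.5 and -2.0, small noise
rng = np.random.default_rng(7)
X_demo = rng.uniform(0, 5, size=(50, 2))
y_demo = 4.0 + X_demo @ np.array([1.5, -2.0]) + rng.normal(scale=0.05, size=50)

lm_demo = LinearRegression().fit(X_demo, y_demo)

# Apply the model formula by hand: b0 + b1*x1 + b2*x2 for every row
manual = lm_demo.intercept_ + X_demo @ lm_demo.coef_
print(np.allclose(manual, lm_demo.predict(X_demo)))  # True
```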

Now you have trained the model. Let’s check out its coefficients.

print('Coefficients: \n', lm.coef_)
print('Intercept:', lm.intercept_)
Model Coefficients
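lm.coef_ is a bare array in the same order as the columns of X_train. A common idiom is to label each coefficient with its column name so that it reads as: holding the other features fixed, a one-unit increase in that feature is associated with a change of that many units in the target. The sketch below shows this on a hypothetical two-feature DataFrame whose column names are borrowed from the article but whose values are synthetic.

```python
import numpy as np
import pandas as pd
from sklearn.linear_model import LinearRegression

# Hypothetical synthetic data; only the column names come from the article
rng = np.random.default_rng(3)
X_demo = pd.DataFrame({
    "Time on App": rng.uniform(8, 16, size=100),
    "Length of Membership": rng.uniform(0, 6, size=100),
})
y_demo = (10.0 + 5.0 * X_demo["Time on App"]
          + 20.0 * X_demo["Length of Membership"]
          + rng.normal(scale=1.0, size=100))

lm_demo = LinearRegression().fit(X_demo, y_demo)

# Label each coefficient with the feature it belongs to
coeff_df = pd.DataFrame(lm_demo.coef_, index=X_demo.columns, columns=["Coefficient"])
print(coeff_df)
```

With the article’s model, the equivalent line would be pd.DataFrame(lm.coef_, index=X.columns, columns=['Coefficient']).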

Step — 6: Model Evaluation

For linear regression, common metrics used to evaluate a model are the Mean Absolute Error (MAE), Mean Squared Error (MSE), Root Mean Squared Error (RMSE), and R² score. You can compute them with scikit-learn’s metrics module after generating predictions for the test set.

from sklearn import metrics
predictions = lm.predict(X_test)  # predictions for the held-out test set
print('MAE:', metrics.mean_absolute_error(y_test, predictions))
print('MSE:', metrics.mean_squared_error(y_test, predictions))
print('RMSE:', np.sqrt(metrics.mean_squared_error(y_test, predictions)))
print('R2:', metrics.r2_score(y_test, predictions))
Error Metrics
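These metrics are simple enough to compute from scratch, which makes their definitions explicit. A minimal NumPy sketch with a hypothetical helper, regression_metrics, checked on four toy values:

```python
import numpy as np


def regression_metrics(y_true, y_pred):
    """MAE, MSE, RMSE, and R^2 computed directly from their definitions."""
    y_true = np.asarray(y_true, dtype=float)
    y_pred = np.asarray(y_pred, dtype=float)
    errors = y_true - y_pred
    mae = np.mean(np.abs(errors))                    # average absolute error
    mse = np.mean(errors ** 2)                       # average squared error
    rmse = np.sqrt(mse)                              # back in the target's units
    ss_res = np.sum(errors ** 2)                     # unexplained variation
    ss_tot = np.sum((y_true - y_true.mean()) ** 2)   # total variation
    r2 = 1.0 - ss_res / ss_tot
    return {"MAE": float(mae), "MSE": float(mse), "RMSE": float(rmse), "R2": float(r2)}


metrics_demo = regression_metrics([3.0, 5.0, 7.0, 9.0], [2.0, 5.0, 8.0, 9.0])
print(metrics_demo)  # MAE 0.5, MSE 0.5, RMSE about 0.707, R2 0.9
```

Here the errors are 1, 0, −1, 0, so MAE = MSE = 2/4 = 0.5, and R² = 1 − 2/20 = 0.9 because the observed values vary by 20 in total around their mean of 6.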

That’s it! You’ve successfully trained a linear regression model on your data set. If you have any questions about this article, please comment below!
