[ML basics][Regression] How to tell if a dataset is linear or not?

Abhinav Mahapatra
4 min read · Jan 7, 2019


Well, the question is pretty simple on this one.

How would you tell if a given dataset is linear or non-linear in nature? Of course, the choice of which models to use will depend on it.

Well then, let us get started.

First, the difference between linear and non-linear functions:

(Left) Linear functions. (Right) Non-linear functions.

Linear function: can simply be defined as a function that always follows the principle of:

input/output = constant.

A linear equation is always a polynomial of degree 1 (for example x+2y+3=0). In the two-dimensional case, they always form lines; in other dimensions, they may also form points, planes, or hyperplanes. Their “shape” is always perfectly straight, with no curves of any kind. This is why we call them linear equations.

Non-linear function: any function that is not linear is, simply put, non-linear. Higher-degree polynomials are non-linear. Trigonometric functions (like sin or cos) are non-linear. Square roots are non-linear.
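A quick way to see the difference numerically (a minimal sketch, not from the original post): for a linear function the change in output per unit change in input stays constant, while for a non-linear one it keeps changing.

import numpy as np

x = np.arange(1, 6)        # inputs 1..5
linear = 3 * x + 2         # degree-1 polynomial
nonlinear = x ** 2         # degree-2 polynomial

print(np.diff(linear))     # [3 3 3 3]  -> constant change per step
print(np.diff(nonlinear))  # [3 5 7 9]  -> change per step keeps growing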

That is all fine and dandy, but how do we find out whether a dataset is linear or not? Graphs are easy if we have a single dimension (though not always, as we will see here), but how do we tackle multi-dimensional datasets?

(Left) Generating a linear dataset. (Right) Graph of the same dataset.

As we can see in the figure above, it is not always straightforward to tell from a graph whether a function is linear or not.

Q. How do we solve it?

So, the idea is to fit a simple linear regression to the dataset and then check how well the least-squares fit performs. If the fit shows high accuracy (a high R² score), it implies the dataset is linear in nature; otherwise, the dataset is non-linear.

Simple right?
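For reference, the R² score used below measures how much better the model's predictions are than simply predicting the mean of Y:

R² = 1 − Σ(yᵢ − ŷᵢ)² / Σ(yᵢ − ȳ)²

where ŷᵢ are the model's predictions and ȳ is the mean of Y. A score close to 1 means the linear fit explains almost all of the variance; a score near zero or negative means it does little better than, or even worse than, just predicting the mean.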

Alright, let us get to code:

Starting with linear dataset:

# General imports
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

# Generating data
X = np.random.randn(100, 1)                # one random feature
c = np.random.uniform(-10, 10, (100,))     # uniform noise term
# adding another linear column (a scaled copy of the first)
X = np.hstack((X, 4 * X))

# Target is a linear function of the second column plus noise
Y = 4 * X[:, 1] + c

plt.scatter(X[:, 0], Y)
plt.show()
plt.scatter(X[:, 1], Y)
plt.show()

# Applying linear regression
from sklearn.linear_model import LinearRegression
regressor = LinearRegression().fit(X, Y)

# Checking the fit; r2_score expects (y_true, y_pred)
from sklearn.metrics import r2_score
print(r2_score(Y, regressor.predict(X)))

Outputs:

(Left) Graph of the first column against Y. (Right) Graph of the second column against Y.
The R² accuracy score is about 84%.

Now to the Non-linear Dataset:

# General imports
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

# Generating data
X = np.random.randn(100, 1)                # one random feature
c = np.random.uniform(-10, 10, (100,))     # uniform noise term
# adding another non-linear column (the square of the first)
X = np.hstack((X, X * X))

# Target depends on the squared column plus noise
Y = 4 * X[:, 1] + c

plt.scatter(X[:, 0], Y)
plt.show()
plt.scatter(X[:, 1], Y)
plt.show()

# Applying linear regression
from sklearn.linear_model import LinearRegression
regressor = LinearRegression().fit(X, Y)

# Checking the fit; r2_score expects (y_true, y_pred)
from sklearn.metrics import r2_score
print(r2_score(Y, regressor.predict(X)))

Outputs:

(Left) Graph of the first column against Y. (Right) Graph of the second column against Y.
The R² accuracy score is about -122%.

Needless to say, that is an extremely undesirable accuracy score. While the code was almost the same, we can see that adding non-linearity had a profound effect on the accuracy score.

Before getting started with a dataset, a four-line check on a small validation set to see whether the dataset is linear or not can save a lot of your time.
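As a rough sketch of such a check (the looks_linear helper, the 0.8 threshold, and the 30% validation split here are my own illustrative choices, not from the original post):

from sklearn.model_selection import train_test_split
from sklearn.linear_model import LinearRegression

def looks_linear(X, y, threshold=0.8):
    # Hold out part of the data as a validation set
    X_train, X_val, y_train, y_val = train_test_split(X, y, test_size=0.3, random_state=0)
    # R2 of a plain linear fit, scored on the held-out data
    score = LinearRegression().fit(X_train, y_train).score(X_val, y_val)
    return score >= threshold, score

is_linear, r2 = looks_linear(X, Y)

Run on the two datasets generated above, it should flag the first one as linear and the second one as not.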

If you have any suggestions for a more streamlined process, or any doubts, please feel free to comment.

Sources used: https://www.quora.com/What-is-the-difference-between-linear-and-non-linear-equations, https://study.com/academy/lesson/how-to-recognize-linear-functions-vs-non-linear-functions.html, https://stackoverflow.com/questions/7181014/determine-if-a-set-of-data-is-from-a-linear-or-logarithmic-function
