WHY DO WE NEED CROSS VALIDATION? (DATA SCIENCE INTERVIEW QUESTION)

Nishesh Gogia
5 min read · Dec 6, 2019


I have been working on Machine learning projects for a while now and one technique which really impressed me is “K FOLD CROSS VALIDATION”.

This technique is one of the best examples of how simple ideas can have a big impact on your Machine Learning model.

SO FIRSTLY WHY DO WE NEED K FOLD CROSS VALIDATION?

AND BEFORE THAT WHAT IS CROSS VALIDATION?

Let's understand CROSS VALIDATION first.

Cross validation comes into the picture when we notice a fundamental problem with splitting the data into train and test sets only.

It is very common to split a dataset into a TRAINING DATASET and a TEST DATASET, but here we want to split the data into three parts.

  1. Training dataset.
  2. Cross-Validation dataset.
  3. Test dataset.

But the question is WHY???

Let’s take a simple example

Let's say you have a query point xq and you want to find its class label (yq). Just to keep it simple, we assume this is a binary classification problem.

So yq can either be 1 or 0.

Suppose I am using the KNN (k-nearest neighbours) algorithm and, for now, I split the data into TRAIN (80%) and TEST (20%) only.

Now, to quickly convey the nature of KNN: it finds the K (let's say K=5 here) nearest points to my xq (the distance measure here is Euclidean), then takes a majority vote. So if THREE of the nearest points belong to class “1”, then yq will be marked as 1.
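As a quick illustration, here is a minimal sketch of that majority-vote idea in plain NumPy (the toy points and labels below are made up for the example):

```python
import numpy as np

def knn_predict(X_train, y_train, xq, k=5):
    """Predict the label of query point xq by majority vote
    among its k nearest training points (Euclidean distance)."""
    dists = np.linalg.norm(X_train - xq, axis=1)  # distance from xq to every training point
    nearest = np.argsort(dists)[:k]               # indices of the k closest points
    votes = y_train[nearest]
    # majority vote: with k=5, if 3 or more neighbours are labelled 1, yq = 1
    return int(np.sum(votes) * 2 > k)

# toy data: two points of class 0 near (0,0), three of class 1 near (1,1)
X = np.array([[0.0, 0.0], [0.1, 0.0], [1.0, 1.0], [1.1, 1.0], [0.9, 1.1]])
y = np.array([0, 0, 1, 1, 1])
print(knn_predict(X, y, np.array([0.95, 1.0]), k=5))  # prints 1
```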

But in this whole process we missed one detail: finding the value of K.

Now finding the value of K is very simple. We start with K=1: fit the model on the TRAINING DATA and then measure its accuracy on the TESTING DATA.

Then again with K=2: fit on the training data, measure accuracy on the test data.

And we repeat the same thing until we find the best K.
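That tuning loop can be sketched as follows. Note this is exactly the flawed procedure the article goes on to criticise, shown here with a made-up two-blob dataset and a simple hand-rolled k-NN scorer:

```python
import numpy as np

def knn_accuracy(X_tr, y_tr, X_te, y_te, k):
    """Accuracy of a k-NN majority-vote classifier on (X_te, y_te)."""
    correct = 0
    for xq, yq in zip(X_te, y_te):
        dists = np.linalg.norm(X_tr - xq, axis=1)
        votes = y_tr[np.argsort(dists)[:k]]
        pred = int(np.sum(votes) * 2 > k)    # majority vote
        correct += (pred == yq)
    return correct / len(y_te)

rng = np.random.default_rng(0)
# toy dataset: two Gaussian blobs, one per class
X = np.vstack([rng.normal(0, 1, (50, 2)), rng.normal(3, 1, (50, 2))])
y = np.array([0] * 50 + [1] * 50)

idx = rng.permutation(100)
train, test = idx[:80], idx[80:]             # the 80/20 split from the text

# the flawed procedure: tuning K by its accuracy on the TEST data
for k in (1, 3, 5, 7):
    acc = knn_accuracy(X[train], y[train], X[test], y[test], k)
    print(f"K={k}  test accuracy={acc:.2f}")
```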

Accuracy will first increase with increasing K and then start decreasing as K grows further.

HERE IS A PLOT OF ACCURACY ON THE TEST DATASET VS. THE VALUE OF “K”.

Everything looks fine up to now.

BUT there is a fundamental problem in this process.

The objective of machine learning is to learn a function that predicts yq for an UNSEEN point xq (PLEASE DO NOTE THE WORD UNSEEN).

Now, what we were doing is choosing the value of K based on the accuracy on the TEST DATASET. In effect, we are using the test dataset to tune the model, and that is the fundamental problem.

We need a dataset that is unseen by the model; only then can we decide whether it is accurate or useless.

To tackle this problem we introduce CROSS VALIDATION, so now the data will be divided into 3 parts.

  1. TRAINING(60%)
  2. CROSS VALIDATION(20%)
  3. TESTING(20%)
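A minimal sketch of such a 60/20/20 split, using shuffled NumPy indices (the dataset size of 1000 is just an assumption for the example):

```python
import numpy as np

rng = np.random.default_rng(42)
n = 1000
idx = rng.permutation(n)          # shuffle the indices of the full dataset

# 60 / 20 / 20 split of the shuffled indices
train_idx = idx[:600]
cv_idx    = idx[600:800]
test_idx  = idx[800:]

print(len(train_idx), len(cv_idx), len(test_idx))  # prints: 600 200 200
```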

Now our problem is solved, right?

Yes, one problem is solved, but now another one appears.

If you observe carefully, you will notice a loss in the TRAINING DATA percentage: it was 80 percent earlier and now it is 60 percent.

There is a Simple Fact in machine learning,

“MORE THE DATA, BETTER THE MODEL”

Now the training set has only 60 percent of the data, which means a loss of information, and that will lower the performance of our model.

NOW HOW TO TACKLE THIS PROBLEM?

IS THERE A WAY SO THAT WE CAN USE 80% OF THE DATA?

Now it's time to introduce

K FOLD CROSS VALIDATION.

First things first: there is nothing common between the K in K-fold and the K in K-nearest neighbours; K is just used as a variable name here.

Now the intuition behind K-fold cross validation is simple; we will understand it with the same example as before.

Let's say we have training data Dtrain that contains 80% of the total data we have. Now let's divide this Dtrain randomly into k folds (here k=4):

D1 (random 20% of the data)

D2 (random 20% of the data)

D3 (random 20% of the data)

D4 (random 20% of the data)

Now we will do the same thing we did earlier: for different values of K we'll find the accuracy on the cross-validation data.

But we don’t have cross validation data right?

Here is the catch.

We will put D1, D2, D3 in the training data (Dtrain), keep D4 as the cross-validation data for K=1, and find the accuracy on the cross-validation data.

Now, keeping K the same (i.e. 1), we will put D2, D3, D4 into the training data, keep D1 as the cross-validation data, and find the accuracy.

Doing this four times while keeping K the same, we cover all the cases (each fold gets to be the cross-validation set exactly once).

Now, for the same K we get different cross-validation accuracies, let's say a1, a2, a3 and a4 respectively, and we take their average.

For K=1, the average accuracy is:

a1' = (a1 + a2 + a3 + a4) / 4

Now for different K we will follow the same procedure and get different values of accuracy. Let's say

for K=2, the average accuracy is:

a2' = (b1 + b2 + b3 + b4) / 4

and continues…

So let's say we have a1', a2', a3', a4', …

for different K. Now we will see which K gives us the maximum accuracy; that K will be chosen and the model will be built with it.

We will further use Test Data to find the actual performance of the model.
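Putting the whole procedure together, here is a minimal NumPy sketch (the toy blob data and the candidate K values are assumptions for the example): split Dtrain into 4 folds, rotate the held-out fold, average the four accuracies for each K, and pick the best one.

```python
import numpy as np

def knn_accuracy(X_tr, y_tr, X_te, y_te, k):
    """Accuracy of a k-NN majority-vote classifier on (X_te, y_te)."""
    preds = [int(np.sum(y_tr[np.argsort(np.linalg.norm(X_tr - xq, axis=1))[:k]]) * 2 > k)
             for xq in X_te]
    return float(np.mean(np.array(preds) == y_te))

rng = np.random.default_rng(1)
# toy Dtrain (the 80% portion): two Gaussian blobs, one per class
X = np.vstack([rng.normal(0, 1, (40, 2)), rng.normal(3, 1, (40, 2))])
y = np.array([0] * 40 + [1] * 40)

folds = np.array_split(rng.permutation(len(X)), 4)   # D1, D2, D3, D4

best_k, best_avg = None, -1.0
for k in (1, 3, 5, 7):                 # the KNN hyperparameter being tuned
    accs = []
    for i in range(4):                 # rotate: each fold is held out once
        cv = folds[i]
        tr = np.concatenate([folds[j] for j in range(4) if j != i])
        accs.append(knn_accuracy(X[tr], y[tr], X[cv], y[cv], k))
    avg = sum(accs) / 4                # a_k' = (a1 + a2 + a3 + a4) / 4
    if avg > best_avg:
        best_k, best_avg = k, avg

print("chosen K:", best_k)
```

The held-out test data is never touched in this loop; it is used only once at the end to report the final performance of the chosen model.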

NOW YOU WILL ASK: HOW DO WE CHOOSE THE K VALUE IN K FOLD CROSS VALIDATION?

Generally, this K value is anything between 5 and 10.

PROBLEM WITH K FOLD CROSS VALIDATION

If you observed carefully, in K-fold cross validation we keep the hyperparameter fixed and repeat the training process K times (once per fold), and that increases the COMPUTATIONAL TIME.

The time it takes to find the optimal hyperparameter increases roughly K times.

Thanks for reading….
