Mismatched training and dev sets in Deep Learning
While building a Deep Learning model, it is common practice to split labelled data into three sets: training, dev and test sets. We train the algorithm on the training set, evaluate it on the dev set, analyse the errors and repeat until the errors on the training and dev sets are acceptably low. We then pick the best-performing model from this process and measure its accuracy on the test set.
One of the important aspects of error analysis is identifying Bias and Variance. Here, the error (the fraction of misclassified examples) on the training set is treated as the Bias, and the difference between the dev set error and the training set error is treated as the Variance. The values of Bias and Variance tell us what to try next; once both are acceptably low, we stop iterating and evaluate the model on the test set. This is the usual cycle in deep learning. Consider the following cases:
- Training set error 1% and dev set error 10%: the model overfits the training set and fails to generalise to unseen examples. This is called High Variance and can be reduced by adding regularisation and training the model again.
- Training set error 10% and dev set error 11%: the model underfits the training set. This is called High Bias and can be reduced by training a bigger network or trying a different neural network architecture.
- Training set error 0.5% and dev set error 1%: the model performs well, and we can go on to evaluate it on the test set.
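The three cases above can be expressed as a small diagnostic function. This is a minimal sketch that follows the article's definitions (Bias is the training error, Variance is the dev error minus the training error). The `threshold` of 2 percentage points for "acceptable" is a hypothetical value chosen only for illustration; a real cut-off depends on the task.

```python
def diagnose(train_error, dev_error, threshold=2.0):
    """Label a model as high bias and/or high variance.

    Errors are in percent. `threshold` (percentage points) is a hypothetical
    cut-off for what counts as acceptable.
    """
    bias = train_error                  # error on the training set
    variance = dev_error - train_error  # extra error on unseen dev examples
    problems = []
    if bias > threshold:
        problems.append("high bias")
    if variance > threshold:
        problems.append("high variance")
    # An empty list means neither gap exceeds the threshold.
    return bias, variance, problems or ["ok"]


for train_err, dev_err in [(1.0, 10.0), (10.0, 11.0), (0.5, 1.0)]:
    print(train_err, dev_err, diagnose(train_err, dev_err))
```

With these inputs the three cases come out as high variance, high bias and ok respectively, matching the list above.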
The above cycle has an implicit assumption: the training and dev/test sets come from the same distribution. But what if that's not the case? Say you want to build an application that lets users upload a picture from a mobile app and classifies whether the picture contains a cat. For training data, we have 10,000 pictures uploaded through the mobile app. By crawling the web, however, we can download a huge number of cat pictures, say 200,000. These two sets of pictures come from different distributions: pictures from the internet tend to be high quality, while pictures uploaded from the mobile app are often blurry and amateurish.
If we combine both types of pictures, shuffle them and form the train, dev and test sets with 205,000, 2,500 and 2,500 examples respectively, all three sets come from the same distribution. The disadvantage is that most of the dev set will come from the web distribution of images rather than the mobile app distribution we actually care about: on average only 2,500 × 10,000 / 210,000 ≈ 119 of the 2,500 dev images would be mobile uploads.
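To see how little of the dev set would come from the mobile app under a random shuffle, the sketch below stands in for images with the strings `"web"` and `"mobile"` (a hypothetical simplification), shuffles all 210,000 of them, takes the first 205,000 as training and the next 2,500 as dev, and counts the mobile uploads in dev. The exact count depends on the random seed, but it stays close to the expected value of about 119.

```python
import random


def mobile_images_in_dev(n_web=200_000, n_mobile=10_000, n_dev=2_500, seed=0):
    """Shuffle web and mobile 'images' together and count mobile ones in dev."""
    pool = ["web"] * n_web + ["mobile"] * n_mobile
    rng = random.Random(seed)
    rng.shuffle(pool)
    n_train = len(pool) - 2 * n_dev         # 205,000 training images
    dev = pool[n_train:n_train + n_dev]     # the next 2,500 images form dev
    return dev.count("mobile")


expected = 2_500 * 10_000 / 210_000         # expected number of mobile dev images
print(round(expected), mobile_images_in_dev())
```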
Scenarios like this force us to use different distributions for the training and dev/test sets. In the above scenario, we can split the data as follows: add 5,000 of the 10,000 mobile uploads to the 200,000 internet images to form a training set of 205,000 images, and split the remaining 5,000 mobile uploads into dev and test sets of 2,500 each. The advantage of this approach is that the dev and test sets now represent the mobile app images we actually care about, so we are aiming the target where we want it to be.
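The split described above can be sketched as follows. Images are represented by hypothetical `(source, index)` tuples; in a real pipeline they would be file paths or arrays. Only the mobile uploads are shuffled and divided, so the dev and test sets come entirely from the target distribution while the training set holds all web images plus the leftover mobile ones.

```python
import random


def split_for_target(web, mobile, n_dev=2_500, n_test=2_500, seed=0):
    """Keep dev and test purely from the mobile (target) distribution and put
    all web images plus the remaining mobile images in training."""
    mobile = list(mobile)
    random.Random(seed).shuffle(mobile)
    dev = mobile[:n_dev]
    test = mobile[n_dev:n_dev + n_test]
    train = list(web) + mobile[n_dev + n_test:]
    return train, dev, test


web = [("web", i) for i in range(200_000)]
mobile = [("mobile", i) for i in range(10_000)]
train, dev, test = split_for_target(web, mobile)
print(len(train), len(dev), len(test))  # 205000 2500 2500
```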
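But how do we evaluate our model? Since the training and dev sets no longer come from the same distribution, the previous definition of Variance is no longer valid: a gap between training and dev error could be caused either by overfitting or by the difference in distributions. In such scenarios, we split our data into four parts:
1. Training set
2. Training-dev set
3. Dev set
4. Test set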
Sets 1 and 2 come from the same distribution, as do sets 3 and 4. In our example, we form the new training set and the training-dev set from the old training set of 205,000 examples, making the training-dev set about the same size as the dev and test sets. The new split therefore has 202,500, 2,500, 2,500 and 2,500 examples. Consider the following scenarios:
1. Training set error 1%, training-dev set error 10% and dev set error 11%.
Variance can now be defined as the difference between the training-dev set error and the training set error. In this case, the model performs poorly on the unseen training-dev examples even though the training and training-dev sets come from the same distribution, so this is a case of High Variance.
2. Training set error 10%, training-dev set error 12% and dev set error 13%.
This is a problem of High Bias, as the model does not perform well even on the data it was trained on.
3. Training set error 2%, training-dev set error 5% and dev set error 14%.
The model performs well on the training and training-dev sets, which share the same distribution, but not on the dev set: the error jumps by 9 percentage points between the training-dev and dev sets. This is the mismatched training and dev set (data mismatch) problem.
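The three scenarios can be checked with a small sketch that computes each gap: training error (Bias), training-dev minus training (Variance) and dev minus training-dev (data mismatch). Reporting the largest gap as the main problem is a simplification assumed here for illustration; in practice several gaps can matter at once.

```python
def diagnose_mismatch(train_error, train_dev_error, dev_error):
    """Return each gap (in percentage points) and the largest one."""
    gaps = {
        "high bias": train_error,                         # error on training data
        "high variance": train_dev_error - train_error,   # same distribution, unseen data
        "data mismatch": dev_error - train_dev_error,     # different distribution
    }
    # Simplification: treat the largest gap as the main problem.
    main = max(gaps, key=gaps.get)
    return gaps, main


scenarios = [(1.0, 10.0, 11.0), (10.0, 12.0, 13.0), (2.0, 5.0, 14.0)]
for errors in scenarios:
    print(errors, diagnose_mismatch(*errors))
```

For the three scenarios this reports high variance, high bias and data mismatch, in that order.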
Generally speaking, one way to address this is to carry out error analysis to understand what makes the training and dev set distributions different. Once we understand the difference, we can think about how to make the training data more similar to the dev data, for example by synthesising artificial data. In our cat classification example, we could go through the training set and add some blurriness to most of the images.
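Adding blur to most training images could look roughly like the sketch below. It assumes grayscale images stored as lists of rows and uses a simple 3x3 box blur written in plain Python as a stand-in for an image library's blur filter. The `fraction=0.8` default is a hypothetical reading of "most of the images".

```python
import random


def box_blur(image):
    """3x3 box blur on a grayscale image (list of rows): each pixel becomes
    the mean of itself and its in-bounds neighbours."""
    h, w = len(image), len(image[0])
    out = [[0.0] * w for _ in range(h)]
    for y in range(h):
        for x in range(w):
            vals = [image[j][i]
                    for j in range(max(0, y - 1), min(h, y + 2))
                    for i in range(max(0, x - 1), min(w, x + 2))]
            out[y][x] = sum(vals) / len(vals)
    return out


def blur_most(images, fraction=0.8, seed=0):
    """Return a new list in which roughly `fraction` of the images are blurred."""
    rng = random.Random(seed)
    return [box_blur(img) if rng.random() < fraction else img for img in images]


spike = [[0, 0, 0], [0, 9, 0], [0, 0, 0]]
print(box_blur(spike))  # the bright centre pixel is spread over its neighbours
```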
The above analysis is based on one of the lectures in Andrew Ng's Deep Learning Specialisation course: https://www.coursera.org/learn/machine-learning-projects/