Splitting test and train data based on strata

Tim Kamanin
Published in teachmymachine
Sep 12, 2017

When we split a dataset into train and test sets, we can choose among several tactics. Very often the split is random: records are picked at random and assigned to one set or the other.

For example, sklearn has a simple function that splits a dataset in one line:

from sklearn.model_selection import train_test_split

train_set, test_set = train_test_split(data, test_size=0.2, random_state=42)

However, this method has a flaw: the sets we get after the split can be unrepresentative of the original data.

For example, let’s imagine we are trying to predict music genre preference, and one of our most important features is age, grouped into categories: up to 13, teens, 20–29, 30–39, 40–49, 50–59 and 60+.

Our random splitting function can easily end up grabbing proportionally more data from one category than from the others. Even worse, it may omit one of the age groups altogether, because it picks items from the dataset at random.
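To see how far a random split can drift, here is a minimal sketch on made-up data. The column name age_group and the group sizes are assumptions chosen for illustration, not part of any real dataset. It compares each group's share in the full data with its share in a randomly drawn test set; the exact numbers depend on the random seed and the library version, so none are quoted here.

```python
import pandas as pd
from sklearn.model_selection import train_test_split

# Hypothetical data: seven age groups with deliberately uneven sizes (100 rows in total).
group_sizes = {"up to 13": 5, "teens": 10, "20-29": 30, "30-39": 25,
               "40-49": 15, "50-59": 10, "60+": 5}
data = pd.DataFrame(
    {"age_group": [group for group, n in group_sizes.items() for _ in range(n)]}
)

# Plain random split: the age groups are not taken into account at all.
train_set, test_set = train_test_split(data, test_size=0.2, random_state=42)

# Share of each age group in the full data vs. in the random test set.
# A group that is missing from the test set shows up with a share of 0.
comparison = pd.DataFrame({
    "overall": data["age_group"].value_counts(normalize=True),
    "random_test": test_set["age_group"].value_counts(normalize=True),
}).fillna(0.0)
comparison["difference"] = comparison["random_test"] - comparison["overall"]
print(comparison.sort_values("overall", ascending=False))
```

Any non-zero value in the difference column is a group that the random test set over- or under-represents.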

What we want is train and test sets that are both representative, each containing the age groups in the same proportions as the full dataset.

How do we do that? Meet stratified sampling: the data is divided into homogeneous groups called strata, and the right number of items is sampled from each stratum to guarantee that the test set is representative of the age groups.
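The idea can be sketched by hand with pandas before reaching for a dedicated splitter. This is only an illustration under assumptions: the data is made up, the column name age_group is hypothetical, and the group sizes are multiples of 5 so that 20% of each group is a whole number.

```python
import pandas as pd

# Hypothetical data: seven age groups with uneven sizes (100 rows in total).
group_sizes = {"up to 13": 5, "teens": 10, "20-29": 30, "30-39": 25,
               "40-49": 15, "50-59": 10, "60+": 5}
data = pd.DataFrame(
    {"age_group": [group for group, n in group_sizes.items() for _ in range(n)]}
)

# Draw 20% of the rows from every stratum (age group) separately.
test_set = data.groupby("age_group").sample(frac=0.2, random_state=42)

# Every row that was not drawn goes to the train set.
train_set = data.drop(test_set.index)

# Each group contributes rows to the test set in proportion to its size.
print(test_set["age_group"].value_counts())
```

Because the sampling happens inside each group, no age group can be left out of the test set, and each one keeps its share of the data.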

In our example, we have seven strata (the age groups). Here’s a general-purpose function that splits a dataset into train and test sets based on the strata provided:

from sklearn.model_selection import StratifiedShuffleSplit


def split_strat_train_test(data, strat_key, test_ratio):
    """
    Split train and test data based on strata to get representative train and test sets
    - data is a Pandas DataFrame
    - strat_key is the strata column name
    - test_ratio is the fraction of rows that goes to the test set
    """
    split = StratifiedShuffleSplit(n_splits=1, test_size=test_ratio, random_state=42)
    # split() yields positional indices, so select rows with iloc rather than loc
    for train_index, test_index in split.split(data, data[strat_key]):
        strat_train_set = data.iloc[train_index]
        strat_test_set = data.iloc[test_index]
    return strat_train_set, strat_test_set

We’ll use it like this:

train_set, test_set = split_strat_train_test(data, "age", 0.2)

As a result, you should get a train set that contains 80% of the original data and a test set that contains the remaining 20%, each representative of the age groups.
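One way to check that a stratified split behaves as described is to compare the age-group shares in the full data, the train set and the test set. The sketch below uses made-up data with a hypothetical age column. For brevity it relies on the stratify argument of train_test_split, which scikit-learn offers for a single stratified split, rather than on the function above.

```python
import pandas as pd
from sklearn.model_selection import train_test_split

# Hypothetical data: seven age groups with uneven sizes (100 rows in total).
group_sizes = {"up to 13": 5, "teens": 10, "20-29": 30, "30-39": 25,
               "40-49": 15, "50-59": 10, "60+": 5}
data = pd.DataFrame(
    {"age": [group for group, n in group_sizes.items() for _ in range(n)]}
)

# Stratified 80/20 split on the age column.
train_set, test_set = train_test_split(
    data, test_size=0.2, stratify=data["age"], random_state=42
)


def group_shares(frame):
    """Return the fraction of rows that falls into each age group."""
    return frame["age"].value_counts(normalize=True).sort_index()


# Side-by-side shares: the three columns should be (nearly) identical.
shares = pd.DataFrame({
    "overall": group_shares(data),
    "train": group_shares(train_set),
    "test": group_shares(test_set),
})
print(shares)
```

If the three columns match closely, both sets mirror the age distribution of the original data. On real data the match is usually approximate, since group sizes rarely divide evenly.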
