scikit-learnのStratifiedKFoldで交差検定

takkii
Music and Technology
7 min readNov 10, 2017

tkmさんのKaggle動画をみていたところ、 sklearn.model_selection.StratifiedKFold が出て来た。

これはCVをする際に、ラベルの比率が揃うようにtrainデータとtestデータを分けてくれるものだ。

K個に分け、1個をtestデータに、K-1個をtrainデータとする。その際にtrainデータにしかないラベルがあると分類がうまくできないので、trainとtestのラベルの比率を等しくするようにする。

K=2の場合

from sklearn.model_selection import StratifiedKFold
import numpy as np
# 9つのデータがそれぞれクラス1、2、3に属しているとする。
X = np.array([[1,1,1],[2,2,2],[3,3,3],[4,4,4],[5,5,5],[6,6,6],[7,7,7],[8,8,8],[9,9,9]])
y = np.array([1,2,3,1,2,3,1,2,3])
# n_splitsでKの数を指定
skf = StratifiedKFold(n_splits=2)
for train_index, test_index in skf.split(X,y):
print("X_train:",X[train_index])
print("y_train:",y[train_index])
print("X_test:",X[test_index])
print("y_test:",y[test_index])
print("-----")

結果

X_train: [[7 7 7]
[8 8 8]
[9 9 9]]
y_train: [1 2 3]
X_test: [[1 1 1]
[2 2 2]
[3 3 3]
[4 4 4]
[5 5 5]
[6 6 6]]
y_test: [1 2 3 1 2 3]
-----
X_train: [[1 1 1]
[2 2 2]
[3 3 3]
[4 4 4]
[5 5 5]
[6 6 6]]
y_train: [1 2 3 1 2 3]
X_test: [[7 7 7]
[8 8 8]
[9 9 9]]
y_test: [1 2 3]
-----

trainとtestにラベル1,2,3のデータが等しい数ずつ入ってる。

K=3の場合

skf = StratifiedKFold(n_splits=3)
for train_index, test_index in skf.split(X,y):
print("X_train:",X[train_index])
print("y_train:",y[train_index])
print("X_test:",X[test_index])
print("y_test:",y[test_index])
print("-----")

結果

X_train: [[4 4 4]
[5 5 5]
[6 6 6]
[7 7 7]
[8 8 8]
[9 9 9]]
y_train: [1 2 3 1 2 3]
X_test: [[1 1 1]
[2 2 2]
[3 3 3]]
y_test: [1 2 3]
-----
X_train: [[1 1 1]
[2 2 2]
[3 3 3]
[7 7 7]
[8 8 8]
[9 9 9]]
y_train: [1 2 3 1 2 3]
X_test: [[4 4 4]
[5 5 5]
[6 6 6]]
y_test: [1 2 3]
-----
X_train: [[1 1 1]
[2 2 2]
[3 3 3]
[4 4 4]
[5 5 5]
[6 6 6]]
y_train: [1 2 3 1 2 3]
X_test: [[7 7 7]
[8 8 8]
[9 9 9]]
y_test: [1 2 3]
-----

同様にtrainとtestにラベル1,2,3のデータが等しい数ずつ入ってる。

K=4の場合

最も少ないラベルの数より、Kを大きくすることはできない。上のデータだと、どのラベルともデータ数が3つなので、3つより多い数で分割できない。

skf = StratifiedKFold(n_splits=4)
for train_index, test_index in skf.split(X,y):
print("X_train:",X[train_index])
print("y_train:",y[train_index])
print("X_test:",X[test_index])
print("y_test:",y[test_index])
print("-----")

結果

ValueError: n_splits=4 cannot be greater than the number of members in each class.

ラベルの数が揃ってない場合

K=2の場合

極力揃えてくれる。

# 9つのデータがそれぞれクラス1、2に属しているとする。
X = np.array([[1,1,1],[2,2,2],[3,3,3],[4,4,4],[5,5,5],[6,6,6],[7,7,7],[8,8,8],[9,9,9]])
y = np.array([1,1,1,1,2,2,2,2,2])
# n_splitsでKの数を指定
skf = StratifiedKFold(n_splits=2)
for train_index, test_index in skf.split(X,y):
print("X_train:",X[train_index])
print("y_train:",y[train_index])
print("X_test:",X[test_index])
print("y_test:",y[test_index])
print("-----")

結果

X_train: [[3 3 3]
[4 4 4]
[8 8 8]
[9 9 9]]
y_train: [1 1 2 2]
X_test: [[1 1 1]
[2 2 2]
[5 5 5]
[6 6 6]
[7 7 7]]
y_test: [1 1 2 2 2]
-----
X_train: [[1 1 1]
[2 2 2]
[5 5 5]
[6 6 6]
[7 7 7]]
y_train: [1 1 2 2 2]
X_test: [[3 3 3]
[4 4 4]
[8 8 8]
[9 9 9]]
y_test: [1 1 2 2]
-----

N=3の場合

skf = StratifiedKFold(n_splits=3)for train_index, test_index in skf.split(X,y):
print("X_train:",X[train_index])
print("y_train:",y[train_index])
print("X_test:",X[test_index])
print("y_test:",y[test_index])
print("-----")

結果

X_train: [[3 3 3]
[4 4 4]
[7 7 7]
[8 8 8]
[9 9 9]]
y_train: [1 1 2 2 2]
X_test: [[1 1 1]
[2 2 2]
[5 5 5]
[6 6 6]]
y_test: [1 1 2 2]
-----
X_train: [[1 1 1]
[2 2 2]
[4 4 4]
[5 5 5]
[6 6 6]
[9 9 9]]
y_train: [1 1 1 2 2 2]
X_test: [[3 3 3]
[7 7 7]
[8 8 8]]
y_test: [1 2 2]
-----
X_train: [[1 1 1]
[2 2 2]
[3 3 3]
[5 5 5]
[6 6 6]
[7 7 7]
[8 8 8]]
y_train: [1 1 1 2 2 2 2]
X_test: [[4 4 4]
[9 9 9]]
y_test: [1 2]
-----

比率を極力近づけるために、3分割されたもののそれぞれのサンプル数は4,3,2と差が大きくなっている。

--

--

takkii
Music and Technology

Competitive Programming, MachineLearning, Manga, Music, BoardGame.