Model Evaluation: Cross-Validation

Satria Suria
Jun 18, 2024

Cross-validation is a statistical method used to evaluate the performance and generalization ability of a machine learning model. It involves partitioning the data into subsets, training the model on some subsets (training sets), and testing it on the remaining subsets (validation sets). The goal is to ensure that the model performs well not just on the training data but also on unseen data.

Why Cross-Validation is Important

  1. Generalization: Cross-validation estimates how well a model generalizes to independent data, which helps detect overfitting or underfitting.
  2. Model Selection: It aids in selecting the best model by comparing performance metrics across different models or hyperparameter settings.
  3. Bias-Variance Trade-off: It provides insights into the bias-variance trade-off of a model, helping to balance complexity and performance.
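
As a rough sketch of points 2 and 3, the snippet below scores two candidate models on the same folds and reports training and validation accuracy side by side. The two estimators and the `summary` dictionary are illustrative choices made for this example, not part of any fixed recipe. `cross_validate` with `return_train_score=True` is the scikit-learn call that returns both sets of scores.

```python
from sklearn.datasets import load_iris
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import KFold, cross_validate
from sklearn.tree import DecisionTreeClassifier

X, y = load_iris(return_X_y=True)

# Reuse the same folds for every candidate so the comparison is like-for-like
cv = KFold(n_splits=5, shuffle=True, random_state=42)

# Hypothetical candidates, chosen only for illustration
candidates = {
    "logistic_regression": LogisticRegression(max_iter=1000),
    "decision_tree": DecisionTreeClassifier(random_state=42),
}

summary = {}
for name, estimator in candidates.items():
    scores = cross_validate(estimator, X, y, cv=cv, return_train_score=True)
    summary[name] = {
        "train": scores["train_score"].mean(),
        "validation": scores["test_score"].mean(),
    }
    print(
        f"{name}: train={summary[name]['train']:.3f}, "
        f"validation={summary[name]['validation']:.3f}"
    )
```

A large gap between training and validation accuracy suggests overfitting (high variance), while low scores on both suggest underfitting (high bias).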

Common Cross-Validation Techniques

1. K-Fold Cross-Validation

K-fold cross-validation is the most commonly used method. The dataset is divided, usually after shuffling, into K folds of roughly equal size. The model is trained K times, each time on K-1 folds with the remaining fold held out for validation, so every observation is used for validation exactly once. The final performance metric is the average of the K validation scores.

Example:

from sklearn.model_selection import KFold, cross_val_score
from sklearn.ensemble import RandomForestClassifier
from sklearn.datasets import load_iris

# Load dataset
iris = load_iris()
X, y = iris.data, iris.target

# Define model (fixed seed so the scores are reproducible)
model = RandomForestClassifier(random_state=42)

# Define K-Fold Cross-Validation
kfold = KFold(n_splits=5, shuffle=True, random_state=42)

# Evaluate model
results = cross_val_score(model, X, y, cv=kfold)

# Print results
print(f"Cross-Validation Scores: {results}")
print(f"Mean Accuracy: {results.mean()}")
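
To make the procedure concrete, here is a minimal sketch of roughly what `cross_val_score` does with a `KFold` splitter: fit a fresh copy of the model on K-1 folds, score it on the held-out fold, and repeat. A `DecisionTreeClassifier` with a fixed seed is an assumption made here so that the manual loop and the library call are deterministic and can be compared directly.

```python
import numpy as np
from sklearn.base import clone
from sklearn.datasets import load_iris
from sklearn.model_selection import KFold, cross_val_score
from sklearn.tree import DecisionTreeClassifier

X, y = load_iris(return_X_y=True)
model = DecisionTreeClassifier(random_state=42)
kfold = KFold(n_splits=5, shuffle=True, random_state=42)

fold_scores = []
val_sizes = []
for train_idx, val_idx in kfold.split(X):
    fold_model = clone(model)  # fresh, unfitted copy for each fold
    fold_model.fit(X[train_idx], y[train_idx])
    # score() returns accuracy for classifiers
    fold_scores.append(fold_model.score(X[val_idx], y[val_idx]))
    val_sizes.append(len(val_idx))

fold_scores = np.array(fold_scores)
reference_scores = cross_val_score(model, X, y, cv=kfold)

print(f"Manual fold scores:    {fold_scores}")
print(f"cross_val_score:       {reference_scores}")
print(f"Validation fold sizes: {val_sizes}")
```

Cloning the model inside the loop matters: each fold must start from an unfitted estimator so nothing learned from one validation fold carries over into the next fit.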

2. Stratified K-Fold Cross-Validation

Stratified K-fold cross-validation works like K-fold but keeps the class proportions in each fold approximately equal to those of the full dataset. This matters most for imbalanced datasets, where a plain random split can leave a fold with few or no examples of a minority class.

Example:

from sklearn.model_selection import StratifiedKFold, cross_val_score
from sklearn.ensemble import RandomForestClassifier
from sklearn.datasets import load_iris

# Load dataset
iris = load_iris()
X, y = iris.data, iris.target

# Define model (fixed seed so the scores are reproducible)
model = RandomForestClassifier(random_state=42)

# Define Stratified K-Fold Cross-Validation
skfold = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)

# Evaluate model
results = cross_val_score(model, X, y, cv=skfold)

# Print results
print(f"Stratified Cross-Validation Scores: {results}")
print(f"Mean Accuracy: {results.mean()}")
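
Iris has perfectly balanced classes, so the effect of stratification is hard to see there. The sketch below uses a made-up imbalanced label vector (90 negatives, 10 positives) and counts how many positives land in each validation fold under both splitters. The data and the `positives_per_fold` helper are hypothetical and exist only for this illustration.

```python
import numpy as np
from sklearn.model_selection import KFold, StratifiedKFold

# Hypothetical imbalanced labels: 90 negatives followed by 10 positives
y = np.array([0] * 90 + [1] * 10)
X = np.zeros((len(y), 1))  # feature values do not affect how the data is split


def positives_per_fold(splitter):
    """Count the positive labels in each validation fold."""
    return [int(y[val_idx].sum()) for _, val_idx in splitter.split(X, y)]


plain = positives_per_fold(KFold(n_splits=5, shuffle=True, random_state=42))
stratified = positives_per_fold(
    StratifiedKFold(n_splits=5, shuffle=True, random_state=42)
)

print(f"Positives per fold, KFold:           {plain}")
print(f"Positives per fold, StratifiedKFold: {stratified}")
```

With `StratifiedKFold`, each of the five folds receives two of the ten positives. With plain `KFold` the counts depend on the shuffle and may be uneven.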

3. Leave-One-Out Cross-Validation (LOOCV)

LOOCV is the extreme case of K-fold in which K equals the number of data points. Each iteration validates on a single data point and trains on all the others. It requires one model fit per data point, so it is computationally expensive, but every fit uses nearly all of the available data.

Example:

from sklearn.model_selection import LeaveOneOut, cross_val_score
from sklearn.ensemble import RandomForestClassifier
from sklearn.datasets import load_iris

# Load dataset
iris = load_iris()
X, y = iris.data, iris.target

# Define model (fixed seed so the scores are reproducible)
model = RandomForestClassifier(random_state=42)

# Define Leave-One-Out Cross-Validation
loocv = LeaveOneOut()

# Evaluate model
results = cross_val_score(model, X, y, cv=loocv)

# Print results
print(f"LOOCV Scores: {results}")
print(f"Mean Accuracy: {results.mean()}")
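
The cost of LOOCV is easy to check before running it. The sketch below asks the splitter how many fits it will trigger and then inspects the per-sample scores. A `KNeighborsClassifier` is swapped in here as an assumption to keep the run fast; it is not the model used in the example above.

```python
import numpy as np
from sklearn.datasets import load_iris
from sklearn.model_selection import LeaveOneOut, cross_val_score
from sklearn.neighbors import KNeighborsClassifier

X, y = load_iris(return_X_y=True)
loocv = LeaveOneOut()

# One fit per sample: the number of splits equals the number of rows in X
n_fits = loocv.get_n_splits(X)
print(f"Number of model fits: {n_fits}")

# Lightweight stand-in model, chosen only to keep the example quick
model = KNeighborsClassifier(n_neighbors=5)
scores = cross_val_score(model, X, y, cv=loocv)

# Each validation set holds one sample, so each accuracy score is 0.0 or 1.0
print(f"Distinct score values: {np.unique(scores)}")
print(f"Mean Accuracy: {scores.mean():.3f}")
```

Because every individual score is either 0 or 1, only the mean is informative; the spread of the per-fold scores says little about model stability.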

4. Time Series Split

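For time series data, it’s important to maintain the chronological order. scikit-learn's TimeSeriesSplit uses an expanding training window: each split trains on an initial segment of the data and validates on the observations that immediately follow, so the model is never trained on data from the future.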

Example:

from sklearn.model_selection import TimeSeriesSplit, cross_val_score
from sklearn.ensemble import RandomForestRegressor
import numpy as np

# Generate synthetic time series data
X = np.arange(100).reshape(-1, 1)
y = np.sin(X).ravel()

# Define model (fixed seed so the scores are reproducible)
model = RandomForestRegressor(random_state=42)

# Define Time Series Split
tscv = TimeSeriesSplit(n_splits=5)

# Evaluate model (regressors are scored with R^2 by default, so scores can be negative)
results = cross_val_score(model, X, y, cv=tscv)

# Print results
print(f"Time Series Split Scores: {results}")
print(f"Mean Score: {results.mean()}")
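
Printing the indices that `TimeSeriesSplit` produces shows the ordering guarantee directly. This is a small sketch on a toy array of 12 time steps (an arbitrary size picked for readability), with the splits collected in a `splits` list for inspection.

```python
import numpy as np
from sklearn.model_selection import TimeSeriesSplit

# Toy series: 12 observations in chronological order
X = np.arange(12).reshape(-1, 1)
tscv = TimeSeriesSplit(n_splits=3)

splits = []
for fold, (train_idx, test_idx) in enumerate(tscv.split(X), start=1):
    splits.append((train_idx, test_idx))
    print(f"Fold {fold}: train={train_idx.tolist()} test={test_idx.tolist()}")
```

In every fold the training indices all come before the test indices, and the training window grows from one fold to the next. Unlike `KFold`, there is no shuffling, since shuffling would leak future observations into the training set.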

Satria Suria

I am learning AI Engineering and Software Development and would like to reinforce my understanding by writing about them. I welcome constructive criticism.