Automatic Sleep Stage Classification with Ear-EEG: Feature extraction, deep learning, and Grad-CAM

Krittaphas Chaisutyakorn
Nov 16, 2023


Automatic Sleep Stage Classification via various approaches using Ear-EEG dataset

This project is part of Brain Code Camp 2023.

Introduction

Polysomnography (PSG) is currently the gold standard for diagnosing sleep-related disorders. However, PSG requires substantial resources and time: each subject must sleep in an unfamiliar laboratory environment with numerous electrodes and wires attached to their body. As a result, many people with sleep disorders still lack access to PSG, especially in resource-limited settings. Several alternatives are being explored, and Ear-EEG is one worth investigating. An in-ear electrode that resembles a consumer in-ear headphone could substantially improve the test's accessibility, ease of use, and comfort.

Dataset

This project uses the open-source "Ear-EEG Sleep Monitoring 2017 (EESM17)" dataset, which consists of 9 polysomnography studies from 9 healthy individuals aged 26 to 44 years. Each recording contains 8 scalp-EEG channels, EOG, EMG, and 7 Ear-EEG channels in each ear, sampled at 200 Hz. Sleep stages and sleep events were scored according to the AASM guidelines.

In-ear EEG probe with 7 channels in each ear from Mikkelsen, K.B. et al.

Exploratory Data Analysis

First, the data were visualized and visually checked for completeness, errors, and noise using the MNE-Python library.

Data from one subject visualized with the MNE-Python library
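Alongside visual inspection, a few numeric checks can flag obvious problems before epoching. The sketch below is an assumption, not the author's procedure: the function name `quick_channel_qc` and both thresholds are hypothetical, and `X` stands for any (n_channels, n_times) array such as the output of `data.get_data()` (MNE returns EEG in volts).

```python
import numpy as np


def quick_channel_qc(X, ch_names, flat_std=1e-7, max_abs=1e-3):
    """Flag channels with missing samples, flat lines, or extreme amplitudes.

    X: array of shape (n_channels, n_times), in volts.
    flat_std and max_abs are illustrative thresholds, not validated values.
    """
    report = {}
    for name, ch in zip(ch_names, X):
        issues = []
        if np.isnan(ch).any():
            issues.append("nan")
        # nan-aware statistics so a few missing samples do not hide other problems
        if np.nanstd(ch) < flat_std:
            issues.append("flat")
        if np.nanmax(np.abs(ch)) > max_abs:
            issues.append("high_amplitude")
        report[name] = issues
    return report
```

Channels flagged here are natural candidates for a closer look in MNE's interactive `raw.plot()` viewer.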

The channels were grouped by type into Scalp-EEG and Ear-EEG sets so that automatic sleep staging performance could be compared between the two.

# Grouping channels by type
channels_type_group = dict(
    scalp_channels=['F3', 'F4', 'C3', 'C4', 'O1', 'O2', 'A1', 'A2'],
    lt_ear_channels=['ELA', 'ELE', 'ELI', 'ELB1', 'ELB', 'ELG', 'ELK'],
    rt_ear_channels=['ERA', 'ERE', 'ERI', 'ERB1', 'ERB', 'ERG', 'ERK'],
    eog_channels=['LOC', 'ROC'],
    osat_channels=['OSAT'],
    emg_channels=['CHIN12'],
)

# Raw.pick() modifies the object in place, so pick from copies of the recording
scalp_eeg = data.copy().pick(channels_type_group['scalp_channels'])

# Ear-EEG = left-ear + right-ear channels
ear_channels = channels_type_group['lt_ear_channels'] + channels_type_group['rt_ear_channels']
ear_eeg = data.copy().pick(ear_channels)
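Before picking, it can help to confirm that every channel in a group actually exists in the recording. The helper below is a hypothetical sketch (the name `split_channel_groups` is not from the original project); in practice `available` would be `data.ch_names`, and the shortened channel lists here are toy values.

```python
def split_channel_groups(groups, available):
    """Return (present, missing) channel lists for each group.

    groups: dict mapping group name -> list of channel names.
    available: channel names present in the recording (e.g. data.ch_names).
    """
    available = set(available)
    result = {}
    for group, names in groups.items():
        present = [ch for ch in names if ch in available]
        missing = [ch for ch in names if ch not in available]
        result[group] = (present, missing)
    return result


# toy example: a shortened Ear-EEG group combining both ears
groups = {
    'scalp_channels': ['F3', 'F4', 'C3', 'C4', 'O1', 'O2', 'A1', 'A2'],
    'ear_channels': ['ELA', 'ELE', 'ERA', 'ERE'],
}
available = ['F3', 'F4', 'C3', 'C4', 'O1', 'O2', 'A1', 'A2', 'ELA', 'ERA', 'ERE']
print(split_channel_groups(groups, available)['ear_channels'])
# (['ELA', 'ERA', 'ERE'], ['ELE'])
```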

The continuous recording then had to be segmented into 30-second epochs, which was done with the pandas and NumPy packages.

import json

import mne
import numpy as np
import pandas as pd

# read channel groups
with open("channels_group.json", 'r') as f:
    channels_group = json.load(f)

# load one subject's recording and sleep scoring
# (forward slashes avoid invalid backslash escapes such as '\s' in Python strings)
i = 1
data = mne.io.read_raw_eeglab(f'ear_eeg/sub-00{i}/ses-001/eeg/sub-00{i}_ses-001_task-sleep_eeg.set', preload=True)
events = pd.read_csv(f'ear_eeg/sub-00{i}/ses-001/eeg/sub-00{i}_ses-001_task-sleep_acq-scoring_events.tsv', sep='\t')
print("signal shape: ", data.get_data().shape)
print("events shape: ", events.shape)

# get the mapping from stage code to stage name
# (the with-block closes the file, so no explicit f.close() is needed)
with open(f'ear_eeg/sub-00{i}/ses-001/eeg/sub-00{i}_ses-001_task-sleep_acq-scoring_events.json', 'r') as f:
    events_id = json.load(f)
events_id = events_id['staging']['Levels']
print(events_id)
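JSON object keys are always strings, which is why the integer stage codes from the TSV file are converted with `str(x)` before the lookup in the next step. The sketch below uses a made-up `Levels` dictionary; the actual stage names in the EESM17 sidecar file may differ.

```python
import json

# hypothetical sidecar content; the real EESM17 stage labels may differ
sidecar_text = '{"staging": {"Levels": {"0": "Wake", "1": "N1", "2": "N2", "3": "N3", "4": "REM"}}}'
levels = json.loads(sidecar_text)['staging']['Levels']

# integer codes, as pandas.read_csv would return them
codes = [0, 2, 2, 4]

# keys are strings, so each code must be converted before the lookup
stage_names = [levels[str(code)] for code in codes]
print(stage_names)  # ['Wake', 'N2', 'N2', 'REM']
```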

The recording and the annotated labels do not start at the same time: the scoring onsets are relative to the "Lights Off" marker rather than to the start of the recording. The signals and labels therefore had to be aligned manually.

# Get lights on & lights off time
events_mark = pd.read_csv(f'ear_eeg/sub-00{i}/ses-001/eeg/sub-00{i}_ses-001_task-sleep_events.tsv', sep='\t')

start_time = events_mark[events_mark['trial_type'] == 'Lights Off']['onset'].values[0]
end_time = events_mark[events_mark['trial_type'] == 'Lights On']['onset'].values[0]

print(f"Start time: {start_time}, End time: {end_time}")

# sampling frequency (200 Hz) and signal array of shape (n_channels, n_times);
# use scalp_eeg.get_data() or ear_eeg.get_data() to epoch a single channel group
sfreq = data.info['sfreq']
X = data.get_data()

# scoring onsets are relative to lights off: shift them and convert to samples
events['onset_hz'] = (events['onset'] + start_time) * sfreq

# map stage codes to stage names and check the class distribution
events['stage_name'] = [events_id[str(x)] for x in events['staging']]
print(events['stage_name'].value_counts())

# cut one 30-second epoch (30 * sfreq samples) per scored event
epoch_len = int(30 * sfreq)
events['signals'] = [X[:, int(x):int(x) + epoch_len] for x in events['onset_hz']]
events['signals_shape'] = [x.shape for x in events['signals']]

events = events.drop(['onset', 'duration', 'staging'], axis=1)
events
This image shows the final preprocessed data as a DataFrame; each row is one 30-second epoch.
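The alignment and slicing steps can be condensed into one NumPy function. This is a sketch under assumptions: `epoch_signals` is a hypothetical name, onsets are in seconds relative to lights off, and, unlike the pandas version, epochs that would run past the end of the recording are dropped so that every epoch has the same shape and the result stacks into a single array.

```python
import numpy as np


def epoch_signals(X, onsets_sec, labels, lights_off_sec, sfreq, epoch_sec=30):
    """Cut (n_channels, n_times) data into fixed-length epochs.

    Returns (epochs, labels); epochs has shape (n_epochs, n_channels, n_samples).
    """
    n_samples = int(round(epoch_sec * sfreq))  # e.g. 30 s * 200 Hz = 6000 samples
    epochs, kept = [], []
    for onset, label in zip(onsets_sec, labels):
        # shift the scoring onset by the lights-off time, then convert to samples
        start = int(round((onset + lights_off_sec) * sfreq))
        stop = start + n_samples
        if start < 0 or stop > X.shape[1]:
            continue  # skip epochs that fall outside the recording
        epochs.append(X[:, start:stop])
        kept.append(label)
    return np.stack(epochs), np.array(kept)
```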

Approach 1: Feature Extraction + Machine Learning

In the first approach, we used hand-crafted features. In clinical practice, EEG is typically analyzed in terms of frequency bands, which from low to high frequency include the delta, theta, alpha, beta, and gamma bands; sleep analysis also commonly uses the sigma band, which covers sleep spindles. We therefore computed the power spectral density (PSD) of each epoch and extracted the power in the delta, theta, alpha, sigma, and beta bands (gamma was not used).
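Written out, the features produced by `compute_psd_array` below are as follows. Here $S_{e,c}(f)$ is the Welch PSD of epoch $e$ and channel $c$, $C$ is the number of channels, the sum over $f'$ runs over all frequency bins between `fmin` and `fmax`, and $F_b = \{f : f_b^{\mathrm{lo}} \le f < f_b^{\mathrm{hi}}\}$ is the set of bins in band $b$. The notation is ours, not the original author's.

```latex
\bar{S}_e(f) = \frac{1}{C}\sum_{c=1}^{C} S_{e,c}(f), \qquad
\tilde{S}_e(f) = \frac{\bar{S}_e(f)}{\sqrt{\sum_{f'} \bar{S}_e(f')^{2}}}, \qquad
x_{e,b} = \frac{1}{|F_b|}\sum_{f \in F_b} \tilde{S}_e(f)
```

Each epoch therefore becomes a vector of five relative band powers, one per entry of `FREQUENCY_BANDS`.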

from mne.time_frequency import psd_array_welch
from sklearn.preprocessing import normalize

FREQUENCY_BANDS = {
    'delta': [0.5, 4.5],
    'theta': [4.5, 8.5],
    'alpha': [8.5, 11.5],
    'sigma': [11.5, 15.5],
    'beta': [15.5, 30],
}

# sfreq must be defined before this cell, e.g. sfreq = data.info['sfreq']
def compute_psd_array(signal, sfreq=sfreq, fmin=0.5, fmax=80):
    """signal: (n_epochs, n_channels, n_times) -> band features (n_epochs, n_bands)."""

    # Welch PSD per epoch and channel: (n_epochs, n_channels, n_freqs)
    psd, freqs = psd_array_welch(signal, sfreq=sfreq, fmin=fmin, fmax=fmax, verbose=False)

    # average across channels, then L2-normalize each epoch's spectrum
    psd_norm = normalize(psd.mean(axis=1), axis=1)

    # mean normalized power inside each conventional frequency band
    x_norm = []
    for band_fmin, band_fmax in FREQUENCY_BANDS.values():
        band_mask = (freqs >= band_fmin) & (freqs < band_fmax)
        psd_band_norm = psd_norm[:, band_mask].mean(axis=-1)
        x_norm.append(psd_band_norm.reshape(len(psd_norm), -1))

    return np.concatenate(x_norm, axis=1)
Power Spectral Density for each frequency band
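A quick way to sanity-check the band features is to feed in a synthetic 10 Hz oscillation, which should land almost entirely in the alpha band. To keep the sketch self-contained, it swaps MNE's `psd_array_welch` for `scipy.signal.welch` and reimplements the channel averaging and L2 normalization with NumPy; the function name `band_features` is hypothetical.

```python
import numpy as np
from scipy.signal import welch

FREQUENCY_BANDS = {
    'delta': [0.5, 4.5],
    'theta': [4.5, 8.5],
    'alpha': [8.5, 11.5],
    'sigma': [11.5, 15.5],
    'beta': [15.5, 30],
}


def band_features(signal, sfreq, fmin=0.5, fmax=80):
    """signal: (n_epochs, n_channels, n_times) -> (n_epochs, n_bands)."""
    freqs, psd = welch(signal, fs=sfreq, axis=-1)
    keep = (freqs >= fmin) & (freqs <= fmax)
    freqs, psd = freqs[keep], psd[..., keep]

    # average over channels, then L2-normalize each epoch's spectrum
    mean_psd = psd.mean(axis=1)
    norm_psd = mean_psd / np.linalg.norm(mean_psd, axis=1, keepdims=True)

    # mean normalized power inside each band
    return np.stack(
        [norm_psd[:, (freqs >= lo) & (freqs < hi)].mean(axis=1)
         for lo, hi in FREQUENCY_BANDS.values()],
        axis=1,
    )


# one synthetic 30 s epoch with 2 channels: 10 Hz sine plus a little noise
sfreq = 200
t = np.arange(30 * sfreq) / sfreq
rng = np.random.default_rng(0)
x = np.sin(2 * np.pi * 10 * t) + 0.05 * rng.standard_normal((2, t.size))
features = band_features(x[np.newaxis], sfreq)
print(dict(zip(FREQUENCY_BANDS, features[0].round(3))))
```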

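The band-power features were then used as input to machine learning classifiers; the code below uses a random forest. Leave-one-out cross-validation was applied over the processed files: in each fold, the epochs from one file were held out for testing, and the epochs from all other files were used for training.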
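If all epochs are kept in a single array instead of one file per subject, scikit-learn's `LeaveOneGroupOut` produces a leave-one-subject-out split when a subject ID is passed for every epoch as `groups`. This is an alternative to the file-based loop, not what the original project used, and the toy arrays below are made up.

```python
import numpy as np
from sklearn.model_selection import LeaveOneGroupOut

# toy data: 3 subjects with 4, 3, and 5 epochs, 5 band-power features each
groups = np.array([1] * 4 + [2] * 3 + [3] * 5)
X = np.random.default_rng(0).random((groups.size, 5))
y = np.arange(groups.size) % 5

logo = LeaveOneGroupOut()
for train_idx, test_idx in logo.split(X, y, groups=groups):
    # each test fold contains exactly one subject, never seen during training
    held_out = np.unique(groups[test_idx])
    print("held-out subject:", held_out, "train size:", train_idx.size)
```

Splitting by subject rather than by epoch keeps epochs from the same night out of both training and test sets, which would otherwise inflate the scores.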

import glob

import numpy as np
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import (accuracy_score, cohen_kappa_score, confusion_matrix,
                             f1_score, precision_score, recall_score)
from sklearn.model_selection import LeaveOneOut
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import FunctionTransformer

# glob order is arbitrary: sort so the i-th signal file matches the i-th label file
signal_files = sorted(glob.glob('processed_data/*_signals.npy'))
label_files = sorted(glob.glob('processed_data/*_labels.npy'))

# fixed label set so every fold's confusion matrix has the same shape,
# even when a held-out file lacks one of the sleep stages
all_labels = np.unique(np.concatenate([np.load(f) for f in label_files]))

loo = LeaveOneOut()

# lists to store per-fold results
acc_sum, f1_sum, precision_sum, recall_sum, kappa_sum, conf_list = [], [], [], [], [], []

for train_index, test_index in loo.split(signal_files):
    print("TRAIN:", train_index, "TEST:", test_index)

    # stack the epochs of all files in each split: (n_epochs, n_channels, n_times)
    X_train = np.concatenate([np.load(signal_files[i]) for i in train_index], axis=0)
    y_train = np.concatenate([np.load(label_files[i]) for i in train_index], axis=0)
    X_test = np.concatenate([np.load(signal_files[i]) for i in test_index], axis=0)
    y_test = np.concatenate([np.load(label_files[i]) for i in test_index], axis=0)
    print('X_train shape: ', X_train.shape, 'X_test shape: ', X_test.shape)

    # band-power features + random forest as one sklearn pipeline
    pipe = make_pipeline(
        FunctionTransformer(compute_psd_array, validate=False),
        RandomForestClassifier(n_estimators=100, random_state=42),
    )
    pipe.fit(X_train, y_train)

    # evaluate on the held-out file
    y_pred = pipe.predict(X_test)
    acc = accuracy_score(y_test, y_pred)
    conf_matrix = confusion_matrix(y_test, y_pred, labels=all_labels)
    print('Accuracy: ', acc)
    print('Confusion matrix:\n', conf_matrix)

    acc_sum.append(acc)
    f1_sum.append(f1_score(y_test, y_pred, average='macro'))
    precision_sum.append(precision_score(y_test, y_pred, average='macro', zero_division=0))
    recall_sum.append(recall_score(y_test, y_pred, average='macro', zero_division=0))
    kappa_sum.append(cohen_kappa_score(y_test, y_pred))
    conf_list.append(conf_matrix)

# per-fold confusion matrices, shape (n_folds, n_classes, n_classes)
conf_sum = np.stack(conf_list)

Result of feature extraction and machine learning approach using Scalp-EEG
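After the loop, the per-fold lists can be summarized into mean and standard deviation, plus a pooled confusion matrix. A minimal sketch, assuming the metric lists and the stacked `conf_sum` array of shape (n_folds, n_classes, n_classes) from the loop; the function name `summarize_folds` is hypothetical, and the example values are toy numbers, not results from this project.

```python
import numpy as np


def summarize_folds(metrics, conf_stack):
    """metrics: dict name -> list of per-fold values.
    conf_stack: (n_folds, n_classes, n_classes) confusion matrices.
    """
    summary = {name: (float(np.mean(v)), float(np.std(v))) for name, v in metrics.items()}

    # pool the counts over folds, then express each row (true class) as fractions
    total = conf_stack.sum(axis=0)
    row_sums = total.sum(axis=1, keepdims=True)
    normalized = np.divide(total, row_sums, out=np.zeros(total.shape), where=row_sums > 0)
    return summary, total, normalized


# toy values for illustration only
metrics = {'accuracy': [0.8, 0.6], 'kappa': [0.7, 0.5]}
conf_stack = np.array([[[3, 1], [0, 4]],
                       [[2, 2], [1, 3]]])
summary, total, normalized = summarize_folds(metrics, conf_stack)
for name, (mean, std) in summary.items():
    print(f"{name}: {mean:.3f} ± {std:.3f}")
print(normalized)
```

With the real results this would be called as `summarize_folds({'accuracy': acc_sum, 'kappa': kappa_sum}, conf_sum)`.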

The same pipeline was then applied to the Ear-EEG data. In addition, the data were downsampled from 200 Hz to 100 Hz to assess how the lower sampling rate affects performance.

Result of feature extraction and machine learning approach using Scalp-EEG and Ear-EEG
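Downsampling from 200 Hz to 100 Hz needs an anti-aliasing low-pass filter before samples are discarded. In MNE the corresponding call is `raw.resample(100)`; the sketch below performs the same 2:1 reduction on a plain array with `scipy.signal.resample_poly`, which filters internally. The helper name `downsample_2x` and the test signals are illustrative assumptions. Content above the new 50 Hz Nyquist frequency is removed, which leaves the delta-to-beta bands (up to 30 Hz) intact.

```python
import numpy as np
from scipy.signal import resample_poly


def downsample_2x(X):
    """Downsample (..., n_times) data by a factor of 2 (e.g. 200 Hz -> 100 Hz).

    resample_poly low-pass filters before decimating, so content above the
    new Nyquist frequency (50 Hz) is attenuated instead of aliasing.
    """
    return resample_poly(X, up=1, down=2, axis=-1)


sfreq = 200
t = np.arange(30 * sfreq) / sfreq    # one 30 s epoch = 6000 samples
slow = np.sin(2 * np.pi * 10 * t)    # 10 Hz, inside the alpha band
fast = np.sin(2 * np.pi * 80 * t)    # 80 Hz, above the new Nyquist limit

slow_ds, fast_ds = downsample_2x(slow), downsample_2x(fast)
print(slow_ds.shape)  # (3000,)
```

Without the filter, the 80 Hz component would fold back to 20 Hz and contaminate the beta band.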

Approach 2: Deep learning method

In this approach, feature extraction was omitted, and the preprocessed signal was fed directly into a neural network consisting of five 1-dimensional convolutional layers followed by dense classifier layers. The network was not yet optimized, and its design was constrained by the available computational resources. The network and its training process were implemented in PyTorch, and training used the categorical cross-entropy loss with the Adam optimizer.
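The network's `forward` ends with a log-softmax, so its outputs are log-probabilities. In PyTorch the conventional pairing for such outputs is `nn.NLLLoss`, which takes the negative log-probability of the true class; `nn.CrossEntropyLoss` instead applies the log-softmax itself to raw logits. The NumPy sketch below (our own illustration, not the project's training code) shows that both give categorical cross-entropy here, because applying log-softmax to values that are already log-probabilities leaves them unchanged.

```python
import numpy as np


def log_softmax(z):
    # subtract the row max for numerical stability
    z = z - z.max(axis=1, keepdims=True)
    return z - np.log(np.exp(z).sum(axis=1, keepdims=True))


def nll_loss(log_probs, targets):
    # mean negative log-probability of the true class (what nn.NLLLoss computes)
    return -log_probs[np.arange(len(targets)), targets].mean()


rng = np.random.default_rng(0)
logits = rng.normal(size=(4, 5))     # 4 epochs, 5 sleep stages
targets = np.array([0, 2, 4, 1])
log_probs = log_softmax(logits)      # what the network's forward() returns

# categorical cross-entropy written out from softmax probabilities
probs = np.exp(logits) / np.exp(logits).sum(axis=1, keepdims=True)
manual = -np.log(probs[np.arange(4), targets]).mean()

print(np.isclose(nll_loss(log_probs, targets), manual))               # True
# log-softmax of log-probabilities changes nothing, so a loss that applies
# log-softmax internally (like nn.CrossEntropyLoss) gives the same value
print(np.isclose(nll_loss(log_softmax(log_probs), targets), manual))  # True
```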

Model architecture: simple sequentially connected network

The neural network approach performs worse than the feature extraction and machine learning approach.

Performance of the neural network compared to the feature extraction approach

Grad-CAM analysis

Explainable AI was also part of the project plan, because an automatic algorithm needs to be interpretable before it can be approved for clinical use. Without interpretability, physicians and sleep specialists may not trust the AI to support clinical decisions.

Grad-CAM (Gradient-weighted Class Activation Mapping) is an explainability method that highlights the regions of the input that were most important for a classifier's decision. It weights the feature maps of a convolutional layer by the gradient of the class score with respect to those maps. Although it was originally designed for image classification, it can also be applied to signal data by computing it on 1D convolutional layers.
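For a 1D convolutional layer with feature maps $A_k(t)$, $k = 1, \dots, K$, over $T$ time steps, and a class score $y^c$, Grad-CAM (Selvaraju et al.) computes one weight per feature map by global-average-pooling the gradients, here over time instead of over image pixels, and then applies a ReLU to the weighted sum of the maps. The 1D form below is our adaptation of the original 2D definition.

```latex
\alpha_k^{c} = \frac{1}{T}\sum_{t=1}^{T}\frac{\partial y^{c}}{\partial A_k(t)}, \qquad
L^{c}(t) = \mathrm{ReLU}\!\left(\sum_{k=1}^{K}\alpha_k^{c}\,A_k(t)\right)
```

The ReLU keeps only time steps whose features push the score of class $c$ up.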

Only minor changes to the original network were needed; the added lines are marked with ##added in the snippet below.

import torch
import torch.nn as nn


class conv_net(nn.Module):

    def __init__(self):
        super().__init__()

        # five Conv1d layers; padding='same' keeps the time length unchanged
        # (move the whole model to the GPU with conv_net().to(device))
        self.conv_block = nn.Sequential(
            nn.Conv1d(1, 64, 5, padding='same'),
            nn.BatchNorm1d(64),
            nn.ReLU(),
            nn.Conv1d(64, 128, 5, padding='same'),
            nn.BatchNorm1d(128),
            nn.ReLU(),
            nn.Conv1d(128, 128, 5, padding='same'),
            nn.BatchNorm1d(128),
            nn.ReLU(),
            nn.Conv1d(128, 64, 5, padding='same'),
            nn.BatchNorm1d(64),
            nn.ReLU(),
            nn.Conv1d(64, 32, 5, padding='same'),
            nn.BatchNorm1d(32),
            nn.ReLU(),
        )

        # 32 channels x 3000 samples (30 s at 100 Hz) = 96000 flattened features
        self.dense_block = nn.Sequential(
            nn.Linear(96000, 50),
            nn.ReLU(),
            nn.Linear(50, 5),
        )

        self.gradients = None  ##added

    def activations_hooks(self, grad):  ##added
        # store the gradient w.r.t. the last convolutional activations
        self.gradients = grad

    def forward(self, x):
        # 5 1D convolutional layers
        x = self.conv_block(x)

        # register the hook whenever gradients are tracked
        # (also in eval mode, which Grad-CAM needs)
        if x.requires_grad:
            x.register_hook(self.activations_hooks)  ##added

        # flatten (batch, 32, n_times) and classify
        x = x.flatten(start_dim=1)
        x = self.dense_block(x)

        # log-probabilities for the 5 sleep stages (pair with nn.NLLLoss)
        return torch.log_softmax(x, dim=1)

    def get_activations_gradient(self):  ##added
        return self.gradients

    def get_activation(self, x):  ##added
        return self.conv_block(x)
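The size of the first dense layer follows from the convolution output length. The PyTorch `Conv1d` documentation gives the output length as floor((L + 2p - d(k - 1) - 1) / s) + 1; with stride 1 and `padding='same'` (equivalent to p = 2 for k = 5) the length is unchanged, so a 3000-sample epoch leaves the last layer as 32 x 3000 = 96000 values. The 3000-sample input is our reading of the `Linear(96000, 50)` layer and the 100 Hz downsampling; the helper names are hypothetical.

```python
def conv1d_out_len(length, kernel_size, stride=1, padding=0, dilation=1):
    # output-length formula from the PyTorch Conv1d documentation
    return (length + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1


def flattened_size(n_times, channels, kernel_size=5):
    # padding='same' with stride 1 matches padding = (k - 1) // 2 for odd k
    length = n_times
    for _ in channels:
        length = conv1d_out_len(length, kernel_size, padding=(kernel_size - 1) // 2)
    return channels[-1] * length


print(flattened_size(3000, [64, 128, 128, 64, 32]))  # 96000 (30 s at 100 Hz)
print(flattened_size(6000, [64, 128, 128, 64, 32]))  # 192000 (30 s at 200 Hz)
```

The second line shows why the network cannot take 200 Hz epochs without changing the dense layer.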

After training, Grad-CAM can be computed for a given epoch with the code below.

import torch
import torch.nn as nn


def get_GradCAM(model, signal, average_factor=10):
    # signal: tensor of shape (1, 1, n_times) on the model's device

    # set model to evaluation mode and clear old gradients
    model.eval()
    model.zero_grad()

    # forward pass; the hook in forward() captures the conv-block gradients
    logits = model(signal)
    pred = logits.argmax(dim=1)

    # backpropagate the predicted class score so the hook stores its gradients
    logits[0, pred.item()].backward()

    # make sure that the model has these 2 functions
    gradients = model.get_activations_gradient()          # (1, C, T)
    activations = model.get_activation(signal).detach()   # (1, C, T)

    # Grad-CAM weights: average each channel's gradients over time
    weights = torch.mean(gradients, dim=2, keepdim=True)  # (1, C, 1)

    # weighted sum of the feature maps, then ReLU
    heatmap = torch.sum(weights * activations, dim=1).squeeze(0)  # (T,)
    heatmap = torch.relu(heatmap)

    # normalize to [0, 1] (the epsilon guards against an all-zero map)
    heatmap = heatmap / (heatmap.max() + 1e-8)

    # average heatmap signals for better visualization, then upsample back
    heatmap = nn.AvgPool1d(average_factor)(heatmap[None, None])
    heatmap = nn.Upsample(scale_factor=average_factor)(heatmap)[0, 0]

    return heatmap.detach().cpu().numpy()


import matplotlib.pyplot as plt
import numpy as np


def plot_signal_heatmap(signal, heatmap, label, length=3000):
    # pick a random window of `length` samples (the whole epoch if length == n_times)
    n_times = signal.shape[-1]
    random_start = np.random.randint(0, n_times - length + 1)

    signal_plot = signal[0, 0].detach().cpu().numpy()
    signal_plot = signal_plot[random_start:random_start + length]

    fig, axs = plt.subplots(2, 1, figsize=(10, 8), sharex=True,
                            gridspec_kw={'height_ratios': [3, 1]})

    axs[0].set_title(f'Signal, Sleep stage {label.item()}')
    axs[0].plot(signal_plot, label='signal')

    axs[1].set_title('Heatmap')
    axs[1].pcolormesh([heatmap[random_start:random_start + length]], cmap='Oranges')

    plt.show()


heatmap = get_GradCAM(model, signal, average_factor=10)

plot_signal_heatmap(signal, heatmap, label, length=3000)
Signal with the Grad-CAM heatmap below it. Orange regions mark the parts of the signal that were most important to the network's prediction.
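Once the activations and gradients have been captured, the Grad-CAM arithmetic itself does not depend on PyTorch. The NumPy sketch below (hypothetical function name `gradcam_1d`, arrays of shape (channels, time)) follows the Grad-CAM definition, with gradients averaged over time per channel, followed by the ReLU, normalization, and an average-then-repeat smoothing step. It is handy for unit-testing the post-processing on small hand-made arrays.

```python
import numpy as np


def gradcam_1d(activations, gradients, average_factor=10):
    """Grad-CAM heatmap for 1D feature maps.

    activations, gradients: arrays of shape (n_channels, n_times) taken from
    the last convolutional layer for one input and one target class.
    """
    # channel weights: gradients averaged over time (global average pooling)
    weights = gradients.mean(axis=1, keepdims=True)

    # weighted sum of the feature maps, keeping only positive evidence
    cam = np.maximum((weights * activations).sum(axis=0), 0.0)

    # scale to [0, 1]; leave an all-zero map unchanged
    if cam.max() > 0:
        cam = cam / cam.max()

    # smooth: average non-overlapping windows, then repeat back to full length
    n_windows = cam.size // average_factor
    pooled = cam[: n_windows * average_factor].reshape(n_windows, average_factor).mean(axis=1)
    return np.repeat(pooled, average_factor)
```

As with `nn.AvgPool1d`, any samples left over when the length is not a multiple of `average_factor` are dropped by the smoothing step.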

Grad-CAM can therefore help explain the network's decisions. In this project, however, the network's performance needs to improve before such an analysis is meaningful.

Conclusion

In this project, two approaches to automatic sleep stage classification were tested on both scalp-EEG and ear-EEG data. As expected, sleep staging on ear-EEG data performed much worse than on scalp-EEG data. The deep learning approach also performed worse than the classical machine learning approach, possibly because of class imbalance, the small amount of data, and limited time and computational resources; it can still be improved. Finally, the project suggests that Grad-CAM can be applied not only to image models but also to signal models.

References

Mikkelsen, K.B., Villadsen, D.B., Otto, M. et al. Automatic sleep staging using ear-EEG. BioMed Eng OnLine 16, 111 (2017). https://doi.org/10.1186/s12938-017-0400-5

Gramfort, A., Luessi, M., Larson, E., Engemann, D.A., Strohmeier, D., Brodbeck, C., Goj, R., Jas, M., Brooks, T., Parkkonen, L., Hämäläinen, M.S. MEG and EEG data analysis with MNE-Python. Frontiers in Neuroscience 7, 267 (2013). https://doi.org/10.3389/fnins.2013.00267

Selvaraju, R.R., Cogswell, M., Das, A., Vedantam, R., Parikh, D., Batra, D. Grad-CAM: Visual Explanations from Deep Networks via Gradient-Based Localization. 2017 IEEE International Conference on Computer Vision (ICCV), Venice, Italy, pp. 618–626 (2017). https://doi.org/10.1109/ICCV.2017.74

Krittaphas Chaisutyakorn

General Physician, Data researcher at Siriraj Data Innovation Center