Deep Learning

Deep Recurrent Neural Networks for Electroencephalography Analysis

Juan Vera
Published in Intuition · 27 min read · Feb 12, 2024


On the application of Deep Learning to Enhance Clinical EEG Analysis.

Electroencephalography, or EEG for short, is an invaluable clinical tool for the diagnosis and treatment of various mental disorders and illnesses.

This includes epilepsy & seizures¹, sleep disorders² (insomnia³), Alzheimer’s disease⁴, and much more.

So, the process for EEG analysis typically goes as follows:

  1. Data Collection & Testing
  2. Signal Processing
  3. Classification
  4. Diagnosis / Treatment Protocol

Data collection & Testing involves the gathering of electrical activity from the brain through electrodes at the scalp. This is typically done over a lengthy period of time ranging from 20 minutes to multiple hours. During data collection, the patient may be asked to follow a set of commands to induce specific brain activity.

Signal Processing is where the raw EEG data is filtered, de-noised, and then transformed (if needed) into the frequency domain for further analysis. This allows meaningful data to be separated from EEG artifacts & noise.

Classification is where EEG data is sorted into specific groups of interest to differentiate between normal and abnormal brain activity. In the case of seizure detection and epilepsy, the data is typically split between epileptiform EEG patterns and normal EEG patterns.

Afterwards, a patient is diagnosed and given a treatment protocol consisting of medication, additional health protocols, and possibly additional EEG recordings.

Unfortunately, there are many limitations that come along with using EEG as a means for diagnostics…

By no means is EEG a perfectly accurate measure of brain health. Given that it records electrical activity at the scalp, it's prone to high levels of noise and signal distortion.

Despite undergoing robust signal processing pipelines to remove noise, EEG data can still be misinterpreted by neurologists, leading to opportunities for misdiagnosis.

Let’s use epilepsy diagnostics as an example.

Erroneous readings of EEG data can lead neurologists to wrongly interpret normal brain activity as abnormal epileptic brain activity.

Such erroneous readings are fueled by the lack of strict, clearly defined criteria for EEG interpretation.

As a result, the most common errors in interpretation are found in the temporal regions of the brain, typically due to incorrectly classifying wicket spikes as epileptiform patterns⁵.

Difference between Wicket and Epileptiform EEG | Source

These errors and misinterpretations can be due to a lack of experience in EEG analysis. It’s evident that a clinician who lacks the appropriate training is at risk for over-interpreting EEG⁶.

But the lack of a universal gold standard for EEG analysis makes it difficult to objectively assess and interpret brain activity, even amongst the most experienced neurologists.

A study conducted by Grant et al. assessed the quality of EEG interpretations through intra-rater and inter-rater reliability amongst 6 pediatric and adult epileptologists⁷.

To clarify, intra-rater reliability defines the reliability of a specific individual at measuring a specific phenomenon.

Inter-rater reliability defines the reliability of measuring a specific phenomenon amongst a group of observers.

It was found that in 5 out of the 6 epileptologists, median intra-rater reliability was ≤ 99% and the upper quartile of the collected data exhibited 100%.

It’s clear that individual epileptologists were highly confident about their personal interpretations of the sample EEG data.

On the flip side, the measured inter-rater reliability exhibited completely different results.

The aggregated kappa value amongst the epileptologists amounted to .44, the lowest being .29 and the highest being .62.

If you aren't sure what kappa values are, feel free to check this video out to clarify!
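In short, a kappa value measures agreement between raters beyond what chance alone would produce: κ = (pₒ − pₑ) / (1 − pₑ), where pₒ is the observed agreement and pₑ is the agreement expected by chance. A value of 1 means perfect agreement, 0 means agreement no better than chance, and values around .44 are conventionally read as only moderate.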

Ultimately, this translates to a pretty high probability that an attempt to precisely and accurately reach a consensus on an EEG interpretation would fail:

“When interpreting EEGs into the seven primary categories, the probability that a randomly selected pair of readers will disagree on a randomly selected category is about 42%, implying that the probability of one reader being wrong is at least 21%”

— Grant et al

Quite shocking isn’t it?

This means there's approximately a 21% chance of an epilepsy misdiagnosis…

And I'm not just projecting conclusions from the study by Grant et al. onto epilepsy diagnostics.

Real world data confirms this!

“A population based study mainly in adults found a misdiagnosis rate of 23%, while 26% of subjects referred to a single adult neurologist with “refractory epilepsy” were found not to have epilepsy.

Developmental changes in the normal EEG, background EEG abnormalities, and “non‐epileptogenic epileptiform” abnormalities have all been used to erroneously support the diagnosis of epilepsy.”

CD Ferrie

Of a population of 214 patients initially diagnosed with epilepsy,

“49 (23.2%) were identified as not having epilepsy and all but two have since been withdrawn from antiepileptic medication.”

Scheepers et al

and

“Each year more than 90,000 people in England and Wales are wrongly given a diagnosis of epilepsy, a new study has estimated. This scale of misdiagnosis may be resulting in unnecessary costs of as much as £138m (€205m; $257m) a year, it says…

Given a prevalence of epilepsy of 7.7 cases in every 1000 people and a misdiagnosis rate of 23%”

Roger Dobson¹⁰
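For context: with roughly 53 million people in England and Wales at the time, a prevalence of 7.7 per 1,000 works out to around 400,000 epilepsy diagnoses, and 23% of that is roughly the 90,000 misdiagnoses quoted above.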

So, given the roughly 21% probability that an EEG reading is misinterpreted, patients (epilepsy patients, in this example) can suffer from the adverse consequences of misdiagnosis, such as:

  • Restricted Driving
  • Unnecessary antiepileptic medication causing drastic side effects such as fatigue, poor focus, vomiting, feeling “drunk”, tremors, worsened mood, anxiety, and much more¹¹.
  • Unnecessary treatment costs, amounting up to >$10k¹²
  • Job insecurity & unemployment¹³

This list could go on and on…

Keep in mind, this is only in regards to epilepsy.

What about other neurological conditions that have EEG-based diagnostics? What about sleep disorders or Alzheimer's?

Odds are that they too suffer from similar cases of misdiagnosis when using clinical EEG…

So how can we improve the accuracy of EEG diagnostics?

Well, one way to mitigate observer error is to adjust the role that the observer plays by introducing a more objective methodology of analysis…

Have you ever heard about deep recurrent neural networks?

What are DRNNs?

A Deep Recurrent Neural Network (DRNN) is a type of machine learning architecture specifically designed to analyze and understand sequences of data such as text, speech, and, fortunately for us: EEG!

They're also great at distinguishing patterns in large datasets to make future predictions based on past observations.

Let’s break down DRNNs a little further.

Neural Networks are a subset of machine learning and serve as the base architecture for deep learning algorithms. They're called neural as they mimic how the brain sends electrical signals through its neurons.

Without going into the interesting mathematics, at the highest level, this is what a neural network looks like:

Pretty simple, huh? | Image by Author

The input layer of a neural network is for taking in pieces of data to either train the model or test an already trained model.

As data is fed into the input layer and through the network, the individual nodes of the network, called neurons, apply various mathematical operations on the data in order to make unique distinctions in specific patterns.

Once the data has fully passed through the network, the output layer sends out the prediction that a neural network makes based on the patterns it ‘sees’.

Sorta like this.

In order to ensure that a network makes accurate predictions based on a dataset, it's important to iteratively ‘train’ the network by repeatedly feeding it diverse data and giving it feedback through back-propagation.

As a network is trained, it ‘learns’ to make accurate predictions by continuously adjusting its weights.

Adjusting the weights is what allows the neural network to modify its mathematical operations and produce more accurate predictions.

The downside to an ordinary neural network is its inability to take in lengthy sequential pieces of data, given its lack of a memory mechanism.

Instead, they’re primarily geared to take in fixed-size inputs and produce fixed-size outputs.
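To make that concrete, here's a minimal sketch (my own toy example, not part of ChronoNet) of a plain feed-forward network in PyTorch, which only ever accepts inputs of one fixed size:

import torch
import torch.nn as nn

# A fixed-size network: it always expects exactly 10 input values and produces 1 output
mlp = nn.Sequential(
    nn.Linear(10, 16),  # input layer -> hidden layer (10 features in, 16 out)
    nn.ReLU(),          # non-linearity
    nn.Linear(16, 1),   # hidden layer -> output layer (a single prediction)
)

x = torch.randn(4, 10)   # a batch of 4 samples, each with exactly 10 features
print(mlp(x).shape)      # torch.Size([4, 1]) -- a longer or shorter sample would raise an error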

Clearly, for EEG, which is a sequential form of data, this isn’t ideal.

So we resolve this by introducing Recurrent Neural Networks.

Here’s what a simple RNN looks like:

Still pretty simple! | Image by Author

The added recurrence to the network is what defines an RNN.

It allows a network to ‘remember’ sequential data by allowing past bits of information in a sequence to influence its future predictions.

So, when using EEG data, an RNN would be able to make predictions based on its temporal patterns over a period of time rather than relying on the immediate behavior of the EEG, whether in the time or frequency domain.

This is extremely important, as most of the value attained from EEG comes from its temporal patterns.

Temporal patterns meaning, patterns seen over a period of time.

Now, to make an RNN deep, all that's left is the introduction of additional hidden layers.

This is what a Deep Recurrent Neural Network (DRNN) looks like:

Yep. This is it | Image by Author

Essentially, hidden layers are the additional layers added to a network that are neither the input nor the output layers.

The addition of hidden layers into an RNN can potentially allow for more accurate predictions as it applies more functions through its additional recurrent units. They tend to be more accurate as they allow a network to recognize increasingly specific patterns.

So, ultimately a Deep Recurrent Neural Network could allow for increasingly accurate EEG classification…
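Before moving on, here's a minimal sketch (my own toy example, with made-up sizes) of what stacking recurrent layers looks like in PyTorch; ChronoNet stacks its GRU layers differently, with skip connections, as we'll see shortly:

import torch
import torch.nn as nn

# Three stacked GRU layers form a 'deep' recurrent network
drnn = nn.GRU(input_size=14, hidden_size=32, num_layers=3, batch_first=True)

x = torch.randn(3, 512, 14)   # (batch, time steps, channels) -- an EEG-like sequence
out, h = drnn(x)              # out holds the top layer's output at every time step
print(out.shape, h.shape)     # torch.Size([3, 512, 32]) torch.Size([3, 3, 32])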

With the guidance of curiosity, I managed to find a paper describing the use of a DRNN for EEG classification.

ChronoNet: A Deep Recurrent Neural Network for Abnormal EEG Identification …

we propose a novel recurrent neural network (RNN) architecture termed ChronoNet which is inspired by recent developments from the field of image classification and designed to work efficiently with EEG data. ¹⁴

Roy et al

To figure out how the model works, I thoroughly read the paper and replicated the model’s code, with some guidance and modifications from here.

If you’re curious to check out the paper for yourself, you can do so here.

ChronoNet

The motivation to create ChronoNet follows a similar line of reasoning as described earlier.

To be able to identify a neurological disorder through EEG, a patient typically needs long-term monitoring of their brain activity. Singular short sessions don’t suffice.

Yet as a result, huge amounts of data are produced which then need to be manually interpreted by neurologists.

The complications that arise from this include incorrect interpretations (given limited expertise and complex data), delays in diagnosis, and delays in treatment.

Figuring out a way to automate EEG analysis at the classification step could accelerate diagnostics and decrease the workload of a neurologist.

In order to make this type of automation reality, while retaining (and perhaps even improving) the accuracy of a human reader, Roy et al., sculpted ChronoNet.

The model architecture is designed just like an RNN, but with some slight modifications.

It takes certain features from Google’s Inception network architecture and DenseNet, a deep convolutional neural network architecture, to create the ideal architecture for EEG analysis.

This is ChronoNet

I'll explain more about the architecture as I show you its replicated code.

For the purposes of building this model, I used a Jupyter Notebook.
Each code snippet represents a singular cell.

First things first, we need to import all our dependencies to design and train the model.

import glob #For getting mat files from our directory

import scipy.io as sio #For loading mat files
import mne #For EEG data handling
import numpy as np #For arrays and mathematical functions on arrays

import torch
import torch.nn as nn #For building our model
import torchmetrics #To assess model accuracy
from torch.utils.data import TensorDataset,DataLoader #To create a dataset from Tensors (our features and labels); To load our data in batches and iterate over it.

import sklearn
from sklearn.model_selection import GroupKFold #Gets GroupKFold Algorithm / Iterator
from sklearn.preprocessing import StandardScaler #Gets StandardScaler to standardize our dataset
from sklearn.base import TransformerMixin,BaseEstimator #To create our StandardScaler3D class

from pytorch_lightning import LightningModule,Trainer #LightningModule to organize our PyTorch code; Trainer to train our model.

SciPy, MNE, and NumPy allow us to take in our dataset and apply the modifications needed to prepare it for our model.

PyTorch is the machine learning framework on which our model is built. It lets us design our model through its neural network (torch.nn) module.

We import TensorDataset and DataLoader to create a dataset from our Tensors (our features and labels) and load it in batches for our model.

From SciKit-Learn, GroupKFold allows for K-Fold Cross Validation (allowing for more effective training), StandardScaler standardizes our data, and TransformerMixin / BaseEstimator let us create a class to standardize data in the shape of a 3D matrix.

Finally, we use PyTorch Lightning, which streamlines the training process. LightningModule provides a framework for defining our model, and the Trainer automates the training of our model.

This might seem a little confusing at first, but I’ll explain later in the article!

So, for testing purposes, we first need to create a Tensor.

This is a Tensor | Source: Investigating Tensors with PyTorch, DataCamp

A Tensor describes a data structure, which can be in the form of scalars, vectors, or matrices. It’s a generalized way of referring to data arrays of different dimensions.

input=torch.randn(3,14,512)
input.shape

We use torch.randn to create a 3-Dimensional Tensor, with 3 being the "depth" of the Tensor representing the batch size, 14 being the "rows / height" of the Tensor (think EEG channels), and 512 being the "columns / width" of the Tensor (think time samples).

Once you run this cell, you should get this output,

torch.Size([3, 14, 512])

verifying that our Tensor is in our desired shape.

Here’s a visual of how it’d look like if drawn out:

*not drawn to scale | Image by Author

Afterward, we create what’s called an Inception block.

Wait, what the heck is an Inception block?

Well, let me break it down for you.

In 2014, researchers at Google released a paper called "Going Deeper with Convolutions".

This paper outlined a convolutional neural network (CNN) “codenamed ‘Inception’ ” which was able to set “the new state of the art [SOTA] for classification and detection in the ImageNet Large-Scale Visual Recognition Challenge 2014”.

What made this Inception network SOTA in image recognition was its use of what are called Inception blocks.

In a vanilla CNN, at each layer you apply a convolution with a kernel to the input data to detect specific patterns or features, which are mapped onto a newly created feature map for the next layer.

Sorta like this :)

If you’d like to learn more about CNNs, feel free to check out this video.

These convolution operations are typically applied with separate kernels within separate layers.

This means, in order to be able to detect specific features of a dataset with more precision, we’d need to add deeper layers in our network.

Sounds simple, right?

Well, not so much. In reality, the introduction of deeper layers to our network can actually hurt the model's accuracy on the desired task.

Too many layers increase model complexity (and computational cost) and can cause the model to overfit, which decreases the performance of the model.

The proposed inception network solves this problem by adding Inception blocks which use multiple kernels per layer to apply multiple convolutions per layer.

So, this is an Inception Block.

At the end of each layer, it takes the resulting output features of each convolution and concatenates them into a singular vector output which then serves as the input for the next inception block.

If you’d like to go a lil deeper, check out this video.

With this, we can more efficiently extract relevant features while keeping computational complexity low.

In essence, this is what we’re adding to ChronoNet in the following lines of code.

class InceptionBlock(nn.Module): # Creating the class for our inception block
    def __init__(self, in_channels):
        super().__init__()
        self.conv1 = nn.Conv1d(in_channels, out_channels = 32, kernel_size = 2, stride = 2, padding = 0) # Defines 1st Convolution
        self.conv2 = nn.Conv1d(in_channels, out_channels = 32, kernel_size = 4, stride = 2, padding = 1) # Defines 2nd Convolution
        self.conv3 = nn.Conv1d(in_channels, out_channels = 32, kernel_size = 8, stride = 2, padding = 3) # Defines 3rd Convolution
        self.relu = nn.ReLU() # Defines the ReLU Activation Function

    # Here, xn is the output of the nth convolution.
    def forward(self, x): # Defining the forward function
        x1 = self.relu(self.conv1(x)) # Performing 1st conv and outputting x1
        x2 = self.relu(self.conv2(x)) # Performing 2nd conv and outputting x2
        x3 = self.relu(self.conv3(x)) # Performing 3rd conv and outputting x3
        x = torch.cat((x1, x2, x3), dim = 1) # Concatenating the convolution outputs along dimension 1 (the channel dimension)
        return x

We create the class for our Inception block, which is made to extract the spatial features of a dataset. It inherits the methods and properties from the nn.Module class from PyTorch. This allows us to define, structure, and sculpt our network.

Within our InceptionBlock class's __init__ function, we use the super() function to initialize the parent nn.Module class and gain access to its methods and properties.

So then, we define our convolutions as described by the paper.

The 1st 1DConvolution (1DConv) uses a kernel size of 2, a stride of 2, and outputs 32 channels or feature maps.

Hence, in_channels, out_channels = 32, kernel_size = 2, stride = 2, padding = 0

The 2nd 1DConv uses a kernel size of 4, a stride length of 2, and outputs 32 feature maps.

in_channels, out_channels = 32, kernel_size = 4, stride = 2, padding = 1

And the 3rd 1DConv uses a kernel size of 8, a stride length of 2, and outputs 32 feature maps.

in_channels, out_channels = 32, kernel_size = 8, stride = 2, padding = 3

in_channels is defined when initializing the class and not within the nn.Conv1d() function, in order to allow for generalization for different potential use-cases for the network.

padding is simply a way of making sure that all three convolutions produce outputs of the same length, so that they can be concatenated later.
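As a quick sanity check (my own addition), the standard 1D convolution output-length formula for the default dilation of 1, L_out = floor((L_in + 2*padding - kernel) / stride) + 1, shows that all three convolutions map a 512-sample input to the same 256-sample output, which is what allows their 32-channel outputs to be concatenated into 96 channels:

def conv1d_out_len(L_in, kernel, stride, padding):
    return (L_in + 2 * padding - kernel) // stride + 1

for kernel, padding in [(2, 0), (4, 1), (8, 3)]:
    print(conv1d_out_len(512, kernel, stride=2, padding=padding))  # 256, 256, 256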

Then the ReLU activation function is called / defined to be applied onto our input data.

This is the ReLU activation function: y = max(0, x)

Activation functions are pretty important!

They decide how each individual neuron within a network is activated.

They also introduce non-linearity, allowing a network to detect non-linear patterns in a dataset.

The next function we've defined is forward(), which feeds our data into the network, performing 3 different convolutions (conv1, conv2, conv3) on the same data in parallel; the resulting outputs of the 3 convolutions are then concatenated into a single tensor, x.

So that's the Inception block, which is reused 3 times in ChronoNet for improved computational efficiency and increased performance.
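A quick way to sanity-check the block we just defined, using the random (3, 14, 512) tensor, input, from earlier:

block = InceptionBlock(14)     # 14 input channels, matching our toy tensor
print(block(input).shape)      # torch.Size([3, 96, 256]) -- 3 x 32 concatenated channels, half the time steps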

Now, we can define the actual architecture of ChronoNet.

class ChronoNet(nn.Module):
    def __init__(self, channel):
        super().__init__()
        self.inception_block1 = InceptionBlock(channel) # 1st Inception Block
        self.inception_block2 = InceptionBlock(96) # 2nd Inception Block
        self.inception_block3 = InceptionBlock(96) # 3rd Inception Block
        self.gru1 = nn.GRU(input_size = 96, hidden_size = 32, batch_first = True) # 1st GRU layer
        self.gru2 = nn.GRU(input_size = 32, hidden_size = 32, batch_first = True) # 2nd GRU layer
        self.gru3 = nn.GRU(input_size = 64, hidden_size = 32, batch_first = True) # 3rd GRU layer
        self.gru4 = nn.GRU(input_size = 96, hidden_size = 32, batch_first = True) # 4th GRU layer
        self.relu = nn.ReLU() # ReLU Activation Function
        self.gru_linear = nn.Linear(in_features = 64, out_features = 1) # Linear layer applied before the 4th GRU
        self.flatten = nn.Flatten() # Flattening Layer
        self.fc1 = nn.Linear(32, 1) # Fully Connected Layer / Output Layer

    def forward(self, x): # Defining the feed forward function
        x = self.inception_block1(x) # Fed to Inception Block 1
        x = self.inception_block2(x) # Fed to Inception Block 2
        x = self.inception_block3(x) # Fed to Inception Block 3
        x = x.permute(0, 2, 1) # Permuted for the GRU layers
        gru_out1, _ = self.gru1(x) # Fed into GRU layer 1
        gru_out2, _ = self.gru2(gru_out1) # Fed into GRU layer 2
        gru_out = torch.cat((gru_out1, gru_out2), dim = 2) # Concatenated, defining the skip connection
        gru_out3, _ = self.gru3(gru_out) # Fed into GRU layer 3
        gru_out = torch.cat((gru_out1, gru_out2, gru_out3), dim = 2) # Concatenated, defining the next 2 skip connections
        gru_out = gru_out.permute(0, 2, 1) # Permuted for the linear layer
        linear_out = self.relu(self.gru_linear(gru_out)) # Fed into the linear layer to reduce dimensionality
        linear_out = linear_out.permute(0, 2, 1) # Permuted for the 4th GRU layer
        gru_out4, _ = self.gru4(linear_out) # Fed into the 4th GRU layer
        x = self.flatten(gru_out4) # Data is flattened for the Fully Connected Layer
        x = self.fc1(x) # Fed into the Fully Connected Layer
        return x # Output

ChronoNet is inheriting the methods and properties from the PyTorch nn.Module class.

The __init__ function defines our structure while the forward function feeds our data through the network.

Under __init__, we define the first 3 inception layers by using our InceptionBlock class we built earlier.

Afterward, the GRU layers are defined through the use of nn.GRU.

Hold up, what is a GRU?

GRU stands for Gated Recurrent Unit. It's an important building block of recurrent neural networks, as it aims to mitigate what's called the vanishing gradient problem.

A vanishing gradient typically occurs when you input datasets of lengthy sequences.

As longer sequences of data are introduced and iterated on, gradients tend to become so small that they don't effectively update the parameters of a layer, leading to a slower training process.

This puts emphasis on more recent datapoints, creating a type of bias within our network.
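A toy illustration (my own, not EEG-specific): if back-propagation through time shrinks the gradient by a factor of about 0.9 at each of 100 time steps, almost nothing is left to update the earliest steps.

factor, steps = 0.9, 100
print(factor ** steps)  # ~2.66e-05 -- the gradient has effectively 'vanished'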

If you’d like to understand this concept more, feel free to check out this video.

So in order to reduce this issue, we introduce Gated Recurrent Units.

Here’s a quick visual:

It makes use of an input and hidden state at time step t, to then apply activation functions and concatenations to produce a new hidden state and output.

I won’t go too into detail here, but if you’re curious you can learn about it here. The mathematics is pretty interesting!

So we define the four GRU layers through nn.GRU, each taking in the appropriate input size to result in the appropriate output size, detecting specific temporal features of a dataset.

Afterward we define the ReLU activation function.

Then we define the final fully connected layers through the PyTorch function, nn.Linear.

These fully connected layers will take in the identified spatial features from the InceptionBlocks and the temporal features from the GRU layers, to output a final result.

Now, moving onto the forward() function we’ve defined, all this function does is feed our data through the network we’ve just defined.

  1. Data is passed through the Inception blocks
  2. Data is passed through the GRU layers
  3. Data is passed through the Linear layers
  4. Then, we get a final output.

We introduce the .permute() function in order to structure our data in the shape the GRU layers expect (batch, sequence, features, since batch_first = True, per the PyTorch docs), and once more to rearrange it for the linear layers.
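Here's a small sketch of what those permutes are doing. Assuming the toy (3, 14, 512) input from earlier, each Inception block halves the time dimension, so the GRU layers receive 64 time steps with 96 channels:

x = torch.randn(3, 96, 64)   # (batch, channels, time), as produced by the Inception blocks
x = x.permute(0, 2, 1)       # (batch, time, features) = (3, 64, 96), what a batch_first=True GRU expects
print(x.shape)               # torch.Size([3, 64, 96])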

Before we move forward, I want to touch on the GRU layers once more.

In our code, [this is not a new Jupyter cell btw]

gru_out1,_=self.gru1(x) # Fed into GRU layer 1
gru_out2,_=self.gru2(gru_out1) # Fed into GRU layer 2
gru_out=torch.cat((gru_out1, gru_out2), dim = 2) # Concatenated, defining the skip connection
gru_out3,_=self.gru3(gru_out) # Fed into GRU layer 3
gru_out = torch.cat((gru_out1, gru_out2, gru_out3), dim = 2) # Concatenated, defining the next 2 skip connections
gru_out = gru_out.permute(0,2,1) # Permuted for the linear layer

you might notice that we’re using torch.cat to concatenate the outputs of each GRU layer onto each next GRU layer.

You can visualize this in the diagram of ChronoNet provided by the original paper.

But why?

You might recall, as I mentioned earlier, that "It [ChronoNet] takes certain features from … DenseNet."

Let me explain what DenseNet is.

You might have heard of another neural network architecture called, ResNet.

It aimed to solve the vanishing gradient problem by adding skip connections: a layer's input is passed around it and added to its output.

just like this :0

This is a very high level description, you can learn more about it here or directly read the original paper here

DenseNet is a deep convolutional neural network (DCNN), that leverages skip connections in a similar fashion as ResNet.

The key difference lies in the number of skip connections the network uses.

DenseNet takes the output of each layer and connects it to every subsequent layer in the network.

Ultimately, the architecture looks like this:

pretty sick, huh?

There are many key advantages to this architecture, including improved computational efficiency, improved back-propagation & deeper supervision for error correction, and, most importantly, improved gradient flow, which further mitigates the vanishing gradient problem.

When using sequential data, such as EEG, mitigating this challenge is super important in order to get accurate classifications.

It’s exactly why these connections, Dense Connections, are implemented in ChronoNet.

The C-RNN architecture is not immune to the problem of degradation which sometimes impedes the training of very deep neural networks…

….

To tackle this issue, inspired by the DenseNet architecture proposed by [13] for CNNs, we incorporate skip connections in the stacked GRU layers of C-RNN to form the C-DRNN architecture…. Intuitively, skip connections will lead to GRU layers being ignored when the data demands a lower model complexity than offered by the entire network¹⁴.

Roy et al.

and that’s exactly what those lines of code do.

So, that’s our model!

We can now begin to test, train, and validate our model!

Using the random tensor, input, which we created earlier (torch.Size([3, 14, 512])), we feed it through the model just like this.

model = ChronoNet(14)
out = model(input)
out.shape

We don’t need to explicitly call the forward() function as PyTorch’s nn.Module class automatically does so for us.

We should get the following output:

torch.Size([3, 1])

where 3 is the batch size and 1 is the single raw output per sample.

If you run out.data rather than out.shape, you'll be able to see the output for each sample in the batch of our randomly created tensor.

For example,

tensor([[0.1636],
[0.1654],
[0.1684]])
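Note that ChronoNet's final layer is a plain nn.Linear, so these numbers are raw logits rather than probabilities (the BCEWithLogitsLoss we'll use later applies the sigmoid internally during training). A minimal sketch of converting them into probabilities at inference time:

probs = torch.sigmoid(out)  # squashes each logit into (0, 1); roughly 0.54 for the example values above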

So, that verifies that we’re getting the expected output shape for our model.

In order to train our model, we first need to load our data onto our program.

In this case, we used this dataset, which holds the EEG data of people with intellectual development disorders (IDD) and a normal control group.
We’ll be training our model with this dataset.

IDD = 'Data/Data/CleanData/CleanData_IDD/Rest'
TDC = 'Data/Data/CleanData/CleanData_TDC/Rest'

Now, the EEG from the dataset is in the .mat format, which denotes MATLAB files.

We don't want our files to be in .mat format; we want them in a format that allows our model to read and use the data.

This is where the library, MNE, sculpted specifically for EEG analysis, comes into play.

We define and build this function

def MatMNE(data):
    ch_names = ['AF3', 'F7', 'F3', 'FC5', 'T7', 'P7', 'O1', 'O2', 'P8', 'T8', 'FC6', 'F4', 'F8', 'AF4'] # Defining Channel Names
    sampling_freq = 128
    info = mne.create_info(ch_names, sfreq=sampling_freq, ch_types='eeg') # Creating an MNE Info Object
    info.set_montage('standard_1020') # Setting our EEG montage
    data = mne.io.RawArray(data, info) # Creating a RawArray to make our data readable for MNE
    data.filter(l_freq=1, h_freq=30) # Defining a bandpass filter
    data.set_eeg_reference() # Setting our reference EEG by taking the average of all channels
    epochs = mne.make_fixed_length_epochs(data, duration=4) # Creating epochs
    return epochs.get_data()

which allows us to convert our data from a .mat file into a 3D NumPy array, making our data readable for our model.

Within the function, we define the metadata of the dataset, both the names of the channels used and the sampling frequency.

Using ch_names and sampling_freq, we use the function mne.create_info() to create an Info object holding all the needed metadata. Then we set the montage for the EEG dataset, which in this case was the 10–20 system.

Afterward, through the function mne.io.RawArray(), we create a raw object that holds our data and Info object. The data is filtered, data.filter, using a bandpass filter from 1 Hz to 30 Hz. Then the data is re-referenced by taking the average of all channels and subtracting it from each channel, through data.set_eeg_reference().

Then, we epoch our data to a length of 4 seconds using, mne.make_fixed_length_epochs(). Our final function, epochs.get_data() then returns our data as a 3D NumPy array.

So now, we load our data and make use of this function

%%capture
# %%capture suppresses the cell's unnecessary logging output

idd_subject=[] #Creating a list for IDD subject data

for idd in glob.glob(IDD+'/*.mat'): #For each .mat file in the IDD directory
    data = sio.loadmat(idd)['clean_data'] #For each .mat file, get the data under the 'clean_data' key as a NumPy array
    data = MatMNE(data) #Passing our data through our MatMNE function and reassigning to 'data'
    idd_subject.append(data) #Appending the patient data to the idd_subject list

So we initialize a list which will hold the EEG of patients with intellectual developmental disorder. Using glob, we get our data from our directory and use scipy.io.loadmat() to load our MATLAB files. We then use our defined function MatMNE() to convert our .mat files into a readable NumPy array. Finally, we append each created array to the idd_subject list.

The same is done for the control group

tdc_subject=[] #Creating a list for TDC (control) subject data

for tdc in glob.glob(TDC+'/*.mat'): #For each .mat file in the TDC directory
    data = sio.loadmat(tdc)['clean_data'] #For each .mat file, get the data under the 'clean_data' key as a NumPy array
    data = MatMNE(data) #Passing our data through our MatMNE function and reassigning to 'data'
    tdc_subject.append(data) #Appending the patient data to the tdc_subject list

Now we create labels for our dataset.

control_epochs_labels=[len(i)*[0] for i in tdc_subject] #Labels our data for TDC as '0'
patients_epochs_labels=[len(i)*[1] for i in idd_subject] #Labels our data for IDD as '1'

We essentially label the data elements of our tdc_subject list with 0s and our idd_subject list with 1s.

Now, we modify our data further,

data_list = tdc_subject + idd_subject #Combines our 2 datasets into a single list
label_list=control_epochs_labels + patients_epochs_labels #Combines the labels of our 2 datasets into a single list
groups_list = [[i]*len(j) for i, j in enumerate(data_list)] #Indexes the test subjects / patients

and combine our data into a single data_list. The same is done with the labels, creating a single label_list. Our data_list is then enumerated to create a group index for each test subject, so every epoch is tagged with the subject it came from.
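To make the group indices concrete, here's a tiny illustrative example with made-up epoch counts (not the real dataset):

example_data = [np.zeros((3, 14, 512)), np.zeros((2, 14, 512))]      # two subjects: 3 epochs and 2 epochs
example_groups = [[i] * len(j) for i, j in enumerate(example_data)]  # the same comprehension as above
print(example_groups)  # [[0, 0, 0], [1, 1]] -- GroupKFold uses these ids to keep each subject's epochs together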

Now, we’d need to standardize our data, bringing data down to a similar scale, to ensure we maintain uniformity in our model.

Given that our data is 3-dimensional, and scikit-learn's StandardScaler() only accepts 2D data, we need to create a class that defines a 3D standard scaler.

So we define the class and functions,

gkf = sklearn.model_selection.GroupKFold() # Assigning the GroupKFold iterator to gkf

class StandardScaler3D(BaseEstimator, TransformerMixin): # For 3D data of shape [Batch, Sequence, Channels]
    def __init__(self):
        self.scaler = StandardScaler()

    def fit(self, X, y=None):
        self.scaler.fit(X.reshape(X.shape[0], -1))
        return self

    def transform(self, X):
        return self.scaler.transform(X.reshape(X.shape[0], -1)).reshape(X.shape)

which will allow us to standardize our 3D data.

We also initialized the GroupKFold() iterator, which will aid us in the next step.

Now, we concatenate our data and modify it a bit further,

data_array= np.concatenate(data_list) 
label_array=np.concatenate(label_list)
group_array =np.concatenate(groups_list)
data_array = np.moveaxis(data_array, 1, 2)

moving the 1st axis of data_array to the 2nd.

Now, we perform what’s called K-Fold Cross Validation using Scikit-Learn’s GroupKFold() function.

In essence, K-Fold Cross Validation is an approach for training a machine learning model that splits our data in various ways to allow for optimal training and testing.

K is the variable that defines the number of groups / folds we divide our data into.

During each iteration of K-Fold Cross Validation, a single fold is assigned as the testing set while the remaining folds serve as the training set. The model is trained on the training set and validated on the testing set.

This is repeated over K iterations, with each iteration having a different testing set.

Afterward, the result of each iteration on the testing set is averaged to yield the overall evaluation for the model.

So, if I wanted to perform 4-Fold Cross Validation on a dataset of 24 data points, it might look something like this:

4 Fold Cross Val | Image by Author

So in the following code,

accuracy = [] # Creates a list to store our model accuracies across iterations
for train_index, val_index in gkf.split(data_array, label_array, groups=group_array):
    # In GroupKFold cross-validation, get unique indices for training (train_index) and validation (val_index) sets.
    # The 'groups' argument ensures that there is no data leakage between different subjects/patients.
    train_features, train_labels = data_array[train_index], label_array[train_index]
    # Gets training features and labels based on the indices obtained at the kth split.
    val_features, val_labels = data_array[val_index], label_array[val_index]
    # Gets validation features and labels based on the indices obtained at the kth split.
    scaler = StandardScaler3D() # Initializes a StandardScaler3D instance for feature scaling.
    train_features = scaler.fit_transform(train_features) # Fits the scaler to the training data, then standardizes it
    val_features = scaler.transform(val_features) # Standardizes the validation data with the same scaler
    train_features = np.moveaxis(train_features, 1, 2) # Flip axes to fit the ChronoNet architecture
    val_features = np.moveaxis(val_features, 1, 2) # Flip axes to fit the ChronoNet architecture

We create a list to store the accuracy of our model for each iteration. Then we use the .split() function to divide our data into 5 folds (the default value).

Then, we define the training features (train_features) and (train_labels) based on the indices obtained at k-split. The same is done for the validation features (val_features) and labels (val_labels).

Then we apply our predefined StandardScaler3D() class to standardize our data.

Now, we convert our NumPy arrays of training and validation data,

train_features = torch.Tensor(train_features)
val_features = torch.Tensor(val_features)
train_labels = torch.Tensor(train_labels)
val_labels = torch.Tensor(val_labels)

into Tensor format.

Tensors differ from NumPy arrays in that they can run on GPUs and provide more functionality for deep learning: they support the operations needed for training neural networks, computing gradients, and performing computation more efficiently.

Now, it’s time to train and test our model.

We’ll be using PyTorch Lightning, to allow for easier training and testing.

We define our class and functions for training / testing the model.

class ChronoModel(LightningModule):
    def __init__(self):
        super(ChronoModel, self).__init__()
        self.model = ChronoNet(14)
        self.lr = 1e-3 ## Defining the learning rate of our model (0.001 per step)
        self.bs = 12 ## Defining the batch size
        self.worker = 2 ## Defining the # of workers (parallel data-loading processes)
        self.acc = torchmetrics.Accuracy(task='binary') ## For measuring the accuracy of our model
        self.criterion = nn.BCEWithLogitsLoss() ## Binary cross-entropy loss that applies the sigmoid to our raw logits internally
        self.train_outputs = [] # To store our outputs from training the model
        self.val_outputs = [] # To store our outputs from validating the model

    def forward(self, x): ## Defining the forward function, for feeding our data into the model
        x = self.model(x)
        return x

    def configure_optimizers(self): ## Defining our optimizer
        return torch.optim.Adam(self.parameters(), lr=self.lr) # Implementing the Adam optimizer on our model

    def train_dataloader(self): # Loads our training data
        dataset = TensorDataset(train_features, train_labels) # Creates a dataset object from the Tensors holding our training features and labels
        dataloader = DataLoader(dataset, batch_size=self.bs, num_workers=self.worker, shuffle=True) # Loads our data in batches using parallel workers; data is shuffled each epoch (shuffle=True) to help prevent overfitting
        return dataloader

    def training_step(self, batch, batch_idx): # Defining our function for a single training step
        signal, label = batch # From each batch, we unpack signal and label data
        output = self(signal.float()) # Output of the model given a signal
        loss = self.criterion(output.flatten(), label.float().flatten()) # Calculating loss using the output and BCEWithLogitsLoss
        acc = self.acc(output.flatten(), label.long().flatten()) # Calculating the accuracy using the output and torchmetrics' Accuracy()
        self.train_outputs.append({'loss': loss, 'acc': acc}) # Appending our output onto our training output list
        return {'loss': loss, 'acc': acc} # Returns our model loss and accuracy

    def on_train_epoch_end(self):
        acc = torch.stack([x['acc'] for x in self.train_outputs]).mean().detach().cpu().numpy().round(2) # Stacks the accuracies into a single Tensor, averages them, detaches from the graph, moves to the CPU, converts to NumPy, and rounds to the nearest hundredth
        loss = torch.stack([x['loss'] for x in self.train_outputs]).mean().detach().cpu().numpy().round(2) # Stacks the losses into a single Tensor, averages them, detaches from the graph, moves to the CPU, converts to NumPy, and rounds to the nearest hundredth
        self.train_outputs.clear() # To free up memory after each epoch
        print('train acc loss', acc, loss) # Printing our final training accuracy and loss

    def val_dataloader(self): # Loads our validation data
        dataset = TensorDataset(val_features, val_labels) # Creates a dataset object from the Tensors holding our validation features and labels
        dataloader = DataLoader(dataset, batch_size=self.bs, num_workers=self.worker, shuffle=True) # Loads our validation data in batches using parallel workers
        return dataloader

    def validation_step(self, batch, batch_idx): # Defining our function for a single validation step
        signal, label = batch # From each batch, we unpack signal and label data
        output = self(signal.float()) # Output of the model given a signal
        loss = self.criterion(output.flatten(), label.float().flatten()) # Calculating loss using the output and BCEWithLogitsLoss
        acc = self.acc(output.flatten(), label.long().flatten()) # Calculating the accuracy using the output and torchmetrics' Accuracy()
        self.val_outputs.append({'loss': loss, 'acc': acc}) # Appending our output onto our validation output list
        return {'loss': loss, 'acc': acc} # Returns our model loss and accuracy

    def on_validation_epoch_end(self):
        acc = torch.stack([x['acc'] for x in self.val_outputs]).mean().detach().cpu().numpy().round(2) # Same averaging as above, for the validation accuracies
        loss = torch.stack([x['loss'] for x in self.val_outputs]).mean().detach().cpu().numpy().round(2) # Same averaging as above, for the validation losses
        self.val_outputs.clear() # To free up memory after each epoch
        print('val acc loss', acc, loss)

Our training process is initialized (__init__) by setting basic configurations such as the learning rate, batch size, number of workers, etc.

Our data will be fed into the network, both training and validation sets, each outputting an accuracy value and a loss value after the end of each training_step and validation_step.

To actually run this process, we can run the following lines of code:

model = ChronoModel()
trainer = Trainer(max_epochs=1)
trainer.fit(model)

And you should get the accuracy / loss for the 1st training and validation set!

After my first run this is what I got:

Definitely not optimal, as of yet, but I hypothesize that given more iterations and updates to the model weights, this’ll become increasingly accurate.
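A minimal sketch of what that might look like (my own addition, not from the original notebook): simply give the Trainer more epochs per fold and watch the printed training / validation accuracy; if you also log metrics with self.log(), you could append each fold's final validation accuracy to the accuracy list created earlier.

model = ChronoModel()
trainer = Trainer(max_epochs=10)  # more epochs than the single demo epoch above
trainer.fit(model)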

Per ChronoNet’s paper, it reached a training accuracy of 90.6% and testing accuracy of 86.57%, “compared to the recently published state-of-the-art performance [19], ChronoNet shows 1.17% better accuracy. Out of the four recurrent architectures, ChronoNet achieves both the best training and testing accuracy”.

So What?

ChronoNet, by Roy et al., shows how ML models can automate tasks that are prone to human error, making way for higher accuracy.

EEG is already prone to so many issues (mistakes in signal processing, a low signal-to-noise ratio, volume conduction, high impedance) that obtaining a reliable, near-accurate EEG interpretation becomes extremely difficult.

You might remember,

“When interpreting EEGs into the seven primary categories, the probability that a randomly selected pair of readers will disagree on a randomly selected category is about 42%, implying that the probability of one reader being wrong is at least 21%”

— Grant et al

21% chance of a misdiagnosis! Crazy, right?

One of the paths forward is to deploy ML models such as ChronoNet into healthcare infrastructure to improve diagnostics and the overall quality of care.

Hopefully some day in the future, this’ll become reality :)

i hope you enjoyed this read

feel free to contact me on twitter or linkedin if you have any questions!

also, feel free to subscribe to my newsletter where I send out bi-weekly updates on what I’m working on!
