Classifying MNIST Digits from Brain Waves using Machine Learning in Python
Writers: Zander Chown and Caelen Hilty
Completed for the course: Computational Analysis of Big Data at DIS Copenhagen
All code available on GitHub.
One of the first problems commonly used to introduce machine learning to students is the MNIST handwritten digits dataset, which has been around since 1998. The dataset is composed of 70 thousand handwritten digits, expressed in 28 x 28 pixel grayscale images. It is a great entry point for learning machine learning techniques because it is somewhat interesting, easy to solve by most algorithms, and used as a benchmark test by the cutting edge of the field.
For some, including David Vivancos, research director at MindBigData, the handwritten digit task has just gotten too easy. The most sensical way to up the challenge? Filter the classic MNIST images through the human brain of course!
On June 1st, 2023, MindBigData released an open dataset with 140,000 unprocessed 2-second electroencephalography (EEG) recordings of one subject, David Vivancos, looking at one MNIST digit at a time on a screen. During each digit presentation, he also listened to an audio recording of the corresponding number being spoken.
With the release of this dataset, MindBigData has challenged the ML community to perform the classic MNIST handwritten digit classification task using only the information contained in the EEG recordings. While this specific task has minimal real-world applications, EEG itself is relevant for understanding normal and diseased brain function, so advances on this dataset could not only improve ML techniques but also further our understanding of the brain and of brain-computer interfaces, and provide insight into the viability of ML-assisted diagnosis.
EEG is a non-invasive neuroimaging technique that measures electrical activity in the brain using an array of electrodes placed on the scalp. It is commonly used in a research setting to study neurological disorders and sleep.
For this project, our goal was to make a first attempt, in Python, at extracting the information needed for this classification task from the raw EEG data. We won’t keep you in suspense: we failed. But we hope that our failure will provide some insight into the nature of the dataset and its challenges to inform future attempts.
Baseline Tests
To get started working with the data, we threw together some quick and dirty tests to establish a baseline accuracy to improve on with a more complex model.
The dataset is huge, so we started with coarse feature extraction: a Fourier transform of each channel’s signal in every trial. A Fourier transform maps a signal from the time domain to the frequency domain, decomposing it into sine and cosine waves of different frequencies whose sum reproduces the original signal.
NumPy’s fast Fourier transform makes this easy. After transforming each signal, we aggregated the frequencies into 5 Hz bins, summing the spectral power within each bin. This reduced the dimensionality of our data from 64,000 data points per trial (128 channels × 500 samples) to just 1,280, a far more manageable set of features for preliminary analysis.
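A minimal NumPy sketch of this binning step, assuming each trial is stored as an array of shape (128, 500) sampled at 250 Hz. The function name `fft_band_features` and its `fmax` cutoff are our own illustrative choices; the text does not say which frequency range was binned.

```python
import numpy as np

def fft_band_features(trial, fs=250.0, bin_width=5.0, fmax=None):
    """Sum FFT power into fixed-width frequency bins, separately for each channel.

    trial: array of shape (n_channels, n_samples), e.g. (128, 500) for a
           2-second recording at 250 Hz.
    Returns a flat feature vector of length n_channels * n_bins.
    """
    n_samples = trial.shape[-1]
    power = np.abs(np.fft.rfft(trial, axis=-1)) ** 2   # one-sided power spectrum per channel
    freqs = np.fft.rfftfreq(n_samples, d=1.0 / fs)     # frequency (Hz) of each FFT bin
    if fmax is None:
        fmax = fs / 2                                  # default: up to the Nyquist frequency
    edges = np.arange(0.0, fmax + bin_width, bin_width)
    n_bins = len(edges) - 1
    # Half-open bins [lo, lo + bin_width); frequencies at or above fmax are dropped
    bin_idx = np.digitize(freqs, edges) - 1
    features = np.zeros((trial.shape[0], n_bins))
    for b in range(n_bins):
        features[:, b] = power[:, bin_idx == b].sum(axis=-1)
    return features.ravel()

# Example on a placeholder trial of random noise
trial = np.random.default_rng(0).normal(size=(128, 500))
print(fft_band_features(trial).shape)            # (3200,): 128 channels x 25 bins
print(fft_band_features(trial, fmax=50).shape)   # (1280,): 128 channels x 10 bins
```

Note that binning the full 0–125 Hz range in 5 Hz steps gives 25 bins per channel; restricting to 0–50 Hz gives 10 bins and therefore 1,280 features. Which range the original analysis used is an assumption here.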
We tried three simple machine learning approaches on this processed data.
- Random Forest Classifier: constructs an ensemble of decision trees that separate the classes based on thresholds along each dimension of the data. Accuracy: 9.61%.
- k-Nearest Neighbors: looks at the k most similar data points from the training set and uses the most frequent class in those ‘neighbors’ to classify a new data point. Accuracy: 10.60%.
- PCA and t-SNE: two unsupervised dimensionality-reduction techniques. PCA finds the linear directions along which the most variance is explained, while t-SNE builds a non-linear low-dimensional embedding that tries to keep nearby points close together. In this reduced space, we can look for separability or clustering of the different classes. Clustering: None.
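For readers who want to reproduce this kind of baseline, here is a minimal scikit-learn sketch. The feature matrix is random placeholder data standing in for the binned FFT features, and the hyperparameters (100 trees, k = 5, an 80/20 split) are illustrative defaults, not the values used in the project.

```python
import numpy as np
from sklearn.decomposition import PCA
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsClassifier

# Placeholder data standing in for the binned FFT features and digit labels.
# Pure noise, so both classifiers should score near chance (about 10%).
rng = np.random.default_rng(0)
X = rng.normal(size=(1000, 256))
y = rng.integers(0, 10, size=1000)

X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=0, stratify=y)

scores = {}
for name, model in {
    "random_forest": RandomForestClassifier(n_estimators=100, random_state=0),
    "knn": KNeighborsClassifier(n_neighbors=5),
}.items():
    model.fit(X_train, y_train)
    scores[name] = model.score(X_test, y_test)   # accuracy on held-out trials
    print(f"{name}: {scores[name]:.2%}")

# Unsupervised view: project to 2-D, then scatter-plot coloured by label to look
# for clusters. sklearn.manifold.TSNE(n_components=2) is the non-linear counterpart.
embedding = PCA(n_components=2).fit_transform(X)
print(embedding.shape)   # (1000, 2)
```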
None of these simple techniques performed significantly better than chance (about 10% for ten classes) on this complex dataset. Admittedly, the Fourier transform binning was extremely coarse and we spent minimal time optimizing model hyperparameters, so any other result would have been surprising.
On to the big model!
Deep Learning Pipeline
With a baseline accuracy established, we moved towards a deep learning strategy with three major phases:
- Preprocessing — clean the raw EEG data and reduce data storage requirements
- Channel Selection — identify which EEG electrodes contained the most important information about the task
- Neural Network Training — train a neural network on the processed data
Preprocessing
Raw EEG data is extremely noisy and unwieldy in size, so processing it was absolutely necessary before training a model. That said, we wanted to keep preprocessing to a minimum to avoid corrupting the signal. Any steps we missed, we hoped, could be handled by the neural network.
Our preprocessing pipeline involved four steps:
- Bandpass filtering
- Decimation
- Truncation
- Z-score normalization
(1) Bandpass filtering
There are two major sources of noise that we sought to eliminate using a bandpass filter.
The first source of noise is aliasing: when a high-frequency oscillating signal is sampled at too low a sampling frequency, the true signal is distorted.
The Nyquist-Shannon theorem states that a signal containing a maximum frequency of B Hz can be represented exactly only if it is sampled at a rate greater than 2B Hz. In other words, frequencies above half of the sampling rate (the Nyquist frequency) cannot be trusted. The raw data were sampled at 250 Hz, so any content above 125 Hz must be discarded.
The second source of noise in our data is electrical interference. The EEG recordings were made in San Lorenzo de El Escorial, Spain, where the mains electricity frequency is 50 Hz. The power spectral density plot of our signals shows a huge spike at 50 Hz, which has nothing to do with the processes in David’s brain and should be eliminated.
To eliminate both sources of noise, we applied a bandpass filter that attenuated frequencies below 0.1 Hz and above 35 Hz. This kept every retained frequency well below the Nyquist frequency (35 Hz < 125 Hz) and removed the electrical interference at 50 Hz. We removed frequencies below 0.1 Hz because they are too slow to be significant in a 2-second stimulus-response paradigm. The filter worked in the frequency domain: we Fourier-transformed each signal, rescaled the amplitude of each frequency by the filter’s response, and then reconstructed the signal.
The SciPy Python package contains a number of functions useful for signal processing, including filters, but we found it instructive to construct our own from scratch. This is a potential source of error in our preprocessing.
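A from-scratch frequency-domain filter of this kind can be sketched in a few lines of NumPy. This version assumes an ideal "brick-wall" response (gain 1 inside 0.1–35 Hz, 0 outside); the actual shape of the filter used in the project is not described, and the name `fft_bandpass` is our own.

```python
import numpy as np

def fft_bandpass(signals, fs=250.0, low=0.1, high=35.0):
    """Brick-wall bandpass filter applied in the frequency domain.

    signals: array of shape (..., n_samples); filtering runs along the last axis.
    Frequency components outside [low, high] Hz are zeroed, the rest are kept
    unchanged, and the signal is rebuilt with the inverse FFT.
    """
    n_samples = signals.shape[-1]
    spectrum = np.fft.rfft(signals, axis=-1)
    freqs = np.fft.rfftfreq(n_samples, d=1.0 / fs)
    gain = ((freqs >= low) & (freqs <= high)).astype(float)   # 1 in the passband, 0 outside
    return np.fft.irfft(spectrum * gain, n=n_samples, axis=-1)

# Example: 10 Hz rhythm + 50 Hz mains hum + DC offset, 2 s at 250 Hz
t = np.arange(500) / 250.0
raw = np.sin(2 * np.pi * 10 * t) + 0.5 * np.sin(2 * np.pi * 50 * t) + 2.0
clean = fft_bandpass(raw)
print(np.allclose(clean, np.sin(2 * np.pi * 10 * t)))   # True
```

An ideal cutoff like this can cause ringing in the time domain. SciPy’s usual route, `scipy.signal.butter(..., btype="bandpass", fs=fs, output="sos")` followed by `scipy.signal.sosfiltfilt`, gives a smoother roll-off with zero phase shift and would be a reasonable cross-check.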
(2) Decimation
Because everything above 45 Hz had been entirely attenuated by the bandpass filter, we could decrease the sampling rate without distorting the filtered signal. We chose to downsample by a factor of 2, bringing the new sampling rate to 125 Hz; the new Nyquist frequency of 62.5 Hz still lies above every frequency the filter left in. This step also halved the data storage requirements, which was highly desirable.
(3) Truncation
For the first hundred milliseconds, the human brain is not yet fully aware of the stimulus. To prevent bleed-over from previous trials and eliminate irrelevant parts of the recording, we truncated each signal by removing the first third of the data.
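The shape bookkeeping for decimation and truncation is easy to check with a small NumPy sketch on a placeholder trial (random numbers, 128 channels × 2 s at 250 Hz). Plain slicing is only safe for decimation because the signal was already low-pass filtered; `scipy.signal.decimate` is an alternative that applies its own anti-aliasing filter first.

```python
import numpy as np

FS_RAW = 250                                                     # raw sampling rate (Hz)
trial = np.random.default_rng(0).normal(size=(128, 2 * FS_RAW))  # 128 channels x 500 samples

# Decimation: keep every 2nd sample (assumes the data were already low-pass filtered)
decimated = trial[:, ::2]            # 125 Hz, 250 samples
# Truncation: drop the first third of each recording
start = decimated.shape[-1] // 3     # 83 samples, about 0.66 s at 125 Hz
truncated = decimated[:, start:]     # 167 samples remain
print(decimated.shape, truncated.shape)   # (128, 250) (128, 167)
```

The 167 remaining samples match the 128 × 167 input shape used for the neural networks later on.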
(4) Z-score normalization
Finally, another source of noise is that the EEG signals were recorded across over 200 sessions. There is a wide range of potential confounding variables that could change the way the subject’s brain responds to a stimulus on any given day: diet, sleep, caffeine, temperature, mental state, etc. could all change the signal in unpredictable ways.
It is likely impossible to fully eliminate these effects. As a simple approach, we used z-score normalization, rescaling the signals such that within each session, the mean signal strength was zero with a standard deviation of 1.
A more advanced technique may be worthwhile for removing the effects of these confounding variables: for example, a correlation-based analysis that eliminates signals deviating significantly, across sessions, from the mean signal for their label.
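A minimal sketch of per-session z-scoring in NumPy. The text says only that signals were normalized within each session; computing a separate mean and standard deviation for each channel (pooled over that session’s trials and time points) is our assumption, as is the helper name `zscore_per_session`.

```python
import numpy as np

def zscore_per_session(X, session_ids, eps=1e-8):
    """Z-score each channel within each recording session.

    X: array of shape (n_trials, n_channels, n_samples)
    session_ids: array of shape (n_trials,) giving each trial's session.
    Assumption: statistics are per channel, pooled over all trials and
    time points belonging to that session.
    """
    X_norm = np.empty_like(X, dtype=float)
    for session in np.unique(session_ids):
        idx = session_ids == session
        block = X[idx]
        mean = block.mean(axis=(0, 2), keepdims=True)   # per-channel mean for this session
        std = block.std(axis=(0, 2), keepdims=True)     # per-channel std for this session
        X_norm[idx] = (block - mean) / (std + eps)
    return X_norm

# Example: 60 placeholder trials, 4 channels, 3 hypothetical sessions
rng = np.random.default_rng(0)
X = rng.normal(loc=5.0, scale=3.0, size=(60, 4, 167))
sessions = np.repeat([0, 1, 2], 20)
Xn = zscore_per_session(X, sessions)
```

Because the statistics come only from each session’s own recordings and never from the labels, the same function can be applied to the held-out test sessions.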
Channel Selection with k-Nearest Neighbors
128 channels of EEG data is a lot of EEG data. Even after cleaning and processing, many of these channels are recording parts of the brain that likely have nothing to do with image recognition, hearing, or even numbers. With this in mind, our next goal after preprocessing was to identify which channels are most important for classification.
The approach we started with was taken from the paper “Visual Brain Decoding for Short Duration EEG Signals.” For this exact classification problem on a different dataset, Mishra, Sharma and Bhavsar trained k-Nearest Neighbors (kNN) models, separately for each channel, to perform binary classification between every pair of digits. Using kNN on time series data doesn’t really make sense to us, but because of their success, we went ahead with this approach. Part of the reason it might make sense in this context is that the models don’t need to be particularly good, only good enough to distinguish important channels from unimportant ones.
There are 45 pairs of digits 0–9 to perform binary classification on. That isn’t too bad with only 14 channels, as in the Mishra et al. dataset, but with 128 channels it becomes a bit much. So our initial approach was instead to use kNN on each channel to classify digit versus no digit. The “no digit” samples were taken from the EEG recordings during the 2 seconds of black screen that David looked at between digits. This approach failed: the very best channels reached only 52% accuracy, a number easily explained by random chance in a two-class problem.
After that failure, we went back to binary classification between all digit pairs. With 45 pairs multiplied by 128 channels, we were a bit intimidated by the 5,760 models that would have to be trained and evaluated. Fortunately, although the extra memory requirements pushed Kaggle to its limits, we were able to multi-thread this process and train models on four channels at a time.
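The per-channel pairwise tournament can be sketched with scikit-learn as below. The number of neighbors, the 3-fold cross-validation, and the helper name `pairwise_channel_scores` are illustrative assumptions; the example runs on a small random placeholder dataset rather than the real 128-channel recordings.

```python
import itertools
import numpy as np
from sklearn.model_selection import cross_val_score
from sklearn.neighbors import KNeighborsClassifier

def pairwise_channel_scores(X, y, n_neighbors=5, cv=3):
    """Score every channel on every digit pair with a binary kNN classifier.

    X: (n_trials, n_channels, n_samples); y: (n_trials,) digit labels 0-9.
    Returns a (n_channels, 45) array of mean cross-validated accuracies.
    """
    pairs = list(itertools.combinations(range(10), 2))   # the 45 digit pairs
    scores = np.zeros((X.shape[1], len(pairs)))
    for ch in range(X.shape[1]):
        for p, (a, b) in enumerate(pairs):
            mask = (y == a) | (y == b)
            knn = KNeighborsClassifier(n_neighbors=n_neighbors)
            # Each trial's time series on this one channel is a single feature vector
            scores[ch, p] = cross_val_score(knn, X[mask, ch, :], y[mask], cv=cv).mean()
    return scores

# Placeholder data: 200 trials, 4 channels, 50 samples, 20 trials per digit
rng = np.random.default_rng(0)
X = rng.normal(size=(200, 4, 50))
y = np.tile(np.arange(10), 20)
scores = pairwise_channel_scores(X, y)
print(scores.shape)                                     # (4, 45)
channel_rank = np.argsort(scores.mean(axis=1))[::-1]    # best channels first
```

With 128 channels, the loops above fit 128 × 45 = 5,760 classifiers (times the number of cross-validation folds). The outer loop over channels is independent from one iteration to the next, so it is the natural place to parallelize, for example with `joblib.Parallel`.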
During the channel tournament, we also used Pearson correlation to remove noisier signals to see if that would improve the results. We correlated every signal in a channel with the mean signal for that channel and excluded any signal whose correlation with the mean was too low. With no idea what threshold to use, we ran the full kNN channel tournament three times: once removing no noisy signals, once keeping only signals with a correlation above 0.2, and once keeping only those above 0.3.
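The correlation filter can be sketched in NumPy as follows, assuming the "mean signal" is the average over all trials on that channel (the text does not say whether it was computed per label). The helper name `keep_correlated` is hypothetical.

```python
import numpy as np

def keep_correlated(signals, threshold=0.2):
    """Flag trials whose signal correlates with the channel's mean signal.

    signals: (n_trials, n_samples) for a single channel.
    Returns a boolean mask: True where Pearson r with the mean exceeds threshold.
    """
    mean_signal = signals.mean(axis=0)
    # Pearson r is the cosine similarity of the mean-centred vectors
    centred = signals - signals.mean(axis=1, keepdims=True)
    mean_centred = mean_signal - mean_signal.mean()
    r = centred @ mean_centred / (
        np.linalg.norm(centred, axis=1) * np.linalg.norm(mean_centred) + 1e-12)
    return r > threshold

# Example: three copies of a waveform and one inverted copy
t = np.linspace(0.0, 1.0, 167)
base = np.sin(2 * np.pi * 3 * t)
signals = np.vstack([base, base, base, -base])
print(keep_correlated(signals))   # [ True  True  True False]
```

Running the tournament with `threshold=0.2` and `threshold=0.3`, plus once without filtering, reproduces the three configurations described above.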
From all of these models we determined which channels were best at distinguishing between all digit pairs and which channel was most accurate for each particular number comparison. Our hope was that there would be a significant overlap between all of these different selections and a few channels would shine out as the best. This did not happen — there was a little overlap between the three different runs of this process, but not a lot. In total, the three runs agreed on only 8 channels as important and disagreed about all other channels that had been identified as important in individual runs.
Because none of the accuracies had been high and we had very little faith in any part of the channel selection process, we decided to go forward with all 128 channels and hope that a neural network could learn to ignore the unimportant ones.
It should be noted that a convolutional neural network trained on the 8 agreed-upon channels achieved 92% accuracy on the training data. However, as will become a theme of this project, this model achieved only 11% accuracy on the test data. As will be explained later, this is not an unusual result for the neural networks trained during this project.
Neural Network Training and Results
Our data is in the shape of 128 channels by 167 time series points. This format is pretty similar to image data, so some of the typical concepts related to working with image data can be applied to our problem.
With image data, a typical component of neural network architecture is the 2-D convolutional layer. These layers work by sliding a series of filters, or ‘kernels’, across the image, preserving information about the spatial arrangement of the data. Our data does have a spatial component: we could, for example, arrange our 128 channels by their location on David’s head. Other teams working with similar datasets that David has created have done this, though the approach displayed no significant advantages.
For simplicity, we chose instead to focus only on the temporal dimension of the data. Just like in image data, where points next to each other should be considered together, in temporal data, points before and after one another should be considered as related. With this in mind, we chose to use a 1D convolutional network, where a convolutional filter with shape (number of channels x filter length) is slid across only the time dimension, summing across channels.
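To make the shape arithmetic of a 1-D convolution concrete, here is a NumPy sketch of a single layer with "valid" padding and stride 1 (both assumptions; the text does not state them). Each filter spans all channels and a short window of time, slides along time only, and sums over channels and filter taps. Like Keras and PyTorch, it computes a cross-correlation, meaning the filter is not flipped.

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view

def conv1d_valid(x, weights, bias):
    """Multi-channel 1-D convolution ('valid' padding, stride 1).

    x: (n_channels, n_times)
    weights: (n_filters, n_channels, filter_len); each filter spans every channel
    bias: (n_filters,)
    Returns (n_filters, n_times - filter_len + 1).
    """
    filter_len = weights.shape[-1]
    # windows[c, t, k] = x[c, t + k]
    windows = sliding_window_view(x, filter_len, axis=1)
    # Slide along time only; multiply and sum over channels c and taps k
    return np.einsum("ctk,fck->ft", windows, weights) + bias[:, None]

# Example with the sizes of the first model described below:
# one 128 x 167 trial and 128 filters of length 16
rng = np.random.default_rng(0)
x = rng.normal(size=(128, 167))
w = rng.normal(size=(128, 128, 16))
b = np.zeros(128)
print(conv1d_valid(x, w, b).shape)   # (128, 152)
```

One practical detail: Keras’s `Conv1D` expects channels-last input of shape (batch, time, channels), so trials stored as 128 × 167 arrays would need to be transposed to 167 × 128 before being fed to such a layer.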
Our initial model was a single 1-D convolutional layer followed by a dense layer. At first, we used the tanh activation function to allow negative activation values, indicating negative correlation between the kernel and that region of the signal. However, we quickly realized that the standard ReLU or leaky ReLU activation functions offered a significant boost to training speed.
Our first “successful” model consisted of a rescaling layer to bring the inputs closer to the -1 to 1 range, a 1-D convolutional layer consisting of 128 filters of size 16, followed by a dense layer of 256 nodes and our 10 node output layer. This model was able to quickly get to 98% train and 99% validation accuracy. While excited, we were skeptical. It seemed unlikely that we could get such a good result so quickly and with such a relatively simple network design. Sure enough, when applied to the unseen test data it achieved a mere 10% accuracy, no better than a guess.
Disappointed, we looked for a reason why our model would do so well on our validation data and yet so poorly on the test data. One possibility we considered was that the test data were all taken from the final recording sessions: could David’s brain have changed during the recording process, so that a model trained on his brain waves from early sessions could no longer understand his brain by the end? When loading the data into a TensorFlow dataset, we had drawn random points for validation from each block of 1,000 data points we loaded at a time. In this way, our training and validation data included data points from the same recording sessions.
If David’s brain had changed, a k-fold cross-validation could reveal it. We divided the data into 20 folds without randomization, keeping the trials in their original order in the dataset. Then we divided it again, this time assigning the data to 20 folds randomly. If our theory about David’s brain was correct, we expected our model to perform well only on the randomized folds.
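Both splitting schemes can be sketched with scikit-learn’s `KFold`, assuming the trials are stored in recording order. We also show `GroupKFold` with session IDs as groups; this is a different technique from the one described above, included because it guarantees that no session contributes trials to both the training and validation sides of a split. The session layout (200 sessions of 5 trials) is placeholder data.

```python
import numpy as np
from sklearn.model_selection import GroupKFold, KFold

# Placeholder: 1,000 trials in recording order, 5 trials per session
n_trials = 1000
X = np.zeros((n_trials, 1))                    # features are irrelevant for the split itself
session_ids = np.repeat(np.arange(200), 5)

ordered = KFold(n_splits=20, shuffle=False)                    # contiguous blocks in recording order
shuffled = KFold(n_splits=20, shuffle=True, random_state=0)   # trials mixed across sessions
by_session = GroupKFold(n_splits=20)                          # whole sessions held out

_, val_ordered = next(ordered.split(X))
_, val_shuffled = next(shuffled.split(X))
train_grp, val_grp = next(by_session.split(X, groups=session_ids))

print(val_ordered[:5])   # [0 1 2 3 4]: the earliest trials
# With GroupKFold, no session appears on both sides of the split
print(np.intersect1d(session_ids[train_grp], session_ids[val_grp]).size)   # 0
```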
To run this test without memory crashes, we had to switch from loading the data as a TensorFlow dataset to loading it as NumPy arrays. After this change, we were no longer able to get beyond 11% validation accuracy in any of our folds, randomized or nonrandomized. Although rechecking our earlier code turned up nothing suspicious (we had split our dataset using the TensorFlow .take() and .skip() methods), there must have been a data-related bug somewhere for us to get our 98% validation result.
Going forward, we used this new, more memory-efficient loading function with NumPy arrays and used the scikit-learn library to split our data into training and validation sets.
Our goal was to find a new model architecture and hyperparameters that would let a model learn features that held up on the validation data as well as on the training data. The problem we encountered time and time again, over more than 100 different attempts, was that either:
1) the model would fail to learn anything, remaining at about 10–11% accuracy (baseline) on both train and validation sets, OR
2) the model would overfit so violently that it would achieve upwards of 90% accuracy on train and only baseline accuracy on validation.
We had never seen overfitting this aggressive before, and we struggled to tune our models, through both their architecture and various regularization techniques, to find a happy balance.
On the model architecture side of the equation, we tried many changes:
- different filter sizes for the convolutional layers, ranging from a size of 5 to a size of 16
- varied number of convolutional layers, anywhere from a single layer up to 5 layers
- 1-D convolutional layers applied to the transposed input, sliding the convolutional filter across channels instead of time
- different sizes of dense layers and different numbers of dense layers
- maximum pooling layers
- batch normalization layers
- LSTM — a recurrent neural network architecture
- 2-D convolutional layers
On the regularization side, we tried models with varying amounts of L1 and L2 weight penalties, models with regular dropout, and models with spatial 1-D dropout. With all of these regularization techniques, the balance point was incredibly precise: adding just a little L1 or L2 regularization or a bit more dropout could instantly turn a model that was violently overfitting into one that could not get beyond baseline on either training or test data.
Another interesting result from testing hyperparameters was the effect of batch size. We initially believed that larger batch sizes would help prevent overfitting and data memorization. This proved to be incredibly wrong. Testing on our first “successful” model, we found that a batch size of 128 led to violent overfitting. At 64, the model learned much more slowly but still ended up overfitting the data. At batch sizes of 32 and 16, the model failed to progress past baseline. As it turns out, the conclusion that larger batch sizes can lead to more memorization-based overfitting is supported by the literature.
Overall result: Failure.
Takeaways
Despite not finding a model architecture that could successfully classify the digits, our work has not been entirely useless.
With minimal prior knowledge about signal preprocessing and EEG data, we were at least able to produce models that could overfit the training data. While this could mean that our model is simply memorizing every data point in the training set, we don’t believe this is the case:
- The training dataset is very large and the models are relatively small, making it difficult to overfit to such an extreme.
- During training, many of our models could easily achieve >95% training accuracy. When we shuffled the labels, these same models could not achieve significantly higher than 10% training accuracy.
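The label-shuffling control can be illustrated with scikit-learn on placeholder data; this is not the project’s CNN, just a demonstration of what memorization looks like under the control. An unpruned decision tree can memorize any labelling of distinct points, so it still reaches 100% training accuracy after the labels are shuffled. The CNNs described above stayed near 10% under the same control, which is what the argument rests on.

```python
import numpy as np
from sklearn.tree import DecisionTreeClassifier

rng = np.random.default_rng(0)
X = rng.normal(size=(500, 50))        # placeholder features (all rows distinct)
y = rng.integers(0, 10, size=500)     # placeholder digit labels
y_shuffled = rng.permutation(y)       # same label counts, link to X destroyed

# An unpruned tree keeps splitting until every leaf is pure, so it can fit
# any labelling of distinct points: a pure memorizer.
tree = DecisionTreeClassifier(random_state=0)
train_acc_shuffled = tree.fit(X, y_shuffled).score(X, y_shuffled)
print(f"train accuracy on shuffled labels: {train_acc_shuffled:.0%}")   # 100%
```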
Thus, assuming no massive oversights in our data pipeline, it is likely not the case that our models were simply memorizing the training data. Instead, they were overfitting to real features that accurately describe the different classes within the training data, but failing to find features that generalize to the test data. If even our naive, sub-optimal preprocessing was capable of extracting some information from the data, there is some hope that this task is possible.
However, our results also give cause for some pessimism. Our primary suspect for the aggressive overfitting we observed in our models is that our preprocessing has failed to eliminate the variability between sessions. Improving on our work will likely involve the application of better domain knowledge to inform high-level feature extraction. Finding universal features that are mostly invariant in the face of the many variables that can affect a subject’s brain activity on any given day may be impossible with the limitations of existing EEG technology.
Sorry, David. At least for now, you’ll have to figure out what digit you’re looking at all by yourself. :(