Code Review: Google Colab MNIST prediction with sklearn and skorch

Apr 1 · 5 min read

For the past few weeks I have been endeavouring to improve my programming skills and have begun studying PyTorch, a library that spun off of the C++ library, Torch. PyTorch is Python friendly, so it can easily be integrated into a program written in Python.

One thing that I have been studying is convolutional neural networks, or CNNs, and in particular I wanted to use a CNN on the MNIST dataset, which is a dataset composed of digital images of the digits 0 to 9. Because I am new to this, I searched the internet for examples of how CNNs are used on the MNIST dataset, but I was unable to find exactly what I wanted. Because I am used to entering competitions on Kaggle and Analytics Vidhya, just to name a few, I wanted to find an algorithm that would actually reveal the predictions.

I searched the internet and found quite a few algorithms, but although they recorded the accuracy and loss of the model, they did not specifically list the predictions, which is what I wanted to see.

After several days of research and asking the search engine different questions, I finally found an algorithm, which is the subject of this post. This program is a little different because it utilises the library skorch, a sklearn-compatible neural network library for PyTorch. The documentation for skorch can be found here:-

The notebook where I found the algorithm used in this post can be found here:- skorch/MNIST.ipynb at master · skorch-dev/skorch

In addition to using skorch, the program also uses the MNIST dataset from openml, a website that hosts a variety of datasets that data scientists can download and practice on. The documentation for the function fetch_openml() can be found here:-

The program used in this post has been written in Google Colab, which is a free online Jupyter Notebook that has many libraries already installed on it. Unfortunately, skorch is not already installed on Google Colab, so it must be installed using the code below:-
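The install cell might look like the following; in a Colab notebook the command is prefixed with "!" so it runs in the shell:

```shell
# Install skorch (and make sure torch is present) in the Colab runtime.
pip install skorch
```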

Once skorch is installed, it is necessary to import the libraries needed to execute the code, being sklearn, numpy, pandas and matplotlib:-

The dataset must then be loaded, and in this instance the MNIST dataset on the openml website is used. Sklearn has a function, fetch_openml(), that can be used for this purpose.

The MNIST dataset on openml has 70,000 rows of data, so before going any further it would be a good idea to set Google Colab to work with the GPU, as it has more memory and will work faster. This is accomplished by using Google Colab’s menu selection, Runtime → Change runtime type → GPU.

The data must be preprocessed, and this is accomplished by first defining the X and y variables. The X variable, being independent, is cast to ‘float32’. The y variable, being dependent, is cast to ‘int64’.

The X variable is then normalised to values between 0 and 1. This is accomplished by dividing X by 255, because each pixel value in the dataset is between 0 and 255.
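The two steps above can be sketched as follows; a small synthetic batch stands in for mnist.data and mnist.target here so the snippet is self-contained (openml returns the labels as strings, hence the cast):

```python
import numpy as np

# Stand-in for the arrays returned by fetch_openml('mnist_784').
raw_pixels = np.random.randint(0, 256, size=(10, 784))
raw_labels = np.array(['5', '0', '4', '1', '9', '2', '1', '3', '1', '4'])

X = raw_pixels.astype('float32')   # independent variable as float32
y = raw_labels.astype('int64')     # dependent variable (labels) as int64

# Each pixel is 0-255, so dividing by 255 normalises X into [0, 1].
X /= 255.0
```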

Once the X variable is normalised, it is split into training and test sets using sklearn’s train_test_split() function:-
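A sketch of the split is below, again on a small synthetic array; the test_size and random_state values here are assumptions, so check the notebook for the split it actually uses:

```python
import numpy as np
from sklearn.model_selection import train_test_split

# Synthetic stand-in for the normalised 70,000 x 784 MNIST arrays.
X = np.random.rand(100, 784).astype('float32')
y = np.random.randint(0, 10, size=100).astype('int64')

# Hold a quarter of the rows out as a test set.
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.25, random_state=42)
```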

(A simple neural network was also created from the preprocessing statements above, but that network is not included in this post. It scored an accuracy of 96.2%.)

A convolutional neural network is then created in order to train on the training set and make predictions on the test set.

The variable XCnn is defined by reshaping the X variable into single-channel 28 x 28 images, so it will fit into the convolutional neural network.

XCnn is then split into training and test sets using sklearn’s train_test_split() function.
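The reshape and second split can be sketched like this, using a synthetic array in place of the real data; as above, the split parameters are assumptions:

```python
import numpy as np
from sklearn.model_selection import train_test_split

X = np.random.rand(100, 784).astype('float32')
y = np.random.randint(0, 10, size=100).astype('int64')

# Each flat 784-pixel row becomes a 1-channel 28x28 image, the
# (N, C, H, W) layout that PyTorch convolution layers expect.
XCnn = X.reshape(-1, 1, 28, 28)

XCnn_train, XCnn_test, y_train, y_test = train_test_split(
    XCnn, y, test_size=0.25, random_state=42)
```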

A convolutional neural network is then created using the PyTorch library:-
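A network along the lines of the one in the skorch MNIST notebook is sketched below; the exact layer sizes are my reconstruction, so treat them as illustrative. With 3x3 convolutions and two rounds of 2x2 max pooling, a 28x28 input shrinks to 64 feature maps of 5x5 pixels, hence the 1600 inputs to the first linear layer:

```python
import torch
from torch import nn
import torch.nn.functional as F

class Cnn(nn.Module):
    def __init__(self, dropout=0.5):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3)
        self.conv2_drop = nn.Dropout2d(p=dropout)
        self.fc1 = nn.Linear(1600, 100)  # 1600 = 64 channels * 5 * 5 pixels
        self.fc2 = nn.Linear(100, 10)    # 10 output classes, digits 0-9

    def forward(self, x):
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
        x = x.view(-1, x.size(1) * x.size(2) * x.size(3))  # flatten
        x = F.relu(self.fc1(x))
        x = F.dropout(x, training=self.training)
        x = F.softmax(self.fc2(x), dim=-1)  # class probabilities
        return x
```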

Once the convolutional neural network is created, the random seed is fixed using torch.manual_seed(0), so the results are reproducible.

The data is then fed into skorch’s NeuralNetClassifier(), where the model is fit, which is very similar to how sklearn’s estimators work.

The model then makes predictions on the test data, and an accuracy of 98.84% is achieved, which is better than the score achieved by the simple neural network:-
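Scoring the predictions can be sketched with sklearn's accuracy_score(); the small arrays below stand in for the real y_test labels and the output of the classifier's predict() call:

```python
import numpy as np
from sklearn.metrics import accuracy_score

# Stand-ins for y_test and the model's predictions on XCnn_test.
y_test = np.array([7, 2, 1, 0, 4, 1, 4, 9])
y_pred = np.array([7, 2, 1, 0, 4, 1, 4, 8])  # one mistake in eight

acc = accuracy_score(y_test, y_pred)
print(acc)  # 0.875 -- fraction of predictions that match the labels
```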

I then implemented a confusion matrix, which shows which examples were predicted correctly and which were in error:-
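The confusion matrix comes from sklearn as well; a small stand-in example shows the idea:

```python
import numpy as np
from sklearn.metrics import confusion_matrix

# Stand-ins for the true and predicted digit labels.
y_test = np.array([0, 0, 1, 1, 2, 2, 2])
y_pred = np.array([0, 0, 1, 2, 2, 2, 1])

# Rows are true digits, columns are predicted digits; the diagonal
# counts correct predictions, off-diagonal cells count the errors.
cm = confusion_matrix(y_test, y_pred)
print(cm)
```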

In conclusion, I am very happy to have stumbled across the skorch library and will endeavour to use it whenever I am working with PyTorch. Skorch provides a wrapper around PyTorch that has a sklearn interface. Skorch abstracts away PyTorch’s training loop, which makes a lot of the boilerplate code obsolete. Skorch also works with many types of data, including PyTorch tensors, numpy arrays, and Python dicts.

The code for this program can be found in its entirety in my personal GitHub account, the link being here:- MNIST/MNIST_PyTorch.ipynb at main · TracyRenee61/MNIST

