MNIST — Digits Classification with Keras

Manish Bhobé
7 min read · Sep 27, 2018

Achieving 99% accuracy on MNIST digits data

In this article I will show you how to develop a deep learning classifier with the Keras library that achieves 99% accuracy on the MNIST digits database. We will develop a Convolutional Neural Network (CNN) for the classification. Code for this is available at my GitHub repository.

There are several excellent articles that explain how a CNN works — please refer to the following links on medium.com.

The MNIST dataset is an image dataset of handwritten digits made available by Yann LeCun et al. here. It has 60,000 training images and 10,000 test images, each of which is a 28 x 28 grayscale image. It is a good beginner’s dataset for trying learning techniques and pattern recognition methods on real-world data while spending minimal effort on pre-processing and formatting.

Loading & Pre-processing Data

You can download the images from the above link, but the Keras library already provides a database of these digits in its keras.datasets module. We are going to use it from this module, to save some pre-processing time. Here is the code to load & display a random set of 14 digits from the dataset:
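A minimal sketch of such a load-and-display step is given here, assuming TensorFlow's bundled Keras and matplotlib are installed. The helper names pick_random_indices, show_digits and load_and_show are illustrative and not taken from my notebook; the 2 x 7 grid layout is also just an assumption.

```python
import numpy as np


def pick_random_indices(num_samples, count=14, seed=None):
    """Return `count` distinct random indices in [0, num_samples)."""
    rng = np.random.default_rng(seed)
    return rng.choice(num_samples, size=count, replace=False)


def show_digits(images, labels, indices, cols=7):
    """Plot the selected digits in a grid, each titled with its label."""
    import matplotlib.pyplot as plt  # imported lazily; only needed for plotting

    rows = int(np.ceil(len(indices) / cols))
    fig, axes = plt.subplots(rows, cols, squeeze=False,
                             figsize=(cols * 1.5, rows * 1.8))
    axes = axes.ravel()
    for ax in axes:
        ax.axis("off")  # hide ticks, including on unused cells
    for ax, idx in zip(axes, indices):
        ax.imshow(images[idx], cmap="gray")
        ax.set_title(str(labels[idx]))
    plt.show()


def load_and_show(seed=None):
    """Load MNIST from keras.datasets and display 14 random training digits."""
    from tensorflow.keras.datasets import mnist  # downloads the data on first use

    (train_data, train_labels), (test_data, test_labels) = mnist.load_data()
    indices = pick_random_indices(len(train_data), count=14, seed=seed)
    show_digits(train_data, train_labels, indices)
    return (train_data, train_labels), (test_data, test_labels)
```

Calling load_and_show(seed=42) would display the same 14 digits on every run, while leaving seed=None gives a different sample each time.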

The above code produces an output like the one below; yours could be different depending on how you have seeded your random number generator.

Random Sample of 14 digits from the training dataset

Before we feed the images to our CNN, we need to do some pre-processing of the images. Following are the pre-processing steps:

  • Reshaping the image data to a tensor of shape (num_samples, image_height, image_width, num_channels) - for our 28x28 grayscale images, this would be (num_samples, 28, 28, 1), where num_samples = 60,000 for the train dataset and num_samples = 10,000 for the test dataset.
  • Re-scaling the image data to values between 0.0 and 1.0. The raw pixels are 8-bit values from 0 to 255, so dividing by 255 maps each pixel into the closed interval [0.0, 1.0].
  • One-hot-encoding the labels. Keras provides a to_categorical() function in its utils module, which we will use.

Following code illustrates the pre-processing steps:
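The three steps can be sketched with plain NumPy as below. This is an illustration under stated assumptions rather than the exact notebook code: division by 255 is assumed for the re-scaling, the one_hot helper is a NumPy stand-in for keras.utils.to_categorical, and the demo arrays are synthetic placeholders for the real MNIST arrays.

```python
import numpy as np


def preprocess_images(images):
    """Reshape (n, 28, 28) uint8 images to (n, 28, 28, 1) floats in [0.0, 1.0]."""
    images = images.reshape((images.shape[0], 28, 28, 1))
    return images.astype("float32") / 255.0


def one_hot(labels, num_classes=10):
    """NumPy stand-in for keras.utils.to_categorical on integer labels."""
    return np.eye(num_classes, dtype="float32")[labels]


# Synthetic stand-ins for the arrays returned by mnist.load_data().
raw_images = np.random.default_rng(0).integers(
    0, 256, size=(5, 28, 28), dtype=np.uint8)
raw_labels = np.array([5, 0, 4, 1, 9])

demo_data = preprocess_images(raw_images)
demo_labels_cat = one_hot(raw_labels)
print(demo_data.shape, demo_labels_cat.shape)  # (5, 28, 28, 1) (5, 10)
```

Applied to the real arrays, the same two calls would yield the train_data / train_labels_cat and test_data / test_labels_cat arrays used in the rest of the article.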

There is one final pre-processing step, where I split the training dataset into a training set and a validation set. I follow the common practice of training my neural network (NN) on the training set and evaluating its performance on the held-out validation set. I do not use the test set during training, but only for the final evaluation & predictions. This way, I always have some data that the model has never seen, which helps me spot over-fitting and guard against it to some extent.

I will set aside 10% of the train set as the validation set (val set). Following code shows you how to split the train set into train & validation sets.
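One possible way to do the 90/10 split is a shuffled hold-out with NumPy, sketched below. The split_train_val helper, its seed argument and the synthetic arrays are assumptions made for illustration; only the output names (train_data2, train_labels_cat2, val_data, val_labels_cat) come from this article.

```python
import numpy as np


def split_train_val(data, labels, val_fraction=0.1, seed=0):
    """Shuffle, then hold out `val_fraction` of the samples for validation."""
    rng = np.random.default_rng(seed)
    order = rng.permutation(len(data))
    num_val = int(round(len(data) * val_fraction))
    val_idx, train_idx = order[:num_val], order[num_val:]
    return data[train_idx], labels[train_idx], data[val_idx], labels[val_idx]


# Synthetic stand-ins shaped like the pre-processed training arrays.
data = np.zeros((100, 28, 28, 1), dtype="float32")
labels_cat = np.eye(10, dtype="float32")[np.arange(100) % 10]

train_data2, train_labels_cat2, val_data, val_labels_cat = split_train_val(
    data, labels_cat)
print(train_data2.shape, val_data.shape)  # (90, 28, 28, 1) (10, 28, 28, 1)
```

On the full 60,000-image train set, a 10% hold-out leaves 54,000 images for training and 6,000 for validation.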

Now we are ready to build our CNN.

Building our CNN with Keras

Keras is a very versatile, yet simple to learn and understand, deep learning library that can run on top of several other deep learning frameworks: it supports Tensorflow, Theano and Microsoft CNTK, with Tensorflow being the default. There are several reasons why you should make Keras your first library for deep learning; here are some of them.

Keras can be used with a CPU as well as a GPU. For this example, I am using Keras configured with Tensorflow on a CPU machine; for a simple model like this MNIST classifier, a CPU configuration suffices. For any serious deep learning project, a GPU is highly recommended, otherwise complex ML models will take excruciatingly long to train.

We will build a CNN with the following architecture, using Keras’ Sequential API:

  • 3 Conv2D layers with 32, 64 and 64 filters respectively, each using the relu activation, kernel_size=(3,3) and padding='same'
  • Each Conv2D layer is followed immediately by a MaxPooling2D layer with pool_size=(2,2)
  • We follow this with a Dense layer with 512 nodes and relu activation (the pooled feature maps have to be flattened before they reach a Dense layer)
  • Finally, our output layer is a Dense layer with 10 nodes (corresponding to the 10 output classes) and the softmax activation function, which we use for multi-class classification
  • We compile the model with the categorical_crossentropy loss and the adam optimizer.

The following code builds and compiles this model:
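A sketch of this architecture with the Sequential API could look as follows, assuming TensorFlow's bundled Keras. Two things here are assumptions not spelled out in the list above: the Flatten layer between the last pooling layer and the first Dense layer, and metrics=["acc"], chosen so that the history keys are named 'acc' and 'val_acc'. The pooled_size helper is purely illustrative.

```python
def pooled_size(size, num_pools, pool=2):
    """Spatial size after `num_pools` rounds of non-overlapping max pooling."""
    for _ in range(num_pools):
        size //= pool
    return size


def build_model():
    """Build and compile the CNN described above (TensorFlow's Keras assumed)."""
    from tensorflow.keras import layers, models  # imported lazily

    model = models.Sequential([
        layers.Input(shape=(28, 28, 1)),
        layers.Conv2D(32, kernel_size=(3, 3), activation="relu", padding="same"),
        layers.MaxPooling2D(pool_size=(2, 2)),
        layers.Conv2D(64, kernel_size=(3, 3), activation="relu", padding="same"),
        layers.MaxPooling2D(pool_size=(2, 2)),
        layers.Conv2D(64, kernel_size=(3, 3), activation="relu", padding="same"),
        layers.MaxPooling2D(pool_size=(2, 2)),
        layers.Flatten(),  # assumption: needed to feed feature maps into Dense
        layers.Dense(512, activation="relu"),
        layers.Dense(10, activation="softmax"),
    ])
    model.compile(optimizer="adam",
                  loss="categorical_crossentropy",
                  metrics=["acc"])  # assumption: names history keys 'acc'/'val_acc'
    return model


# 'same' padding keeps the size through each convolution; each pooling halves it.
side = pooled_size(28, num_pools=3)    # 28 -> 14 -> 7 -> 3
flattened_features = side * side * 64  # 64 filters in the last Conv2D layer
print(flattened_features)  # 576
```

Calling build_model().summary() would print the layer-by-layer output shapes, which is a quick way to confirm that the first Dense layer receives 576 flattened features.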

We train the model on the (train_data2, train_labels_cat2) training set and validate it on the (val_data, val_labels_cat) validation set that we created earlier. Training is run for 15 epochs, with a batch size of 64. Here is the code to kick off the training loop.

results = model.fit(train_data2, train_labels_cat2,
                    epochs=15, batch_size=64,
                    validation_data=(val_data, val_labels_cat))

Simple enough! The fit() call returns a History object, stored here in results, whose history attribute is a dict holding 4 metrics per epoch: training loss, training accuracy, validation loss and validation accuracy. These can be easily accessed with results.history['loss'], results.history['acc'], results.history['val_loss'] and results.history['val_acc'] respectively. (Depending on the Keras version and how the metric was named in compile(), the accuracy keys may be 'accuracy' and 'val_accuracy' instead.) Since we ran 15 epochs, each results.history['XXX'] entry is a list with 15 elements.

To assess how our model performed, let us plot results.history['acc'] and results.history['val_acc'] vs epochs, and results.history['loss'] and results.history['val_loss'] vs epochs. The iPython Notebook for this article provides a show_plots() helper function that does just that: it displays 2 plots laid out in a 1x2 grid as shown below.

# display plots...
show_plots(results.history)
Plots of losses and accuracies vs epochs
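For readers without the notebook at hand, a helper of this kind might look roughly like the sketch below. It is a hypothetical plot_history() written for illustration, not the notebook's show_plots(); it assumes matplotlib is installed and that the accuracy keys are 'acc' and 'val_acc'.

```python
def history_curves(history, metric="acc"):
    """Return (epochs, {panel title: (training series, validation series)})."""
    epochs = list(range(1, len(history["loss"]) + 1))
    curves = {
        "Loss": (history["loss"], history["val_loss"]),
        "Accuracy": (history[metric], history["val_" + metric]),
    }
    return epochs, curves


def plot_history(history, metric="acc"):
    """Draw loss and accuracy vs epochs side by side in a 1x2 grid."""
    import matplotlib.pyplot as plt  # imported lazily; only needed for plotting

    epochs, curves = history_curves(history, metric)
    fig, axes = plt.subplots(1, 2, figsize=(12, 4))
    for ax, (title, (train, val)) in zip(axes, curves.items()):
        ax.plot(epochs, train, "bo", label="training")  # blue dots
        ax.plot(epochs, val, "r-", label="validation")  # red line
        ax.set_title(title)
        ax.set_xlabel("epoch")
        ax.legend()
    plt.show()
```

It would be called as plot_history(results.history), or as plot_history(results.history, metric="accuracy") if the accuracy keys are named that way.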

The plots suggest that this Keras model is over-fitting the training data after ~4 epochs:

  • We can see from the left plot that the training loss (blue dots) falls smoothly towards zero, whereas the validation loss (red line) falls for 4 epochs and then flattens out at a value around 0.03.
  • Also in the accuracy plot (right), we observe that the training accuracy (blue dots) rises towards 100%, whereas the validation accuracy flattens out after around 4 epochs, at a value around 99%.
  • This is typical behavior of an over-fitting model. However, our validation loss is not significantly higher than our training loss, indicating that the model is over-fitting only slightly.
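The "flattens out after ~4 epochs" reading can also be checked numerically by looking up the epoch with the lowest validation loss. The snippet below is a small sketch of that idea; the best_epoch helper is hypothetical and the loss values are made up for illustration, not the results of this run.

```python
def best_epoch(val_losses):
    """Return the 1-based epoch with the lowest validation loss."""
    return min(range(len(val_losses)), key=lambda i: val_losses[i]) + 1


# Made-up validation losses for illustration only (not this article's results).
example_val_loss = [0.060, 0.045, 0.036, 0.031, 0.033, 0.032, 0.034]
print(best_epoch(example_val_loss))  # 4
```

With the real training run, the call would be best_epoch(results.history['val_loss']).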

Evaluating Model against Test Data

Now for the most critical question: how does this model fare with test data? Recall that we had set aside the test data, so the model has not seen it yet. A good performance metric against our test data will indicate that our model is generalizing well. The main aim of developing deep models is creating models that generalize well, so our predictions against unseen data will be accurate in most instances.

Recall that our test data is in the test_data and test_labels_cat arrays, which hold the 10,000 test images & labels that we set aside. Here is the code to evaluate the model’s performance against test data:

test_loss, test_accuracy = model.evaluate(
    test_data, test_labels_cat, batch_size=64)
print('Test loss: %.4f accuracy: %.4f' % (test_loss, test_accuracy))

The above code produces the following results:

10000/10000 [==============================] - 13s 1ms/step
Test loss: 0.0308 accuracy: 0.9919

Wow! We have achieved an accuracy of 99.2% on our test data indicating that our model has generalized well, even though it was overfitting slightly.

Predictions

Finally, let’s run some predictions using this model; we will run them again on the test_data and test_labels_cat arrays. The code below displays only the first 25 predictions for brevity.

import numpy as np

# class probabilities for all 10,000 test images, shape (10000, 10)
predictions = model.predict(test_data)
# arg-max over the class axis converts probabilities / one-hot rows to digits
first25_preds = np.argmax(predictions, axis=1)[:25]
first25_true = np.argmax(test_labels_cat, axis=1)[:25]
print(first25_preds)
print(first25_true)
>>> array([7, 2, 1, 0, 4, 1, 4, 9, 5, 9, 0, 6, 9, 0, 1, 5, 9, 7, 3, 4, 9, 6, 6, 5, 4])
>>> array([7, 2, 1, 0, 4, 1, 4, 9, 5, 9, 0, 6, 9, 0, 1, 5, 9, 7, 3, 4, 9, 6, 6, 5, 4])

Notice that we got 100% accuracy with the first 25 images (NOTE: the >>> in the above block indicates the output prompt and all text after >>> is the output).

Let’s check how many of the 10,000 images were incorrectly predicted by our model.

# how many mismatches?
(np.argmax(predictions, axis=1) !=
 np.argmax(test_labels_cat, axis=1)).sum()
>>> 81

We got 81/10,000 incorrect predictions — not bad for such a simple model.
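To see which images were misclassified rather than only how many, the same comparison can return indices. The sketch below uses a tiny synthetic example; the mismatches helper and the toy arrays are assumptions for illustration, not part of the notebook.

```python
import numpy as np


def mismatches(pred_probs, labels_cat):
    """Indices where the arg-max prediction differs from the one-hot label."""
    pred = np.argmax(pred_probs, axis=1)
    true = np.argmax(labels_cat, axis=1)
    return np.flatnonzero(pred != true)


# Tiny synthetic example: 4 samples, 3 classes, one wrong prediction.
probs = np.array([[0.9, 0.05, 0.05],
                  [0.1, 0.8, 0.1],
                  [0.2, 0.5, 0.3],   # predicted class 1, true class 2
                  [0.1, 0.1, 0.8]])
labels = np.eye(3)[[0, 1, 2, 2]]

wrong = mismatches(probs, labels)
print(wrong, 1 - len(wrong) / len(labels))  # [2] 0.75
```

The same arithmetic ties the two results of this section together: 81 mismatches out of 10,000 gives 1 - 81/10000 = 0.9919, which matches the test accuracy reported by model.evaluate().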

Summary

In this article I showed you how we can use Keras to develop a CNN that achieves 99% accuracy on the MNIST digits dataset. Since the images are rather simple (small and grayscale), we expect deep learning models with simple architectures to perform reasonably well.

With the Keras library, we have to write a lot less code to create, train & evaluate the model compared to low level APIs like Tensorflow. Keras sits on top of Tensorflow and takes care of the low level details, leaving us to concentrate on data pre-processing and development of model architecture for our problem. I have also provided an iPython Notebook with corresponding Tensorflow code on my GitHub repository. You will notice that we get similar performance but we have to write a lot more code with Tensorflow.

Hope you enjoyed this article. I welcome your comments & feedback.

Manish Bhobé

IT Professional. Data Science, ML & Deep Learning enthusiast. Loves working with Tensorflow, Pytorch, scikit-learn, Python, Numpy & Pandas. Aspiring author.