Published in Analytics Vidhya

The Fashionable “Hello World” of Deep Learning

It is time to move on from the MNIST handwritten digits database and embrace the Fashion-MNIST database.

Photo by NordWood Themes on Unsplash

“If it doesn’t work on MNIST, it won’t work at all”, they said. “Well, if it does work on MNIST, it may still fail on others.”

Update 2021: This article was published a long time ago, and some details and features of the libraries may have changed since. Please consult the official documentation for current details.

Training a model on the MNIST handwritten digits data is a classic exercise in the machine learning community, and it is often considered the entry point to deep learning.

The MNIST handwritten digits database has been around for a long time and has been studied intensively. It is a set of 60,000 training images plus 10,000 test images, assembled by the National Institute of Standards and Technology (the NIST in MNIST) in the 1980s. The task is to classify grayscale images of handwritten digits (28 × 28 pixels) into 10 categories (0 through 9). “Solving” MNIST is the “Hello World” of deep learning: it is what you do to verify that your algorithms work as expected. Anyone working in deep learning is likely to encounter MNIST again and again.

The Fashion-MNIST

Fashion-MNIST was released in August 2017 as a direct drop-in replacement for the original MNIST dataset for benchmarking machine learning algorithms. It shares the same image size and the same structure of training and test splits: 60,000 training and 10,000 test examples, each a 28 × 28 grayscale image associated with a label from 10 classes.

Need for replacement

A model that performs well on the MNIST handwritten digits does not necessarily perform well on other data, largely because the images we work with today are far more complex than MNIST's handwritten digits.

  • MNIST is too easy: convolutional networks can reach around 99% accuracy on MNIST, a level that rarely carries over to real-world data.
  • MNIST is overused.
  • MNIST cannot represent modern computer vision (CV) tasks, which tend to be more complex.

Training a neural network in Keras on the Fashion-MNIST dataset

Since this article uses Keras, make sure you have Keras installed and running.

Let us now explore the Fashion-MNIST dataset and use it to train a neural network that classifies images into clothing categories.

This article is only meant as an introduction to the Fashion MNIST dataset. For a detailed understanding of how the network is trained in Keras, please refer here.

There are 10 clothing categories, each assigned an integer label from 0 to 9.

1. Loading the Fashion-MNIST dataset in Keras

The Fashion-MNIST dataset is available through Keras's datasets module: fashion_mnist.load_data() downloads it on first use and returns four NumPy arrays.

import keras
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
from keras.datasets import fashion_mnist
(train_images, train_labels), (test_images, test_labels) = fashion_mnist.load_data()

The dataset stores only integer labels, not the class names, so we create a list, label_names, that maps each integer label to its clothing category.

label_names = ['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat',
               'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot']

2. Data Exploration

Let us look at the shape of training and testing data.

# Training data
train_images.shape
(60000, 28, 28)
len(train_labels)  # total number of training images
60000
train_labels
array([9, 0, 0, ..., 3, 0, 5], dtype=uint8)
# Test data
test_images.shape
(10000, 28, 28)
len(test_labels)  # total number of test images
10000
  • The training data consists of 60,000 images, with each image represented by 28 x 28 pixels.
  • Likewise, the testing data consists of 10,000 images, with each image again represented by 28 x 28 pixels.
  • Each label, in both the training and test data, is an integer between 0 and 9 that indexes into label_names.
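Before training, it can help to confirm that the classes are balanced. The sketch below assumes integer label arrays like train_labels; the helper name class_counts is hypothetical, and the demo uses a small made-up label array rather than the real data. Fashion-MNIST's training set contains 6,000 images per class, so on the real train_labels every count should be 6000.

```python
import numpy as np

def class_counts(labels, num_classes=10):
    """Return how many examples carry each integer label 0..num_classes-1."""
    # minlength keeps a zero entry for any class that never appears
    return np.bincount(np.asarray(labels), minlength=num_classes)

# Made-up labels standing in for train_labels (not the real data)
toy_labels = np.array([0, 0, 3, 5, 9, 0], dtype=np.uint8)
print(class_counts(toy_labels))  # [3 0 0 1 0 1 0 0 0 1]
```

On the real data, class_counts(train_labels) should return ten entries of 6000, and class_counts(test_labels) ten entries of 1000.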

3. Preprocessing the Data

The data must be preprocessed before it is fed into the network. Let's look at one image from the training set.

plt.imshow(train_images[1])
plt.grid(False)
plt.colorbar()
plt.show()

The color bar shows that pixel intensities fall in the range 0 to 255. We will rescale these values to the range 0 to 1.

# Rescale the training and test images to the range 0 to 1
train_images = train_images / 255.0
test_images = test_images / 255.0
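Dividing a uint8 array by the float 255.0 produces a float64 array with values in [0, 1]. The sketch below uses a small made-up array in place of train_images and a hypothetical helper name, rescale, to show the effect and check the range numerically. A numeric check is more direct than a plot, because plt.imshow rescales its colors automatically, so an image looks the same before and after scaling.

```python
import numpy as np

def rescale(images):
    """Map uint8 pixel intensities 0..255 to floats in [0, 1]."""
    # True division by the float 255.0 keeps fractional values
    return images / 255.0

toy = np.array([[0, 51], [204, 255]], dtype=np.uint8)  # made-up 2x2 "image"
scaled = rescale(toy)
print(scaled.dtype, scaled.min(), scaled.max())  # float64 0.0 1.0
```

The same check on the real data, print(train_images.min(), train_images.max()), should print 0.0 1.0 after rescaling.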

Let's display the first 30 training images with their class names to check that the data still looks right after scaling.

plt.figure(figsize=(8,10))
for i in range(30):
    plt.subplot(5, 6, i+1)
    plt.xticks([])
    plt.yticks([])
    plt.grid(False)
    plt.imshow(train_images[i])
    plt.xlabel(label_names[train_labels[i]])
plt.show()

The images and their labels look as expected. Let us move on to building the model.

4. Building the Network architecture

We will configure the layers of the model and then proceed with compiling the model.

  • Setting up the Layers
from keras import models
from keras import layers
network = models.Sequential()
network.add(layers.Flatten(input_shape=(28, 28)))  # 28 x 28 image -> 784-value vector
network.add(layers.Dense(128, activation='relu'))
network.add(layers.Dense(10, activation='softmax'))

Each 28 × 28 image is treated as a flat vector of 784 values, which passes through a sequence of two Dense layers. These are densely connected, or fully connected, neural layers. The first Dense layer has 128 nodes. The second (and last) layer is a 10-node softmax layer that returns an array of 10 probability scores summing to 1; each score is the probability that the current image belongs to one of the 10 classes.
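To see what the two Dense layers compute, here is a NumPy sketch of one forward pass through the same architecture: flatten, a 128-unit ReLU layer, and a 10-unit softmax layer. The weights are random rather than trained Keras weights, and the names dense_forward and softmax are hypothetical. The point is the shapes and the parameter count: 784 × 128 + 128 = 100,480 for the first layer and 128 × 10 + 10 = 1,290 for the second, 101,770 in total, which should match what network.summary() reports for the two Dense layers.

```python
import numpy as np

rng = np.random.default_rng(0)

def softmax(z):
    """Row-wise softmax; subtracting the row max avoids overflow in exp."""
    z = z - z.max(axis=1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=1, keepdims=True)

def dense_forward(x, params):
    """Flatten -> Dense(128, relu) -> Dense(10, softmax), written out by hand."""
    w1, b1, w2, b2 = params
    h = np.maximum(0.0, x.reshape(len(x), -1) @ w1 + b1)  # (batch, 128)
    return softmax(h @ w2 + b2)                            # (batch, 10)

# Random weights standing in for trained ones (illustration only)
params = (rng.normal(0, 0.05, (784, 128)), np.zeros(128),
          rng.normal(0, 0.05, (128, 10)), np.zeros(10))

batch = rng.random((4, 28, 28))      # four fake "images" with values in [0, 1)
probs = dense_forward(batch, params)
print(probs.shape)                   # (4, 10)
print(probs.sum(axis=1))             # each row sums to 1
print(sum(p.size for p in params))   # 101770
```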

  • Compiling the model

After the model has been built, we enter the compilation phase, which primarily consists of three essential elements:

  • Loss function: a measure of the difference between the predicted and actual values, which we try to minimize while training the network.
  • Optimizer: determines how the network's weights are updated based on the loss function.
  • Metrics: used to monitor the model's performance during training and testing. In this case, we use accuracy.
network.compile(optimizer='rmsprop',
                loss='sparse_categorical_crossentropy',  # labels are integers 0-9, not one-hot
                metrics=['accuracy'])
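Keras offers two forms of the cross-entropy loss: 'categorical_crossentropy' expects one-hot target vectors, while 'sparse_categorical_crossentropy' takes integer labels such as train_labels directly. Both compute the same quantity, the negative log of the probability assigned to the true class, averaged over the batch. The NumPy sketch below (function names hypothetical, not Keras internals) shows that the two agree on a small made-up example.

```python
import numpy as np

def sparse_cce(probs, labels, eps=1e-7):
    """Mean of -log p[true class], with integer labels."""
    p = np.clip(probs[np.arange(len(labels)), labels], eps, 1.0)
    return -np.mean(np.log(p))

def categorical_cce(probs, one_hot, eps=1e-7):
    """Same loss, but with targets given as one-hot rows."""
    return -np.mean(np.sum(one_hot * np.log(np.clip(probs, eps, 1.0)), axis=1))

probs = np.array([[0.7, 0.2, 0.1],
                  [0.1, 0.8, 0.1]])
labels = np.array([0, 1])
one_hot = np.eye(3)[labels]  # [[1, 0, 0], [0, 1, 0]]

print(sparse_cce(probs, labels))        # -(log 0.7 + log 0.8) / 2, about 0.290
print(categorical_cce(probs, one_hot))  # same value
```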

5. Training the Model

To start training, we call network.fit, i.e., the model is "fit" to the training data.

network.fit(train_images, train_labels, epochs=5, batch_size=128)
Epoch 1/5
60000/60000 [==============================] - 3s 55us/step - loss: 0.5805 - acc: 0.7989
Epoch 2/5
60000/60000 [==============================] - 3s 44us/step - loss: 0.4159 - acc: 0.8507
Epoch 3/5
60000/60000 [==============================] - 3s 42us/step - loss: 0.3692 - acc: 0.8679
Epoch 4/5
60000/60000 [==============================] - 3s 45us/step - loss: 0.3403 - acc: 0.8767
Epoch 5/5
60000/60000 [==============================] - 3s 44us/step - loss: 0.3185 - acc: 0.8842

This model reaches an accuracy of about 0.88 (or 88%) on the training data.
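With batch_size=128, each epoch walks through the 60,000 training images in mini-batches, and the optimizer updates the weights once per batch. Since 60,000 / 128 = 468.75, there are 469 batches per epoch, the last one holding 96 images (newer Keras versions show this count in the progress bar as 469/469 instead of 60000/60000). The sketch below imitates that bookkeeping, including the per-epoch shuffling Keras applies by default; the helper minibatch_indices is hypothetical.

```python
import numpy as np

def minibatch_indices(n_samples, batch_size, rng):
    """Yield shuffled index arrays that cover every sample once per epoch."""
    order = rng.permutation(n_samples)
    for start in range(0, n_samples, batch_size):
        yield order[start:start + batch_size]

rng = np.random.default_rng(42)
batches = list(minibatch_indices(60000, 128, rng))
print(len(batches))      # 469
print(len(batches[-1]))  # 96
```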

6. Model Evaluation

We evaluate the model’s performance on the test dataset.

test_loss, test_acc = network.evaluate(test_images, test_labels)
print('test_acc:', test_acc)
10000/10000 [==============================] - 0s 30us/step
test_acc: 0.8683

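The accuracy on the test set comes out to around 87%, which is lower than the accuracy on the training data. This gap between training and test performance is a sign of overfitting: the model does slightly worse on images it has not seen during training.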
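Accuracy is the fraction of images whose highest-probability class matches the true label. The sketch below computes it on made-up probabilities (the helper name accuracy is hypothetical) and works out the train/test gap from the logs above: 0.8842 - 0.8683 = 0.0159, or about 1.6 percentage points.

```python
import numpy as np

def accuracy(probs, labels):
    """Fraction of rows whose argmax equals the integer label."""
    return np.mean(np.argmax(probs, axis=1) == labels)

probs = np.array([[0.1, 0.9, 0.0],   # predicts 1
                  [0.6, 0.3, 0.1],   # predicts 0
                  [0.2, 0.2, 0.6],   # predicts 2
                  [0.5, 0.4, 0.1]])  # predicts 0
labels = np.array([1, 0, 2, 1])
print(accuracy(probs, labels))       # 0.75 (three of four correct)

train_acc, test_acc = 0.8842, 0.8683  # final values from the logs above
print(round(train_acc - test_acc, 4)) # 0.0159
```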

7. Making Predictions

Let us now use the trained model to make predictions on some images.

predictions = network.predict(test_images)

Let us see what our model predicts for the test image at index 10 (indexing starts at 0).

predictions[10]
array([9.7699827e-05, 5.6700603e-05, 1.0853803e-01, 1.0991561e-06, 8.7915897e-01, 4.5721102e-10, 1.2143801e-02, 1.0442269e-10, 3.1952586e-06, 4.5470620e-07], dtype=float32)

The prediction is an array of 10 numbers, each the probability that the image belongs to the corresponding clothing category. We can also find the label with the highest probability for this image.

np.argmax(predictions[10])
4

4 corresponds to ‘Coat’. Thus, our model predicts that this image is a coat. Let us check whether the prediction is correct.

test_labels[10]
4

Indeed it is 4, i.e., a ‘Coat.’
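The lookup above can be wrapped in a small helper that returns both the class name and the model's confidence. This is a sketch with a hypothetical function name, describe_prediction; it repeats the label_names list so the block runs on its own, and it uses the probability vector printed above for predictions[10].

```python
import numpy as np

label_names = ['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat',
               'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot']

def describe_prediction(probs, names=label_names):
    """Return (class name, probability) for the most likely class."""
    idx = int(np.argmax(probs))
    return names[idx], float(probs[idx])

# Probability vector printed above for predictions[10]
p10 = np.array([9.7699827e-05, 5.6700603e-05, 1.0853803e-01, 1.0991561e-06,
                8.7915897e-01, 4.5721102e-10, 1.2143801e-02, 1.0442269e-10,
                3.1952586e-06, 4.5470620e-07], dtype=np.float32)
name, prob = describe_prediction(p10)
print(f"{name}: {prob:.1%}")  # Coat: 87.9%
```

The runner-up in this vector is index 2 (‘Pullover’, about 10.9%), which is the kind of near-miss the visualizations below make easy to spot.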

Analyzing Predictions

Images are best understood through visualizations. Let us plot each test image together with its predicted label and the probability the model assigns to that label, using two helper functions, plot_image and plot_value_array.

Correct predictions are labeled in green and incorrect ones in red. The number gives the model's confidence, as a percentage, in the predicted label.
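The bodies of plot_image and plot_value_array are not shown in this article, although the plotting cells below call them. The following is a sketch, assuming they behave as just described (green caption for a correct prediction, red for a wrong one, confidence as a percentage, true label in brackets). It is modeled on the helpers of the same names in TensorFlow's basic clothing-classification tutorial and is not necessarily the author's exact code.

```python
import numpy as np
import matplotlib.pyplot as plt

label_names = ['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat',
               'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot']

def plot_image(i, predictions, true_labels, images):
    """Draw image i; caption = predicted name, confidence %, (true name)."""
    probs = predictions[i]
    true_label = int(true_labels[i])
    plt.grid(False)
    plt.xticks([])
    plt.yticks([])
    plt.imshow(images[i], cmap=plt.cm.binary)
    predicted_label = int(np.argmax(probs))
    # Green caption for a correct prediction, red for a wrong one
    color = 'green' if predicted_label == true_label else 'red'
    plt.xlabel("{} {:2.0f}% ({})".format(label_names[predicted_label],
                                          100 * np.max(probs),
                                          label_names[true_label]),
               color=color)

def plot_value_array(i, predictions, true_labels):
    """Bar chart of the 10 class probabilities for image i."""
    probs = predictions[i]
    true_label = int(true_labels[i])
    plt.grid(False)
    plt.xticks(range(10))
    plt.yticks([])
    bars = plt.bar(range(10), probs, color='#777777')
    plt.ylim([0, 1])
    predicted_label = int(np.argmax(probs))
    # Predicted bar turns red, then the true class is colored green,
    # so a correct prediction shows a single green bar
    bars[predicted_label].set_color('red')
    bars[true_label].set_color('green')
```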

Let’s look at the test image at index 15, its prediction, and the class probabilities.

i = 15
plt.figure(figsize=(8,3))
plt.subplot(1,2,1)
plot_image(i, predictions, test_labels, test_images)
plt.subplot(1,2,2)
plot_value_array(i, predictions, test_labels)
plt.xticks(range(10), label_names, rotation=90)
plt.yticks(range(1))
plt.show()

The green label indicates that the model correctly predicts ‘Trouser’.

i = 1000
plt.figure(figsize=(8,3))
plt.subplot(1,2,1)
plot_image(i, predictions, test_labels, test_images)
plt.subplot(1,2,2)
plot_value_array(i, predictions, test_labels)
plt.xticks(range(10), label_names, rotation=90)
plt.yticks(range(0,1))
plt.show()

For the test image at index 1000, the model predicts ‘Shirt’ with 71% probability. In reality, however, it is a ‘T-shirt/top’. The red label indicates a wrong prediction, with the correct label written in brackets alongside the predicted one.

Let us plot a few more images.

num_rows = 5
num_cols = 3
num_images = num_rows*num_cols
plt.figure(figsize=(5*num_cols, 2*num_rows))
for i in range(num_images):
    plt.subplot(num_rows, 2*num_cols, 2*i+1)
    plot_image(i, predictions, test_labels, test_images)
    plt.subplot(num_rows, 2*num_cols, 2*i+2)
    plot_value_array(i, predictions, test_labels)
    plt.xticks(range(10))
plt.tight_layout()
plt.show()

For the majority of these images, the model's predictions are correct. The wrong predictions are marked in red.

Conclusion

Fashion-MNIST improves on the original MNIST in data complexity: it is a harder benchmark, so accuracy scores on it say more about a model. Since its release, it has been a topic of discussion among researchers, students, and other enthusiasts in this domain. Han Xiao, Senior Scientist III @ Tencent AI Lab and the person who released the dataset, says: “In the future, I will continue maintaining Fashion-MNIST, making it accessible to more people around the globe. It doesn’t matter whether you are a researcher, a student, a professor, or an enthusiast. You are welcome to use Fashion-MNIST in papers, meetups, hackathons, tutorials, classrooms, or even on T-shirt.” So, go ahead and start experimenting with the fashionable version of MNIST.
