Fashion-MNIST with tf.Keras

This is a tutorial on how to classify the Fashion-MNIST dataset with tf.keras, using a Convolutional Neural Network (CNN) architecture. In just a few lines of code, you can define and train a model that classifies the images with over 90% accuracy, even without much optimization.


Fashion-MNIST can be used as a drop-in replacement for the original MNIST dataset (10 categories of handwritten digits). It shares the same image size (28x28) and the same structure of training (60,000) and testing (10,000) splits. It’s great for writing “hello world” tutorials for deep learning.

Keras is a popular and well-regarded high-level deep learning API. It’s built right into TensorFlow, in addition to being an independent open source project. You can write all your usual great Keras programs as you normally would using tf.keras, with the main change being just the imports. Using tf.keras enables you to take advantage of functionality like eager execution and … should you want to down the road. Here, I’ll cover the basics.

I will try to go over some of the deep learning terminology along the way. If you are a beginner to deep learning, I encourage you to compare and contrast this tutorial with the much older MNIST tutorial that uses the original low-level TensorFlow APIs, to see how much easier things have become.

Run this notebook in Colab

All the code below is in a Jupyter Notebook on my GitHub. You can open the notebook with zero setup by opening it directly in Google Colab, which runs on a Google VM in the cloud. Choose this option if you just want to quickly open the notebook and follow along with this tutorial. To learn more about Colab, read the official blog or my blog post on Colab.


There are ten categories to classify in the fashion_mnist dataset:

Label Description
0 T-shirt/top
1 Trouser
2 Pullover
3 Dress
4 Coat
5 Sandal
6 Shirt
7 Sneaker
8 Bag
9 Ankle boot
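When you later print or plot predictions, it helps to keep these descriptions in a Python list indexed by label, so that an integer label can be turned into readable text. Here is a minimal sketch. The names fashion_mnist_labels and label_name are illustrative only and are not part of the Keras API:

```python
# Text descriptions for the ten Fashion-MNIST classes, indexed by label (0-9)
fashion_mnist_labels = [
    "T-shirt/top",  # 0
    "Trouser",      # 1
    "Pullover",     # 2
    "Dress",        # 3
    "Coat",         # 4
    "Sandal",       # 5
    "Shirt",        # 6
    "Sneaker",      # 7
    "Bag",          # 8
    "Ankle boot",   # 9
]


def label_name(label):
    """Return the text description for an integer class label (hypothetical helper)."""
    label = int(label)  # accepts Python ints and NumPy integer scalars alike
    if not 0 <= label < len(fashion_mnist_labels):
        raise ValueError(f"label must be in 0-9, got {label}")
    return fashion_mnist_labels[label]


print(label_name(2))  # Pullover
print(label_name(9))  # Ankle boot
```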

Import the fashion_mnist dataset

Let’s import the dataset and prepare it for training, validation and test.

Load the fashion_mnist data with the keras.datasets API in just one line of code. load_data() returns both the training and the test sets as NumPy arrays. Each grayscale image is 28x28.

# Note: in Colab you can run "pip install" directly in the notebook by prefixing it with !
# Quote the version spec so the shell doesn't treat > as an output redirect
!pip install -q -U "tensorflow>=1.8.0"
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
# Load the fashion-mnist pre-shuffled train data and test data
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.fashion_mnist.load_data()
print("x_train shape:", x_train.shape, "y_train shape:", y_train.shape)
# x_train shape: (60000, 28, 28) y_train shape: (60000,)

Visualize the data

What I like best about Jupyter Notebook is the visualization. You can use the matplotlib library’s imshow() to take a look at one of the images from the training dataset. Note that each image is grayscale, with a shape of 28x28.

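# Show one of the images from the training dataset
plt.imshow(x_train[0], cmap='gray')  # index 0 is an arbitrary choice; any index from 0 to 59,999 works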

Data normalization

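We then normalize the data so that the inputs are on approximately the same scale. The raw pixel values are integers from 0 to 255, and dividing by 255 maps them to floats between 0 and 1.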

x_train = x_train.astype('float32') / 255
x_test = x_test.astype('float32') / 255

Split the data into train/validation/test datasets

In the earlier step of importing the data, we had 60,000 images for training and 10,000 for testing. Now we further split the training data into train/validation sets. Here is how each type of dataset is used in deep learning:

  • Training data: used for training the model
  • Validation data: used for tuning the hyperparameters and evaluating the models
  • Test data: used to test the model after it has gone through initial vetting on the validation set
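Here is a minimal NumPy sketch of this split. It rests on two assumptions that are not necessarily the author's exact code. First, the first 5,000 training images are held out for validation, which leaves 55,000 for training. This is safe because the data comes pre-shuffled. Second, the single grayscale channel axis is added so that the Conv2D input_shape=(28,28,1) matches, turning (N, 28, 28) into (N, 28, 28, 1). The helper names split_train_valid and add_channel_axis are hypothetical. Small stand-in arrays are used so the example runs without downloading the dataset:

```python
import numpy as np


def split_train_valid(x, y, n_valid=5000):
    """Hold out the first n_valid examples for validation (hypothetical helper).

    Fashion-MNIST's training set is already shuffled, so slicing off the
    first n_valid examples gives a representative validation set.
    """
    if not 0 < n_valid < len(x):
        raise ValueError("n_valid must be between 1 and len(x) - 1")
    x_valid, y_valid = x[:n_valid], y[:n_valid]
    x_train, y_train = x[n_valid:], y[n_valid:]
    return (x_train, y_train), (x_valid, y_valid)


def add_channel_axis(x):
    """Reshape (N, 28, 28) grayscale images to (N, 28, 28, 1) for Conv2D."""
    return x.reshape(x.shape[0], x.shape[1], x.shape[2], 1)


# Small stand-in arrays with the real per-image shape; on the real data,
# split_train_valid(x_train, y_train) would give 55,000 / 5,000 examples.
x_full = np.zeros((600, 28, 28), dtype="float32")
y_full = np.arange(600) % 10

(x_train, y_train), (x_valid, y_valid) = split_train_valid(x_full, y_full, n_valid=100)
x_train = add_channel_axis(x_train)
x_valid = add_channel_axis(x_valid)
print(x_train.shape, x_valid.shape)  # (500, 28, 28, 1) (100, 28, 28, 1)
```

Apply add_channel_axis to x_test as well. If you compile with loss='categorical_crossentropy', the integer labels also need one-hot encoding, for example tf.keras.utils.to_categorical(y_train, 10). With sparse_categorical_crossentropy they can stay as integers.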


Let’s define the model and train it.

Create the model architecture

There are two APIs for defining a model in Keras: the Sequential model API and the Functional API.

In this tutorial we are using the Sequential model API to create a simple CNN model that repeats a block of a convolution layer followed by a pooling layer and then a dropout layer. If you are interested in a tutorial using the Functional API, check out Sara Robinson’s blog.

Note that you only need to define the input data shape in the first layer. The last layer is a dense layer with softmax activation that classifies the 10 categories of data in fashion_mnist.

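model = tf.keras.Sequential()
# Must define the input shape in the first layer of the neural network
model.add(tf.keras.layers.Conv2D(filters=64, kernel_size=2, padding='same', activation='relu', input_shape=(28,28,1)))
model.add(tf.keras.layers.MaxPooling2D(pool_size=2))
model.add(tf.keras.layers.Dropout(0.3))
model.add(tf.keras.layers.Conv2D(filters=32, kernel_size=2, padding='same', activation='relu'))
model.add(tf.keras.layers.MaxPooling2D(pool_size=2))
model.add(tf.keras.layers.Dropout(0.3))
model.add(tf.keras.layers.Flatten())  # flatten the 7x7x32 feature maps into a vector for the dense layers
model.add(tf.keras.layers.Dense(256, activation='relu'))
model.add(tf.keras.layers.Dense(10, activation='softmax'))
model.summary()  # Take a look at the model summary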
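The following pure-Python sketch shows where the numbers in the model summary come from by working out the output sizes and trainable parameter counts by hand. It assumes the conv, pool and dropout stack described above uses two Conv2D layers (64 and 32 filters, kernel_size=2, padding='same'), each followed by 2x2 max pooling, then Flatten, Dense(256) and Dense(10). Pooling, dropout and flatten layers have no trainable parameters:

```python
def conv2d_params(kernel_size, in_channels, filters):
    # Each filter has kernel_size * kernel_size * in_channels weights plus one bias
    return kernel_size * kernel_size * in_channels * filters + filters


def dense_params(inputs, units):
    # One weight per (input, unit) pair plus one bias per unit
    return inputs * units + units


# padding='same' keeps the spatial size through each Conv2D;
# each 2x2 max pooling halves it: 28 -> 14 -> 7
conv1 = conv2d_params(2, 1, 64)    # 28x28x64 output, pooled to 14x14x64
conv2 = conv2d_params(2, 64, 32)   # 14x14x32 output, pooled to 7x7x32
flat = 7 * 7 * 32                  # Flatten: a vector of 1568 values
dense1 = dense_params(flat, 256)   # Dense(256)
dense2 = dense_params(256, 10)     # Dense(10) with softmax
total = conv1 + conv2 + dense1 + dense2

print(conv1, conv2, dense1, dense2)  # 320 8224 401664 2570
print(f"Total params: {total:,}")    # Total params: 412,778
```

Almost all of the parameters (401,664 of 412,778, about 97%) sit in the first Dense layer. That is because it connects every one of the 1,568 flattened values to each of its 256 units.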

Compile the model

We use model.compile() to configure the learning process before training the model. This is where you define the type of loss function, optimizer and the metrics evaluated by the model during training and testing.
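For a 10-class softmax output like this one, a typical configuration is model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy']) with one-hot labels. Alternatively, loss='sparse_categorical_crossentropy' accepts the integer labels directly. Treat these arguments as an assumption; the exact settings in the notebook may differ.

Categorical cross-entropy is the average negative log of the probability the model assigns to the true class. The NumPy sketch below computes it by hand on a small made-up batch:

```python
import numpy as np


def categorical_crossentropy(y_true_onehot, y_pred, eps=1e-7):
    """Mean categorical cross-entropy over a batch.

    y_true_onehot: (batch, num_classes) one-hot labels
    y_pred: (batch, num_classes) softmax probabilities
    """
    # Clip probabilities so log(0) can never occur
    y_pred = np.clip(y_pred, eps, 1.0 - eps)
    per_example = -np.sum(y_true_onehot * np.log(y_pred), axis=1)
    return per_example.mean()


# Two examples, three classes (kept small for readability)
y_true = np.array([[0, 1, 0],
                   [1, 0, 0]], dtype="float32")
y_pred = np.array([[0.1, 0.8, 0.1],   # 0.8 on the true class: small loss
                   [0.2, 0.5, 0.3]])  # only 0.2 on the true class: larger loss
loss = categorical_crossentropy(y_true, y_pred)
print(round(float(loss), 4))  # 0.9163
```

The loss only looks at the probability of the correct class, so pushing that probability toward 1 is what training optimizes.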


Train the model

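We will train the model with a batch_size of 64 for 10 epochs, checking its performance on the validation set at the end of each epoch:

model.fit(x_train, y_train, batch_size=64, epochs=10,
          validation_data=(x_valid, y_valid))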
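By default, fit() shuffles the training data each epoch (shuffle=True) and updates the weights once per mini-batch of batch_size examples. The NumPy sketch below shows that bookkeeping. It assumes 55,000 training images remain after a 5,000-image validation hold-out, which is a hypothetical split size. With batch_size=64, one epoch then takes ceil(55000 / 64) = 860 steps, and the last batch holds only the 24 leftover examples:

```python
import math

import numpy as np


def iterate_minibatches(num_examples, batch_size, rng):
    """Yield index arrays for one epoch of shuffled mini-batches."""
    order = rng.permutation(num_examples)  # a fresh random order every epoch
    for start in range(0, num_examples, batch_size):
        yield order[start:start + batch_size]


num_train, batch_size, epochs = 55000, 64, 10
steps_per_epoch = math.ceil(num_train / batch_size)
print("steps per epoch:", steps_per_epoch)                # 860
print("total weight updates:", steps_per_epoch * epochs)  # 8600

rng = np.random.default_rng(0)
batches = list(iterate_minibatches(num_train, batch_size, rng))
print(len(batches), len(batches[-1]))                     # 860 24
```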

Test Accuracy

And we get a test accuracy of over 90%.

# Evaluate the model on test set
score = model.evaluate(x_test, y_test, verbose=0)
# Print test accuracy
print('\n', 'Test accuracy:', score[1])
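model.evaluate() returns the loss followed by the value of each compiled metric. That is why score[1] is the test accuracy, assuming accuracy is the only metric passed to compile(). Accuracy itself is the fraction of images whose highest-probability class matches the true label. Here is a NumPy sketch with made-up probabilities:

```python
import numpy as np


def accuracy_from_probs(probs, labels):
    """Fraction of rows whose argmax class equals the integer label."""
    predicted = np.argmax(probs, axis=1)
    return float(np.mean(predicted == labels))


# Made-up softmax outputs for 4 test images over 3 classes
probs = np.array([[0.7, 0.2, 0.1],
                  [0.1, 0.1, 0.8],
                  [0.3, 0.4, 0.3],   # predicts class 1, but the label is 0
                  [0.5, 0.4, 0.1]])
labels = np.array([0, 2, 0, 0])  # integer ground-truth labels
print(accuracy_from_probs(probs, labels))  # 0.75
```

If your test labels are one-hot encoded, convert them back first with np.argmax(y_test, axis=1).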

Visualize the predictions

Now we can use the trained model to make predictions (classifications) on the test dataset with model.predict(x_test) and visualize them. If a label is shown in red, the prediction does not match the true label; otherwise it is green.
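The red/green rule can be written as a small helper that turns the probabilities from model.predict() into a caption and a color for each image. In this sketch, describe_predictions is a hypothetical name and the "predicted (true)" caption format is our own choice. A made-up probability array stands in for real predictions. In the notebook you would pass model.predict(x_test) and the integer test labels:

```python
import numpy as np

fashion_mnist_labels = ["T-shirt/top", "Trouser", "Pullover", "Dress", "Coat",
                        "Sandal", "Shirt", "Sneaker", "Bag", "Ankle boot"]


def describe_predictions(y_hat, true_labels, names=fashion_mnist_labels):
    """Return (caption, color) per image: green if correct, red otherwise."""
    results = []
    for probs, true_label in zip(y_hat, true_labels):
        predicted = int(np.argmax(probs))
        true_label = int(true_label)
        color = "green" if predicted == true_label else "red"
        caption = f"{names[predicted]} ({names[true_label]})"
        results.append((caption, color))
    return results


# Made-up predictions for 2 images: the first is correct, the second is not
y_hat = np.zeros((2, 10))
y_hat[0, 9] = 1.0   # predicts "Ankle boot"
y_hat[1, 6] = 1.0   # predicts "Shirt"
true_labels = np.array([9, 0])  # true classes: "Ankle boot", "T-shirt/top"
for caption, color in describe_predictions(y_hat, true_labels):
    print(color, caption)
# green Ankle boot (Ankle boot)
# red Shirt (T-shirt/top)
```

To draw the 15-image grid with matplotlib, you could pass each pair to ax.set_title(caption, color=color) on a 3x5 grid of subplots.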

Visualization of 15 predictions

