MNIST Using Recurrent Neural Network

Ting-Hao Chen
Machine Learning Notes
4 min read · Jan 9, 2018

Let’s try to predict handwritten digits with a recurrent neural network (RNN).

If you are interested, the code (Jupyter notebook and Python file) for this post can be found here.

MNIST dataset

In this tutorial, I am going to demonstrate how to use a recurrent neural network to classify the famous handwritten digits of the MNIST dataset.

The original dataset can be downloaded here:

http://yann.lecun.com/exdb/mnist/

However, we are going to load the same MNIST dataset directly through TensorFlow.

from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets("MNIST_data/")

The MNIST dataset consists of:

  • mnist.train: 55000 training images
  • mnist.validation: 5000 validation images
  • mnist.test: 10000 test images

Each image is 28 pixels (rows) by 28 pixels (columns). We treat each image as a sequence: the first row is the first time step, the second row is the second time step, and so on. Therefore, n_steps = number of rows and n_inputs = number of columns.
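As a quick sanity check, here is a minimal sketch (it assumes mnist has already been loaded as above) that reshapes one flattened image into its 28-step sequence form:

# a quick shape check: one flattened image -> 28-step sequence of 28 values
image = mnist.train.images[0]        # shape: (784,)
sequence = image.reshape(28, 28)     # shape: (n_steps, n_inputs) = (28, 28)
print(image.shape, sequence.shape)   # (784,) (28, 28)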

RNN data flow

The usual TensorFlow MNIST tutorials load the labels one-hot encoded; we are not going to do that here, so the labels stay as plain integer class indices (which is exactly what sparse_softmax_cross_entropy_with_logits expects). We do, however, need to reshape the images from [num_data, 28*28] to [num_data, n_steps, n_inputs]. The RNN produces an output at every time step, but we only care about the last one, so "state" in the code below (the final hidden state of the cell) is what we use for classification.

import tensorflow as tf

# hyperparameters
n_neurons = 128
learning_rate = 0.001
batch_size = 128
n_epochs = 10

# parameters
n_steps = 28   # 28 rows
n_inputs = 28  # 28 cols
n_outputs = 10 # 10 classes

# build a rnn model
X = tf.placeholder(tf.float32, [None, n_steps, n_inputs])
y = tf.placeholder(tf.int32, [None])
cell = tf.nn.rnn_cell.BasicRNNCell(num_units=n_neurons)
output, state = tf.nn.dynamic_rnn(cell, X, dtype=tf.float32)  # state: [batch_size, n_neurons]
logits = tf.layers.dense(state, n_outputs)
cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=y, logits=logits)
loss = tf.reduce_mean(cross_entropy)
optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(loss)
prediction = tf.nn.in_top_k(logits, y, 1)
accuracy = tf.reduce_mean(tf.cast(prediction, tf.float32))
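A side note on the design choice: for a BasicRNNCell, the final state returned by dynamic_rnn is the same tensor as the output at the last time step, so the dense layer could equivalently be fed the last slice of output. A minimal sketch of that alternative:

# equivalent alternative for BasicRNNCell: classify from the last time step's output
logits_alt = tf.layers.dense(output[:, -1, :], n_outputs)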

Next, reshape the test dataset to [num_test, n_steps, n_inputs]:

# input data
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets("MNIST_data/")
X_test = mnist.test.images # X_test shape: [num_test, 28*28]
X_test = X_test.reshape([-1, n_steps, n_inputs])
y_test = mnist.test.labels

Now, go ahead and train the model!

# initialize the variables
init = tf.global_variables_initializer()

# train the model
with tf.Session() as sess:
    sess.run(init)
    n_batches = mnist.train.num_examples // batch_size
    for epoch in range(n_epochs):
        for batch in range(n_batches):
            X_train, y_train = mnist.train.next_batch(batch_size)
            X_train = X_train.reshape([-1, n_steps, n_inputs])
            sess.run(optimizer, feed_dict={X: X_train, y: y_train})
        loss_train, acc_train = sess.run(
            [loss, accuracy], feed_dict={X: X_train, y: y_train})
        print('Epoch: {}, Train Loss: {:.3f}, Train Acc: {:.3f}'.format(
            epoch + 1, loss_train, acc_train))
    loss_test, acc_test = sess.run(
        [loss, accuracy], feed_dict={X: X_test, y: y_test})
    print('Test Loss: {:.3f}, Test Acc: {:.3f}'.format(loss_test, acc_test))

The result is pretty good: the test accuracy reaches 97.2% without any further tuning.

Epoch: 1, Train Loss: 0.287, Train Acc: 0.906
Epoch: 2, Train Loss: 0.205, Train Acc: 0.938
Epoch: 3, Train Loss: 0.094, Train Acc: 0.977
Epoch: 4, Train Loss: 0.099, Train Acc: 0.961
Epoch: 5, Train Loss: 0.023, Train Acc: 0.992
Epoch: 6, Train Loss: 0.022, Train Acc: 1.000
Epoch: 7, Train Loss: 0.110, Train Acc: 0.969
Epoch: 8, Train Loss: 0.042, Train Acc: 0.992
Epoch: 9, Train Loss: 0.086, Train Acc: 0.969
Epoch: 10, Train Loss: 0.101, Train Acc: 0.977
Test Loss: 0.099, Test Acc: 0.972

It is always nice to plot the training loss against the epoch, so you can tell whether the training loss has converged.
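Here is a minimal plotting sketch; it assumes you appended loss_train and acc_train once per epoch, inside the training loop above, to two lists named train_losses and train_accs (names I made up for this example):

import matplotlib.pyplot as plt

# one entry per epoch, collected during training
epochs = range(1, n_epochs + 1)
plt.plot(epochs, train_losses, label='train loss')
plt.plot(epochs, train_accs, label='train accuracy')
plt.xlabel('epoch')
plt.legend()
plt.show()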

Train Loss and Train Accuracy in every epoch

Let’s visualize the test images and the predictions made by the RNN model.
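One way to do it is a minimal sketch like the following; it assumes it runs inside the with tf.Session() block above (otherwise you would need to save and restore the trained variables) and recomputes the predicted classes with tf.argmax on the logits:

import matplotlib.pyplot as plt

# predicted class for each test image
y_pred = sess.run(tf.argmax(logits, axis=1), feed_dict={X: X_test})

# show the first five test images with their predicted labels
fig, axes = plt.subplots(1, 5, figsize=(10, 2))
for i, ax in enumerate(axes):
    ax.imshow(X_test[i].reshape(28, 28), cmap='gray')
    ax.set_title('pred: {}'.format(y_pred[i]))
    ax.axis('off')
plt.show()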

Prediction
