Handwritten digit recognition using PyTorch
In this tutorial, we build a digit classifier by training a neural network on the MNIST dataset.
First, we import the libraries needed to build and train a neural network. This assumes PyTorch and torchvision are already installed.
We import torchvision because it provides many popular datasets, such as MNIST and CIFAR, so loading MNIST takes only a few lines of code.
The MNIST dataset consists of grayscale images of handwritten digits, each 28*28 pixels.
The dataset is downloaded into a data directory created in the working directory.
torchvision.datasets is used to download and load the dataset, while torch.utils.data.DataLoader returns an iterator over it.
With the data in place, the next step is to define the neural network. We build a simple network with one hidden layer of 500 neurons. The input layer has 28*28 = 784 units, one per pixel, and the output layer has 10 units, one for each digit from 0 to 9. ReLU is used as the activation function.
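A minimal sketch of this architecture as an nn.Module subclass. The class name NeuralNet is an assumption. Flattening each image inside forward, rather than in the training loop, is a design choice of this sketch.

```python
import torch
import torch.nn as nn


class NeuralNet(nn.Module):
    """Fully connected network: 784 inputs -> 500 hidden units (ReLU) -> 10 outputs."""

    def __init__(self, input_size=784, hidden_size=500, num_classes=10):
        super().__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_size, num_classes)

    def forward(self, x):
        # Flatten each 1x28x28 image into a 784-dimensional vector
        x = x.reshape(x.size(0), -1)
        out = self.relu(self.fc1(x))
        # Return raw class scores (logits); CrossEntropyLoss applies log-softmax itself
        return self.fc2(out)
```

The network returns unnormalized scores and no softmax layer. This is intentional, because the cross-entropy loss used below expects logits.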
Next, we define the loss function and the optimizer: cross-entropy loss and Adam, respectively. We also need to set the network's hyper-parameters.
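A minimal sketch of the hyper-parameters, loss function and optimizer. The specific values (5 epochs, batch size 100, learning rate 0.001) are illustrative assumptions, not necessarily the ones used in this tutorial. An nn.Sequential model with the same 784-500-10 architecture stands in for the network so the snippet runs on its own.

```python
import torch
import torch.nn as nn

# Hyper-parameters (illustrative values; tune them for your own runs)
input_size = 784      # 28 * 28 pixels
hidden_size = 500
num_classes = 10      # digits 0-9
num_epochs = 5
batch_size = 100
learning_rate = 0.001

# Stand-in model with the same architecture as described above
model = nn.Sequential(
    nn.Flatten(),                          # 1x28x28 -> 784
    nn.Linear(input_size, hidden_size),
    nn.ReLU(),
    nn.Linear(hidden_size, num_classes),
)

# Cross-entropy loss on raw logits, optimized with Adam
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
```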
Now comes the training part:
- Feed the inputs forward through the network.
- Compute the loss.
- Back-propagate to compute the gradients.
- Update the weights with the optimizer.
- Repeat the four steps above until the model converges.
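The training steps above can be sketched as a loop over mini-batches. This is a minimal sketch that assumes a model, a loss function, an optimizer and a DataLoader already exist. The function name train is hypothetical, and a fixed number of epochs stands in for "until the model converges".

```python
import torch
import torch.nn as nn


def train(model, train_loader, criterion, optimizer, num_epochs, device="cpu"):
    """Run the four training steps on every mini-batch for num_epochs epochs.

    Hypothetical helper; returns the loss of the last mini-batch processed.
    Assumes train_loader yields at least one batch.
    """
    model.to(device)
    model.train()
    for epoch in range(num_epochs):
        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)

            # 1. Feed forward
            outputs = model(images)
            # 2. Compute the loss
            loss = criterion(outputs, labels)
            # 3. Back-propagate: clear old gradients, then compute new ones
            optimizer.zero_grad()
            loss.backward()
            # 4. Update the weights
            optimizer.step()

        print(f"Epoch {epoch + 1}/{num_epochs}, last batch loss: {loss.item():.4f}")
    return loss.item()
```

Usage would look like train(model, train_loader, criterion, optimizer, num_epochs=5). Calling optimizer.zero_grad() before loss.backward() matters because PyTorch accumulates gradients across backward passes by default.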
Our network is now trained. Next, we calculate its percentage accuracy on the test data.
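A minimal sketch of computing test accuracy, assuming a trained model and a DataLoader over the test set. The function name evaluate is hypothetical. The predicted digit is the index of the largest output score, and torch.no_grad() skips gradient tracking because no weights are updated here.

```python
import torch


def evaluate(model, test_loader, device="cpu"):
    """Return the percentage of test images whose predicted class matches the label."""
    model.to(device)
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():  # gradients are not needed for evaluation
        for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            # The predicted class is the index of the largest output score
            predicted = outputs.argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    return 100.0 * correct / total
```

For example, print(f"Test accuracy: {evaluate(model, test_loader):.2f}%") reports the result as a percentage.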
You should get an accuracy of about 95–96%. Even higher accuracy is possible by tuning the hyper-parameters or changing the network architecture itself, and convolutional neural networks can reach ~99% and above.
Image source: Google Images
Thanks to https://github.com/yunjey for his excellent introductory PyTorch tutorials.