Neural Networks implementation from the ground up part 4 — training on MNIST dataset

Satvik Nema
Published in The Deep Hub
5 min read · Jul 21, 2024

In the last blog, we completed our implementation of a neural network. In this blog, we will test it against a real dataset: MNIST. It is a collection of images of handwritten digits between 0 and 9, which our network will learn to classify.

FYI: all the source code is available here

Here are a few examples of what the data looks like:

[Image: sample handwritten digits from the dataset]

Each image is a 28x28 greyscale image, with each pixel taking a value in the range [0, 255], where 0 is white and 255 is black.

Let’s break this down:

  1. Download the dataset and load it in Java.
  2. Preprocess the values so that the pixels are flattened into a 1D input matrix and scaled to values between 0 and 1.
  3. Configure the neural network to have 2 hidden layers of 16 neurons each. What about input and output? The input layer will have 784 neurons (the flattened 28 * 28 pixels) and the output layer will have 10 neurons, where the ith neuron indicates how likely it is that the current input is the digit i.
  4. Repeat the training process for ’N’ epochs and then validate the network with unseen samples.
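
Steps 2 and 3 determine the shapes the network works with. As a minimal sketch using plain arrays instead of the series' Matrix class (MnistPreprocessSketch, flatten and oneHot are hypothetical names used only for illustration), the flattening and the 10-element target vector could look like this:

```java
import java.util.Arrays;

// Hypothetical helper for illustration; the series itself uses its own Matrix class.
class MnistPreprocessSketch {

    // Flattens a rows x cols image into a single array, row by row.
    static double[] flatten(double[][] image) {
        int rows = image.length;
        int cols = image[0].length;
        double[] flat = new double[rows * cols];
        for (int i = 0; i < rows; i++) {
            System.arraycopy(image[i], 0, flat, i * cols, cols);
        }
        return flat;
    }

    // Encodes a digit label as a 10-element vector with a 1 at index `label`.
    static double[] oneHot(int label) {
        double[] target = new double[10];
        target[label] = 1.0;
        return target;
    }

    public static void main(String[] args) {
        double[] input = flatten(new double[28][28]);
        System.out.println("input neurons: " + input.length);
        System.out.println("target for digit 3: " + Arrays.toString(oneHot(3)));
    }
}
```

A 28x28 image therefore becomes 784 input values, and each label becomes a target vector that the 10 output neurons are trained to reproduce.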

Loading it in Java and preprocessing

The dataset is free to download here. After downloading and extracting it, we make a MnistReader class:

import java.io.BufferedInputStream;
import java.io.DataInputStream;
import java.io.FileInputStream;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;

// Matrix, Pair and MathUtils come from the project's source code.
public class MnistReader {

    public static List<Pair<Matrix, Matrix>> getDataForNN(
            String imagePath, String labelsPath, int samples) {
        try {
            return getDataForNNHelper(imagePath, labelsPath, samples);
        } catch (IOException e) {
            throw new RuntimeException(e);
        }
    }

    private static List<Pair<Matrix, Matrix>> getDataForNNHelper(
            String imagesPath, String labelsPath, int samples) throws IOException {
        List<Pair<Matrix, Matrix>> data = new ArrayList<>();
        try (DataInputStream trainingDis =
                        new DataInputStream(new BufferedInputStream(new FileInputStream(imagesPath)));
                DataInputStream labelDis =
                        new DataInputStream(new BufferedInputStream(new FileInputStream(labelsPath)))) {
            // Image file header: magic number, image count, rows, columns.
            // Every field has to be read to advance the stream, even if it is not used.
            int magicNumber = trainingDis.readInt();
            int numberOfItems = trainingDis.readInt();
            int nRows = trainingDis.readInt();
            int nCols = trainingDis.readInt();

            // Label file header: magic number, label count.
            int labelMagicNumber = labelDis.readInt();
            int numberOfLabels = labelDis.readInt();

            // samples == -1 means "read the whole file".
            numberOfItems = samples == -1 ? numberOfItems : samples;

            for (int t = 0; t < numberOfItems; t++) {
                // Pixels are stored row by row, one unsigned byte each.
                double[][] imageContent = new double[nRows][nCols];
                for (int i = 0; i < nRows; i++) {
                    for (int j = 0; j < nCols; j++) {
                        imageContent[i][j] = trainingDis.readUnsignedByte();
                    }
                }
                Matrix imageData =
                        new Matrix(imageContent)
                                .apply(pixel -> MathUtils.scaleValue(pixel, 0, 255, 0, 1))
                                .flatten()
                                .transpose();

                // Turn the label into a 10x1 vector with a 1 at the label's index.
                int label = labelDis.readUnsignedByte();
                double[] output = new double[10];
                output[label] = 1;
                Matrix outputMatrix = new Matrix(new double[][] {output}).transpose();
                data.add(Pair.of(imageData, outputMatrix));
            }
        }
        return data;
    }
}
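
The reader above reads the magic numbers but never looks at them. As a minimal sketch that is not part of the original source, here is how the image header could be validated, assuming the standard IDX layout in which image files start with the magic number 2051 and every header field is a big-endian 32-bit integer. IdxHeaderSketch and its methods are hypothetical names.

```java
import java.io.ByteArrayInputStream;
import java.io.DataInputStream;
import java.io.IOException;
import java.io.UncheckedIOException;
import java.nio.ByteBuffer;

// Hypothetical helper for illustration only.
class IdxHeaderSketch {
    static final int IMAGE_MAGIC = 2051; // 0x00000803, IDX file with unsigned bytes and 3 dimensions

    // Reads the 16-byte image header and returns {imageCount, rows, cols}.
    static int[] readImageHeader(DataInputStream in) {
        try {
            int magic = in.readInt();
            if (magic != IMAGE_MAGIC) {
                throw new IllegalArgumentException(
                        "Not an IDX image file, magic number was " + magic);
            }
            int count = in.readInt();
            int rows = in.readInt();
            int cols = in.readInt();
            return new int[] {count, rows, cols};
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
    }

    // Builds an in-memory header so the sketch runs without the dataset on disk.
    static byte[] fakeHeader(int magic, int count, int rows, int cols) {
        // ByteBuffer is big-endian by default, matching DataInputStream.readInt().
        return ByteBuffer.allocate(16)
                .putInt(magic).putInt(count).putInt(rows).putInt(cols)
                .array();
    }

    public static void main(String[] args) {
        byte[] header = fakeHeader(IMAGE_MAGIC, 60000, 28, 28);
        int[] dims = readImageHeader(new DataInputStream(new ByteArrayInputStream(header)));
        System.out.println(dims[0] + " images of " + dims[1] + "x" + dims[2]);
    }
}
```

Failing fast on a wrong magic number catches the common mistake of swapping the image and label paths, which would otherwise produce garbage inputs without any error.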

Because our matrices are immutable and return a new matrix after every operation, the pre-processing becomes a one-liner:

Matrix imageData =
        new Matrix(imageContent)
                .apply(pixel -> MathUtils.scaleValue(pixel, 0, 255, 0, 1))
                .flatten()
                .transpose();

The function names are self-explanatory, so no further explanation is needed here.
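
For readers who have not seen the earlier parts, scaleValue is a linear range mapping. The sketch below is an assumption about how such a helper could be written, based only on the call `scaleValue(pixel, 0, 255, 0, 1)`; the parameter names and the class ScaleSketch are hypothetical, and the real MathUtils may differ.

```java
// Hypothetical stand-in for MathUtils.scaleValue, for illustration only.
class ScaleSketch {

    // Linearly maps `value` from the range [fromMin, fromMax] to [toMin, toMax].
    static double scaleValue(
            double value, double fromMin, double fromMax, double toMin, double toMax) {
        return toMin + (value - fromMin) * (toMax - toMin) / (fromMax - fromMin);
    }

    public static void main(String[] args) {
        // A pixel of 0 stays 0, a pixel of 255 becomes 1.
        System.out.println(scaleValue(0, 0, 255, 0, 1));
        System.out.println(scaleValue(255, 0, 255, 0, 1));
    }
}
```

Keeping the inputs between 0 and 1 stops large raw pixel values from dominating the weighted sums in the first layer.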

Training

We now implement the MnistTrainer class, which trains on the loaded input and adjusts the weights and biases:

import java.util.List;
import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.Data;
import lombok.NoArgsConstructor;

@Builder
@AllArgsConstructor
@NoArgsConstructor
@Data
public class MnistTrainer {
    private NeuralNetwork neuralNetwork;
    private int iterations;
    private double learningRate;

    public void train(List<Pair<Matrix, Matrix>> trainingData) {
        // Log about 100 times over the whole run, or every epoch for short runs.
        int mod = Math.max(1, iterations / 100);
        double error = 0;
        for (int t = 0; t < iterations; t++) {
            for (Pair<Matrix, Matrix> trainingDatum : trainingData) {
                neuralNetwork.trainForOneInput(trainingDatum, learningRate);
                // Sum of squared output errors for this sample, averaged over the dataset.
                double errorAdditionTerm =
                        neuralNetwork.getOutputErrorDiff().apply(x -> x * x).sum()
                                / trainingData.size();
                error += errorAdditionTerm;
            }

            neuralNetwork.setAverageError(error);

            if ((t == 0) || ((t + 1) % mod == 0)) {
                System.out.println("after " + (t + 1) + " epochs, average error: " + error);
            }
            // Reset the error and reshuffle the samples before the next epoch.
            error = 0;
            trainingData = MathUtils.shuffle(trainingData);
        }
    }
}
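
The "average error" printed by the trainer is the sum of squared output errors per sample, averaged over all samples in the epoch. Here is the same calculation as a small sketch with plain arrays, so it can be checked by hand; EpochErrorSketch and its method names are hypothetical.

```java
// Hypothetical helper for illustration; mirrors the error bookkeeping in train().
class EpochErrorSketch {

    // Sum over all output neurons of (predicted - expected)^2 for one sample.
    static double sumSquaredError(double[] predicted, double[] expected) {
        double sum = 0;
        for (int k = 0; k < predicted.length; k++) {
            double diff = predicted[k] - expected[k];
            sum += diff * diff;
        }
        return sum;
    }

    // Average of the per-sample squared errors over the whole dataset.
    static double averageError(double[][] predictions, double[][] targets) {
        double error = 0;
        for (int n = 0; n < predictions.length; n++) {
            error += sumSquaredError(predictions[n], targets[n]) / predictions.length;
        }
        return error;
    }

    public static void main(String[] args) {
        double[][] predictions = {{1, 0}, {0.5, 0.5}};
        double[][] targets = {{1, 0}, {0, 1}};
        // First sample is perfect (0), second contributes 0.25 + 0.25 = 0.5.
        System.out.println(averageError(predictions, targets));
    }
}
```

With 10 output neurons, a network that is wrong on every sample can score well above 1, so the value is a loss to watch going down rather than a percentage.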

This will be called from the main method.

import java.io.IOException;
import java.time.Duration;
import java.time.Instant;
import java.util.List;

public class Main {
    public static void main(String[] args) throws IOException {
        String rootPath = "/Users/satvik.nema/Documents/mnist_dataset/";
        String trainImagesPath = rootPath + "train-images.idx3-ubyte";
        String trainLabelsPath = rootPath + "train-labels.idx1-ubyte";

        List<Pair<Matrix, Matrix>> mnistTrainingData =
                MnistReader.getDataForNN(trainImagesPath, trainLabelsPath, 60000);

        List<Integer> hiddenLayersNeuronsCount = List.of(16, 16);

        // 784 input rows and 10 output rows, taken from the first sample.
        // List.getFirst() needs Java 21 or newer.
        int inputRows = mnistTrainingData.getFirst().getA().getRows();
        int outputRows = mnistTrainingData.getFirst().getB().getRows();

        MnistTrainer mnistTrainer =
                MnistTrainer.builder()
                        .neuralNetwork(
                                NNBuilder.create(inputRows, outputRows, hiddenLayersNeuronsCount))
                        .iterations(100)
                        .learningRate(0.01)
                        .build();

        Instant start = Instant.now();
        mnistTrainer.train(mnistTrainingData);
        Instant end = Instant.now();

        // Duration.between takes the start first, then the end.
        long seconds = Duration.between(start, end).getSeconds();
        System.out.println("Time taken for training: " + seconds + "s");
    }
}

Notice how we set the 2 hidden layers with 16 neurons each in hiddenLayersNeuronsCount.
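
To get a feel for how small this network is, we can count its trainable parameters. The sketch below assumes fully connected layers with one bias per neuron, which may not match the series' NeuralNetwork class in every detail. ParameterCountSketch is a hypothetical name.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical helper for illustration only.
class ParameterCountSketch {

    // Counts weights and biases of a fully connected network.
    static int countParameters(int inputs, List<Integer> hiddenLayers, int outputs) {
        List<Integer> sizes = new ArrayList<>();
        sizes.add(inputs);
        sizes.addAll(hiddenLayers);
        sizes.add(outputs);

        int total = 0;
        for (int i = 1; i < sizes.size(); i++) {
            int previous = sizes.get(i - 1);
            int current = sizes.get(i);
            // One weight per connection plus one bias per neuron in this layer.
            total += previous * current + current;
        }
        return total;
    }

    public static void main(String[] args) {
        // (784*16 + 16) + (16*16 + 16) + (16*10 + 10) = 12560 + 272 + 170
        System.out.println(countParameters(784, List.of(16, 16), 10)); // prints 13002
    }
}
```

Under that assumption, training adjusts about 13 thousand numbers, most of them in the first layer.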

Network’s performance

The MNIST dataset also includes 10,000 separate testing samples. We will use them to test how our trained network performs on unseen data.

We start with a MnistTester:

import java.util.ArrayList;
import java.util.List;
import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.Data;
import lombok.NoArgsConstructor;

@Builder
@AllArgsConstructor
@NoArgsConstructor
@Data
public class MnistTester implements NeuralNetworkTester {
    private NeuralNetwork neuralNetwork;

    public double validate(List<Pair<Matrix, Matrix>> testData) {
        double error = 0;
        int countMissed = 0;
        // (index, actual, predicted) for every misclassified sample.
        List<String> missedIndexes = new ArrayList<>();
        int index = 0;
        for (Pair<Matrix, Matrix> sample : testData) {
            neuralNetwork.feedforward(sample.getA());
            Matrix output = neuralNetwork.getLayerOutputs().getLast();
            // The predicted digit is the row of the largest output value;
            // the actual digit is the row holding the 1 in the one-hot label.
            int predicted = output.max().getB()[0];
            int actual = sample.getB().max().getB()[0];
            if (predicted != actual) {
                countMissed++;
                missedIndexes.add("(" + index + ", " + actual + ", " + predicted + ")");
            }

            // Same error measure as in training: squared error averaged over all samples.
            Matrix errorMatrix = output.subtract(sample.getB());
            error += errorMatrix.apply(x -> x * x).sum() / testData.size();
            index++;
        }
        System.out.printf("Total: %s, wrong: %s%n", testData.size(), countMissed);
        return error;
    }
}
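
The tester turns the 10 output values into a single digit by picking the index of the largest one. Here is that idea in isolation as a sketch on a plain array; ArgMaxSketch is a hypothetical name, and the series' `Matrix.max()` may be implemented differently.

```java
// Hypothetical helper for illustration only.
class ArgMaxSketch {

    // Returns the index of the largest value; ties go to the first occurrence.
    static int argMax(double[] values) {
        int best = 0;
        for (int i = 1; i < values.length; i++) {
            if (values[i] > values[best]) {
                best = i;
            }
        }
        return best;
    }

    public static void main(String[] args) {
        double[] output = {0.01, 0.02, 0.05, 0.03, 0.10, 0.02, 0.01, 0.04, 0.70, 0.02};
        System.out.println("predicted digit: " + argMax(output)); // prints predicted digit: 8
    }
}
```

Applied to a one-hot label it simply recovers the digit, which is why the same call works for both the prediction and the expected value.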

And running our validation:

String testImagesPath = rootPath + "t10k-images.idx3-ubyte";
String testLabelsPath = rootPath + "t10k-labels.idx1-ubyte";

// -1 loads every sample in the test files.
List<Pair<Matrix, Matrix>> mnistTestingData =
        MnistReader.getDataForNN(testImagesPath, testLabelsPath, -1);

// @Data on MnistTrainer generates this getter for the trained network.
NeuralNetwork trainedNetwork = mnistTrainer.getNeuralNetwork();
MnistTester mnistTester = MnistTester.builder().neuralNetwork(trainedNetwork).build();
double error = mnistTester.validate(mnistTestingData);
System.out.println("Error: " + error);

Time for the truth

The accuracy is pretty good. Here is the output of a run with 1000 epochs:

after 1 epochs, average error: 0.6263571412645461
after 100 epochs, average error: 0.07471539583255844
after 200 epochs, average error: 0.060457042431757556
after 300 epochs, average error: 0.052867280710867826
after 400 epochs, average error: 0.04818163691903281
after 500 epochs, average error: 0.04496163434230489
after 600 epochs, average error: 0.04240323875238682
after 700 epochs, average error: 0.04034903547585861
after 800 epochs, average error: 0.03881550591240332
after 900 epochs, average error: 0.037430996099864056
after 1000 epochs, average error: 0.03629978820188779
Time taken for training: 4676s

Testing:
Total: 10000, wrong: 613

So out of the 10k testing samples, we got only 613 wrong! That’s an accuracy of ~93.9%. Not so bad for a homegrown neural network, is it?
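
The accuracy figure follows directly from the two numbers the tester prints. As a quick sketch of the arithmetic (AccuracySketch is a hypothetical name):

```java
// Hypothetical helper for illustration only.
class AccuracySketch {

    // Percentage of samples classified correctly.
    static double accuracyPercent(int total, int wrong) {
        return 100.0 * (total - wrong) / total;
    }

    public static void main(String[] args) {
        // 9387 correct out of 10000 samples.
        System.out.printf("%.2f%%%n", accuracyPercent(10000, 613)); // prints 93.87%
    }
}
```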

Let’s examine which samples were wrong.

For instance, this one was supposed to be a 4, which our network classified as an 8:

And this one was supposed to be a 5, which the network classified as a 3:

Well, these mistakes can be forgiven, can’t they? xD

People have actually achieved an accuracy of over 99% on this dataset. Here are the overall benchmarks on this dataset (we break into the top 40).

And this concludes our Neural Networks from scratch series :)

Resources

  1. https://github.com/SatvikNema/neural-net
  2. https://madhuramiah.medium.com/how-i-increased-the-accuracy-of-mnist-prediction-from-84-to-99-41-63ebd90cc8a0
  3. https://medium.com/thedeephub/neural-networks-implementation-from-the-ground-up-part-3-backpropagation-e9126938edac
  4. https://paperswithcode.com/sota/image-classification-on-mnist?metric=Accuracy
