Neural Networks implementation from the ground up: Part 1

Satvik Nema · Published in The Deep Hub · 4 min read · Jun 22, 2024

How many times have you heard the terms machine learning, AI, neural networks, LLMs, and so on, and, if you’re like me, had no idea what any of them actually mean?

Last weekend I decided I’d had enough and started to look behind the curtain of these buzzwords.


I decided to implement a neural network in Java without using any libraries or frameworks, to see what it actually takes. Just plain two-dimensional floating point arrays along with some college-level differential calculus. Why, you ask? Because why not!

The network will be configurable. What exactly will be configurable?

  1. number of hidden layers
  2. number of neurons per layer
  3. other general things like learning rate, starting weights and biases

I am not going into the depths of the mathematics here. There are a LOT of free online resources to get you started on the math needed to understand a neural network at a theoretical level. I will link these in the resources section.

This blog series is all about the implementation in Java, and comes in 4 parts:

  1. Setting up the Matrix object and Neural Network structure (this blog)
  2. Configuring weights and biases and coding up one iteration of feedforward flow
  3. Coding one iteration for backpropagation flow
  4. Tying it all together to train the MNIST dataset with good accuracy

Getting started

So let’s get started with building the most fundamental component in a neural network: the Matrix.

It wraps a conventional 2D array with commonly used matrix operations:

import java.util.Random;
import java.util.function.Function;

public class Matrix {
    // seeded so that runs are reproducible
    private static final Random random = new Random(1);
    private final double[][] content;

    private final int rows;
    private final int columns;

    public Matrix(int i, int j) {
        this.rows = i;
        this.columns = j;
        this.content = new double[i][j];
    }

    public Matrix(double[][] arr) {
        int r = arr.length, c = arr[0].length;
        content = arr;
        rows = r;
        columns = c;
    }

    // the usual matrix operations
    public Matrix dot(Matrix b) {}

    public Matrix cross(Matrix b) {}

    public Matrix transpose() {}

    // transforms each element using the given function
    public Matrix apply(Function<Double, Double> f) {}

    public Matrix add(Matrix b) {}

    public Matrix subtract(Matrix b) {}

    // factory methods
    public static Matrix random(int r, int c, double start, double end) {}

    public static Matrix ones(int r, int c) {}

    public Matrix flatten() {}

    public double sum() {}
}

The methods dot, cross, transpose, add, subtract are the usual matrix operations and don’t need further explanation.
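To give a flavour of what these implementations look like, here is a sketch of how transpose might be written given the fields above (a sketch only; the repo’s actual code may differ):

// inside the Matrix class
public Matrix transpose() {
    // swap dimensions: an r x c matrix becomes c x r
    double[][] result = new double[columns][rows];
    for (int i = 0; i < rows; i++) {
        for (int j = 0; j < columns; j++) {
            result[j][i] = content[i][j];
        }
    }
    // return a new instance; the original matrix is untouched
    return new Matrix(result);
}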

For the others:

  1. apply will transform each element in the matrix according to the Function passed. It can be used as matrix.apply(x -> x * x), which squares each element and returns a new Matrix object.
  2. random will generate a matrix with the given dimensions (r rows and c columns), filled with random double values in [start, end).
  3. ones will return a matrix filled with 1s.
  4. flatten will convert a matrix of dimension r x c into 1 x (r*c), with elements lined up one after the other row-wise.
  5. sum will add up all the values in the matrix and return the result.

Note that content is final and a Matrix object is immutable, i.e. once created, the contents of a given Matrix instance cannot be changed; every operation returns a new Matrix.
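To make the API concrete, here is a quick usage sketch. The values in the comments follow the semantics described above (the stubs themselves are filled in at the repo linked below):

public class MatrixDemo {
    public static void main(String[] args) {
        Matrix a = new Matrix(new double[][] {
            {1, 2},
            {3, 4}
        });

        // apply returns a NEW matrix; 'a' itself is never modified
        Matrix squared = a.apply(x -> x * x); // {{1, 4}, {9, 16}}

        Matrix b = Matrix.ones(2, 2);  // {{1, 1}, {1, 1}}
        Matrix sum = a.add(b);         // {{2, 3}, {4, 5}}

        Matrix flat = a.flatten();     // {{1, 2, 3, 4}}, a 1 x 4 matrix
        double total = a.sum();        // 1 + 2 + 3 + 4 = 10.0
    }
}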

All the methods are implemented here: https://github.com/SatvikNema/neural-net/blob/main/src/main/java/com/satvik/ml/util/Matrix.java

Now, coming to the NeuralNetwork structure itself.

Consider this example:

We have a lot going on in this diagram. Let’s break it down:

  1. Green coloured neurons are inputs, yellow ones are the hidden layers and red ones are outputs
  2. Wij denotes the weight of the edge going from neuron i in layer L to neuron j in layer L+1. To avoid redundant labels, I have kept them the same for each layer.
  3. Bij denotes the bias set for the neuron j in the layer i.
  4. Wi denotes the weight matrix for the layer i.
  5. Hij denotes the output of neuron j in layer i.

We use the below structure to encapsulate all of this in the NeuralNetwork object:

import java.io.IOException;
import java.util.List;

public class NeuralNetwork {

    private List<Matrix> weights;
    private List<Matrix> biases;
    private List<Matrix> layerOutputs;
    private int layers;
    private Matrix outputErrorDiff;
    private double averageError;

    // persist a trained network to disk / load it back
    public void serialise(String filePath) throws IOException {}

    public static NeuralNetwork deserialise(String filePath) throws IOException {}

    // one iteration of the feedforward flow
    public void feedforward(Matrix input) {}
}
  1. weights.get(i) stores the weight matrix for layer i
  2. biases.get(i) stores the biases for layer i
  3. layerOutputs.get(i) denotes the output of the layer i
  4. layers simply stores the number of layers in the neural network, which will be 3 in this case
  5. outputErrorDiff and averageError don’t matter for now. They will be used during backpropagation
  6. feedforward does one iteration of the feedforward flow using the given weights and biases. We will cover this in the next part
  7. serialise() and deserialise() will be used to persist/load an already trained neural network to/from disk. Their implementation is here
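To give a feel for how these fields fit together, here is a minimal sketch of setting up the weight and bias matrices for a small network. The layer sizes and the dimension convention (a weight matrix of size sizes[i+1] x sizes[i] per layer transition) are assumptions for illustration only; we will work out the exact dimensions carefully in the next part:

import java.util.ArrayList;
import java.util.List;

public class NetworkSetupDemo {
    public static void main(String[] args) {
        // hypothetical layer sizes: 2 inputs -> 3 hidden neurons -> 2 outputs
        int[] sizes = {2, 3, 2};

        List<Matrix> weights = new ArrayList<>();
        List<Matrix> biases = new ArrayList<>();

        for (int i = 0; i < sizes.length - 1; i++) {
            // assumed convention: weights.get(i) maps layer i's output
            // (a sizes[i] x 1 column) to layer i+1 (a sizes[i+1] x 1 column)
            weights.add(Matrix.random(sizes[i + 1], sizes[i], -1, 1));

            // one bias per neuron in layer i+1
            biases.add(Matrix.random(sizes[i + 1], 1, -1, 1));
        }

        // these lists (plus the layer count) are what the
        // NeuralNetwork object above holds on to
    }
}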

We have set up the basic building blocks required to train a neural network.

In the next blog, we will see how to configure the starting weights and biases (the dimensions can get tricky due to matrix cross products) and implement the feedforward method.

Stay tuned!

Resources

  1. For visualising graphs, partial derivatives, and why the gradient always points in the direction of steepest ascent: the first 25 videos of this playlist will set your intuitions on the right path.
  2. For math equations and their derivations: this write-up from brilliant.org hits the sweet spot.
  3. For a high-level overview of what exactly a neural network is: check out this series from 3Blue1Brown.
