Implementing an Artificial Neural Network in Pure Java (No external dependencies).

Visualization of the training loss with JavaFX

Deep learning frameworks have made implementing neural networks so easy that it is tempting to treat the learning process as a black box: stack arbitrary layers together and expect everything to work automagically [1]. Building a solid foundation in machine learning (ML) by implementing core concepts from scratch, such as the backpropagation algorithm for NNs, CNNs and RNNs, is important. Take the time to understand its derivation, try to derive it yourself, then implement it in code and see if you can make it work. The knowledge you gain will stick, and it will be independent of any framework you decide to learn later. During my own learning process I thought it was worth knowing what happens under the hood, if only out of intellectual curiosity. In this article I present my simple implementation of a two-layer NN in pure Java.

If you’re in a hurry, here is the complete code. Its Python/NumPy version can be found here.

Network Architecture

Below is the two-layer feed-forward neural network we are going to implement in Java. We will use this architecture throughout, but all the concepts scale to any number of layers and nodes.

Two-layer neural network

The pattern we will teach our neural network to recognize is the XOR operation, y = x1 XOR x2. Its truth table is shown below: the output is 1 when exactly one of the inputs is 1, and 0 otherwise.

XOR table:
0 XOR 0 = 0
0 XOR 1 = 1
1 XOR 0 = 1
1 XOR 1 = 0
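To feed the table to the network, one convenient layout (an assumption, but consistent with the single-row Prediction = [[...]] output in the training results) stores each training example as a column: X is 2 × 4 and the labels Y are 1 × 4. The class name XorData is hypothetical.

```java
// Hypothetical data layout: each column of X is one example (x1, x2),
// and the matching column of Y is its XOR label.
class XorData {
    static final double[][] X = {
        {0, 0, 1, 1}, // x1 for the four examples
        {0, 1, 0, 1}  // x2 for the four examples
    };
    static final double[][] Y = {{0, 1, 1, 0}}; // y = x1 XOR x2, shape 1 x 4

    public static void main(String[] args) {
        for (int j = 0; j < X[0].length; j++) {
            int x1 = (int) X[0][j], x2 = (int) X[1][j];
            // Sanity check: the stored label equals Java's bitwise XOR of the inputs.
            System.out.println(x1 + " XOR " + x2 + " = " + (int) Y[0][j] + " (check: " + (x1 ^ x2) + ")");
        }
    }
}
```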

Some Background Mathematics

The following are the forward propagation equations for the neural network architecture above [2]. Superscripts indicate the layer and subscripts indicate the node index.
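As an illustration, assuming two input features x_1, x_2 (the XOR inputs) and two hidden nodes (the hidden-layer size here is an illustrative assumption), with σ the activation function, the per-node forward equations can be written as:

```latex
\begin{aligned}
z_1^{[1]} &= w_{11}^{[1]} x_1 + w_{12}^{[1]} x_2 + b_1^{[1]}, & a_1^{[1]} &= \sigma\big(z_1^{[1]}\big) \\
z_2^{[1]} &= w_{21}^{[1]} x_1 + w_{22}^{[1]} x_2 + b_2^{[1]}, & a_2^{[1]} &= \sigma\big(z_2^{[1]}\big) \\
z_1^{[2]} &= w_{11}^{[2]} a_1^{[1]} + w_{12}^{[2]} a_2^{[1]} + b_1^{[2]}, & \hat{y} = a_1^{[2]} &= \sigma\big(z_1^{[2]}\big)
\end{aligned}
```

Here w_{jk}^{[l]} is the weight from input k of layer l to node j of layer l.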

Part 1: Forward propagation Equations


We have all used for-loops for most tasks that need to iterate over a long list of elements. I am sure almost everybody reading this article wrote their first matrix or vector multiplication with a for-loop back in high school or college. The for-loop has served the programming community long and well. However, it comes with some baggage: it is often slow when processing large data sets (many millions of records, as is common in this age of Big Data) [3].

So let’s vectorize our equations, starting by combining the computation of all the nodes in the hidden layer into a single matrix equation (ignore the activation functions for now; we will come back to them).
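Assuming, for illustration, two input features and two hidden nodes, the hidden-layer equations for one example stack into a single matrix product. When m examples are stacked as the columns of X, Z^{[1]} becomes a 2 × m matrix and the bias column b^{[1]} is added to every column.

```latex
\underbrace{\begin{bmatrix} z_1^{[1]} \\ z_2^{[1]} \end{bmatrix}}_{Z^{[1]}}
=
\underbrace{\begin{bmatrix} w_{11}^{[1]} & w_{12}^{[1]} \\ w_{21}^{[1]} & w_{22}^{[1]} \end{bmatrix}}_{W^{[1]}}
\underbrace{\begin{bmatrix} x_1 \\ x_2 \end{bmatrix}}_{X}
+
\underbrace{\begin{bmatrix} b_1^{[1]} \\ b_2^{[1]} \end{bmatrix}}_{b^{[1]}}
```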

Things to note:

1. Increasing the number of nodes in a layer increases the number of rows of that layer’s weight matrix.

2. Increasing the number of input features increases the number of columns of the weight matrix (and the number of rows of the input vector).

Here is what I mean:

For example, let’s say we add another node to the hidden layer. The matrix equation grows downwards: the weight matrix gains a row, and the bias and output vectors each gain an entry.
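The same growth is visible in code. Below is a minimal sketch of the two matrix operations the hidden layer needs, a matrix product and a bias added to every column. The class MatrixOps and the weight values are hypothetical, chosen only to show shapes: with three hidden nodes W1 has three rows, so Z1 comes out with three rows, one per node, and four columns, one per XOR example.

```java
// Hypothetical helper for illustration (not the article's own matrix class).
class MatrixOps {
    // Matrix product: (n x k) times (k x m) gives (n x m).
    static double[][] dot(double[][] a, double[][] b) {
        int n = a.length, k = a[0].length, m = b[0].length;
        if (b.length != k) throw new IllegalArgumentException("inner dimensions differ");
        double[][] out = new double[n][m];
        for (int i = 0; i < n; i++)
            for (int j = 0; j < m; j++)
                for (int p = 0; p < k; p++)
                    out[i][j] += a[i][p] * b[p][j];
        return out;
    }

    // Adds the bias column (n x 1) to every column of z (n x m), i.e. to every example.
    static double[][] addBias(double[][] z, double[][] bias) {
        double[][] out = new double[z.length][z[0].length];
        for (int i = 0; i < z.length; i++)
            for (int j = 0; j < z[0].length; j++)
                out[i][j] = z[i][j] + bias[i][0];
        return out;
    }

    public static void main(String[] args) {
        double[][] x = {{0, 0, 1, 1}, {0, 1, 0, 1}}; // 2 features x 4 examples
        double[][] w1 = {{1, 0}, {0, 1}, {1, 1}};    // 3 hidden nodes -> 3 rows; 2 features -> 2 columns
        double[][] b1 = {{0}, {0}, {-1}};            // one bias per hidden node
        double[][] z1 = addBias(dot(w1, x), b1);
        System.out.println(z1.length + " x " + z1[0].length); // prints "3 x 4"
    }
}
```

Adding a fourth hidden node would only mean adding a row to w1 and an entry to b1; the code itself does not change.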

Activation function

We need activation functions to learn complex, non-linear mappings between the inputs and the target outputs of our data. In the previous section I ignored the activation functions to keep the explanation simple. We will be using the sigmoid activation function.
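A minimal sketch of the sigmoid, σ(z) = 1 / (1 + e^(-z)), in plain Java. The class name Activation is hypothetical. The derivative is written in terms of the activation a = σ(z), because σ′(z) = σ(z)(1 - σ(z)), which is the form backpropagation uses.

```java
// Hypothetical helper: sigmoid activation and its derivative.
class Activation {
    // sigma(z) = 1 / (1 + e^{-z}) squashes any real number into the range (0, 1).
    static double sigmoid(double z) {
        return 1.0 / (1.0 + Math.exp(-z));
    }

    // Element-wise sigmoid over a matrix (turns Z into A).
    static double[][] sigmoid(double[][] z) {
        double[][] a = new double[z.length][z[0].length];
        for (int i = 0; i < z.length; i++)
            for (int j = 0; j < z[0].length; j++)
                a[i][j] = sigmoid(z[i][j]);
        return a;
    }

    // Derivative expressed through the activation a = sigma(z): sigma'(z) = a * (1 - a).
    static double sigmoidPrime(double a) {
        return a * (1.0 - a);
    }

    public static void main(String[] args) {
        System.out.println(sigmoid(0.0));      // prints 0.5
        System.out.println(sigmoidPrime(0.5)); // prints 0.25, the derivative's maximum
    }
}
```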

Part 2: Backpropagation Equations

To build an intuitive understanding of how backprop works, I will use the diagram below to illustrate the gradient computation, which gradient descent then uses to update the learnable parameters w and b. For simplicity I will use a single-layer neural network (logistic regression); the same idea scales to an N-layer neural network [2].
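Assuming two input features (as in XOR), the computation graph for logistic regression is a short chain. Backprop walks it from right to left, multiplying the local derivative of each step:

```latex
x_1, x_2, w_1, w_2, b
\;\longrightarrow\;
z = w_1 x_1 + w_2 x_2 + b
\;\longrightarrow\;
a = \hat{y} = \sigma(z)
\;\longrightarrow\;
\mathcal{L}(a, y)
```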

We are going to use the cross-entropy loss to compute the cost.
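The per-example cross-entropy loss is L(a, y) = -(y log a + (1 - y) log(1 - a)), and the cost is its average over the m training examples. A minimal sketch, assuming predictions and labels are stored as 1 × m row vectors (the class name CrossEntropy is hypothetical):

```java
// Hypothetical helper: binary cross-entropy cost averaged over m examples.
class CrossEntropy {
    // a: 1 x m predictions, strictly inside (0, 1) (a sigmoid output guarantees this
    // unless it saturates); y: 1 x m labels in {0, 1}.
    static double cost(double[][] a, double[][] y) {
        int m = y[0].length;
        double sum = 0.0;
        for (int j = 0; j < m; j++) {
            // L(a, y) = -(y log a + (1 - y) log(1 - a))
            sum += -(y[0][j] * Math.log(a[0][j]) + (1 - y[0][j]) * Math.log(1 - a[0][j]));
        }
        return sum / m; // the cost is the mean loss
    }

    public static void main(String[] args) {
        double[][] y = {{0, 1, 1, 0}};
        double[][] a = {{0.5, 0.5, 0.5, 0.5}};
        System.out.println(cost(a, y)); // approximately 0.6931, i.e. ln 2
    }
}
```

With every prediction at 0.5 the cost is ln 2, roughly what a network whose outputs all sit near 0.5 scores; it falls towards 0 as the predictions approach the labels.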

Computing dw, the derivative of the loss with respect to the weights. This can be done using the chain rule, as shown below.
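Assuming the cross-entropy loss L(a, y) = -(y log a + (1 - y) log(1 - a)), a = σ(z) and z = w_1 x_1 + w_2 x_2 + b, the chain rule works out as follows for a single example. The product of the first two factors simplifies to a - y because (-y/a + (1 - y)/(1 - a)) · a(1 - a) = -y(1 - a) + (1 - y)a.

```latex
\begin{aligned}
\frac{\partial \mathcal{L}}{\partial a} &= -\frac{y}{a} + \frac{1-y}{1-a}, \qquad
\frac{\partial a}{\partial z} = a(1-a) \\
dz = \frac{\partial \mathcal{L}}{\partial z} &= \frac{\partial \mathcal{L}}{\partial a}\,\frac{\partial a}{\partial z} = a - y \\
dw_1 = \frac{\partial \mathcal{L}}{\partial w_1} &= dz\,\frac{\partial z}{\partial w_1} = (a - y)\,x_1, \qquad
dw_2 = (a - y)\,x_2
\end{aligned}
```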

Computing db, the derivative of the loss with respect to the bias.
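For the cross-entropy loss with a sigmoid output, ∂L/∂z = a - y, and since ∂z/∂b = 1, db = a - y for a single example (while dw_i = (a - y) x_i). Over m examples the cost is an average, so the gradients are averages too. A minimal sketch; the class name LogisticGradients and the examples-as-columns layout are assumptions:

```java
import java.util.Arrays;

// Hypothetical helper: gradients of the average cross-entropy cost for logistic regression.
class LogisticGradients {
    // x: n x m inputs (examples as columns), a: 1 x m predictions, y: 1 x m labels.
    // Returns {dw (length n), {db}}.
    static double[][] gradients(double[][] x, double[][] a, double[][] y) {
        int n = x.length, m = y[0].length;
        double[] dw = new double[n];
        double db = 0.0;
        for (int j = 0; j < m; j++) {
            double dz = a[0][j] - y[0][j];                     // dL/dz = a - y
            for (int i = 0; i < n; i++) dw[i] += dz * x[i][j]; // dz * dz/dw_i = dz * x_i
            db += dz;                                          // dz * dz/db = dz * 1
        }
        for (int i = 0; i < n; i++) dw[i] /= m; // average over the m examples
        return new double[][] {dw, {db / m}};
    }

    public static void main(String[] args) {
        double[][] x = {{0, 0, 1, 1}, {0, 1, 0, 1}}; // XOR inputs, examples as columns
        double[][] y = {{0, 1, 1, 0}};
        double[][] a = {{0.2, 0.6, 0.7, 0.4}};      // made-up predictions for illustration
        double[][] g = gradients(x, a, y);
        System.out.println("dw = " + Arrays.toString(g[0]) + ", db = " + g[1][0]);
    }
}
```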

Update Equations

We will use gradient descent to update the parameters of each layer as follows.
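With α the learning rate (a symbol introduced here) and l indexing the two layers, the gradient descent update reads:

```latex
\begin{aligned}
W^{[l]} &:= W^{[l]} - \alpha \, dW^{[l]} \\
b^{[l]} &:= b^{[l]} - \alpha \, db^{[l]}, \qquad l = 1, 2
\end{aligned}
```

Each step moves every parameter a small distance against its gradient, so the cost decreases as long as α is not too large.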

Full Java Code

Understanding the concepts above is crucial to understanding how this code works. A separate helper file contains all the matrix operations.
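For readers who want something self-contained to run, here is a minimal sketch of a two-layer network trained on XOR with full-batch gradient descent. It is not the author’s code: the class name XorNetwork, the four hidden nodes, the seed 42, the learning rate 0.5 and the uniform [-1, 1) initialization are all assumptions. It uses the standard vectorized backprop for this architecture: dZ2 = A2 - Y at the output, then back through W2 and the sigmoid derivative A1(1 - A1) for the hidden layer.

```java
import java.util.Arrays;
import java.util.Random;

// Minimal sketch (not the article's own code): 2 inputs -> sigmoid hidden layer -> 1 sigmoid output,
// trained on XOR. Hidden size, seed, learning rate and initialization are assumptions.
class XorNetwork {
    double[][] w1, b1, w2, b2; // shapes: hidden x 2, hidden x 1, 1 x hidden, 1 x 1

    XorNetwork(int hidden, long seed) {
        Random rng = new Random(seed); // fixed seed so runs are repeatable
        w1 = random(hidden, 2, rng);
        b1 = new double[hidden][1];    // biases start at zero
        w2 = random(1, hidden, rng);
        b2 = new double[1][1];
    }

    static double[][] random(int rows, int cols, Random rng) {
        double[][] m = new double[rows][cols];
        for (double[] row : m)
            for (int j = 0; j < cols; j++) row[j] = rng.nextDouble() * 2 - 1; // uniform in [-1, 1)
        return m;
    }

    static double[][] dot(double[][] a, double[][] b) { // (n x k)(k x m) -> n x m
        double[][] out = new double[a.length][b[0].length];
        for (int i = 0; i < a.length; i++)
            for (int j = 0; j < b[0].length; j++)
                for (int k = 0; k < b.length; k++) out[i][j] += a[i][k] * b[k][j];
        return out;
    }
    static double[][] transpose(double[][] a) {
        double[][] t = new double[a[0].length][a.length];
        for (int i = 0; i < a.length; i++)
            for (int j = 0; j < a[0].length; j++) t[j][i] = a[i][j];
        return t;
    }
    static double[][] rowSums(double[][] a) { // n x m -> n x 1 (sum over examples)
        double[][] s = new double[a.length][1];
        for (int i = 0; i < a.length; i++)
            for (double v : a[i]) s[i][0] += v;
        return s;
    }

    // A = sigmoid(W X + b), with the bias column broadcast across all examples.
    static double[][] layer(double[][] w, double[][] x, double[][] b) {
        double[][] z = dot(w, x);
        for (int i = 0; i < z.length; i++)
            for (int j = 0; j < z[0].length; j++) z[i][j] = 1.0 / (1.0 + Math.exp(-(z[i][j] + b[i][0])));
        return z;
    }

    double[][] predict(double[][] x) {
        return layer(w2, layer(w1, x, b1), b2);
    }

    static double cost(double[][] a, double[][] y) { // average cross-entropy
        int m = y[0].length;
        double sum = 0.0;
        for (int j = 0; j < m; j++)
            sum -= y[0][j] * Math.log(a[0][j]) + (1 - y[0][j]) * Math.log(1 - a[0][j]);
        return sum / m;
    }

    static void update(double[][] p, double[][] grad, double lr, int m) { // p := p - lr * grad / m
        for (int i = 0; i < p.length; i++)
            for (int j = 0; j < p[0].length; j++) p[i][j] -= lr * grad[i][j] / m;
    }

    // One forward + backward pass over all examples; returns the cost before the update.
    double step(double[][] x, double[][] y, double lr) {
        int m = x[0].length;
        double[][] a1 = layer(w1, x, b1);  // hidden activations, hidden x m
        double[][] a2 = layer(w2, a1, b2); // predictions, 1 x m
        double[][] dz2 = new double[1][m];
        for (int j = 0; j < m; j++) dz2[0][j] = a2[0][j] - y[0][j]; // dL/dZ2 = A2 - Y
        double[][] dz1 = dot(transpose(w2), dz2);                   // propagate back through W2
        for (int i = 0; i < dz1.length; i++)
            for (int j = 0; j < m; j++) dz1[i][j] *= a1[i][j] * (1 - a1[i][j]); // times sigmoid'
        // All gradients are computed before any parameter changes.
        double[][] dw2 = dot(dz2, transpose(a1)), db2 = rowSums(dz2);
        double[][] dw1 = dot(dz1, transpose(x)), db1 = rowSums(dz1);
        update(w2, dw2, lr, m); update(b2, db2, lr, m);
        update(w1, dw1, lr, m); update(b1, db1, lr, m);
        return cost(a2, y);
    }

    public static void main(String[] args) {
        double[][] x = {{0, 0, 1, 1}, {0, 1, 0, 1}}; // XOR inputs, one example per column
        double[][] y = {{0, 1, 1, 0}};
        XorNetwork net = new XorNetwork(4, 42);
        for (int i = 0; i < 4000; i++) {
            double c = net.step(x, y, 0.5);
            if (i % 1000 == 0) System.out.println("Cost = " + c);
        }
        System.out.println("Prediction = " + Arrays.deepToString(net.predict(x)));
    }
}
```

With a different seed or too few hidden nodes, training on XOR can stall in a poor local minimum, so the printed numbers will not necessarily match the results below.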

Training Results

Below is the result after training the NN for 4000 iterations. The final prediction (Prediction = [[0.01212, 0.9864, 0.986300, 0.01569]]) shows that our network has done a good job of emulating the XOR operation: the two inner values, for the inputs (0, 1) and (1, 0) whose XOR is 1, are pushed towards 1, while the two outer values, for (0, 0) and (1, 1), are pushed towards 0.

Cost = 0.1257569282040295
Prediction = [[0.15935, 0.8900528, 0.88589, 0.0877284694]]



Cost = 0.015787269324306925
Prediction = [[0.013838, 0.984561, 0.9844246, 0.0177971]]
Cost = 0.013869971354598404
Prediction = [[0.01212, 0.9864, 0.986300, 0.01569]]


This is not an efficient implementation of a neural network, but my intention was to convey an intuitive understanding of machine learning concepts and to show how to translate them into code.

Did you find this article helpful? Did you spot any mistakes? (Possibly, because this is my first article and English is not my first language.) Have opinions or comments? Drop them below.