Deep Learning for Java (DL4J) Getting Started: Tutorial

Build Iris Classification Neural Network using Deep Learning for Java (DL4J) library.


This is the first tutorial in a series I’ll be writing in which you’ll build neural networks using DL4J (a Java-based deep learning library).


The only prerequisite is a knowledge of Java. If you have worked with basic Java SE and understand basic Object-Oriented Programming (OOP) concepts, you’ll be good to go. A basic understanding of neural networks or deep learning concepts would also be a big plus.


We’ll be using IntelliJ IDEA CE. Download the Community version:

Download Community Version

Iris Classification Neural Network using DL4J

Open IntelliJ IDEA CE and Create a new project, name it LearningDL4J.

Open pom.xml file.

Add the following code in the pom.xml:
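The snippet did not survive formatting, so here is a sketch of roughly what the dependencies section would look like, based on the Maven coordinates named in the library-installation steps further down (treat the exact versions as assumptions and match them to what you install):

```xml
<properties>
    <dl4j.version>1.0.0-alpha</dl4j.version>
</properties>

<dependencies>
    <dependency>
        <groupId>org.nd4j</groupId>
        <artifactId>nd4j-native-platform</artifactId>
        <version>1.0.1-alpha</version>
    </dependency>
    <dependency>
        <groupId>org.deeplearning4j</groupId>
        <artifactId>deeplearning4j-core</artifactId>
        <version>${dl4j.version}</version>
    </dependency>
</dependencies>
```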







You’ll notice that ${dl4j.version} is highlighted in red. You might have two such red-highlighted errors. Let’s solve these first:

Go to File → Project Structure → Libraries (under Project Settings) → + → From Maven… → and in the search bar add org.nd4j:nd4j-native-platform:1.0.1-alpha

Next, do the same thing and install org.deeplearning4j:deeplearning4j-core:1.0.0-alpha

Finally, following the same steps, search for and install org.apache.cassandra:cassandra-all:1.1.4

Ensure that you have added all three libraries listed above. Apply, and press OK to close the Project Structure window.
After you’re done adding these libraries you’ll notice both errors are gone from pom.xml.

About & Download Dataset

We’ll use the “Hello World” of the machine learning world: the Iris classification dataset.

The Iris dataset is a set of data gathered from the flowers of three different species (Iris setosa, Iris versicolor, and Iris virginica).

Download the Iris Classification .csv file or search for the CSV file on the internet.

Columns 1-4 contain the different features of the species, and column 5 contains the class of the record, i.e. the species, coded with the value 0 for Iris-setosa, 1 for Iris-versicolor, and 2 for Iris-virginica.
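For reference, a few rows in this format look like the following (one row per flower: sepal length, sepal width, petal length, petal width, class label); the feature values shown here are the first row of each species in the classic Iris dataset:

```
5.1,3.5,1.4,0.2,0
7.0,3.2,4.7,1.4,1
6.3,3.3,6.0,2.5,2
```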

Building Neural Network

The first step we need to do is load the dataset. As neural networks work with numbers, we’ll do vectorization (transforming real-world data into a series of numbers). DL4J uses the DataVec library to do this.

Create a new Java class inside src > java > (create a package, optional) > Right Click > New > Java Class

I’ve named my Java class IrisClassification.

TYPE the following code in the IrisClassification class:

public static void main(String[] args) {
    BasicConfigurator.configure();
    loadData();
}

private static void loadData() {
    try (RecordReader recordReader = new CSVRecordReader(0, ',')) {
        recordReader.initialize(new FileSplit(
                new ClassPathResource("iris.csv").getFile()));
    } catch (Exception e) {
        System.out.println("Error: " + e.getLocalizedMessage());
    }
}

Before explaining the code above let’s address the errors you might be facing. If you have copy-pasted the code then you are missing these imports:

import org.apache.log4j.BasicConfigurator;
import org.datavec.api.records.reader.RecordReader;
import org.datavec.api.records.reader.impl.csv.CSVRecordReader;
import org.datavec.api.split.FileSplit;
import org.nd4j.linalg.io.ClassPathResource;
// Do type the code and avoid copy-pasting, as I won't be showing the imports in the code snippets from here on

So now you might be left with this error: “Try-with-resources are not supported at language level 5”.

To solve this, either click on Module Settings in the error popover or open Project Structure from the File menu.

Select “12 - No new language features” from the Language level dropdown, click Apply and then OK. This should resolve that error.

Now, in the code above, you have the main function, which just calls the loadData() function, and BasicConfigurator.configure();, which basically sets up a logger required by DL4J, so just add it there.
Inside the loadData() function we have a try-catch statement in which we have a CSVRecordReader object that takes 2 arguments. Here 0 indicates the number of lines to skip (so we’re not skipping any lines), followed by the delimiter, which is , as it’s a CSV file.

The next line, where we initialize the recordReader object, is where we have to provide the file path. What I did is place the iris.csv file in the resources folder.

Next, we’ll iterate over the dataset, and to do that we use DataSetIterator, which is an interface with multiple implementations. Datasets can be of huge sizes, so it also provides the ability to page through or cache the values, which comes in handy in a lot of cases; but since this dataset only has 150 rows, we’ll read all the data at once.

Now within the try block TYPE this code:

DataSetIterator iterator = new RecordReaderDataSetIterator(recordReader, 150, 4, 3);
DataSet allData = iterator.next();
allData.shuffle(123);

In these three lines of code, we made a new RecordReaderDataSetIterator and provided it with the recordReader object that contains our file, along with other arguments: the batch size, which is 150; the 4 columns that hold the feature values of the flowers; and the three labels, i.e. 0, 1, and 2.
The next() call basically loads all the data into memory at once, and lastly, we shuffled the values because in the dataset I’m using the values are arranged pretty neatly: the first 50 rows have label 0, the next 50 label 1, and so on. The value 123 is the seed, which will shuffle them in the same order every time we run the program.

Next, we’ll normalize the data. If you’re coming from a Python data-analytics background, you can think of it as a fit-transform of the data. Note that normalization may differ for different types of data: for example, when working with images of various sizes, we would first collect size statistics and then scale the images to a uniform size. Since in this case we have numeric data, normalization means transforming it into a normal distribution, and to do that we’ll use NormalizerStandardize. So, in the try clause, after the shuffle, type the following code:

DataNormalization normalizer = new NormalizerStandardize();
normalizer.fit(allData);
normalizer.transform(allData);

Now, after the fit-transform, we split the data. Splitting means dividing the dataset into a train set and a test set. We’ll use the train data to train our deep learning model, and the test data to evaluate the model.

So, I’m dividing the data into 65% (0.65) for training and the remaining 35% (0.35) for testing:

SplitTestAndTrain testAndTrain = allData.splitTestAndTrain(0.65);
DataSet trainingData = testAndTrain.getTrain();
DataSet testingData = testAndTrain.getTest();

The data is now ready for us to start building our neural network.

Before we start setting up the network, I made one small change in the code: I’ve initialized two constants, FEATURES_COUNT (4) and CLASSES_COUNT (3), as it’s good practice. It’s totally optional!

Create a new function named irisNNetwork and TYPE this code in it.

private static void irisNNetwork(DataSet trainingData, DataSet testData) {
    MultiLayerConfiguration configuration = new NeuralNetConfiguration.Builder()
            .activation(Activation.TANH)
            .weightInit(WeightInit.XAVIER)
            .updater(new Nesterovs(0.1, 0.9))
            .l2(0.0001) // regularization strength; a typical small value
            .list()
            .layer(0, new DenseLayer.Builder().nIn(FEATURES_COUNT).nOut(3).build())
            .layer(1, new DenseLayer.Builder().nIn(3).nOut(3).build())
            // loss function and softmax output are standard multi-class choices
            .layer(2, new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
                    .activation(Activation.SOFTMAX).nIn(3).nOut(CLASSES_COUNT).build())
            .backprop(true).pretrain(false)
            .build();


The above code is essentially one long statement. To appreciate the beauty of this one line, imagine a chain: each . adds a new piece to the chain.

A simple neural network consists of an input layer, a hidden layer, and an output layer.

This is what we’re trying to build, and the above code is the configuration for building this structure: for example, how many nodes (circles) there will be in the input layer, among many other settings.

The activation() function runs inside the nodes (circles). Think of it as the formula that does some calculation at each node; based on the results of these calculations, our network produces results. Before explaining more about activation functions, I recommend watching this 5-minute video, where the presenter also talks about the activation function at 1:40.

There are a lot of different activation functions, and each has its advantages. So, without making it complex at this point, you just need to know why activation functions are required (which is explained in the video). In our case, we’re using TANH, which is short for “hyperbolic tangent”.
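To build some intuition (this is plain Java, not part of the DL4J code), here’s a tiny sketch of what the hyperbolic tangent does: it squashes any real input into the range (-1, 1), which keeps node outputs bounded:

```java
public class TanhDemo {
    public static void main(String[] args) {
        // tanh squashes any real input into the open interval (-1, 1)
        for (double x : new double[]{-5.0, -1.0, 0.0, 1.0, 5.0}) {
            System.out.printf("tanh(%5.1f) = %8.4f%n", x, Math.tanh(x));
        }
    }
}
```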

In the video, the presenter points out that the channels connecting layers have weights, so weightInit basically specifies one of the many ways to set up the initial weights. XAVIER draws each weight for a node from a Gaussian distribution and is a good choice to start off with.

updater is used to set the learning rate, which is very important as it can have a huge effect on the ability of the network. This is something we might have to go back and forth to tweak, but in our really simple example we’re using a value of 0.1. Nesterovs is basically used to set the momentum (here 0.9) applied to the weight updates.

One of the problems that occurs in neural networks is overfitting. This happens when the network learns weights tailored too closely to the training data, which results in bad outcomes on new data. To solve or avoid overfitting we use l2 regularization, which penalizes the network for overly large weights and so prevents overfitting.
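To make the idea concrete, here’s a small plain-Java sketch of what an L2 penalty looks like; the weights and the lambda value are made-up illustration values, not taken from our network:

```java
public class L2PenaltyDemo {
    public static void main(String[] args) {
        // hypothetical weights and regularization strength, for illustration only
        double[] weights = {0.5, -1.2, 3.0};
        double lambda = 0.0001;

        // L2 penalty = lambda * sum of squared weights;
        // large weights are punished much harder than small ones
        double sumOfSquares = 0.0;
        for (double w : weights) {
            sumOfSquares += w * w;
        }
        double penalty = lambda * sumOfSquares;

        System.out.println("L2 penalty added to the loss: " + penalty);
    }
}
```

The penalty is added to the training loss, so gradient descent has an incentive to keep weights small.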

Now we’ll be creating the layers. A DenseLayer is basically a fully connected layer. The first layer (index 0) should contain the same number of nodes as there are feature columns in the training data, i.e. 4; that’s what we did in the nIn() part of the first layer. This first layer will act as the input layer. The nOut() part of this layer is explained in the second layer’s explanation.
The second layer contains 3 nodes; this is a value we can vary, but then nOut in the previous layer has to match it. This will act as our hidden layer.
The third and final layer, in this case, will be our output layer, and it has to match the number of classes, i.e. 3, or CLASSES_COUNT. So, this will create our network.


The above network takes four inputs from our dataset, processes them in the hidden layer, and sends the signal to the output layer. This is a simple classifier.

After the layers are created, we do two things. First, we set backpropagation to true using backprop. When the signal moves from the Input Layer → Hidden Layer → Output Layer, it’s called forward propagation; when it moves from the Output Layer → Hidden Layer → Input Layer, it’s called backpropagation. When forward propagation is finished, the network analyzes the output; if the output is wrong, the network uses backpropagation to adjust its weights and biases, and then forward propagation starts again. So, basically, backpropagation is the central mechanism by which neural networks learn. It is the messenger telling the network whether or not it made a mistake when it made a prediction (on the output layer).
Second, we set pretrain to false, as we don’t want pre-training of the network. When you start training a neural network, you start by initializing the weights randomly. As soon as training starts, the weights get changed through the backpropagation mechanism we’ve just learned about. Once you’re satisfied with the results on your training data, you might want to save the weights of the network so that the trained neural network can perform a similar task later. When you want to use this trained model on another similar dataset, you can use the weights you saved from the previous network as the initial weight values for your new experiment. Initializing the weights this way is referred to as using a pre-trained network. Since our example is very simple, we don’t need to do this, and that is why it’s set to false.
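As a toy illustration of the “adjust weights” step at the heart of backpropagation (all numbers here are made up; real backpropagation computes the gradient from the loss):

```java
public class WeightUpdateDemo {
    public static void main(String[] args) {
        double weight = 0.8;       // current weight on a channel (hypothetical)
        double learningRate = 0.1; // same learning rate as our network uses
        double gradient = 0.5;     // slope of the loss w.r.t. this weight (hypothetical)

        // the core update rule: nudge the weight a small step against the gradient,
        // so the next forward pass produces a slightly smaller error
        weight = weight - learningRate * gradient;

        System.out.println("Updated weight: " + weight);
    }
}
```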

Well, this is the simplest theory I could come up with to explain the chain we made in that one line of code. Let’s get back to the coding again.

Now that we have the configuration, it’s time to create the neural network (basically initializing and running it). Type these lines inside the irisNNetwork() function:

MultiLayerNetwork model = new MultiLayerNetwork(configuration);
model.init();
model.fit(trainingData);

Next, we test the trained model by using the rest of the dataset and verify the results with evaluation metrics for our three classes:

INDArray output = model.output(testData.getFeatureMatrix());
Evaluation eval = new Evaluation(3);
eval.eval(testData.getLabels(), output);
System.out.println(eval.stats());
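For intuition about the accuracy figure the evaluation reports, here’s a tiny plain-Java sketch of how accuracy is computed; the predicted and actual labels are made up for illustration:

```java
public class AccuracyDemo {
    public static void main(String[] args) {
        // hypothetical predicted vs. actual class labels (0, 1, or 2)
        int[] predicted = {0, 1, 2, 1, 0};
        int[] actual    = {0, 1, 1, 1, 0};

        int correct = 0;
        for (int i = 0; i < actual.length; i++) {
            if (predicted[i] == actual[i]) {
                correct++;
            }
        }

        // accuracy = correct predictions / total predictions
        double accuracy = (double) correct / actual.length;
        System.out.println("Accuracy: " + accuracy); // 4 of 5 correct -> 0.8
    }
}
```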

One last thing before we run it: we need to call this function from the try block in the loadData() function, after the train-test split. So let’s add that:

irisNNetwork(trainingData, testingData);

Yes! Go ahead and run it. You’ve earned it… (If the run button is disabled, right-click anywhere in the code editor and select Run from there.)

This is the output I got. You might get slightly different results. (NOTE: If you run into an error or cannot execute the program, refer to the Possible Errors section at the end.)

The result above shows the accuracy of the model to be 81%, which is not bad. We would usually tweak some other parameters to see if the model can be improved, and save a copy of the best model.

Here is the full code:
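The full listing did not survive formatting, so here is the program assembled from the snippets above. The loss function, softmax output activation, and the l2 value were not spelled out in the text and are standard choices (assumptions on my part); the import paths assume the 1.0.0-alpha versions of the libraries:

```java
import org.apache.log4j.BasicConfigurator;
import org.datavec.api.records.reader.RecordReader;
import org.datavec.api.records.reader.impl.csv.CSVRecordReader;
import org.datavec.api.split.FileSplit;
import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator;
import org.deeplearning4j.eval.Evaluation;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.SplitTestAndTrain;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.dataset.api.preprocessor.DataNormalization;
import org.nd4j.linalg.dataset.api.preprocessor.NormalizerStandardize;
import org.nd4j.linalg.io.ClassPathResource;
import org.nd4j.linalg.learning.config.Nesterovs;
import org.nd4j.linalg.lossfunctions.LossFunctions;

public class IrisClassification {

    private static final int FEATURES_COUNT = 4;
    private static final int CLASSES_COUNT = 3;

    public static void main(String[] args) {
        BasicConfigurator.configure();
        loadData();
    }

    private static void loadData() {
        try (RecordReader recordReader = new CSVRecordReader(0, ',')) {
            // iris.csv lives in the resources folder
            recordReader.initialize(new FileSplit(
                    new ClassPathResource("iris.csv").getFile()));

            // read all 150 rows at once; 4 feature columns, 3 classes
            DataSetIterator iterator = new RecordReaderDataSetIterator(
                    recordReader, 150, FEATURES_COUNT, CLASSES_COUNT);
            DataSet allData = iterator.next();
            allData.shuffle(123);

            // fit-transform: normalize features to a standard distribution
            DataNormalization normalizer = new NormalizerStandardize();
            normalizer.fit(allData);
            normalizer.transform(allData);

            // 65% train, 35% test
            SplitTestAndTrain testAndTrain = allData.splitTestAndTrain(0.65);
            DataSet trainingData = testAndTrain.getTrain();
            DataSet testingData = testAndTrain.getTest();

            irisNNetwork(trainingData, testingData);
        } catch (Exception e) {
            System.out.println("Error: " + e.getLocalizedMessage());
        }
    }

    private static void irisNNetwork(DataSet trainingData, DataSet testData) {
        MultiLayerConfiguration configuration = new NeuralNetConfiguration.Builder()
                .activation(Activation.TANH)
                .weightInit(WeightInit.XAVIER)
                .updater(new Nesterovs(0.1, 0.9))
                .l2(0.0001) // assumed regularization strength; tune as needed
                .list()
                .layer(0, new DenseLayer.Builder().nIn(FEATURES_COUNT).nOut(3).build())
                .layer(1, new DenseLayer.Builder().nIn(3).nOut(3).build())
                .layer(2, new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
                        .activation(Activation.SOFTMAX)
                        .nIn(3).nOut(CLASSES_COUNT).build())
                .backprop(true).pretrain(false)
                .build();

        MultiLayerNetwork model = new MultiLayerNetwork(configuration);
        model.init();
        model.fit(trainingData);

        // evaluate on the held-out test data
        INDArray output = model.output(testData.getFeatureMatrix());
        Evaluation eval = new Evaluation(CLASSES_COUNT);
        eval.eval(testData.getLabels(), output);
        System.out.println(eval.stats());
    }
}
```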

Possible Errors

After doing all of this, you might still run into some errors when you run the program. Here is a common error I’ve seen, but feel free to post a comment on this article and I’ll reply with a solution:

If you get “Error: java: Compilation failed: internal java compiler error” ensure the following:

The project Java SDK is version 12.0.2. Then, if that’s okay, open Preferences:

Make sure the “Target bytecode version” is 12.


As always feel free to post a comment if you face any error, I’ll try my best to answer.

I’m learning DL4J, and as I successfully build something in it I post it on Medium, so do watch out for future DL4J tutorials; further on, I’ll be going into computer vision tutorials where you’ll be working with images. But to get a basic understanding of how to build a neural network using DL4J, working on the Iris dataset is the best place to start.

Where to go from here? Here is the next thing you should do after you have finished this one:

Do press the 👏 so that others can find it as well and do highlight if you find any typos.




DataCTW is a website of online tutorials hosted in different places, all completely free. Our aim is to provide tutorials from scratch so that anyone looking to learn something new can benefit from them.

Chaudhry Talha


Passionate about using technology for Social Impact. Let’s connect:
