

Image Classification Neural Network Tutorial: Getting Started with DL4J

We’ll use the MNIST handwritten-digit image dataset to build an image classification neural network with DL4J.


Make sure you have finished this tutorial first:


We’ll be using a famous dataset called MNIST (basically the hello world of image classification). The MNIST dataset consists of 70,000 grayscale images of 28×28 pixels showing handwritten digits 0–9. Of these, 60,000 form the training set, which is used to train the network, and the remaining 10,000 form the test set. Download the .zip file:
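The data-loading code in this tutorial takes each label from the name of the folder an image sits in, so it expects the unzipped archive to contain a training and a testing folder, each with sub-folders 0 to 9. Checking that layout before training can save you from a confusing NumberFormatException later. The sketch below is a hypothetical helper (MnistFolderCheck and its methods are illustrative names, not part of DL4J) that counts the image files per digit folder so you can compare the totals with the 60,000 and 10,000 quoted above.

```java
import java.io.File;
import java.util.Map;
import java.util.TreeMap;

// Hypothetical helper: counts image files per digit sub-folder
// (assumed layout: <root>/<split>/<digit>/<image file>)
final class MnistFolderCheck {

    // Returns digit -> number of files, skipping anything that is not a single-digit directory
    static Map<Integer, Integer> countPerDigit(File splitFolder) {
        Map<Integer, Integer> counts = new TreeMap<>();
        File[] children = splitFolder.listFiles();
        if (children == null) {
            return counts; // folder missing or not a directory
        }
        for (File child : children) {
            if (!child.isDirectory() || !child.getName().matches("\\d")) {
                continue; // e.g. .DS_Store or other stray entries
            }
            File[] images = child.listFiles(File::isFile);
            counts.put(Integer.parseInt(child.getName()), images == null ? 0 : images.length);
        }
        return counts;
    }

    // Sum of all per-digit counts
    static int total(Map<Integer, Integer> counts) {
        return counts.values().stream().mapToInt(Integer::intValue).sum();
    }

    public static void main(String[] args) {
        // Pass the same folder you use for RESOURCES_FOLDER_PATH as the first argument
        File root = new File(args.length > 0 ? args[0] : ".");
        for (String split : new String[]{"training", "testing"}) {
            Map<Integer, Integer> counts = countPerDigit(new File(root, split));
            System.out.println(split + ": " + counts + " total=" + total(counts));
        }
    }
}
```

If the layout matches the assumption, each split should list the ten digits 0 to 9.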

Let’s start coding

Since you have already created a project, add a new Java class named MinstClassifier to it.

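// At the top of MinstClassifier.java
import org.datavec.image.loader.NativeImageLoader;
import org.deeplearning4j.datasets.iterator.impl.ListDataSetIterator;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
import org.nd4j.evaluation.classification.Evaluation; // older DL4J releases: org.deeplearning4j.eval.Evaluation
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.dataset.api.preprocessor.ImagePreProcessingScaler;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.learning.config.Nesterovs;
import org.nd4j.linalg.lossfunctions.LossFunctions;
import java.io.File;
import java.io.IOException;
import java.util.Collections;
import java.util.List;
import java.util.Random;

// Inside the MinstClassifier class
private static final String RESOURCES_FOLDER_PATH = "ADD_PATH_TO_RESOURCE_HERE"; // folder that holds training/ and testing/
private static final int HEIGHT = 28;      // image height in pixels
private static final int WIDTH = 28;       // image width in pixels
private static final int N_SAMPLES_TRAINING = 60000;
private static final int N_SAMPLES_TESTING = 10000;
private static final int N_OUTCOMES = 10;  // one class per digit 0-9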
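These constants fix the shape of the data: in getDataSetIterator every image becomes one row of HEIGHT * WIDTH = 784 values in the input matrix, and every label becomes one row of N_OUTCOMES = 10 values in the output matrix, with 1.0 in the column of the digit and 0 elsewhere (one-hot encoding). The plain-Java sketch below shows both row layouts without DL4J. It assumes the pixels are flattened row by row, so pixel (r, c) lands at index r * WIDTH + c; that ordering is an assumption made for illustration, and RowLayouts is a hypothetical name.

```java
import java.util.Arrays;

// Hypothetical illustration of the input and output row layouts; not DL4J code
final class RowLayouts {
    static final int HEIGHT = 28;
    static final int WIDTH = 28;
    static final int N_OUTCOMES = 10;

    // Flattens a HEIGHT x WIDTH image into one row, assuming row-major order
    static double[] flatten(double[][] image) {
        double[] row = new double[HEIGHT * WIDTH];
        for (int r = 0; r < HEIGHT; r++) {
            for (int c = 0; c < WIDTH; c++) {
                row[r * WIDTH + c] = image[r][c];
            }
        }
        return row;
    }

    // One-hot label row: all zeros except 1.0 at the digit's index
    static double[] oneHot(int digit) {
        double[] row = new double[N_OUTCOMES];
        row[digit] = 1.0;
        return row;
    }

    public static void main(String[] args) {
        double[][] image = new double[HEIGHT][WIDTH];
        image[1][2] = 255.0;                       // one bright pixel at row 1, column 2
        double[] row = flatten(image);
        System.out.println(row.length);            // 784
        System.out.println(row[1 * WIDTH + 2]);    // 255.0
        System.out.println(Arrays.toString(oneHot(3)));
    }
}
```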
private static DataSetIterator getDataSetIterator(String folderPath, int nSamples) throws IOException {
    File folder = new File(folderPath);
    File[] digitFolders = folder.listFiles();
    NativeImageLoader nil = new NativeImageLoader(HEIGHT, WIDTH);
    ImagePreProcessingScaler scaler = new ImagePreProcessingScaler(0, 1);

    // One row per image: HEIGHT*WIDTH pixel values in, a one-hot vector of N_OUTCOMES labels out
    INDArray input = Nd4j.create(new int[]{nSamples, HEIGHT * WIDTH});
    INDArray output = Nd4j.create(new int[]{nSamples, N_OUTCOMES});

    int n = 0;
    for (File digitFolder : digitFolders) {
        if (!digitFolder.isDirectory()) continue; // skip stray files such as .DS_Store
        // Each sub-folder is named after the digit it contains
        int labelDigit = Integer.parseInt(digitFolder.getName());
        File[] imageFiles = digitFolder.listFiles();

        for (File imgFile : imageFiles) {
            INDArray img = nil.asRowVector(imgFile); // 28x28 image flattened into one row
            scaler.transform(img);                   // pixel values 0-255 -> 0-1
            input.putRow(n, img);
            output.put(n, labelDigit, 1.0);          // one-hot label
            n++;
        }
    }

    //Joining input and output matrices into a dataset
    DataSet dataSet = new DataSet(input, output);
    //Convert the dataset into a list
    List<DataSet> listDataSet = dataSet.asList();
    //Shuffle content of list randomly
    Collections.shuffle(listDataSet, new Random(System.currentTimeMillis()));
    int batchSize = 10;

    //Build and return a dataset iterator
    DataSetIterator dsi = new ListDataSetIterator<DataSet>(listDataSet, batchSize);
    return dsi;
}
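Two steps of getDataSetIterator are easy to picture in plain Java. ImagePreProcessingScaler(0, 1) rescales 8-bit pixel values from the 0 to 255 range into 0 to 1, which for this target range amounts to dividing by 255, and ListDataSetIterator hands the shuffled examples to the network in mini-batches of batchSize = 10. The sketch below mimics both steps; it is an illustration under those assumptions rather than the library's implementation, and ScaleShuffleBatch is a hypothetical name. Unlike the tutorial, which seeds the shuffle with the current time, it takes an explicit seed so runs can be reproduced.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Random;

// Hypothetical plain-Java illustration of pixel scaling, shuffling and batching
final class ScaleShuffleBatch {

    // Scales an 8-bit pixel value (0-255) into [0, 1]
    static double scalePixel(int pixel) {
        return pixel / 255.0;
    }

    // Shuffles a copy of the examples with the given seed, then cuts it into consecutive batches
    static <T> List<List<T>> shuffleAndBatch(List<T> examples, int batchSize, long seed) {
        List<T> shuffled = new ArrayList<>(examples);
        Collections.shuffle(shuffled, new Random(seed));
        List<List<T>> batches = new ArrayList<>();
        for (int start = 0; start < shuffled.size(); start += batchSize) {
            int end = Math.min(start + batchSize, shuffled.size());
            batches.add(new ArrayList<>(shuffled.subList(start, end)));
        }
        return batches;
    }

    public static void main(String[] args) {
        System.out.println(scalePixel(0) + " " + scalePixel(255)); // 0.0 1.0
        List<Integer> ids = new ArrayList<>();
        for (int i = 0; i < 25; i++) ids.add(i);
        List<List<Integer>> batches = shuffleAndBatch(ids, 10, 42L);
        System.out.println(batches.size() + " batches, last has "
                + batches.get(batches.size() - 1).size()); // 3 batches, last has 5
    }
}
```

In this sketch the final batch is simply smaller when the number of examples is not a multiple of the batch size; for 60,000 images and a batch size of 10 every batch is full. Passing a fixed seed instead of System.currentTimeMillis() is a small change that makes the batch order repeatable between runs.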
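public static void main(String[] args) throws IOException {
    long t0 = System.currentTimeMillis();
    DataSetIterator dataSetIterator = getDataSetIterator(RESOURCES_FOLDER_PATH + "/training", N_SAMPLES_TRAINING);
    buildModel(dataSetIterator);

    long t1 = System.currentTimeMillis();
    double t = (double) (t1 - t0) / 1000.0;
    System.out.print("\n\nTotal time: " + t + " seconds");
}

private static void buildModel(DataSetIterator dsi) throws IOException {
    int rngSeed = 123;
    int nEpochs = 2;

    System.out.println("Build Model...");
    MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
            .seed(rngSeed)
            .updater(new Nesterovs(0.006, 0.9))  // learning rate 0.006, momentum 0.9
            .list()
            .layer(new DenseLayer.Builder()      // hidden layer
                    .nIn(HEIGHT * WIDTH)         // one input per pixel
                    .nOut(1000)                  // assumed hidden-layer size
                    .activation(Activation.RELU) // assumed activation
                    .build())
            .layer(new OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
                    .nIn(1000)
                    .nOut(N_OUTCOMES)
                    .activation(Activation.SOFTMAX) // class probabilities for the NLL loss
                    .build())
            .build();

    MultiLayerNetwork model = new MultiLayerNetwork(conf);
    model.init();
    //Print score every 500 iterations
    model.setListeners(new ScoreIterationListener(500));

    System.out.println("Train Model...");
    for (int i = 0; i < nEpochs; i++) {
        model.fit(dsi);
        dsi.reset();
    }

    DataSetIterator testDsi = getDataSetIterator(RESOURCES_FOLDER_PATH + "/testing", N_SAMPLES_TESTING);
    System.out.println("Evaluating Model...");
    Evaluation eval = model.evaluate(testDsi);
    System.out.println(eval.stats());
}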
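Some quick arithmetic helps when reading the training output: with 60,000 training images and batchSize = 10, one epoch is 6,000 iterations, so ScoreIterationListener(500) prints roughly 12 scores per epoch, or about 24 over nEpochs = 2. For evaluation, the softmax output layer gives ten class probabilities per test image; the predicted digit is the most probable one, and accuracy is the fraction of predictions that match the digit in the one-hot label. This plain-Java sketch reproduces that arithmetic and an argmax-based accuracy; it is an illustration, not DL4J's Evaluation class, and TrainingMath is a hypothetical name.

```java
// Hypothetical plain-Java illustration of the training arithmetic and argmax accuracy
final class TrainingMath {

    // Number of mini-batches (iterations) needed to see every example once
    static int iterationsPerEpoch(int nSamples, int batchSize) {
        return (nSamples + batchSize - 1) / batchSize; // ceiling division
    }

    // Index of the largest value in a row, i.e. the predicted digit
    static int argMax(double[] row) {
        int best = 0;
        for (int i = 1; i < row.length; i++) {
            if (row[i] > row[best]) {
                best = i;
            }
        }
        return best;
    }

    // Fraction of rows whose predicted class matches the class marked in the one-hot label
    static double accuracy(double[][] probabilities, double[][] oneHotLabels) {
        int correct = 0;
        for (int i = 0; i < probabilities.length; i++) {
            if (argMax(probabilities[i]) == argMax(oneHotLabels[i])) {
                correct++;
            }
        }
        return (double) correct / probabilities.length;
    }

    public static void main(String[] args) {
        int perEpoch = iterationsPerEpoch(60000, 10);
        // prints: 6000 iterations per epoch, 12 score prints per epoch
        System.out.println(perEpoch + " iterations per epoch, " + perEpoch / 500 + " score prints per epoch");
        double[][] probs = {{0.1, 0.7, 0.2}, {0.6, 0.3, 0.1}};
        double[][] labels = {{0, 1, 0}, {0, 0, 1}};
        System.out.println(accuracy(probs, labels)); // 0.5
    }
}
```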


You have now built an image classification model with DL4J. MNIST digit recognition is the hello world of classification in deep learning, and the further classification tutorials I’ll be writing will build on this one, so treat it as a prerequisite.



DataCTW is a website of free online tutorials hosted in different places. Our aim is to provide tutorials from scratch so that anyone looking to learn something new can benefit from them.
