GridDB
Published in

GridDB

How to Implement an Artificial Neural Network Using Java | GridDB: Open Source Time Series Database for IoT

What is an Artificial Neural Network?

Write the Data into GridDB

import java.io.IOException;
import java.util.Properties;
import java.util.Collection;
import java.util.Scanner;
import com.toshiba.mwcloud.gs.Collection;
import com.toshiba.mwcloud.gs.GridStore;
import com.toshiba.mwcloud.gs.GSException;
import com.toshiba.mwcloud.gs.GridStoreFactory;
import com.toshiba.mwcloud.gs.RowKey;
import com.toshiba.mwcloud.gs.Query;
import com.toshiba.mwcloud.gs.RowSet;
import java.io.File;
// Row schema for the GridDB collection: one instance maps to one record
// of Churn_Modelling.csv. Field names become the container's column names
// and the declaration order fixes the column order, so neither may change.
public static class BankCustomers {

// Unique row key for the collection (the CSV "RowNumber" column).
@RowKey String rowNumber;
// tenure, hasCrCard and isActiveMember are numeric in the CSV but are
// stored as strings here; they are re-typed later by the DataVec schema.
String surname, geography, gender, tenure, hasCrCard, isActiveMember;
int customerId, creditScore, age, numOfProducts, exited;
Double balance, estimatedSalary;
}
// Connection settings for a multicast-discovered GridDB cluster.
Properties props = new Properties();
props.setProperty("notificationAddress", "239.0.0.1");
props.setProperty("notificationPort", "31999");
props.setProperty("clusterName", "defaultCluster");
props.setProperty("user", "admin");
props.setProperty("password", "admin");
GridStore store = GridStoreFactory.getInstance().getGridStore(props);

// Create (or reopen) the collection whose rows are BankCustomers objects.
Collection<String, BankCustomers> coll = store.putCollection("col01", BankCustomers.class);

File file1 = new File("Churn_Modelling.csv");
// try-with-resources guarantees the Scanner (and the underlying file
// handle) is closed even if a parse error is thrown mid-file.
try (Scanner sc = new Scanner(file1)) {
    // Skip the CSV header. nextLine() (not next()) is used so the whole
    // line is consumed; next() splits on whitespace, not on newlines.
    if (sc.hasNextLine()) {
        sc.nextLine();
    }

    while (sc.hasNextLine()) {
        // One full CSV record per iteration.
        String[] dataList = sc.nextLine().split(",");

        BankCustomers bc = new BankCustomers();
        bc.rowNumber = dataList[0];
        bc.customerId = Integer.parseInt(dataList[1]);
        bc.surname = dataList[2];
        bc.creditScore = Integer.parseInt(dataList[3]);
        bc.geography = dataList[4];
        bc.gender = dataList[5];
        bc.age = Integer.parseInt(dataList[6]);
        bc.tenure = dataList[7];
        bc.balance = Double.parseDouble(dataList[8]);
        bc.numOfProducts = Integer.parseInt(dataList[9]);
        bc.hasCrCard = dataList[10];
        bc.isActiveMember = dataList[11];
        bc.estimatedSalary = Double.parseDouble(dataList[12]);
        bc.exited = Integer.parseInt(dataList[13]);

        // Collection containers store rows with put(); append() exists
        // only on TimeSeries containers in the GridDB Java API.
        coll.put(bc);
    }
}

Retrieve the Data from GridDB

Query<bankcustomers> query = coll.query("select *");
RowSet</bankcustomers><bankcustomers> rs = query.fetch(false);
RowSet res = query.fetch();</bankcustomers>

Data Pre-Processing

import org.datavec.api.records.reader.RecordReader;
import org.datavec.api.records.reader.impl.transform.TransformProcessRecordReader;
import org.datavec.api.records.reader.impl.csv.CSVRecordReader;
import org.datavec.api.split.FileSplit;
import org.datavec.api.transform.TransformProcess;
import org.deeplearning4j.api.storage.StatsStorage;
import org.datavec.api.transform.schema.Schema;
import org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator;
import org.deeplearning4j.datasets.iterator.DataSetIteratorSplitter;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.MultiLayerConfiguration;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.weights.WeightInit;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.optimize.listeners.ScoreIterationListener;
import org.deeplearning4j.ui.api.UIServer;
import org.deeplearning4j.ui.storage.InMemoryStatsStorage;
import org.deeplearning4j.ui.stats.StatsListener;
import org.deeplearning4j.util.ModelSerializer;
import org.nd4j.evaluation.classification.Evaluation;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.dataset.api.iterator.DataSetIterator;
import org.nd4j.linalg.dataset.api.preprocessor.NormalizerStandardize;
import org.nd4j.linalg.dataset.api.preprocessor.DataNormalization;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.learning.config.Adam;
import org.nd4j.linalg.io.ClassPathResource;
import org.nd4j.linalg.lossfunctions.impl.LossMCXENT;
import org.slf4j.LoggerFactory;
import org.slf4j.Logger;
import java.util.Arrays;
// Declares the raw CSV layout (name and type for each of the 14 columns,
// in file order) so DataVec can parse and transform the churn records.
private static Schema createSchema() {
    return new Schema.Builder()
            .addColumnString("rowNumber")
            .addColumnInteger("customerId")
            .addColumnString("surname")
            .addColumnInteger("creditScore")
            // Categorical columns list their full value vocabulary so the
            // one-hot / integer encodings are stable across runs.
            .addColumnCategorical("geography", Arrays.asList("France", "Germany", "Spain"))
            .addColumnCategorical("gender", Arrays.asList("Male", "Female"))
            .addColumnsInteger("age", "tenure")
            .addColumnDouble("balance")
            .addColumnsInteger("numOfProducts", "hasCrCard", "isActiveMember")
            .addColumnDouble("estimatedSalary")
            // Training label: 1 = customer left the bank, 0 = retained.
            .addColumnInteger("exited")
            .build();
}
// Builds the feature pipeline on top of a raw reader: drops the identifier
// columns (useless as predictors), integer-encodes gender, one-hot encodes
// geography, then removes one of the one-hot columns so the remaining two
// are not linearly dependent (dummy-variable trap).
private static RecordReader dataTransform(RecordReader reader, Schema scm) {
    final TransformProcess pipeline = new TransformProcess.Builder(scm)
            .removeColumns("rowNumber", "customerId", "surname")
            .categoricalToInteger("gender")
            .categoricalToOneHot("geography")
            .removeColumns("geography[France]")
            .build();
    return new TransformProcessRecordReader(reader, pipeline);
}

Split the Data into Train and Test Sets

// Opens the CSV file and wires it through the transform pipeline.
//
// Fixes vs. the original: RecordReader is an interface and cannot be
// instantiated directly (a CSVRecordReader is used instead, skipping the
// header line); the reader is initialized exactly once, from the File
// parameter rather than an out-of-scope variable; and the helper called
// is dataTransform, the method actually defined above (there is no
// applyTransform in this program).
//
// @param file CSV file to read
// @return a reader that emits transformed (model-ready) records
private static RecordReader generateReader(File file) throws IOException, InterruptedException {
    final RecordReader reader = new CSVRecordReader(1); // skip the header row
    reader.initialize(new FileSplit(file));
    return dataTransform(reader, createSchema());
}

Define the Input and Output Labels

// Index of the label column after the transform pipeline: 3 columns
// removed, geography expanded to 2 one-hot columns -> 11 features, label
// ("exited") at index 11.
final int labIndex = 11;
final int numClasses = 2;  // churned (1) vs. retained (0)
final int batchSize = 8;
// Per-class loss weights to counter the class imbalance in the data set.
final INDArray weightArray = Nd4j.create(new double[]{0.57, 0.75});
// generateReader takes a File; the original passed the GridDB RowSet
// (res), which does not compile. Feed it the source CSV directly.
final RecordReader reader = generateReader(new File("Churn_Modelling.csv"));
final DataSetIterator iterator = new RecordReaderDataSetIterator.Builder(reader, batchSize)
        .classification(labIndex, numClasses)
        .build();

Data Scaling

// Standardize features to zero mean / unit variance: statistics are
// collected from the iterator first, then applied to every batch it emits.
final DataNormalization scaler = new NormalizerStandardize();
scaler.fit(iterator);
iterator.setPreProcessor(scaler);
// 1250 batches of 8 = 10,000 rows total; 80% train / 20% test.
final DataSetIteratorSplitter iteratorSplitter = new DataSetIteratorSplitter(iterator, 1250, 0.8);

Define the Neural Network

// Feed-forward binary classifier: 11 inputs -> 6 -> 6 -> 4 -> 2 softmax
// outputs. NOTE(review): in DL4J, dropOut(0.9) is the probability of
// RETAINING a unit (10% dropped), not the drop probability.
final MultiLayerConfiguration config = new NeuralNetConfiguration.Builder()
        .weightInit(WeightInit.RELU_UNIFORM) // He-style init suited to ReLU
        .updater(new Adam(0.015D))           // adaptive learning rate
        .list()
        .layer(new DenseLayer.Builder()
                .nIn(11).nOut(6)
                .activation(Activation.RELU)
                .dropOut(0.9)
                .build())
        .layer(new DenseLayer.Builder()
                .nIn(6).nOut(6)
                .activation(Activation.RELU)
                .dropOut(0.9)
                .build())
        .layer(new DenseLayer.Builder()
                .nIn(6).nOut(4)
                .activation(Activation.RELU)
                .dropOut(0.9)
                .build())
        // Class-weighted cross-entropy loss to offset label imbalance.
        .layer(new OutputLayer.Builder(new LossMCXENT(weightArray))
                .nIn(4).nOut(2)
                .activation(Activation.SOFTMAX)
                .build())
        .build();

Train the Model

// Live training dashboard: stats are pushed to an in-memory store that
// the DL4J UI server renders in the browser.
final UIServer uiServer = UIServer.getInstance();
final StatsStorage statsStorage = new InMemoryStatsStorage();
uiServer.attach(statsStorage);

final MultiLayerNetwork multiNetwork = new MultiLayerNetwork(config);
multiNetwork.init();
// Log the score every 100 iterations and stream training stats to the UI.
multiNetwork.setListeners(new ScoreIterationListener(100), new StatsListener(statsStorage));

// 100 epochs over the training split.
multiNetwork.fit(iteratorSplitter.getTrainIterator(), 100);

// Evaluate on the held-out split with human-readable class labels.
final Evaluation ev = multiNetwork.evaluate(iteratorSplitter.getTestIterator(), Arrays.asList("0", "1"));
System.out.println(ev.stats());

Make Predictions

// Schema for unlabeled prediction input: identical to the training schema
// except that the target column ("exited") is absent, because records to
// be scored carry no label.
private static Schema createSchema() {
    return new Schema.Builder()
            .addColumnString("rowNumber")
            .addColumnInteger("customerId")
            .addColumnString("surname")
            .addColumnInteger("creditScore")
            .addColumnCategorical("geography", Arrays.asList("France", "Germany", "Spain"))
            .addColumnCategorical("gender", Arrays.asList("Male", "Female"))
            .addColumnsInteger("age", "tenure")
            .addColumnDouble("balance")
            .addColumnsInteger("numOfProducts", "hasCrCard", "isActiveMember")
            .addColumnDouble("estimatedSalary")
            .build();
}
// Scores unlabeled records with a previously trained, serialized model.
//
// @param inputFile CSV of records to score (training layout minus "exited")
// @param modelPath path to the ModelSerializer output holding both the
//                  network and the fitted normalizer
// @return per-record class probabilities (softmax outputs)
public static INDArray predictedOutput(File inputFile, String modelPath) throws IOException, InterruptedException {
    final File file = new File(modelPath);
    final MultiLayerNetwork net = ModelSerializer.restoreMultiLayerNetwork(file);
    final RecordReader reader = generateReader(inputFile);
    // Restore the normalizer with the statistics learned at training time.
    // Do NOT call fit() on it again (the original did): re-fitting on the
    // prediction data overwrites the training mean/std and skews every
    // score the network produces.
    final NormalizerStandardize normalizer = ModelSerializer.restoreNormalizerFromFile(file);
    final DataSetIterator iterator = new RecordReaderDataSetIterator.Builder(reader, 1).build();
    iterator.setPreProcessor(normalizer);
    return net.output(iterator);
}

Compile and Run the Code

export CLASSPATH=$CLASSPATH:/home/osboxes/Downloads/griddb_4.6.0-1_amd64/usr/griddb-4.6.0/bin/gridstore.jar
javac BankRetention.java
java BankRetention

--

--

Get the Medium app

A button that says 'Download on the App Store', and if clicked it will lead you to the iOS App store
A button that says 'Get it on, Google Play', and if clicked it will lead you to the Google Play store