Using Deep Java Library on Spring Boot

Arika Saputro
3 min read · Jul 14, 2022


Hi, this is my first post on Medium. I hope it will be helpful.

Deep Java Library (https://djl.ai) is a Java library that supports multiple deep learning frameworks such as Apache MXNet, PyTorch, and TensorFlow. It is already used by tech companies such as Netflix and Amazon. Deep Java Library was initially released on 30 Nov 2019, and the latest version (0.18.0) was released on 12 July 2022.

I think this is worth a try, considering that big data applications (e.g., Spark, Kafka, Beam, Flink) are Java/Scala based, yet there is no single Java library that works across all deep learning engines (TensorFlow, PyTorch, etc.). DJL provides NumPy-like operators (https://javadoc.io/doc/ai.djl/api/latest/ai/djl/ndarray/NDArray.html) for doing deep learning fully in Java, supports the MXNet, TensorFlow, PyTorch, and ONNX engines, and requires no deployment changes when a new version of a model is trained with a different engine.
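To get a feel for the NDArray API, here is a minimal sketch (not from the original post) of a couple of NumPy-style operations in DJL:

import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.Shape;

public class NDArrayDemo {
    public static void main(String[] args) {
        // The NDManager owns the native memory behind every NDArray it creates
        try (NDManager manager = NDManager.newBaseManager()) {
            NDArray a = manager.create(new float[]{1f, 2f, 3f, 4f}, new Shape(2, 2));
            NDArray b = a.add(10);    // element-wise add, like a + 10 in NumPy
            NDArray c = a.matMul(b);  // 2x2 matrix multiplication
            System.out.println(c);
        }
    }
}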

Let’s try implementing DJL on Spring Boot for a basic classification problem: MNIST digit classification.

Dependencies:

<properties>
    <java.version>1.8</java.version>
    <ai.djl.version>0.18.0</ai.djl.version>
    <tensorflow-native-auto.version>2.4.1</tensorflow-native-auto.version>
    <mxnet-native-auto.version>1.8.0</mxnet-native-auto.version>
    <jna.version>5.3.0</jna.version>
</properties>

I will use the Spring Boot starter dependencies, the DJL TensorFlow and MXNet dependencies, and the DJL basicdataset dependency for the MNIST dataset. DJL makes it easy to download and load the MNIST dataset into memory via the Mnist class contained in ai.djl.basicdataset.

Spring Boot starter dependencies:

<dependency>
    <groupId>org.springframework.boot</groupId>
    <artifactId>spring-boot-starter-web</artifactId>
</dependency>
<dependency>
    <groupId>org.springframework.boot</groupId>
    <artifactId>spring-boot-starter-test</artifactId>
    <scope>test</scope>
</dependency>

DJL dependencies:

<dependency>
    <groupId>ai.djl</groupId>
    <artifactId>api</artifactId>
    <version>${ai.djl.version}</version>
</dependency>
<dependency>
    <groupId>ai.djl.tensorflow</groupId>
    <artifactId>tensorflow-api</artifactId>
    <version>${ai.djl.version}</version>
</dependency>
<dependency>
    <groupId>ai.djl.tensorflow</groupId>
    <artifactId>tensorflow-engine</artifactId>
    <version>${ai.djl.version}</version>
</dependency>
<dependency>
    <groupId>ai.djl.mxnet</groupId>
    <artifactId>mxnet-engine</artifactId>
    <version>${ai.djl.version}</version>
</dependency>

Auto dependencies that download the correct native artifact at runtime:

<dependency>
    <groupId>ai.djl.mxnet</groupId>
    <artifactId>mxnet-native-auto</artifactId>
    <version>${mxnet-native-auto.version}</version>
    <scope>runtime</scope>
</dependency>
<dependency>
    <groupId>ai.djl.tensorflow</groupId>
    <artifactId>tensorflow-native-auto</artifactId>
    <version>${tensorflow-native-auto.version}</version>
    <scope>runtime</scope>
</dependency>

Basicdataset dependency:

<dependency>
    <groupId>ai.djl</groupId>
    <artifactId>basicdataset</artifactId>
    <version>${ai.djl.version}</version>
    <type>jar</type>
</dependency>

ML Model:

Here I create a basic MLP (multilayer perceptron) model:

public class MultilayerPerceptron extends SequentialBlock {
    public MultilayerPerceptron(int input, int output, int[] hidden, Function<NDList, NDList> activation) {
        add(Blocks.batchFlattenBlock(input));
        for (int hiddenSize : hidden) {
            add(Linear.builder().setUnits(hiddenSize).build());
            add(activation);
        }
        add(Linear.builder().setUnits(output).build());
    }
}
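With the arguments used later in the service layer ({128, 64} hidden units and Activation::sigmoid), this block flattens each 28×28 MNIST image into a 784-element vector, passes it through two fully connected layers of 128 and 64 units with sigmoid activations, and ends with a 10-unit output layer, one unit per digit class.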

Service Layer:

In the service layer, I will create service classes for training and testing:
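The service interfaces and the constants (MODEL_NAME, MODEL_DIRE, IMAGE_FILE, BATCH_SIZE, LIMIT, EPOCH) referenced below are not shown in this post; here is a minimal sketch of what they could look like, the concrete names and values being assumptions rather than part of the original code:

public interface TrainingService {
    TrainingResult trainMnistDataset() throws Exception;
}

public interface TestingService {
    Classifications testDataset() throws Exception;
}

public interface HelperService {
    RandomAccessDataset getDataSet(Dataset.Usage usage, int batchSize, int limit) throws Exception;
    DefaultTrainingConfig constructTrainingConfig(String outputDire) throws Exception;
}

// Hypothetical constants; tune the values and paths to your own setup
public final class MLConstants {
    public static final String MODEL_NAME = "mlp";
    public static final String MODEL_DIRE = "build/model";
    public static final String IMAGE_FILE = "src/main/resources/digit.png";
    public static final int BATCH_SIZE = 64;
    public static final int LIMIT = Integer.MAX_VALUE;
    public static final int EPOCH = 2;
}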

@Service
public class TrainingServiceImpl implements TrainingService {

    @Autowired
    HelperService helperService;

    @Override
    public TrainingResult trainMnistDataset() throws Exception {
        try (Model model = Model.newInstance(MODEL_NAME)) {
            model.setBlock(new MultilayerPerceptron(Mnist.IMAGE_HEIGHT * Mnist.IMAGE_WIDTH, Mnist.NUM_CLASSES, new int[]{128, 64}, Activation::sigmoid));
            RandomAccessDataset trainingSet = this.helperService.getDataSet(Dataset.Usage.TRAIN, BATCH_SIZE, LIMIT);
            RandomAccessDataset validateSet = this.helperService.getDataSet(Dataset.Usage.TEST, BATCH_SIZE, LIMIT);
            try (Trainer trainer = model.newTrainer(this.helperService.constructTrainingConfig(MODEL_NAME))) {
                trainer.setMetrics(new Metrics());
                trainer.initialize(new Shape(1, Mnist.IMAGE_HEIGHT * Mnist.IMAGE_WIDTH));
                EasyTrain.fit(trainer, EPOCH, trainingSet, validateSet);
                Path modelDire = Paths.get(MODEL_DIRE);
                Files.createDirectories(modelDire);
                model.setProperty("Epoch", String.valueOf(EPOCH));
                model.save(modelDire, MODEL_NAME);
                return trainer.getTrainingResult();
            }
        }
    }
}

@Service
public class TestingServiceImpl implements TestingService {

    @Override
    public Classifications testDataset() throws Exception {
        Image img = ImageFactory.getInstance()
                .fromFile(Paths.get(IMAGE_FILE));
        try (Model model = Model.newInstance(MODEL_NAME)) {
            model.setBlock(new MultilayerPerceptron(Mnist.IMAGE_HEIGHT * Mnist.IMAGE_WIDTH, Mnist.NUM_CLASSES, new int[]{128, 64}, Activation::sigmoid));
            model.load(Paths.get(MODEL_DIRE));
            List<String> classes = IntStream.range(0, 10)
                    .mapToObj(String::valueOf)
                    .collect(Collectors.toList());
            Translator<Image, Classifications> translator =
                    ImageClassificationTranslator.builder()
                            .addTransform(new ToTensor())
                            .optSynset(classes)
                            .build();
            try (Predictor<Image, Classifications> predictor = model.newPredictor(translator)) {
                return predictor.predict(img);
            }
        }
    }
}

Here I create a helper service to get the dataset and construct the training config:

@Service
public class HelperServiceImpl implements HelperService {

    @Override
    public RandomAccessDataset getDataSet(Dataset.Usage usage, int batchSize, int limit) throws Exception {
        Mnist mnist = Mnist.builder()
                .optUsage(usage)
                .optLimit(limit)
                .setSampling(batchSize, true)
                .build();
        mnist.prepare(new ProgressBar());
        return mnist;
    }

    @Override
    public DefaultTrainingConfig constructTrainingConfig(String outputDire) throws Exception {
        SaveModelTrainingListener listener = new SaveModelTrainingListener(outputDire);
        listener.setSaveModelCallback(
                trainer -> {
                    TrainingResult result = trainer.getTrainingResult();
                    Model model = trainer.getModel();
                    float accuracy = result.getTrainEvaluation("Accuracy");
                    model.setProperty("Loss", String.format("%.5f", result.getValidateLoss()));
                    model.setProperty("Accuracy", String.format("%.5f", accuracy));
                });
        return new DefaultTrainingConfig(Loss.softmaxCrossEntropyLoss())
                .addEvaluator(new Accuracy())
                .optDevices(Engine.getInstance().getDevices(Engine.getInstance().getGpuCount()))
                .addTrainingListeners(TrainingListener.Defaults.logging(outputDire))
                .addTrainingListeners(listener);
    }
}
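A note on this config: the SaveModelTrainingListener takes care of saving the model as training progresses, and the callback above stamps the model properties with the latest validation loss and training accuracy, so the saved model carries its own quality figures.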

REST Controller Layer:

@RestController
@RequestMapping("/mlp")
public class MLPController {

    @Autowired
    TrainingService trainingService;

    @Autowired
    TestingService testingService;

    @GetMapping("/training")
    public ResponseEntity<ResponseMessage> trainMnist() {
        try {
            trainingService.trainMnistDataset();
            return ResponseEntity.ok(ResponseMessage.builder().message("Training Data Done").build());
        } catch (Exception e) {
            return ResponseEntity.ok(ResponseMessage.builder().message(e.getMessage()).build());
        }
    }

    @GetMapping("/testing")
    public ResponseEntity<ResponseMessage> predict() {
        try {
            Classifications classifications = testingService.testDataset();
            classifications.setTopK(1);
            return ResponseEntity.ok(ResponseMessage.builder().message(classifications.toString()).build());
        } catch (Exception e) {
            return ResponseEntity.ok(ResponseMessage.builder().message(e.getMessage()).build());
        }
    }
}
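ResponseMessage is not shown in the original post; since it is created with a builder, a minimal sketch using Lombok (an assumption on my part) could look like this:

import lombok.Builder;
import lombok.Getter;

// Simple response wrapper returned by both controller endpoints
@Getter
@Builder
public class ResponseMessage {
    private String message;
}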

Let’s run it.

[Screenshot: training the MNIST dataset]
[Screenshot: testing an MNIST image of class 6]
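Both endpoints are plain GET requests, so assuming Spring Boot's default port 8080 they can be called at http://localhost:8080/mlp/training to train and save the model, and at http://localhost:8080/mlp/testing to classify the image at IMAGE_FILE.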

That’s all! I think building machine learning features with Deep Java Library on Spring Boot is a simple and powerful approach that lets us combine existing Spring technology with the most proven deep learning frameworks (MXNet, PyTorch, and TensorFlow).

I hope this will be helpful to you. Cheers!
