How to invoke a trained TensorFlow model from Java programs

The primary language in which TensorFlow machine learning models are created and trained is Python. However, many enterprise server programs are written in Java. So, you will often run into situations where you need to invoke a TensorFlow model that you trained in Python from a Java program.

If you are using CloudML on the Google Cloud Platform, this is no problem — in CloudML, predictions are made through a REST API call and so you can do this from any programming language. But what if you have downloaded the TensorFlow model, and want to carry out predictions offline?

Here’s how you can make predictions in Java using TensorFlow models that were trained in Python.

Note: The TensorFlow team has now started to add Java bindings. See https://github.com/tensorflow/tensorflow/tree/master/tensorflow/java for details. Try that first, and if it doesn’t work for you, come back here.

Write out model files in Python

The first thing to do is save the TensorFlow model in Python in two formats: (a) the weights, biases, etc. as a checkpoint file written by a tf.train.Saver, and (b) the graph itself as a protobuf file. To preserve your sanity, you might want to save the graph in both text and binary protobuf formats. Reading through the text format helps you find the names TensorFlow assigned to nodes that you did not name explicitly. The code to write these three files from Python:

import tensorflow as tf

# create a Saver object as normal in Python to save your variables
saver = tf.train.Saver(...)
# use a saver_def to get the "magic" strings needed to restore
saver_def = saver.as_saver_def()
print(saver_def.filename_tensor_name)
print(saver_def.restore_op_name)
# write out 3 files (sess is the tf.Session you trained the model in):
# the variable values, the binary graph, and the text graph
saver.save(sess, 'trained_model.sd')
tf.train.write_graph(sess.graph_def, '.', 'trained_model.proto', as_text=False)
tf.train.write_graph(sess.graph_def, '.', 'trained_model.txt', as_text=True)

In my case, the two magic strings printed out from saver_def were save/Const:0 and save/restore_all, so that’s what you’ll see in my Java code. If yours are different, change these when you write your Java code.

The .sd file contains the weights, biases, etc. (the actual values of the Variables in your graph). The .proto file is a binary file containing your computation graph, and the .txt file is its text version.

Invoking TensorFlow C++ from Java

Even though you may have used TensorFlow in Python to feed data to your model and train it, the TensorFlow Python package actually calls a C++ implementation to do the real work. Therefore, we can use the Java Native Interface (JNI) to invoke the C++ code directly, and use C++ to create the graph and restore the model’s weights and biases from Java.

Rather than write all the JNI calls by hand, you can use an open-source library called JavaCPP. To use it, add this dependency to your Maven pom.xml:

<dependency>
    <groupId>org.bytedeco.javacpp-presets</groupId>
    <artifactId>tensorflow</artifactId>
    <version>0.9.0-1.2</version>
</dependency>

If you are using some other build management system, add the JavaCPP presets for TensorFlow and all of their dependencies to your application’s classpath.

Create model in Java

In your Java code, read the proto file to create a Graph definition as follows (imports are omitted for clarity):

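final Session session = new Session(new SessionOptions());
GraphDef def = new GraphDef();
// ReadBinaryProto also returns a Status; check it so a bad path fails here
Status s = tensorflow.ReadBinaryProto(Env.Default(),
        "somedir/trained_model.proto", def);
if (!s.ok()) {
    throw new RuntimeException(s.error_message().getString());
}
s = session.Create(def);
if (!s.ok()) {
    throw new RuntimeException(s.error_message().getString());
}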

Next, restore the weights and biases from the saved model file using Session::Run(). Note how the magic strings from saver_def are used.

// restore: feed the checkpoint path to save/Const:0 and run save/restore_all
Tensor fn = new Tensor(tensorflow.DT_STRING, new TensorShape(1));
StringArray a = fn.createStringArray();
a.position(0).put("somedir/trained_model.sd");
s = session.Run(
        new StringTensorPairVector(new String[] {"save/Const:0"}, new Tensor[] {fn}),
        new StringVector(),                   // no output tensors to fetch
        new StringVector("save/restore_all"), // target op to run
        new TensorVector());
if (!s.ok()) {
    throw new RuntimeException(s.error_message().getString());
}
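In the restore call above, save/Const:0 is fed as a tensor, while save/restore_all is passed in the list of target nodes, i.e. as a bare operation name. TensorFlow tensor names follow the op_name:output_index convention, and a name without a suffix is treated as output 0 when it is used as a tensor, which is why a name like Sigmoid can be fetched without :0. The helper below is a hypothetical illustration of that convention (TensorNames is not part of TensorFlow or JavaCPP); it can be handy when mapping node names from trained_model.txt to the tensor names you feed and fetch. It deliberately does not handle control-dependency names such as ^op.

```java
/**
 * Hypothetical helper (not part of TensorFlow or JavaCPP) for names that follow
 * TensorFlow's "op_name:output_index" convention. A name with no ":index" suffix
 * refers to output 0 of that op when it is used as a tensor.
 * Control-dependency names ("^op") are not handled.
 */
public final class TensorNames {
    private TensorNames() {}

    /** Returns the operation part, e.g. "save/Const" for "save/Const:0". */
    public static String opName(String name) {
        int colon = name.lastIndexOf(':');
        return colon < 0 ? name : name.substring(0, colon);
    }

    /** Returns the output index, defaulting to 0 when no suffix is present. */
    public static int outputIndex(String name) {
        int colon = name.lastIndexOf(':');
        if (colon < 0) {
            return 0;
        }
        // Throws NumberFormatException for malformed suffixes such as "op:" or "op:x".
        int index = Integer.parseInt(name.substring(colon + 1));
        if (index < 0) {
            throw new IllegalArgumentException("Negative output index in " + name);
        }
        return index;
    }

    /** Returns the fully qualified tensor name, e.g. "Sigmoid" -> "Sigmoid:0". */
    public static String tensorName(String name) {
        return opName(name) + ":" + outputIndex(name);
    }

    public static void main(String[] args) {
        System.out.println(opName("save/Const:0"));      // save/Const
        System.out.println(outputIndex("save/Const:0")); // 0
        System.out.println(tensorName("Sigmoid"));       // Sigmoid:0
    }
}
```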

Making predictions in Java

At this point, your model is ready. You can now use it to make predictions. This is similar to how you’d do it in Python: you pass in values for all your placeholders and evaluate the output node. The difference is that you have to know the actual names of the placeholder and output nodes. If you didn’t give these nodes unique names in Python, TensorFlow assigned names to them. You can find out what they are by looking at the trained_model.txt file that was written out, or you can go back to your Python code and give the key nodes names that you will remember. In my case, the input placeholder was called Placeholder, the dropout placeholder was called Placeholder_2, and the output node was called Sigmoid. You’ll see these referenced in the Session::Run() call below.
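If trained_model.txt is large, scanning it by hand gets tedious. The sketch below is a hypothetical helper (GraphTextScanner is an illustrative name, not a TensorFlow or JavaCPP class) that reads the text GraphDef and maps each node name to its op type, then prints the nodes whose op is Placeholder, i.e. the ones you must feed. It assumes the default layout written by tf.train.write_graph(..., as_text=True), where each top-level node’s name: and op: fields are indented by exactly two spaces; it is a line-based scan, not a full protobuf text-format parser. You still pick the output node (Sigmoid in my case) yourself.

```java
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Paths;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

/**
 * Hypothetical helper: lists node names and op types from a text GraphDef.
 * Assumes each top-level node's "name:" and "op:" lines are indented by
 * exactly two spaces (nested fields are indented further and are ignored).
 */
public final class GraphTextScanner {
    private static final Pattern NAME = Pattern.compile("  name: \"(.*)\"");
    private static final Pattern OP = Pattern.compile("  op: \"(.*)\"");

    private GraphTextScanner() {}

    /** Maps each node name to its op type, in file order. */
    public static Map<String, String> nodeOps(List<String> lines) {
        Map<String, String> result = new LinkedHashMap<>();
        String currentName = null;
        for (String line : lines) {
            Matcher name = NAME.matcher(line);
            if (name.matches()) {
                currentName = name.group(1); // "name" precedes "op" inside a node
                continue;
            }
            Matcher op = OP.matcher(line);
            if (op.matches() && currentName != null) {
                result.put(currentName, op.group(1));
                currentName = null;
            }
        }
        return result;
    }

    public static void main(String[] args) throws IOException {
        if (args.length != 1) {
            System.err.println("usage: GraphTextScanner path/to/trained_model.txt");
            return;
        }
        List<String> lines = Files.readAllLines(Paths.get(args[0]), StandardCharsets.UTF_8);
        // Placeholders are the inputs that must be fed in Session::Run().
        nodeOps(lines).forEach((node, op) -> {
            if (op.equals("Placeholder")) {
                System.out.println("placeholder: " + node);
            }
        });
    }
}
```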

In my case, the neural network uses 5 predictor variables. To make predictions for 2 sets of inputs at once, my input is a 2x5 matrix (one row per set of predictors). My network has only one output, so for 2 sets of inputs the output tensor is a 2x1 matrix. The dropout placeholder is fed a hardcoded keep probability of 1.0, because at prediction time we keep all nodes; dropout applies only during training. So, I have:

// try to predict for two (2) sets of inputs
Tensor inputs = new Tensor(
        tensorflow.DT_FLOAT, new TensorShape(2, 5));
FloatBuffer x = inputs.createBuffer();
// rows are written one after the other (row-major order)
x.put(new float[] {-6.0f, 22.0f, 383.0f, 27.781754111198122f, -6.5f});
x.put(new float[] {66.0f, 22.0f, 2422.0f, 45.72160947712418f, 0.4f});
// keep probability of 1.0 for every row, i.e. dropout switched off
Tensor keepall = new Tensor(
        tensorflow.DT_FLOAT, new TensorShape(2, 1));
((FloatBuffer) keepall.createBuffer()).put(new float[] {1f, 1f});
TensorVector outputs = new TensorVector();
// to predict each time, pass in values for placeholders
outputs.resize(0);
s = session.Run(
        new StringTensorPairVector(new String[] {"Placeholder", "Placeholder_2"},
                new Tensor[] {inputs, keepall}),
        new StringVector("Sigmoid"), new StringVector(), outputs);
if (!s.ok()) {
    throw new RuntimeException(s.error_message().getString());
}
// this is how you get back the predicted value from outputs
FloatBuffer output = outputs.get(0).createBuffer();
for (int k = 0; k < output.limit(); ++k) {
    System.out.println("prediction=" + output.get(k));
}
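The two x.put(...) calls rely on the input buffer being filled row by row, and because the output node is a Sigmoid, each prediction is a score between 0 and 1. The sketch below wraps those steps in a hypothetical helper (PredictionIO is an illustrative name, not part of TensorFlow or JavaCPP): it packs a batch in row-major order with shape checks, builds the all-ones keep-probability values, and turns scores into yes/no labels at a threshold you choose (0.5 in the demo is an assumption, not something from the model). It uses only java.nio, so FloatBuffer.allocate and FloatBuffer.wrap stand in for the buffers returned by createBuffer(), and the scores in main are made-up values, not real model output.

```java
import java.nio.FloatBuffer;
import java.util.Arrays;

/**
 * Hypothetical helpers for the data going into and coming out of the prediction
 * call. Plain FloatBuffers stand in for the buffers returned by createBuffer().
 */
public final class PredictionIO {
    private PredictionIO() {}

    /** Copies rows into target in row-major order, as a [rows, numFeatures] tensor expects. */
    public static void putRows(FloatBuffer target, float[][] rows, int numFeatures) {
        // Validate everything first so a bad row never leaves a half-filled buffer.
        for (float[] row : rows) {
            if (row.length != numFeatures) {
                throw new IllegalArgumentException(
                        "Expected " + numFeatures + " features, got " + row.length);
            }
        }
        if (target.remaining() < rows.length * numFeatures) {
            throw new IllegalArgumentException("Buffer too small for the batch");
        }
        for (float[] row : rows) {
            target.put(row); // relative put advances the position past this row
        }
    }

    /** Keep-probability values of 1.0f, one per row, so dropout is switched off. */
    public static float[] keepAll(int batchSize) {
        float[] ones = new float[batchSize];
        Arrays.fill(ones, 1.0f);
        return ones;
    }

    /** Turns sigmoid scores into yes/no labels; the threshold is the caller's choice. */
    public static boolean[] toLabels(FloatBuffer scores, float threshold) {
        boolean[] labels = new boolean[scores.limit()];
        for (int k = 0; k < labels.length; ++k) {
            labels[k] = scores.get(k) >= threshold; // absolute get: position unchanged
        }
        return labels;
    }

    public static void main(String[] args) {
        float[][] rows = {
            {-6.0f, 22.0f, 383.0f, 27.781754111198122f, -6.5f},
            {66.0f, 22.0f, 2422.0f, 45.72160947712418f, 0.4f},
        };
        FloatBuffer x = FloatBuffer.allocate(rows.length * 5); // stands in for inputs.createBuffer()
        putRows(x, rows, 5);
        System.out.println(x.get(5)); // first feature of the second row
        System.out.println(Arrays.toString(keepAll(rows.length)));

        // Illustrative scores only, not real model output.
        FloatBuffer output = FloatBuffer.wrap(new float[] {0.12f, 0.87f});
        System.out.println(Arrays.toString(toLabels(output, 0.5f)));
    }
}
```

Validating before writing is a small design choice: a malformed row is reported without leaving a partially filled input tensor behind.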

That’s it: you are now using Java to carry out your predictions. There are several steps, but that is to be expected when mixing three programming languages (Python, C++ and Java). The important thing is that it can be done, and that it is relatively straightforward.

Of course, doing this doesn’t take advantage of hardware acceleration and distribution. If you want to make predictions at a very high rate in real-time, you should consider using CloudML.