Using a Pre-Trained TensorFlow Model on Android — Part 2
In Part 1, I introduced you to the TensorFlowInferenceInterface and the org.tensorflow:tensorflow-android dependency. Together they provide an easy way to embed pre-trained TensorFlow models in your Android app.
In this post, we’ll dig into the details by looking at a simple example GitHub project.
What’s in the Dependency?
Let’s look at what the org.tensorflow:tensorflow-android dependency brings into our project. The Project view in Android Studio helpfully lets us browse to the app/build/intermediates/exploded-aar folder to see what’s been downloaded.
We can see that we have a native binary file for each of four architectures: arm64-v8a, armeabi-v7a, x86, and x86_64. There is also a classes.jar, which contains the TensorFlowInferenceInterface and other supporting classes.
Project Structure
Our TensorFlow graph (.pb) and labels (.txt) are in app/src/main/assets.
The TensorFlowImageClassifier interfaces with the TensorFlowInferenceInterface.
Code — Initialization
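Inside TensorFlowImageClassifier, initialize the TensorFlowInferenceInterface from the app’s AssetManager and the model filename (c is the classifier instance being created):
c.inferenceInterface = new TensorFlowInferenceInterface(assetManager, modelFilename);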
Code — Model Definitions
We need to define some names and sizes necessary to interface with our TensorFlow model.
Most of the values shown below are from MainActivity.
private static final String MODEL_FILE =
"file:///android_asset/mnist_model_graph.pb";
private static final String LABEL_FILE =
"file:///android_asset/graph_label_strings.txt";
These model and label file names match the files we saw in the assets folder earlier.
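On Android, a file:///android_asset/ URI refers to the app’s assets folder, so one way to use LABEL_FILE is to strip that prefix and open the remaining relative path through the AssetManager. The sketch below assumes the label file holds one label per line. LabelLoader, assetPath, and readLabels are hypothetical names, not code from the sample project, and a StringReader with made-up contents stands in for the asset stream so the example runs off-device.

```java
import java.io.BufferedReader;
import java.io.IOException;
import java.io.Reader;
import java.io.StringReader;
import java.util.ArrayList;
import java.util.List;

public class LabelLoader {
    // Prefix used by MODEL_FILE and LABEL_FILE in MainActivity.
    static final String ASSET_PREFIX = "file:///android_asset/";

    // Hypothetical helper: "file:///android_asset/x.txt" -> "x.txt",
    // the relative path AssetManager.open(...) expects.
    static String assetPath(String uri) {
        return uri.startsWith(ASSET_PREFIX) ? uri.substring(ASSET_PREFIX.length()) : uri;
    }

    // Hypothetical helper: read one label per line, skipping blank lines
    // (an assumption about the file format, not taken from the project).
    static List<String> readLabels(Reader source) throws IOException {
        List<String> labels = new ArrayList<>();
        try (BufferedReader br = new BufferedReader(source)) {
            String line;
            while ((line = br.readLine()) != null) {
                String label = line.trim();
                if (!label.isEmpty()) {
                    labels.add(label);
                }
            }
        }
        return labels;
    }

    public static void main(String[] args) throws IOException {
        String path = assetPath("file:///android_asset/graph_label_strings.txt");
        System.out.println(path); // graph_label_strings.txt
        // On a device the Reader would be
        // new InputStreamReader(assetManager.open(path)).
        // Sample contents below are assumed for illustration only.
        List<String> labels = readLabels(new StringReader("0\n1\n2\n"));
        System.out.println(labels); // [0, 1, 2]
    }
}
```

Keeping the file parsing separate from the AssetManager call means the parsing logic can be unit-tested on the JVM without an emulator.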
private static final String INPUT_NAME = "input";
private static final String OUTPUT_NAME = "output";
The names of the model’s input and output come from our mnist.py training script. Most of that script deals with training a model from scratch, but these two lines show where the input and output tensors are defined and named:
x_2 = tf.placeholder("float", shape=[None, 784], name="input")
# ...
OUTPUT = tf.nn.softmax(tf.matmul(FC1, W_OUT) + B_OUT, name="output")
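The output node applies tf.nn.softmax to the final layer, so the values the app later fetches are non-negative scores that sum to 1 and can be read as per-class confidences. Purely as an illustration of what that node computes (TensorFlow does this inside the graph; the app never runs this code), here is a numerically stable softmax in Java. SoftmaxDemo is a hypothetical class name.

```java
import java.util.Arrays;

public class SoftmaxDemo {
    // softmax(x)_i = exp(x_i - max) / sum_j exp(x_j - max).
    // Subtracting the max avoids overflow and does not change the result.
    static float[] softmax(float[] logits) {
        float max = Float.NEGATIVE_INFINITY;
        for (float v : logits) {
            max = Math.max(max, v);
        }
        double[] exps = new double[logits.length];
        double sum = 0.0;
        for (int i = 0; i < logits.length; i++) {
            exps[i] = Math.exp(logits[i] - max);
            sum += exps[i];
        }
        float[] probs = new float[logits.length];
        for (int i = 0; i < logits.length; i++) {
            probs[i] = (float) (exps[i] / sum);
        }
        return probs;
    }

    public static void main(String[] args) {
        float[] probs = softmax(new float[] {2.0f, 1.0f, 0.1f});
        // Roughly [0.659, 0.242, 0.099]; the values sum to 1 (up to rounding).
        System.out.println(Arrays.toString(probs));
    }
}
```

This is why a confidence threshold on the fetched outputs is meaningful: each value is already a probability-like score between 0 and 1.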
Another way to figure out the names of your input and output nodes is to import your TensorFlow model into TensorBoard and inspect it there.
The final thing we need to specify is the input size.
private static final int INPUT_SIZE = 28;
In our MNIST example, the training data consisted entirely of 28x28 pixel images of handwritten characters; to keep things simple, our sample project has users draw their character on a 28x28 “pixel” canvas (by passing 28 into the DrawModel class constructor).
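The placeholder in mnist.py has shape [None, 784], and 784 is simply 28x28: each drawing reaches the model as one flat array of pixel values. A minimal sketch of that flattening, assuming row-major order (row by row, left to right) to match how the training images were flattened, and assuming pixel intensities are already floats. PixelFlattener and flatten are hypothetical names, not part of the sample project’s DrawModel.

```java
public class PixelFlattener {
    static final int INPUT_SIZE = 28; // same value as INPUT_SIZE in MainActivity

    // Hypothetical helper: copy an INPUT_SIZE x INPUT_SIZE grid into a
    // float[INPUT_SIZE * INPUT_SIZE] in row-major order.
    static float[] flatten(float[][] grid) {
        if (grid.length != INPUT_SIZE) {
            throw new IllegalArgumentException("expected " + INPUT_SIZE + " rows");
        }
        float[] pixels = new float[INPUT_SIZE * INPUT_SIZE];
        for (int row = 0; row < INPUT_SIZE; row++) {
            if (grid[row].length != INPUT_SIZE) {
                throw new IllegalArgumentException(
                        "expected " + INPUT_SIZE + " columns in row " + row);
            }
            // Row r occupies indices [r * INPUT_SIZE, (r + 1) * INPUT_SIZE).
            System.arraycopy(grid[row], 0, pixels, row * INPUT_SIZE, INPUT_SIZE);
        }
        return pixels;
    }

    public static void main(String[] args) {
        float[][] grid = new float[INPUT_SIZE][INPUT_SIZE];
        grid[1][2] = 1.0f; // a single "ink" pixel at row 1, column 2
        float[] pixels = flatten(grid);
        System.out.println(pixels.length);                 // 784
        System.out.println(pixels[1 * INPUT_SIZE + 2]);     // 1.0
    }
}
```

Because the graph itself does not record this shape, the size checks live on the Java side; a grid of the wrong size fails loudly here instead of producing a confusing error inside TensorFlow.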
Why do we actually need to define this? According to the code comments:
Ideally, inputSize could have been retrieved from the shape of the input operation. Alas, the placeholder node for input in the graphdef typically used does not specify a shape, so it must be passed in as a parameter.
Fortunately, the output size (the number of classes) can be read directly from the TensorFlow model:
int numClasses =
(int) c.inferenceInterface.graph().operation(outputName).output(0).shape().size(1);
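The size(1) call reads the second dimension of the output tensor’s shape. The first dimension (index 0) is the batch size, which is None in mnist.py and therefore unknown; TensorFlow’s Java Shape reports an unknown dimension as -1. The sketch below mimics that lookup with a plain long[] and then allocates the outputNames and outputs buffers that recognizeImage passes to run and fetch. The [-1, 10] shape assumes a 10-class digit model, and OutputBuffers is a hypothetical name.

```java
public class OutputBuffers {
    final String[] outputNames; // passed to run(...)
    final float[] outputs;      // filled by fetch(...)

    // outputShape mirrors the "output" tensor's shape: index 0 is the
    // batch dimension (-1 = unknown), index 1 is the number of classes.
    OutputBuffers(String outputName, long[] outputShape) {
        if (outputShape.length < 2) {
            throw new IllegalArgumentException("expected a [batch, classes] shape");
        }
        long numClasses = outputShape[1];
        if (numClasses <= 0) {
            throw new IllegalArgumentException("number of classes must be known");
        }
        this.outputNames = new String[] {outputName};
        this.outputs = new float[(int) numClasses];
    }

    public static void main(String[] args) {
        // Assumed shape for a 10-class MNIST model: [None, 10].
        OutputBuffers b = new OutputBuffers("output", new long[] {-1, 10});
        System.out.println(b.outputs.length); // 10
    }
}
```

Allocating the output array once, at construction time, lets every call to recognizeImage reuse it rather than creating a new array per drawing.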
Code — Running the Classifier via TensorFlowInferenceInterface
Now we need to pass the user’s 28x28 pixel drawing into our pre-trained classifier.
Here are some (simplified) highlights of the recognizeImage method from the TensorFlowImageClassifier:
@Override
public List<Recognition> recognizeImage(final float[] pixels) {
    // Copy the input data into TensorFlow.
    inferenceInterface.feed(inputName, pixels, new long[]{inputSize * inputSize});
    // Run the inference call.
    inferenceInterface.run(outputNames);
    // Copy the output Tensor back into the output array.
    inferenceInterface.fetch(outputName, outputs);
    // Find the best classifications.
    for (int i = 0; i < outputs.length; ++i) {
        // <snip>
    }
    return recognitions;
}
We feed in the pixel data, run the classifier, then fetch the outputs.
Those outputs are then sorted to find the result with the highest confidence (above a specified threshold), which is shown to the user.
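The loop elided in the recognizeImage listing turns the outputs array into a ranked list. One way to do that, sketched here under assumptions rather than copied from the sample project: keep only classes whose confidence exceeds a threshold, order them with a PriorityQueue (highest confidence first), and return at most a few results. Recognition below is a simplified, hypothetical stand-in for the project’s class, and the threshold of 0.1 and the result limit are assumed values.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.PriorityQueue;

public class TopResults {
    // Simplified, hypothetical stand-in for the project's Recognition class.
    static final class Recognition {
        final String title;
        final float confidence;

        Recognition(String title, float confidence) {
            this.title = title;
            this.confidence = confidence;
        }

        @Override
        public String toString() {
            return title + " (" + confidence + ")";
        }
    }

    // Keep classes scoring above threshold, highest confidence first,
    // and return at most maxResults of them.
    static List<Recognition> topResults(float[] outputs, List<String> labels,
                                        float threshold, int maxResults) {
        PriorityQueue<Recognition> queue = new PriorityQueue<>(
                Math.max(1, outputs.length),
                (a, b) -> Float.compare(b.confidence, a.confidence)); // descending
        for (int i = 0; i < outputs.length; ++i) {
            if (outputs[i] > threshold) {
                String title = i < labels.size() ? labels.get(i) : "unknown";
                queue.add(new Recognition(title, outputs[i]));
            }
        }
        List<Recognition> recognitions = new ArrayList<>();
        while (!queue.isEmpty() && recognitions.size() < maxResults) {
            recognitions.add(queue.poll());
        }
        return recognitions;
    }

    public static void main(String[] args) {
        float[] outputs = {0.05f, 0.70f, 0.20f, 0.05f}; // made-up softmax scores
        List<String> labels = Arrays.asList("0", "1", "2", "3");
        System.out.println(topResults(outputs, labels, 0.1f, 3)); // [1 (0.7), 2 (0.2)]
    }
}
```

The app would then display the first element of the returned list, if any; an empty list means no class cleared the threshold.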
You can read more about TensorFlowInferenceInterface.java in the TensorFlow Android contrib.
Warning — the TensorFlowInferenceInterface API Changes!
TensorFlow is still under active development, and the TensorFlowInferenceInterface API changed between the r1.1 and r1.2 releases of TensorFlow.
You can see this by comparing the TensorFlowImageClassifier at r1.1 (which uses fillNodeFloat) with the one at r1.2 (which uses feed).
This might confuse you if you are viewing older examples.
DISCLOSURE STATEMENT: These opinions are those of the author. Unless noted otherwise in this post, Capital One is not affiliated with, nor is it endorsed by, any of the companies mentioned. All trademarks and other intellectual property used or displayed are the ownership of their respective owners. This article is © 2017 Capital One.
For more on APIs, open source, community events, and developer culture at Capital One, visit DevExchange, our one-stop developer portal: https://developer.capitalone.com/