Saving and Loading a TensorFlow model using the SavedModel API
The SavedModel API allows you to save a trained model into a format that can be easily loaded in Python or Java (and soon JavaScript), uploaded to GCP ML Engine, or served with a TensorFlow Serving server.
This post will cover saving a trained model in Python and then loading that model in both Java and Python.
What’s Saved?
assets/
assets.extra/
variables/
variables.data-*****-of-*****
variables.index
saved_model.pb
The saved_model.pb file is the serialized MetaGraphDef, which holds the graph structure. The variables folder holds your learned weights. The assets folder allows you to add external files that may be needed, and assets.extra is a place where libraries can add their own assets.
“MetaGraph is a dataflow graph, plus its associated variables, assets, and signatures. A MetaGraphDef is the protocol buffer representation of a MetaGraph.”
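If you want a quick sanity check that a directory actually contains these artifacts, TF 1.x ships a small helper; here is a sketch using a hypothetical path:

import tensorflow as tf

# True if the directory contains a saved_model.pb or saved_model.pbtxt
if tf.saved_model.loader.maybe_saved_model_directory("./export_path"):
    print("Looks like a SavedModel directory")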
How is this different from tf.train.Saver?
This is different from the Saver API’s (tf.train.Saver) save method, which only saves the variables by adding save and restore operations to the graph.
# construct graph!
...
# add save/restore ops
saver = tf.train.Saver()
...
# save after training
save_path = saver.save(sess, "/tmp/model.ckpt")
The Saver API saves the variables in checkpoint files and requires you to reconstruct the graph in order to load them. This is what you want when you’re splitting training into separate sessions and need a quick way to resume. But for loading a model in a different language, or to “package” a complete model, the SavedModel API is recommended.
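To make the contrast concrete, here’s a minimal restore sketch: you have to rebuild the exact same graph before restore can run (the construction code is elided just like above).

# rebuild the same graph as the training script
x = tf.placeholder(tf.float32, [None, 784])
...
# restore the checkpointed variables into it
saver = tf.train.Saver()
with tf.Session() as sess:
    saver.restore(sess, "/tmp/model.ckpt")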
Saving
For this example we’ll use the MNIST beginner tutorial from the official TensorFlow documentation. We will modify the script in two ways: adding names to ops, and adding a couple of lines to save the model after training.
Adding names
We’re adding names to the input and output operations so that we can reference them by name when loading the model. This step isn’t strictly necessary, but it makes loading a lot easier.
x = tf.placeholder(tf.float32, [None, 784], name="myInput")
Most, if not all, TF operations allow you to specify a name. For this example I was able to add the name to the input placeholder.
y = tf.nn.softmax(tf.matmul(x, W) + b, name="myOutput")
Adding a name to the output for this script is also straightforward.
There are cases where you have a final output node/tensor but didn’t get the chance to add a name for whatever reason. In those cases you can use tf.identity, which lets you attach a name to a given Tensor.
def addNameToTensor(someTensor, theName):
    return tf.identity(someTensor, name=theName)
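For example, if the tutorial script had built its output without a name, you could wrap it on the way out (the variable names here are just illustrative):

y = addNameToTensor(tf.nn.softmax(tf.matmul(x, W) + b), "myOutput")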
Save
Add these lines after training, using the same session (sess). The easiest way to save is the tf.saved_model.simple_save function:
tf.saved_model.simple_save(sess,
                           export_dir,
                           inputs={"myInput": x},
                           outputs={"myOutput": y})
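One caveat before moving on: the export directory must not already exist, or the save will fail. A common workaround (my suggestion, not part of the tutorial) is to version the path, e.g. with a timestamp:

import os
import time

# SavedModel refuses to overwrite an existing directory,
# so create a fresh, versioned path for every export
export_dir = os.path.join("./export", str(int(time.time())))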
simple_save lets you save quickly with the minimal set of arguments. It uses some sensible defaults to provide this convenience, and one of the most important, for our case, is the tag. The tag distinguishes the different MetaGraphDefs saved and is needed when loading the model. You can use any string for a tag, but by default simple_save uses tag_constants.SERVING (“serve”). It also builds a Predict API signature, which means any TensorFlow Serving server can load the model.
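That also means you can load with the constant instead of the raw string; a small sketch, assuming TF 1.x:

from tensorflow.python.saved_model import tag_constants

# tag_constants.SERVING is just the string "serve"
tf.saved_model.loader.load(sess, [tag_constants.SERVING], export_dir)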
The way to save multiple MetaGraphDefs, define your own tags, or include assets is to use the builder.
import tensorflow as tf
from tensorflow.python.saved_model import tag_constants
from tensorflow.python.saved_model.signature_def_utils_impl import predict_signature_def

builder = tf.saved_model.builder.SavedModelBuilder(export_path)
signature = predict_signature_def(inputs={'myInput': x},
                                  outputs={'myOutput': y})
# using custom tag instead of: tags=[tag_constants.SERVING]
builder.add_meta_graph_and_variables(sess=sess,
                                     tags=["myTag"],
                                     signature_def_map={'predict': signature})
builder.save()
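Keep in mind that whatever tag you pass to add_meta_graph_and_variables is the tag you’ll need at load time:

# the tags list must match what was used when saving
tf.saved_model.loader.load(sess, ["myTag"], export_path)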
Loading
Loading is similar in Python and Java, and will probably be the same for any language that supports the API. We need the directory of the SavedModel artifacts, the tag that was used, and the names of the input/output tensors.
The TensorFlow JavaScript API doesn’t support this format yet :(.
TensorFlow Serving and GCP each have their own independent steps, but both were painless for my simple use cases (aside from learning a bit about gRPC for TF Serving). We won’t be covering those here.
Loading in Python
with tf.Session(graph=tf.Graph()) as sess:
    tf.saved_model.loader.load(sess, ["serve"], export_path)
    graph = tf.get_default_graph()
    print(graph.get_operations())
    sess.run('myOutput:0',
             feed_dict={'myInput:0': ...})
After calling the load function, the graph is loaded as the default graph, so you can interact with it as if you had reconstructed it yourself. The variables are also restored, so you can start running inference on new data right away.
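Putting it together, here’s a self-contained sketch of running inference on a dummy MNIST-shaped batch (the random input is just a stand-in for real data):

import numpy as np
import tensorflow as tf

export_path = "./export_path"  # wherever the model was saved
batch = np.random.rand(1, 784).astype(np.float32)  # stand-in for a real image

with tf.Session(graph=tf.Graph()) as sess:
    tf.saved_model.loader.load(sess, ["serve"], export_path)
    probs = sess.run('myOutput:0', feed_dict={'myInput:0': batch})
    print(probs.argmax(axis=1))  # predicted digit for each row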
Loading in Java
compile "org.tensorflow:tensorflow:1.8.0"
import org.tensorflow.*;

// load the SavedModel with the tag it was saved with
SavedModelBundle savedModelBundle = SavedModelBundle.load("./export_path", "serve");
Graph graph = savedModelBundle.graph();
printOperations(graph);

// feed the input tensor and fetch the named output
Tensor result = savedModelBundle.session().runner()
        .feed("myInput", tensorInput)
        .fetch("myOutput")
        .run().get(0);
The TensorFlow Java API hates boxed types, so make sure you’re using primitives. If you’re getting wrong results, it’s probably a good idea to validate your inputs after converting them to Tensors.