How to write a Custom Keras model so that it can be deployed for Serving
How to adapt custom Layers, Models, losses, preprocessing, and postprocessing into a servable API
If the only Keras models you write are sequential or functional models with pre-built layers like Dense and Conv2D, you can ignore this article. But at some point in your ML career, you will find yourself subclassing a Layer or a Model, writing your own loss function, or needing custom preprocessing or postprocessing during serving. And at that point, you will find that you cannot easily export your model using
model.save(EXPORT_PATH)
Even if you are able to save the model successfully, you might find that the model cannot be deployed successfully to TensorFlow Serving or to one of the managed services, such as Vertex AI or Sagemaker, that wrap TF Serving. And even if you successfully deploy it, you may find that the results are unintuitive or flat-out wrong.
The documentation that explains what you have to do is unfortunately scattered across multiple pages. Some of the recommended approaches work only if you get everything exactly right and fail silently otherwise; some cause dramatic slowdowns in training; others are inflexible at serving time; some make saving the model take hours; and the error messages, when you do get them, can be hard to understand.
This article is about what you have to do if you have a custom anything (Layer, Model, Lambda, Loss, Preprocessor, Postprocessor) in Keras.
Example Model
To illustrate, I’m going to use a Named Entity Recognition (NER) model from the Keras examples. Basically, an NER model trained to identify names and locations will take a sentence of the form:
John went to Paris
and return:
NAME out out LOCATION
How this model works isn’t all that important. What matters is that it involves custom Keras layers and a custom Keras model (i.e., they subclass layers.Layer and keras.Model):
class TransformerBlock(layers.Layer):
    ...

class TokenAndPositionEmbedding(layers.Layer):
    ...

class NERModel(keras.Model):
    ...
The full code that goes with this article is on GitHub.
Such an NLP model will have to do some custom preprocessing of the input text. Basically, the input sentence is split into words, lowercased, and then converted into an index based on a vocabulary. The model is trained on the vocabulary ids:
def map_record_to_training_data(record):
    record = tf.strings.split(record, sep="\t")
    tokens = tf.strings.split(record[0])
    tokens = tf.strings.lower(tokens)
    tokens = vocab_lookup_layer(tokens)
    tags = tf.strings.split(record[1])
    tags = label_lookup_layer(tags)
    return tokens, tags
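Each line of the training file, then, is tab-separated: the sentence on one side, the space-separated tags on the other. Something like this (illustrative; see the notebook for how train.csv is actually built):
John went to Paris\tB-NAME OUT OUT B-LOCATION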
The input tf.data pipeline is:
train_dataset = (
tf.data.TextLineDataset('train.csv')
.map(map_record_to_training_data)
.padded_batch(batch_size)
)
and the model is trained with a custom loss:
loss = CustomNonPaddingTokenLoss()
ner_model.compile(optimizer="adam", loss=loss)
ner_model.fit(train_dataset, epochs=10)
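For reference, the custom loss simply masks out the padding tokens so that they don’t dominate training. A minimal sketch of what such a loss looks like, assuming the padding label id is 0 (the class in the notebook differs in its details):
class CustomNonPaddingTokenLoss(keras.losses.Loss):
    def call(self, y_true, y_pred):
        # per-token cross-entropy, without any reduction
        loss_fn = keras.losses.SparseCategoricalCrossentropy(
            from_logits=False, reduction=tf.keras.losses.Reduction.NONE
        )
        loss = loss_fn(y_true, y_pred)
        # zero out the loss wherever the label is the padding id (0)
        mask = tf.cast(y_true > 0, dtype=tf.float32)
        return tf.reduce_sum(loss * mask) / tf.reduce_sum(mask)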
You can see why I picked this example — it has custom *everything*.
Export
What happens when we try to export this model to deploy it? Typically, this is what it will involve:
model.save(EXPORT_PATH)
Then, we will take the saved model and give it to a service such as Sagemaker or Vertex AI, which will do:
model = tf.keras.models.load_model(EXPORT_PATH)
model.predict(...)
Unfortunately, because of all the custom layers and code above, this straightforward approach won’t work.
When we do:
ner_model.save(EXPORT_PATH)
we get an error:
Unknown loss function: CustomNonPaddingTokenLoss. Please ensure this object is passed to the `custom_objects` argument. See https://www.tensorflow.org/guide/keras/save_and_serialize#registering_the_custom_object for details.
Problem 1: Unknown Loss Function
How do we solve this? I will show you how to register custom objects later in this article. But the first thing to realize is that WE DO NOT NEED TO EXPORT THE LOSS. The loss is needed only for training, not for deployment.
So, we have a much simpler thing we can do. Just remove the loss:
# remove the custom loss before saving.
ner_model.compile('adam', loss=None)
ner_model.save(EXPORT_PATH)
Success! Bottom line: remove custom losses before exporting Keras models for deployment. It’s a one-liner.
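For completeness: if you really did need the custom loss after loading (say, to resume training), the route the error message points at is to pass the class back in explicitly when loading. A sketch:
model = tf.keras.models.load_model(
    EXPORT_PATH,
    custom_objects={'CustomNonPaddingTokenLoss': CustomNonPaddingTokenLoss}
)
But for serving, dropping the loss is simpler.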
Problem 2: Wrong input shape
Now, let’s do what TensorFlow Serving (or the managed service that wraps TensorFlow Serving, such as Sagemaker or Vertex AI) does: load the model we just saved and call its predict() method:
sample_input = [
"Justin Trudeau went to New Delhi India",
"Vladimir Putin was chased out of Kyiv Ukraine"
]
model.predict(sample_input)
Unfortunately, we get an error:
Could not find matching concrete function to call loaded from the SavedModel. Got:
  * Tensor("inputs:0", shape=(None,), dtype=string)
Expected:
  * TensorSpec(shape=(None, None), dtype=tf.int64, name='inputs')
Capturing the preprocessing
This is because we are trying to send in a full sentence (a string), but our model was trained on a set of vocabulary ids. That’s why the expected input is a set of int64.
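You can see this for yourself by inspecting the exported signature (a quick diagnostic; the exact output varies by TensorFlow version):
loaded = tf.saved_model.load(EXPORT_PATH)
print(loaded.signatures['serving_default'].structured_input_signature)
# roughly: ((), {'inputs': TensorSpec(shape=(None, None), dtype=tf.int64, name='inputs')})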
We did this in our tf.data pipeline when we called tf.strings.split(), tf.strings.lower(), and vocab_lookup_layer():
def map_record_to_training_data(record):
    record = tf.strings.split(record, sep="\t")
    tokens = tf.strings.split(record[0])
    tokens = tf.strings.lower(tokens)
    tokens = vocab_lookup_layer(tokens)
    ...
We’ve got to repeat that preprocessing during prediction too.
How?
Well, we could use a preprocessing container on Vertex AI or some similar functionality. But that sort of defeats the purpose of having a simple, all-in-one deployed Keras model.
Instead, we should reorganize our tf.data input pipeline. What we want is a function (here, I call it process_descr) that we can call from both the tf.data pipeline and from our exported model:
def process_descr(descr):
# split the string on spaces, and make it a rectangular tensor
tokens = tf.strings.split(tf.strings.lower(descr))
tokens = vocab_lookup_layer(tokens)
max_len = MAX_LEN # max([x.shape[0] for x in tokens])
input_words = tokens.to_tensor(default_value=0, shape=[None, max_len])  # None keeps the batch dimension as-is
return input_words
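As a quick check, a batch of sentences comes back as a dense [batch_size, MAX_LEN] tensor of vocabulary ids (the ids themselves depend on your vocabulary):
ids = process_descr(["Joe Biden visited Paris"])
print(ids.shape)   # (1, MAX_LEN)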
This allows our training code to work as before. When we are ready to save, though, we need to create a layer so that the preprocessing code is included in the model itself.
Another way to do that is to define a prediction signature that calls the preprocessing function. However, that approach is problematic: if there are errors in your custom model, you won’t find out about them (ask me how I know).
New model with preprocessing layer
A simpler approach, one that reports the errors and gives you a chance to fix them, is to do what we did with the loss function: define a new standard model with a Lambda layer that does the preprocessing before feeding the input to the custom model, and save that:
temp_model = tf.keras.Sequential([
tf.keras.Input(shape=[], dtype=tf.string, name='description'),
tf.keras.layers.Lambda(process_descr),
ner_model
])
temp_model.compile('adam', loss=None)
temp_model.save(EXPORT_PATH)
!ls -l {EXPORT_PATH}
Unfortunately, the errors that the signature approach would have silently swallowed now surface. What error?
Problem 3: Untracked Tensor
The error message we get is:
Tried to export a function which references 'untracked' resource Tensor
What's that about? Here's the issue: when you write a custom Keras layer, loss, or model, you are defining code. But when you export the model, you have to make a flat, self-contained file out of it. What happens to the code? It's lost! How can the prediction work then?
You need to tell Keras how to re-create your objects, i.e., how to pass in all the constructor arguments. Then Keras can serialize that configuration along with the model, resurrect the objects at load time, and do the right thing.
The way you do that is by defining a get_config() method that returns all the constructor arguments. Basically, a custom layer that looks like this:
class TokenAndPositionEmbedding(layers.Layer):
def __init__(self, maxlen, vocab_size, embed_dim):
super(TokenAndPositionEmbedding, self).__init__()
self.token_emb = keras.layers.Embedding(
input_dim=vocab_size, output_dim=embed_dim
)
self.pos_emb = keras.layers.Embedding(input_dim=maxlen, output_dim=embed_dim)
def call(self, inputs):
maxlen = tf.shape(inputs)[-1]
positions = tf.range(start=0, limit=maxlen, delta=1)
position_embeddings = self.pos_emb(positions)
token_embeddings = self.token_emb(inputs)
return token_embeddings + position_embeddings
will have to look like this:
@tf.keras.utils.register_keras_serializable() # 1
class TokenAndPositionEmbedding(layers.Layer):
def __init__(self, maxlen, vocab_size, embed_dim, **kwargs): # 2
super(TokenAndPositionEmbedding, self).__init__(**kwargs) # 3
self.token_emb = keras.layers.Embedding(
input_dim=vocab_size, output_dim=embed_dim
)
#4 save the constructor parameters for get_config()
self.maxlen = maxlen
self.vocab_size = vocab_size
self.embed_dim = embed_dim
self.pos_emb = keras.layers.Embedding(input_dim=maxlen, output_dim=embed_dim)
def call(self, inputs):
maxlen = tf.shape(inputs)[-1]
positions = tf.range(start=0, limit=maxlen, delta=1)
position_embeddings = self.pos_emb(positions)
token_embeddings = self.token_emb(inputs)
return token_embeddings + position_embeddings
def get_config(self): # 5
config = super().get_config()
# save constructor args
config['maxlen'] = self.maxlen
config['vocab_size'] = self.vocab_size
config['embed_dim'] = self.embed_dim
return config
There are 5 changes:
- Add the decorator to register the custom layer with Keras
- Add **kwargs to the constructor parameters
- Pass **kwargs to the super constructor
- Save the constructor parameters as instance fields
- Define a get_config() method that returns the constructor args
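A cheap way to verify you got the serialization right, before trying to save the whole model, is to round-trip the layer's config (a sketch; the constructor values are just illustrative):
layer = TokenAndPositionEmbedding(maxlen=128, vocab_size=20000, embed_dim=32)
config = tf.keras.layers.serialize(layer)       # uses get_config() under the hood
restored = tf.keras.layers.deserialize(config)  # works because the class is registered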
Once we do this to our custom Layer and Model classes (look at the notebook on GitHub for the full code), we are still unable to save. The error message is exactly the same as before: an untracked Tensor. But we just went through all our custom classes and did the right thing. What’s happening?
Problem 4: Untracked resource in Lambda Layer
We still have a problem. Even though we went through and fixed all the custom layers and models, there is still one piece of user-defined code: the Lambda layer we are using for the preprocessing! It uses the vocabulary, and that vocab_lookup_layer is a resource that is untracked:
def process_descr(descr):
# split the string on spaces, and make it a rectangular tensor
tokens = tf.strings.split(tf.strings.lower(descr))
tokens = vocab_lookup_layer(tokens)
max_len = MAX_LEN # max([x.shape[0] for x in tokens])
input_words = tokens.to_tensor(default_value=0, shape=[None, max_len])  # None keeps the batch dimension as-is
return input_words
Just pickling the function isn’t going to fix this.
Bottom line: Lambda layers are dangerous because it’s hard to notice which resources they capture and leave untracked.
I recommend that you get rid of any Lambda layers and replace them with custom layers.
Let’s do that, making sure to remember the 5 steps that are necessary for every custom Layer:
@tf.keras.utils.register_keras_serializable(name='descr')
class PreprocLayer(layers.Layer):
    def __init__(self, vocab_lookup_layer, **kwargs):
        super(PreprocLayer, self).__init__(**kwargs)
        # save the constructor parameters for get_config() to work properly
        self.vocab_lookup_layer = vocab_lookup_layer

    def call(self, descr, training=False):
        # split the string on spaces, and make it a rectangular tensor
        tokens = tf.strings.split(tf.strings.lower(descr))
        tokens = self.vocab_lookup_layer(tokens)
        max_len = MAX_LEN  # max([x.shape[0] for x in tokens])
        input_words = tokens.to_tensor(default_value=0, shape=[None, max_len])
        return input_words

    def get_config(self):
        config = super().get_config()
        # save constructor args
        config['vocab_lookup_layer'] = self.vocab_lookup_layer
        return config
Now, our temporary model for saving becomes:
temp_model = tf.keras.Sequential([
tf.keras.Input(shape=[], dtype=tf.string, name='description'),
PreprocLayer(vocab_lookup_layer),
ner_model
])
temp_model.compile('adam', loss=None)
temp_model.save(EXPORT_PATH)
!ls -l {EXPORT_PATH}
Of course, you should go back and fix your tf.data input pipeline to use the layer instead of the process_descr function. Fortunately, that’s quite easy. Just replace the process_descr() call with a call to PreprocLayer():
PreprocLayer(vocab_lookup_layer)(['Joe Biden visited Paris'])
does the same thing as:
process_descr(['Joe Biden visited Paris'])
Postprocessing
Now, when we load the model and call predict, we get correct behavior:
model = tf.keras.models.load_model(EXPORT_PATH)
sample_input = [
"Justin Trudeau went to New Delhi India",
"Vladimir Putin was chased out of Kyiv Ukraine"
]
model.predict(sample_input)
This returns a set of probabilities, though:
array([[[7.6006036e-03, 4.3546227e-03, 9.7820580e-01, 1.3501652e-03,
5.0268644e-03, 3.4619651e-03],
[6.8284925e-03, 1.7240658e-02, 9.1373536e-04, 9.6674633e-01,
5.9596724e-03, 2.3111277e-03],
That’s annoying. Can we postprocess the output? Sure. We pretty much have to find the argmax of this array and then look up the tag corresponding to that index. For example, if the second entry has the maximum probability, we’d look up mapping[1], which is B-NAME.
Now that we have the pattern for custom Keras code down, it’s a simple matter of applying the custom-layer approach again:
@tf.keras.utils.register_keras_serializable(name='tagname')
class OutputTagLayer(layers.Layer):
    def __init__(self, mapping, **kwargs):
        super(OutputTagLayer, self).__init__(**kwargs)
        # save the constructor parameters for get_config() to work properly
        self.mapping = mapping
        # construct a lookup table from tag id to tag name
        self.mapping_lookup = tf.lookup.StaticHashTable(
            tf.lookup.KeyValueTensorInitializer(
                tf.range(start=0, limit=len(mapping.values()), delta=1, dtype=tf.int64),
                tf.constant(list(mapping.values()))),
            default_value='[PAD]')

    def call(self, descr_tags, training=False):
        prediction = tf.argmax(descr_tags, axis=-1)
        prediction = self.mapping_lookup.lookup(prediction)
        return prediction

    def get_config(self):
        config = super().get_config()
        # save constructor args
        config['mapping'] = self.mapping
        return config
But that code? What the @#$@R$@# is that?
Well, where in plain Python we can simply do:
mapping[ np.argmax(prob) ]
we have to use mapping_lookup.lookup():
prediction = tf.argmax(descr_tags, axis=-1)
prediction = self.mapping_lookup.lookup(prediction)
That mapping_lookup is itself a resource; it is the TensorFlow equivalent of a dict:
self.mapping_lookup = tf.lookup.StaticHashTable(
tf.lookup.KeyValueTensorInitializer(
tf.range(start=0, limit=len(mapping.values()), delta=1, dtype=tf.int64),
tf.constant(list(mapping.values()))),
default_value='[PAD]')
At the time of writing, TensorFlow has a bug in storing dicts whose keys are integers, so I am hacking around it with the call to tf.range(). Sorry.
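With both custom layers in hand, the final export model sandwiches the NER model between preprocessing and postprocessing. A sketch, mirroring the earlier temp_model (mapping is the id-to-tag dict built during training):
serving_model = tf.keras.Sequential([
    tf.keras.Input(shape=[], dtype=tf.string, name='description'),
    PreprocLayer(vocab_lookup_layer),   # string -> vocabulary ids
    ner_model,                          # vocabulary ids -> per-tag probabilities
    OutputTagLayer(mapping)             # probabilities -> tag names
])
serving_model.compile('adam', loss=None)
serving_model.save(EXPORT_PATH)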
But the benefit of doing this is that you can simply take this model and deploy it. No extra preprocessing containers, postprocessing containers, etc. The code all runs on the GPU (it’s accelerated!) and is super-fast.
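For example, once the SavedModel is handed to TensorFlow Serving, the client can send raw sentences straight to the REST endpoint (a sketch; the host, port, and model name depend on how you deploy it):
import json, requests

payload = {"instances": ["Justin Trudeau went to New Delhi India"]}
resp = requests.post("http://localhost:8501/v1/models/ner:predict",
                     data=json.dumps(payload))
print(resp.json()["predictions"])   # tag names for each token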
It’s also intuitive: we send in strings and get back the identified tags. Sending in:
model = tf.keras.models.load_model(EXPORT_PATH)
sample_input = [
"Justin Trudeau went to New Delhi India",
"Vladimir Putin was chased out of Kyiv Ukraine"
]
model.predict(sample_input)
gives back:
array([[b'B-NAME', b'I-NAME', b'OUT', b'OUT', b'B-LOCATION',
b'I-LOCATION', b'I-LOCATION', b'[PAD]', b'[PAD]', b'[PAD]',
b'[PAD]', b'[PAD]', b'[PAD]', b'[PAD]', b'[PAD]', b'[PAD]'],
[b'B-NAME', b'I-NAME', b'OUT', b'OUT', b'OUT', b'OUT',
b'B-LOCATION', b'I-LOCATION', b'[PAD]', b'[PAD]', b'[PAD]',
b'[PAD]', b'[PAD]', b'[PAD]', b'[PAD]', b'[PAD]']], dtype=object)
i.e. the tag for each of the words in the provided sentences.
Enjoy!
Suggested Reading:
- My full code is on GitHub. I omitted some details from this article for readability. Please do refer to the notebook. You can run it in Colab or any Jupyter environment.
- Named Entity Recognition (NER) model from the Keras examples. Great illustrative model. You won’t be able to deploy it though.
- Read all about Keras custom layers here. Don’t miss the “optional” section on serialization. It’s mandatory if you want to deploy your custom layers.
- Once you write your custom layers, you have to do custom object registration. But registering custom objects is very error-prone. You are likely to forget a few, and the error message won’t tell you which one you missed. It’s better to use the decorator shortcut. Hope you saw it: it’s the third option discussed on that page.
- Read about Lambda layers here. Nothing there about the pitfalls of exporting a model with Lambda layers that have global objects in them. Hope you read the previous sections on serializing models and realized they apply equally to the functions being wrapped by the Lambda layer!
Sorry for the snark.