Importing subclassed model weights into Keras Functional API models

Anuj Arora
Published in Dive into ML/AI
Mar 25, 2021

Keras offers three ways to define neural network architectures: the Sequential API, the Functional API, and model subclassing. More about these can be read here. This article introduces a method to import the weights of a subclassed model into a Functional API model.
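For readers new to the distinction, here is a minimal sketch of the same small two-layer network written in all three styles. It assumes TensorFlow 2.x with its bundled `tf.keras`; the layer sizes and the `SubclassedMLP` name are illustrative and not taken from the article.

```python
import tensorflow as tf

# 1. Sequential API: a plain linear stack of layers.
sequential = tf.keras.Sequential([
    tf.keras.Input(shape=(4,)),
    tf.keras.layers.Dense(8, activation="relu"),
    tf.keras.layers.Dense(2),
])

# 2. Functional API: layers are called on symbolic inputs to build a graph.
inputs = tf.keras.Input(shape=(4,))
hidden = tf.keras.layers.Dense(8, activation="relu")(inputs)
outputs = tf.keras.layers.Dense(2)(hidden)
functional = tf.keras.Model(inputs=inputs, outputs=outputs)

# 3. Model subclassing (hypothetical example class): layers are created
#    in __init__ and wired together by ordinary Python code in call().
class SubclassedMLP(tf.keras.Model):
    def __init__(self):
        super().__init__()
        self.hidden = tf.keras.layers.Dense(8, activation="relu")
        self.head = tf.keras.layers.Dense(2)

    def call(self, x):
        return self.head(self.hidden(x))

subclassed = SubclassedMLP()

# All three map a batch of 3 four-dimensional inputs to 3 two-dimensional outputs.
x = tf.random.normal((3, 4))
for model in (sequential, functional, subclassed):
    print(type(model).__name__, tuple(model(x).shape))
```

The Sequential and Functional versions are explicit graphs of layers that Keras can inspect, whereas the subclassed version's forward pass exists only as the Python code inside `call()`.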

I have previously written about importing pre-trained TensorFlow 1 model weights into Keras. Back then, Google Research had not yet provided a TensorFlow 2 implementation of SimCLR. The TF2 version, however, uses model subclassing to train the model and save its weights. As a result, the model loaded from the saved model files is not a Keras object, as discussed here:

The object returned by tf.saved_model.load is not a Keras object

Saurabh Saxena (a contributor to the SimCLR repository) expresses a similar sentiment in an issue about the pretrained weights:

SavedModel saves the object hierarchy to track variables but other information is ignored so the restored model isn't actually a tf.keras.Model so won't have the summary function.
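The point made in both quotes can be checked directly. The sketch below assumes TensorFlow 2.x; `TinySubclassed` is a hypothetical stand-in for the SimCLR model, and the export directory is just a temporary folder.

```python
import tempfile

import tensorflow as tf

# Hypothetical tiny subclassed model standing in for the SimCLR ResNet.
class TinySubclassed(tf.keras.Model):
    def __init__(self):
        super().__init__()
        self.dense = tf.keras.layers.Dense(2)

    def call(self, x):
        return self.dense(x)

model = TinySubclassed()
model(tf.zeros((1, 4)))  # call once so that the weights are created

# Save in the SavedModel format, then restore with the low-level loader.
export_dir = tempfile.mkdtemp()
tf.saved_model.save(model, export_dir)
loaded = tf.saved_model.load(export_dir)

# The restored object tracks the saved objects, but it is not a tf.keras.Model,
# so Keras-only methods such as summary() are not available on it.
print(isinstance(loaded, tf.keras.Model))
print(hasattr(loaded, "summary"))
```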

However, I enjoy the flexibility afforded by Keras model objects defined through the Functional API. Hence, I wanted to find a way to transfer the saved weights of the ResNet-50 with selective kernels into a Functional API model. After a lot of struggling, I arrived at the code snippet below, which did the job for me.
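The following is not the author's snippet, which works with the actual SimCLR ResNet-50 (SK) checkpoint. It is a minimal sketch of the general idea under stated assumptions: rebuild the same architecture with the Functional API, then copy the subclassed model's variables into it one by one, checking shapes along the way. It assumes TensorFlow 2.x and that both models enumerate their weights in the same order; `SubclassedMLP`, `build_functional_mlp` and `copy_variables` are hypothetical names. In the article's setting, the source variables would come from the object returned by `tf.saved_model.load` (for example its `variables` attribute, where the restored object exposes one) rather than from an in-memory model.

```python
import tensorflow as tf

# Hypothetical subclassed model, playing the role of the saved SimCLR network.
class SubclassedMLP(tf.keras.Model):
    def __init__(self):
        super().__init__()
        self.hidden = tf.keras.layers.Dense(8, activation="relu")
        self.head = tf.keras.layers.Dense(2)

    def call(self, x):
        return self.head(self.hidden(x))

def build_functional_mlp():
    """Hypothetical Functional API twin of SubclassedMLP (same layers, same order)."""
    inputs = tf.keras.Input(shape=(4,))
    hidden = tf.keras.layers.Dense(8, activation="relu")(inputs)
    outputs = tf.keras.layers.Dense(2)(hidden)
    return tf.keras.Model(inputs, outputs)

def copy_variables(source_variables, target_model):
    """Copy source variables into target_model.weights by position, checking shapes."""
    source_variables = list(source_variables)
    target_weights = target_model.weights
    if len(source_variables) != len(target_weights):
        raise ValueError(
            f"{len(source_variables)} source variables vs "
            f"{len(target_weights)} target weights"
        )
    for src, dst in zip(source_variables, target_weights):
        if tuple(src.shape) != tuple(dst.shape):
            raise ValueError(
                f"shape mismatch: {src.name} {tuple(src.shape)} "
                f"vs {dst.name} {tuple(dst.shape)}"
            )
        # Copy the values into the existing functional-model variable.
        dst.assign(src.numpy())

source = SubclassedMLP()
x = tf.random.normal((3, 4))
source(x)  # build the subclassed model so that its variables exist

target = build_functional_mlp()
copy_variables(source.variables, target)

# With identical weights, both models should produce the same outputs.
tf.debugging.assert_near(source(x), target(x))
```

Copying by position keeps the sketch short; for a real ResNet-50 the two orderings may not line up, in which case matching on variable names or layer paths is the safer choice.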
