Selfie2Anime with TFLite — Part 2: TFLite Model
This is part 2 of an end-to-end tutorial on how to convert a TF 1.x model to TensorFlow Lite (TFLite) and then deploy it to an Android app that transforms a selfie image into a plausible anime. (Part 1 | Part 2 | Part 3) It is the first in a series of end-to-end TFLite tutorials from awesome-tflite.
Here is a step-by-step summary:
- Generate a SavedModel out of the pre-trained U-GAT-IT model checkpoints.
- Convert the SavedModel using the latest TFLite converter
- Run inference in Python with the converted model
- Add metadata to enable easy integration with a mobile app
- Run a model benchmark to check that the model runs well on mobile
Model saving with TF1 — create a SavedModel from pre-trained checkpoints
Please note that this part needs to run in a TensorFlow 1.x runtime. We used TensorFlow 1.14 because that was the version the model code was written with.
The U-GAT-IT authors provided two checkpoints: one saved after 50 epochs (~4.6 GB) and the other after 100 epochs (4.7 GB). We will use a much lighter version from Kaggle that is suitable for mobile deployment.
Download and extract the model checkpoints from Kaggle
So, first things first! Let’s download the checkpoints from Kaggle with the Kaggle API. On kaggle.com, go to My Account/API and click “Create new API token”, which triggers the download of kaggle.json, containing your API credentials. Then, in Colab, you can specify the following to set the environment variables:
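The original snippet is not reproduced here. As a minimal sketch, assuming kaggle.json has the standard “username” and “key” fields, the credentials can be exported as the KAGGLE_USERNAME and KAGGLE_KEY environment variables, which the Kaggle API client reads. The helper name set_kaggle_env is hypothetical.

```python
import json
import os


def set_kaggle_env(kaggle_json_path="kaggle.json"):
    """Export Kaggle API credentials from kaggle.json as environment variables.

    The Kaggle API client reads KAGGLE_USERNAME and KAGGLE_KEY, so this is an
    alternative to copying kaggle.json into ~/.kaggle/.
    """
    with open(kaggle_json_path) as f:
        creds = json.load(f)  # kaggle.json holds {"username": ..., "key": ...}
    os.environ["KAGGLE_USERNAME"] = creds["username"]
    os.environ["KAGGLE_KEY"] = creds["key"]
    return creds["username"]
```

Setting environment variables keeps the key out of the notebook's visible cells, as long as you upload kaggle.json instead of pasting its contents.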
Let’s download the checkpoints and extract them:
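The download itself is done with the Kaggle CLI (kaggle datasets download), using the dataset name from the accompanying notebook. For the extraction step, here is a sketch with Python’s zipfile module, assuming the download is a .zip archive. The helper extract_checkpoints and the default folder name “checkpoint” are hypothetical; use whatever directory the model code expects.

```python
import os
import zipfile


def extract_checkpoints(zip_path, out_dir="checkpoint"):
    """Extract a downloaded checkpoint archive and list the files it contained."""
    os.makedirs(out_dir, exist_ok=True)
    with zipfile.ZipFile(zip_path) as archive:
        names = archive.namelist()
        archive.extractall(out_dir)
    # Report only regular files; directory entries in a zip end with "/".
    return sorted(n for n in names if not n.endswith("/"))
```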
Load model checkpoints and connect the tensors
This step usually varies from model to model. The general workflow is:
- Define the input and output tensors of the model.
- Instantiate the model and connect the input and output tensors so that a computation graph can be built.
- Load the pre-trained checkpoints into the model’s graph.
- Generate the SavedModel.
Step 2 in particular varies from model to model, so it is hard to describe in general terms beforehand. In this section, we focus only on the parts of the code that are important to understand; for the full implementation, please check out the Colab Notebook that accompanies this tutorial.
In our case, the input and output tensors and their details can be accessed from an instance of the main model class. So we start by creating an instance of the UGATIT model class:
data refers to the model configuration, as can be seen here, and the UGATIT class comes from here. At this point, the model is instantiated. Next, we load the checkpoints into the session that holds the model graph, which is what the load_checkpoint() method does:
At this point, creating the SavedModel takes only a few keystrokes. Remember that we are still inside the TF1 session into which the checkpoints were loaded.
As the code shows, the input and output tensors can be accessed from the model graph itself. Once this code runs, the SavedModel files are ready, and we can proceed with converting the SavedModel to a TFLite model.
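The four-step workflow (define tensors, build the graph, restore the checkpoint, export) can be sketched on a toy graph. This is not the U-GAT-IT code: the tensor names “input” and “output”, the single scale variable standing in for the generator, and the helper export_saved_model are hypothetical. It uses tf.compat.v1 graph and session APIs, which exist in TF 1.14 and also run under TF 2.x in graph mode.

```python
import tensorflow as tf

tf1 = tf.compat.v1  # TF1-style graph/session APIs


def export_saved_model(export_dir, checkpoint_path=None, image_size=256):
    """Toy version of the four-step workflow (NOT the real U-GAT-IT graph).

    export_dir must not exist yet or must be empty; the SavedModel builder
    refuses to overwrite an existing export.
    """
    graph = tf1.Graph()
    with graph.as_default():
        # 1. Define the input tensor with the training-time NHWC shape.
        image = tf1.placeholder(tf.float32, [1, image_size, image_size, 3], name="input")
        # 2. Build the graph; one trainable scale stands in for the generator.
        scale = tf1.get_variable("scale", shape=[], initializer=tf1.ones_initializer())
        fake = tf.tanh(image * scale, name="output")
        saver = tf1.train.Saver()
        with tf1.Session(graph=graph) as sess:
            # 3. Restore pre-trained weights if a checkpoint prefix is given.
            if checkpoint_path:
                saver.restore(sess, checkpoint_path)
            else:
                sess.run(tf1.global_variables_initializer())
            # 4. Export; simple_save writes a "serving_default" signature.
            tf1.saved_model.simple_save(
                sess, export_dir, inputs={"input": image}, outputs={"output": fake}
            )
    return export_dir
```

In the real notebook, the placeholder and output tensor come from the UGATIT instance and the restore happens in load_checkpoint(); only the final simple_save call has the same shape as here.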
Prepare the TFLite model
Time to shift gears to TensorFlow 2.x (2.2.0 or later, including nightly builds). In this section, we take the SavedModel generated previously and convert it to a TFLite FlatBuffer, which is about 10 MB in size and perfectly usable in a mobile application. Then we will use a few of the latest TensorFlow Lite tools to prepare the model for deployment:
- Run inference in Python with the TFLite model to verify that it still works after the conversion.
- Add metadata to the TFLite model to make integrating it into an Android app easier with Android Studio’s ML Model Binding plugin.
- Use the Benchmark tool to see how the model would perform on mobile devices.
Convert SavedModel to TFLite with TF2
First, we load the SavedModel files and create a concrete function from them:
The advantage of converting this way is that it gives us the flexibility to set the shapes of the input and output tensors of the resulting TFLite model. You can see this in the following code snippet:
It is recommended to use the same input and output tensor shapes that were used when training the model. In this case, the shape is (1, 256, 256, 3), where 1 is the batch dimension. This is required because the model expects data of shape (BATCH_SIZE, IMAGE_SHAPE, IMAGE_SHAPE, NB_CHANNELS). To do the actual conversion, we run the following:
Unless we explicitly set an optimization option on the converter, the model remains a float model. You can explore the different optimization options available in TFLite here.
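Putting the conversion steps together, here is a sketch assuming the SavedModel has a default serving signature (which simple_save creates). The helper convert_saved_model and its optimize flag are hypothetical; the flag shows the difference described above, since setting tf.lite.Optimize.DEFAULT enables dynamic-range quantization, while leaving it off keeps a float model.

```python
import tensorflow as tf


def convert_saved_model(saved_model_dir, input_shape=(1, 256, 256, 3), optimize=False):
    """Convert a SavedModel to a TFLite flatbuffer through its serving signature."""
    model = tf.saved_model.load(saved_model_dir)
    concrete_func = model.signatures[tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY]
    # Pin the input to the training-time shape, including the batch size of 1.
    concrete_func.inputs[0].set_shape(list(input_shape))
    converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func])
    if optimize:
        # Optional dynamic-range quantization; without it the model stays float.
        converter.optimizations = [tf.lite.Optimize.DEFAULT]
    return converter.convert()
```

For example, tflite_model = convert_saved_model("saved_model") followed by writing the bytes to selfie2anime.tflite (the directory name is an assumption). Newer TF releases may log a deprecation warning for from_concrete_functions without a trackable object; the conversion still runs.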
Run Inference with TFLite model
After the conversion and before deploying the .tflite model, it’s always a good practice to run inference in Python to confirm that it’s working as intended.
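Here is a minimal inference sketch with tf.lite.Interpreter, assuming the model takes and returns images scaled to [-1, 1] (check this against the model’s training code). The helpers to_model_input and run_tflite are hypothetical, and the nearest-neighbour resize is a dependency-free stand-in for a proper image library.

```python
import numpy as np
import tensorflow as tf


def to_model_input(image, size=256):
    """Turn an HxWx3 uint8 RGB image into a (1, size, size, 3) float32 batch.

    ASSUMPTION: the model expects pixels scaled to [-1, 1].
    """
    image = np.asarray(image)
    h, w = image.shape[:2]
    # Nearest-neighbour resize via integer indexing.
    rows = np.arange(size) * h // size
    cols = np.arange(size) * w // size
    resized = image[rows][:, cols, :3]
    scaled = resized.astype(np.float32) / 127.5 - 1.0
    # Add the leading batch dimension: (size, size, 3) -> (1, size, size, 3).
    return scaled[np.newaxis, ...]


def run_tflite(model_content, image):
    """Run one inference with the TFLite interpreter and return a uint8 image."""
    interpreter = tf.lite.Interpreter(model_content=model_content)
    interpreter.allocate_tensors()
    input_detail = interpreter.get_input_details()[0]
    output_detail = interpreter.get_output_details()[0]
    size = int(input_detail["shape"][1])  # 256 for the converted model
    batch = to_model_input(image, size).astype(input_detail["dtype"])
    interpreter.set_tensor(input_detail["index"], batch)
    interpreter.invoke()
    output = interpreter.get_tensor(output_detail["index"])[0]
    # Map the (assumed) [-1, 1] output range back to 0..255 pixel values.
    return ((output + 1.0) * 127.5).clip(0, 255).astype(np.uint8)
```

To use it on the converted file, read selfie2anime.tflite in binary mode and pass the bytes along with an RGB array of the selfie; tf.lite.Interpreter(model_path=...) is an equivalent way to load the file directly.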
We tried the model on a few faces, and it produces much better results on female faces than on male ones. A closer look at the training dataset explains why: it contains only female faces, so the model is biased toward them.
Here is a screenshot of the test result:
Add Metadata to TFLite Model
Let’s add metadata to the TensorFlow Lite model so that we can auto-generate model inference code on Android.
Option A: via command line
If you are adding metadata with the Python script via the command line, make sure to first run pip install tflite-support in your conda or virtualenv environment, and set up the folder structure as follows:
Then use the metadata_writer_for_selfie2anime.py script to add metadata to the selfie2anime.tflite model:
Option B: via Colab
Alternatively, you can use this Colab notebook. Remember to run pip install tflite-support there first as well. This option may be easier if you are not familiar with running Python scripts from the command line: all you need to do is launch the notebook in a browser, upload the selfie2anime.tflite file, and execute all cells.
Two new files, selfie2anime.tflite and selfie2anime.json, are created under the model_with_metadata folder. The new selfie2anime.tflite contains the model metadata, which we can use as input to ML Model Binding in Android Studio when deploying the model to Android. The selfie2anime.json file lets you verify that the metadata added to the model is correct.
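To check selfie2anime.json programmatically rather than by eye, here is a small hypothetical helper. It assumes the JSON uses the TFLite metadata schema’s field names (subgraph_metadata, input_tensor_metadata, output_tensor_metadata); the actual tensor names depend on what metadata_writer_for_selfie2anime.py writes.

```python
import json


def summarize_metadata(json_text):
    """Pull the model name and tensor names out of a TFLite metadata JSON dump."""
    meta = json.loads(json_text)
    # A model normally has one subgraph; fall back to an empty one if absent.
    subgraph = (meta.get("subgraph_metadata") or [{}])[0]
    return {
        "name": meta.get("name"),
        "inputs": [t.get("name") for t in subgraph.get("input_tensor_metadata", [])],
        "outputs": [t.get("name") for t in subgraph.get("output_tensor_metadata", [])],
    }
```

For example, pass it the contents of model_with_metadata/selfie2anime.json and compare the returned names with what you expect ML Model Binding to expose.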
To learn more about how the TFLite metadata works, refer to the documentation here.
Benchmark model perf on Android (Optional)
As an optional step, we used the TFLite Android Model Benchmark tool to get the runtime performance on Android before deploying it. Please refer to the instructions on the benchmark tool for details.
Here is a high-level summary of the steps:
- Configure the Android NDK/SDK: there are some Android SDK/NDK prerequisites, after which you build the tool with Bazel
- Build the benchmark APK
- Use adb (Android Debug Bridge) to install the benchmarking tool and push the selfie2anime.tflite model to the Android device
- Run the benchmark tool
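The adb steps above can be scripted. This sketch follows the command pattern in the TFLite Android benchmark app instructions (org.tensorflow.lite.benchmark/.BenchmarkModelActivity with --graph and --num_threads); verify the flags against the current docs. The APK path, thread count, and helper names are assumptions.

```python
import subprocess


def benchmark_commands(model_file="selfie2anime.tflite",
                       apk_path="benchmark_model.apk", num_threads=4):
    """Build adb command lines for installing and running the benchmark app."""
    device_model = f"/data/local/tmp/{model_file}"
    # Inner double quotes keep both flags together for the device-side shell.
    benchmark_args = f'"--graph={device_model} --num_threads={num_threads}"'
    return [
        ["adb", "install", "-r", "-d", "-g", apk_path],
        ["adb", "push", model_file, "/data/local/tmp"],
        ["adb", "shell", "am", "start", "-S",
         "-n", "org.tensorflow.lite.benchmark/.BenchmarkModelActivity",
         "--es", "args", benchmark_args],
    ]


def run_benchmark(**kwargs):
    """Run the commands in order; requires adb and a connected device."""
    for cmd in benchmark_commands(**kwargs):
        subprocess.run(cmd, check=True)
```

The app writes its results to the device log, so filtering adb logcat for “Inference timings” shows the line reported below.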
We see the following benchmark result, and it’s a bit slow, at roughly 7.3 seconds per inference on average (all timings are in microseconds): Inference timings in us: Init: 7135, First inference: 7428506, Warmup (avg): 7.42851e+06, Inference (avg): 7.26313e+06
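The benchmark reports microseconds, which are easy to misread at this scale. A small hypothetical parser converts a result line like the one above into seconds; applied to it, the average inference time comes out at about 7.26 s and initialization at about 7 ms.

```python
def parse_timings(line):
    """Convert a benchmark "Inference timings in us" line into seconds."""
    _, _, body = line.partition("Inference timings in us:")
    timings = {}
    for part in body.split(","):
        name, _, value = part.partition(":")
        if value.strip():
            # Values are microseconds, possibly in scientific notation (7.26313e+06).
            timings[name.strip()] = float(value) / 1e6
    return timings
```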
Now that you have a TensorFlow Lite model, let’s see how to deploy it to Android (Part 3).