Convert fast.ai trained image classification model to iOS app via ONNX and Apple Core ML
fast.ai offers a free, open-source library and course for people who want to learn about and train deep learning models. One reason fast.ai is so popular is that it makes training models easier and faster for machine learning practitioners. Training with fast.ai is fast, thanks to several built-in techniques (e.g. the learning rate finder and stochastic gradient descent with restarts (SGDR)), yet it often achieves state-of-the-art results.
For many people who have their own ideas and data, training a model is not enough. Deploying their fast.ai-trained model to production (e.g. an iOS or Android application, or a web service) is the actual goal. This tutorial shows one way to turn an image classifier trained with fast.ai into an iOS app.
Because fast.ai is built on PyTorch, we’ll convert the underlying PyTorch model to ONNX and then to an Apple Core ML model, which we’ll import into an iOS project.
PyTorch → ONNX → Apple Core ML
What we need to install
- PyTorch: if the fast.ai environment is still using an old PyTorch version, it’s better to upgrade to the latest release (e.g. PyTorch 0.4.1) via https://pytorch.org
- onnx-coreml: https://github.com/onnx/onnx-coreml
> Preparation: The PyTorch ONNX exporter currently doesn’t support AdaptivePooling layers, but fast.ai uses them so that a model can be trained on different input image sizes (a way to help prevent overfitting). If we only care about one size, say 299, we can replace each AdaptivePooling layer with a supported pooling layer of fixed size. Go to the fastai library and edit the fastai/layers.py file as follows:
- Add 2 classes: MyAdaptiveMaxPool2d and MyAdaptiveAvgPool2d
- Replace the 2 layers (Max and Average) in AdaptiveConcatPool2d with the 2 new classes above.
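The two replacement classes might look like the following minimal sketch. It assumes the feature map reaching the pooling layer has a fixed spatial size, so pooling with a kernel equal to that size gives the same result as adaptive pooling to a 1×1 output. The class bodies are my assumption rather than the exact code edited into fastai/layers.py, and AdaptiveConcatPool2d is re-declared here in simplified form only to show where the new classes plug in.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class MyAdaptiveMaxPool2d(nn.Module):
    """Max pooling over the whole feature map (equivalent to adaptive pooling to 1x1)."""

    def forward(self, x):
        # x has shape (batch, channels, height, width);
        # the kernel covers the full spatial extent of the input.
        return F.max_pool2d(x, kernel_size=(x.size(2), x.size(3)))


class MyAdaptiveAvgPool2d(nn.Module):
    """Average pooling over the whole feature map (equivalent to adaptive pooling to 1x1)."""

    def forward(self, x):
        return F.avg_pool2d(x, kernel_size=(x.size(2), x.size(3)))


class AdaptiveConcatPool2d(nn.Module):
    """Simplified stand-in: concatenates max-pooled and average-pooled features."""

    def __init__(self):
        super().__init__()
        self.mp = MyAdaptiveMaxPool2d()
        self.ap = MyAdaptiveAvgPool2d()

    def forward(self, x):
        # Concatenate along the channel dimension, doubling the channel count.
        return torch.cat([self.mp(x), self.ap(x)], dim=1)
```

Because the kernel size is read from the input tensor, the same classes work for any single fixed image size; only the exported graph is tied to the size used at export time.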
> Training: Now we can train an image classifier with a fixed image size, using the replacement pooling layers. Assume that we trained a Dog vs Cat model as in this notebook from the fast.ai course, and reached an accuracy of about 0.9975 on the validation set.
Note: I’ve tried both Resnext50 and Resnet50; the PyTorch ONNX exporter and onnx-coreml converted both without errors.
> Extracting the PyTorch model from fast.ai learn.model:
- Add an image transformation layer at the front: before being forwarded through the model, all image data is normalized into the [0, 1] range. This is done by the Dataloader, not by the model. Therefore, we need to add a transformation layer at the beginning of the model. Here I just scale the image data by dividing it by 255.0. I hope that Apple Core ML or the ONNX library will do this for us in the future.
- Replace the LogSoftmax layer at the end with Softmax, so that the model outputs probabilities instead of the log-probabilities used to compute the loss.
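These two changes could be sketched as below. ImageScale and to_inference_model are hypothetical names, and the code assumes the trained model is an nn.Sequential whose last child is nn.LogSoftmax. A tiny stand-in network takes the place of learn.model so that the snippet runs on its own.

```python
import torch
import torch.nn as nn


class ImageScale(nn.Module):
    """Hypothetical front layer: maps raw pixel values in [0, 255] to [0, 1]."""

    def forward(self, x):
        return x / 255.0


def to_inference_model(trained):
    """Prepend the scaling layer and swap a trailing LogSoftmax for Softmax.

    Assumes `trained` is an nn.Sequential; the layers are reused, not copied.
    """
    layers = list(trained.children())
    if isinstance(layers[-1], nn.LogSoftmax):
        layers = layers[:-1]  # drop the log-probability output layer
    model = nn.Sequential(ImageScale(), *layers, nn.Softmax(dim=1))
    return model.eval()  # inference mode for dropout and batch-norm layers


# Stand-in for learn.model: one convolution that yields 2 class scores.
toy = nn.Sequential(nn.Conv2d(3, 2, kernel_size=4), nn.LogSoftmax(dim=1))
final_model = to_inference_model(toy)

# Raw pixel values in [0, 255], batch of one 4x4 RGB image.
probs = final_model(torch.rand(1, 3, 4, 4) * 255.0)
print(probs.sum().item())  # should be close to 1.0
```

Since the scaling now lives inside the model, the app can feed raw pixel values without a separate preprocessing step.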
- Firstly, to export the model to ONNX we need to create a dummy input image with shape (3, 299, 299). This dummy input does not affect our model’s weights. The iOS app will predict one image at a time, so we don’t use a batch of images here.
- Secondly, we need to name our input layer (the image) via the “input_names” parameter when calling torch.onnx.export. After that, we need to pass the same name as “image_input_names” when converting to Apple Core ML. This way, Xcode will treat our input as an image instead of an MLMultiArray.
- Last but not least, for a classification problem, we need to create a text file containing the labels of all output classes. For example, I’ve created a “labels.txt” file containing 2 lines, one label per line.
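Put together, the export and conversion steps could look roughly like the sketch below. The label names "cats" and "dogs", the output name, the description text, and both helper names are assumptions for illustration. The convert arguments reflect my understanding of onnx-coreml and should be checked against the installed version. The dummy tensor carries a leading batch dimension of 1, which I assume the convolution and batch-norm layers require.

```python
from pathlib import Path


def write_labels(path, labels):
    """Write one class label per line, in the order of the model's outputs."""
    Path(path).write_text("\n".join(labels) + "\n")


def export_to_coreml(final_model, name="dog_vs_cat_image", size=299,
                     labels_path="labels.txt"):
    """Export a PyTorch model to ONNX, then convert the ONNX file to Core ML."""
    # Imported here so write_labels works without the conversion packages.
    import torch
    from onnx_coreml import convert

    onnx_path = name + ".onnx"
    # One dummy image; its values are irrelevant, only the shape matters.
    # Move it to the model's device first if the model lives on a GPU.
    dummy_input = torch.randn(1, 3, size, size)
    torch.onnx.export(final_model, dummy_input, onnx_path,
                      input_names=["image"], output_names=["dog_vs_cat"])

    mlmodel = convert(onnx_path,
                      mode="classifier",
                      image_input_names=["image"],  # same name as input_names above
                      class_labels=labels_path)
    mlmodel.short_description = "Dog vs Cat classifier"  # hypothetical text
    mlmodel.save(onnx_path + ".mlmodel")
    return mlmodel


# Assumed label names, in the assumed order of the model's outputs.
write_labels("labels.txt", ["cats", "dogs"])
```

Calling export_to_coreml(model) on the extracted model would then write the .onnx file and the .onnx.mlmodel file next to the script.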
After this step, we will get a file named “dog_vs_cat_image.onnx.mlmodel” to import into Xcode.
Any remaining descriptions attached to the model will be displayed in Xcode automatically when we import the Core ML model.
> Importing the mlmodel into Xcode:
This step is quite straightforward.
- Download and open the example project from the Apple site.
- Drag the “dog_vs_cat_image.onnx.mlmodel” file into Xcode. It is worth renaming it first to follow the CamelCase class name convention, e.g. DogvsCatModel.mlmodel.
- In the ImageClassificationViewController.swift file, replace the default model class (MobileNet) with our new model class (DogvsCatModel).
- Build and run the app in Xcode, and there we go.
The full code for training and converting the model can be found here: