20-Minute Masterpiece: Training your own style transfer model

Mobile-ready style transfer models with Google Colab and Fritz AI

Jameson Toole

--

Artistic style transfer is one of the most popular creativity tools made possible by machine learning. People love to see their photos and videos transformed into works by famous artists or entirely new abstract masterpieces. We recently released 11 ready-to-use style filters as part of the Fritz SDK, but we often get asked if it’s possible to train custom styles. In this tutorial, I’ll show you how to train your very own style transfer model in just 20 minutes using our open source training template and Google Colab notebooks.

For those unfamiliar, Google Colab is an interactive, notebook-style compute environment available free to anyone. Thanks, Google! You can write and execute code in a virtual machine managed by Google and even get free access to GPUs and TPUs to accelerate your machine learning research. We’ll take advantage of those GPUs in this tutorial to cut model training time down to minutes.
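
If you want to confirm that Colab actually attached a GPU to your notebook (Runtime > Change runtime type), a quick one-cell check looks like this:

import tensorflow as tf
# Prints something like '/device:GPU:0' when a GPU runtime is attached;
# an empty string means the notebook is running on CPU only.
print(tf.test.gpu_device_name())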

To get started, head over to the Fritz AI Style Transfer Training Notebook.

This post will walk you through each step of the process. If you want to see what the final results will look like, try out the Fritz AI Studio app.

To get the most out of your mobile ML models, you’ll need to monitor and manage them in production. Our new free ebook explores best practices for this and other stages of the project lifecycle.

Getting started with Colab

This link will take you to a View Only version of the notebook. Make a copy of it by pressing the OPEN IN PLAYGROUND button at the top left. If you don’t already have Google Colab installed as a Google App, follow the prompts and instructions.

Installing dependencies

Start by installing some dependencies. Note the specific package versions and branches. Machine learning tools evolve quickly and changes are often breaking. If you’re trying to run this and get errors, let us know!

pip install tensorflow keras==2.2.4 numpy matplotlib pillow
pip install git+https://www.github.com/keras-team/keras-contrib.git
pip install git+https://github.com/apple/coremltools.git@master
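
Since version mismatches are the most common source of errors here, it’s worth printing what actually got installed before moving on:

import tensorflow
import keras
import coremltools
# keras should report 2.2.4; if the others look unexpected, re-run the installs
print('tensorflow', tensorflow.__version__)
print('keras', keras.__version__)
print('coremltools', coremltools.__version__)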

Next we’ll clone the Fritz AI Style Transfer repository from GitHub. Friendly reminder—if you find this helpful, star the repo!

git clone https://github.com/fritzlabs/fritz-models.git# add it to your python path
cd fritz-models/style_transfer
export PYTHONPATH=$PYTHONPATH:`pwd`
# create a data folder we’ll use later. Git will ignore this
mkdir data/
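
To confirm the PYTHONPATH change took effect, you can try importing the package that the training scripts (and the stylization snippet later in this post) rely on:

# from fritz-models/style_transfer/
python -c "import style_transfer.models; print('style_transfer is importable')"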

Choosing a style image

Now that everything is installed, let’s find a style we’d like to apply to photos. Certain style images work better than others. I’ve found that you get the best results when styles have a vibrant, contrasting color palette, large geometric shapes, and sharp edges. A painting like van Gogh’s Starry Night might be popular, but the iconic brush stroke texture is too small for the model to pick up during training. Consider cropping images to retain the color palette and improve the resulting style. Cubist and impressionist paintings work very well.

If you’re not set on using the style from a famous artist, images tagged as geometric or abstract on Unsplash and Flickr work really well. Make sure you check the licenses of any images before you use them. Here’s a good image I found and cropped for this tutorial:

Photo by Darius Family.
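
If you’d rather crop programmatically than in an image editor, a minimal Pillow sketch is below. The input file name and crop box are placeholders; the output path matches the --style-images argument used in the training command later on.

import PIL.Image
# Placeholder input path and crop box: pick a region that keeps the
# large shapes and vibrant colors the model learns best from.
style = PIL.Image.open('data/original_style_image.jpg')
cropped = style.crop((0, 0, 1024, 1024))  # (left, upper, right, lower) in pixels
cropped.save('data/style_image.jpg')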

Training data

Training a style transfer network requires two inputs: the style image we found in the previous section and a much larger set of arbitrary images for the style to be applied to. During training, the model stylizes each image over and over, and we evaluate the network on how well it mixes the style and content of the two inputs. The COCO dataset — specifically the 5000 images in the 2017 validation set — is perfect for the job.

Download the dataset to the data/ folder we created earlier:

# From fritz-models/style_transfer/data/
wget http://images.cocodataset.org/zips/val2017.zip
unzip val2017.zip

Our training code expects images to be in the TFRecord format. This format has a few advantages over a folder of individual images. First, the TensorFlow Dataset API can read TFRecord files in chunks from cloud storage, making it suitable for distributed training. Second, the format makes it easy to package multiple fields with a single data record, so annotation data can be stored right next to the raw images. We don’t make use of the latter feature in this tutorial, but it’s particularly useful for other machine learning tasks like object detection or image segmentation.
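
To make the format concrete, here’s a hypothetical sketch of how a single record can bundle an encoded image with an extra field, and how the Dataset API streams records back. The feature names are illustrative only; the repository’s scripts define the schema actually used for training.

import glob
import tensorflow as tf
# Grab any image from the unzipped COCO validation set
image_path = glob.glob('data/val2017/*.jpg')[0]
with open(image_path, 'rb') as f:
    jpeg_bytes = f.read()
# Pack the encoded image and an extra (illustrative) field into one record
example = tf.train.Example(features=tf.train.Features(feature={
    'image/encoded': tf.train.Feature(
        bytes_list=tf.train.BytesList(value=[jpeg_bytes])),
    'image/filename': tf.train.Feature(
        bytes_list=tf.train.BytesList(value=[image_path.encode()])),
}))
with tf.io.TFRecordWriter('data/example.tfrecord') as writer:
    writer.write(example.SerializeToString())
# The Dataset API reads these records in chunks, locally or from cloud storage
dataset = tf.data.TFRecordDataset('data/example.tfrecord')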

The Fritz Style Transfer repository has a script that converts a folder of images into a TFRecord dataset:

# from /path/to/fritz-models/style_transfer/
python create_training_dataset.py \
--output data/training_images.tfrecord \
--image-dir data/val2017/

Training the network

It’s time to train the network. Running the command below will train a network that’ll work well on live video for iOS devices. The training script will print out progress updates every 10 training iterations.

python style_transfer/train.py \
--training-image-dset data/training_images.tfrecord \
--style-images data/style_image.jpg \
--model-checkpoint data/style_name_025.h5 \
--image-size 256,256 \
--alpha 0.25 \
--num-iterations 500 \
--batch-size 24 \
--fine-tune-checkpoint example/starry_night_256x256_025.h5

Here’s a breakdown of each argument we used.

  • --training-image-dset: the TFRecord dataset we created earlier.
  • --style-images: a path to the style image we downloaded.
  • --model-checkpoint: a path to save model checkpoints to during training.
  • --image-size: resize training images to these dimensions for training. Don’t worry, we’ll change this size to accept higher resolution images when we put the trained model into an app. 256px is a good size for training, as it preserves image content while being small and fast to train on.
  • --alpha: a parameter controlling the number of weights in our network. In theory, more weights lead to better styles, but at the cost of model size and speed. 0.25 is a good value for models intended to run in mobile apps.
  • --num-iterations: the number of training steps to take. A good rule of thumb is that `num-iterations * batch-size` should be roughly 80,000 for models trained from scratch and 12,000 for models starting from a pre-trained checkpoint. The command above uses 500 * 24 = 12,000, matching the fine-tuning case.
  • --batch-size: the number of images in each training batch. If this is set too high, training will crash because of memory errors. 24 is a good choice for GPUs in Google Colab.
  • --fine-tune-checkpoint: a path to a pre-trained model to start from. Used for transfer learning or to re-start training on a model. I’ve included a model trained on Starry Night that you can use for this tutorial.

If you start training from the starry_night_256x256_025.h5 checkpoint included in the example/ folder of the repository, a new style can be trained in about 20 minutes. If you’re starting training from scratch, the entire process should take about 3 hours. Make sure you’re using a GPU runtime in Google Colab.

There are a few other training arguments I didn’t use above that you might find useful as you experiment (an example command follows the list):

  • --style-weight: change this to make stylized images look more or less like the style image. The default value is 0.0001. For style images with subtle texture or color palettes, a higher value like 0.001 or 0.01 may work better. For style images with strong, large patterns, a value closer to 0.00001 is best.
  • --use-small-network: if present, an even smaller model architecture will be trained. Styles won’t look quite as crisp, but runtime performance is higher. Use this if you’re working with video on Android.
  • --checkpoint-interval: the epoch interval at which to save models. Default is 10.
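
For example, a fine-tuning run that weights the style a bit more heavily and saves checkpoints more often might look like this. The extra values are illustrative starting points, not recommendations from the original training run:

python style_transfer/train.py \
--training-image-dset data/training_images.tfrecord \
--style-images data/style_image.jpg \
--model-checkpoint data/style_name_025.h5 \
--image-size 256,256 \
--alpha 0.25 \
--num-iterations 500 \
--batch-size 24 \
--style-weight 0.001 \
--checkpoint-interval 5 \
--fine-tune-checkpoint example/starry_night_256x256_025.h5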

Stylize an image

This snippet loads our trained model and applies the style to an image you might find on Instagram. Comments explain what’s happening at each block.

import style_transfer.models
import keras
import PIL.Image
import requests
import numpy
import matplotlib.pyplot as pyplot
from io import BytesIO
# Clear the keras session of the training we just did
keras.backend.clear_session()
# Load the model
image_size = (640, 640)
model = style_transfer.models.StyleTransferNetwork.build(
    image_size, alpha=0.25,
    checkpoint_file='/content/fritz-models/style_transfer/data/style_name_025.h5')
# Download an image to stylize
original_url = 'https://farm3.staticflickr.com/2907/14746369554_b783ba8d13_o_d.png'
response = requests.get(original_url)
original_image = PIL.Image.open(BytesIO(response.content))
original_image = original_image.resize(image_size)
# Pre-process input: add a batch dimension, drop any alpha channel,
# and roughly center pixel values around zero
input_data = numpy.array(original_image)[None, :, :, :3] - 120.0
# Stylize the photo
output_data = model.predict(input_data)
output_image = PIL.Image.fromarray(output_data[0].astype('uint8'))
# Display the result
pyplot.imshow(output_image)
pyplot.axis('off')
pyplot.show()

Our stylized image!

Converting trained models to mobile formats

Finally, it’s time to convert models to a mobile-friendly format. For the best in-app performance, use the native runtimes for each platform: Core ML for iOS and TensorFlow Mobile / Lite for Android. At the time of writing, TensorFlow Lite support for this model is incomplete, so TensorFlow Mobile is the safer fallback on Android; a TensorFlow Lite conversion script is included below if you want to try it.

There are two scripts to convert trained Keras models to mobile-friendly formats. The first produces a Core ML model for iOS:

python convert_to_coreml.py \
--keras-checkpoint data/style_name_025.h5 \
--alpha 0.25 \
--image-size 640,640 \
--coreml-model data/style_name_025.mlmodel

Note that we’re changing the input image size expected by the model to be 640x640 pixels. Core ML requires that model input sizes be fixed. This will change in Core ML 2, which can accept input images of arbitrary size.
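
To confirm the fixed 640x640 input made it into the converted model, you can inspect the spec with coremltools (assuming the output path from the command above):

import coremltools
# The image input description should report a fixed 640x640 size
mlmodel = coremltools.models.MLModel('data/style_name_025.mlmodel')
print(mlmodel.get_spec().description.input)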

For Android, you can convert your models to the TensorFlow Lite format with the following script.

python convert_to_tflite.py \
--keras-checkpoint data/style_name_025.h5 \
--alpha 0.25 \
--image-size 640,640 \
--tflite-file data/style_name_025.tflite
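
As a quick sanity check before shipping the Android model, you can run the converted file through the TensorFlow Lite Interpreter in Python. This assumes a TensorFlow build that ships tf.lite.Interpreter and the output path from the command above:

import numpy as np
import tensorflow as tf
# Load the converted model and confirm it expects the 640x640 input we exported
interpreter = tf.lite.Interpreter(model_path='data/style_name_025.tflite')
interpreter.allocate_tensors()
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()
print(input_details[0]['shape'])
# Push a random image through the network to make sure inference runs
dummy = np.random.uniform(0, 255, input_details[0]['shape']).astype(
    input_details[0]['dtype'])
interpreter.set_tensor(input_details[0]['index'], dummy)
interpreter.invoke()
print(interpreter.get_tensor(output_details[0]['index']).shape)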

You can download your models directly from Google Colab notebooks with a single line of Python.

# Download the mlmodel
from google.colab import files
files.download('/content/fritz-models/style_transfer/data/style_name_025.mlmodel')

For information on integrating these models into your mobile app, take a look at these tutorials for iOS and Android.

Conclusion

I hope this tutorial has been helpful and that you enjoy your custom style transfer model! If you’re interested in templates for other common ML models for tasks like object detection, image segmentation, and more, send us an email at info@fritz.ai.

Here’s a Gist with all of the commands in the Google Colab notebook in case you’d like to run things on your own machine.

Discuss this post on Hacker News.

Editor’s Note: Heartbeat is a contributor-driven online publication and community dedicated to exploring the emerging intersection of mobile app development and machine learning. We’re committed to supporting and inspiring developers and engineers from all walks of life.

Editorially independent, Heartbeat is sponsored and published by Fritz AI, the machine learning platform that helps developers teach devices to see, hear, sense, and think. We pay our contributors, and we don’t sell ads.

If you’d like to contribute, head on over to our call for contributors. You can also sign up to receive our weekly newsletters (Deep Learning Weekly and the Fritz AI Newsletter), join us on Slack, and follow Fritz AI on Twitter for all the latest in mobile machine learning.
