Neural Style Transfer with Swift for TensorFlow

James Thompson
May 2, 2019 · 10 min read
Left: The Painted Ladies by King of Hearts / Wikimedia Commons / CC-BY-SA-3.0 | Middle: The Starry Night by Vincent van Gogh, c. June 1889 | Right: Image generated by employing the Neural Style Transfer algorithm, Gatys et al.

What is Neural Style Transfer?

Simply put, Neural Style Transfer [Gatys et al.] is a process that takes the style of one image and the content of another and generates a new image that exhibits the stylistic features of the style image while preserving the high-level structure of the content image.

To make this work, we need:

  • A content image.
  • A style image.
  • A styled image. This is the output. We can initialize it in a number of ways (which we’ll explore below).
  • A way of extracting style information from layer activations.
  • A pre-trained conv net (we’ll use VGG-19 [Karen Simonyan, Andrew Zisserman]).

Install dependencies via SPM

Installing packages:
    .package(url: "", from: "0.16.1")
        Path
With SwiftPM flags: []
Working in: /tmp/tmp7lgo13l0/swift-install
Compile Swift Module 'jupyterInstalledPackages' (1 sources)
Initializing Swift...
Installation complete!

Import dependencies

Grab the VGG-19 checkpoint

Make sure you have wget installed on your system. If you’d rather not use it, you can try using curl instead.

Add a gram matrix method to Tensor

We use the Gram matrix to extract the correlation structure between the filters of a given layer. Because it sums filter co-activations over every spatial position, it keeps which features occur together but discards where they occur, which effectively decouples texture information from the global arrangement of the scene. We’ll use this later when we calculate the perceptual loss. See the paper for more details.
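To make the idea concrete, here is a minimal plain-Swift sketch of a Gram matrix on nested arrays rather than a `Tensor` extension; it does not use the Swift for TensorFlow API. It assumes a layer's activations have already been reshaped from [height, width, channels] to [positions, channels], and it divides by the number of positions, which is one common normalization; the post's own normalization may differ. The name `gramMatrix` is illustrative.

```swift
// Plain-Swift sketch of a Gram matrix (not the Swift for TensorFlow API).
// `features` has one row per spatial position (H * W rows) and one column per channel (C columns).
func gramMatrix(_ features: [[Float]]) -> [[Float]] {
    guard let channels = features.first?.count else { return [] }
    var gram = [[Float]](repeating: [Float](repeating: 0, count: channels), count: channels)
    for row in features {
        precondition(row.count == channels, "every position must have the same channel count")
        for i in 0..<channels {
            for j in 0..<channels {
                // Accumulate how strongly channels i and j fire together at this position.
                gram[i][j] += row[i] * row[j]
            }
        }
    }
    // Normalize by the number of positions so the scale does not depend on image size.
    let positions = Float(features.count)
    return gram.map { $0.map { $0 / positions } }
}

// Two positions, two channels.
let g = gramMatrix([[1, 2], [3, 4]])
print(g)  // [[5.0, 7.0], [7.0, 10.0]]
```

Note that swapping the two rows (that is, moving features to different positions) gives the same result, which is exactly the "where doesn't matter" property described above.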

Define a layer for the input image

Disclaimer: This is probably not the best way to handle this, but after trying a few alternatives, I landed on this. The idea is that we want to freeze the parameters of the conv net and only update the input image during backprop. At the time of writing, I couldn’t find a straightforward way to A) freeze a layer while still computing all of the gradients and B) make the input image a tunable parameter. I have some ideas on how I might handle this better, but this works well enough for now.
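The freeze-the-net, train-the-input idea can be sketched with a toy model. In this hypothetical example, `FrozenLinearNet` stands in for VGG-19 as a fixed linear map whose weights are constants, and plain gradient descent updates only the input array `image` until the network's output matches a target. It illustrates the concept only; it is not the layer used in the post.

```swift
// Toy model of "freeze the network, train the input image".
// `FrozenLinearNet` is a hypothetical stand-in for VGG-19: its weights are `let` constants.
struct FrozenLinearNet {
    let weights: [[Float]]

    // Forward pass: a matrix-vector product.
    func forward(_ x: [Float]) -> [Float] {
        weights.map { row in zip(row, x).reduce(Float(0)) { $0 + $1.0 * $1.1 } }
    }

    // Gradient of sum((forward(x) - target)^2) with respect to the INPUT x:
    // 2 * transpose(weights) * (forward(x) - target).
    func inputGradient(_ x: [Float], target: [Float]) -> [Float] {
        let residual = zip(forward(x), target).map { $0 - $1 }
        var grad = [Float](repeating: 0, count: x.count)
        for (r, row) in zip(residual, weights) {
            for k in 0..<x.count { grad[k] += 2 * r * row[k] }
        }
        return grad
    }
}

let net = FrozenLinearNet(weights: [[1, 0], [0, 2]])
let target: [Float] = [3, 4]
var image: [Float] = [0, 0]  // the only trainable quantity
for _ in 0..<200 {
    let grad = net.inputGradient(image, target: target)
    for k in image.indices { image[k] -= 0.05 * grad[k] }
}
print(image)  // converges toward [3, 2], since forward([3, 2]) == [3, 4]
```

Because the weights are constants, nothing but the input can change, which is the property we want from the frozen conv net.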

Define a struct to store the output activations

Swift for TensorFlow also doesn’t currently support differentiable control flow, which leads to more repetitive code, though I’ve been told that this feature is coming very soon! It also doesn’t support pulling activations out of a net after a forward pass. There are workarounds for both of these issues, but they are less than straightforward. At some point this functionality will be built into the Swift for TensorFlow deep learning framework.

Define a concrete pooling type

We’ll use this to easily swap which type of pooling we’re using. Again, once differentiable control flow lands, this sort of thing shouldn’t be necessary. I tried getting this to work nicely with generics but to no avail.
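A minimal sketch of such a pooling switch, assuming a hypothetical `PoolingType` enum that applies VGG's 2x2, stride-2 pooling to a single channel stored as nested `Float` arrays. It is a plain-Swift stand-in for the Swift for TensorFlow pooling layers, not a wrapper around them.

```swift
// Hypothetical enum selecting the pooling operation (plain Swift, not S4TF layers).
enum PoolingType {
    case max
    case average

    // 2x2 pooling with stride 2 on one channel; odd trailing rows/columns are dropped.
    func pool(_ input: [[Float]]) -> [[Float]] {
        let rows = input.count / 2
        let cols = (input.first?.count ?? 0) / 2
        var output = [[Float]](repeating: [Float](repeating: 0, count: cols), count: rows)
        for r in 0..<rows {
            for c in 0..<cols {
                let window = [input[2 * r][2 * c], input[2 * r][2 * c + 1],
                              input[2 * r + 1][2 * c], input[2 * r + 1][2 * c + 1]]
                switch self {
                case .max: output[r][c] = window.max()!
                case .average: output[r][c] = window.reduce(0, +) / 4
                }
            }
        }
        return output
    }
}

let channel: [[Float]] = [[1, 2, 5, 6],
                          [3, 4, 7, 8]]
print(PoolingType.max.pool(channel))      // [[4.0, 8.0]]
print(PoolingType.average.pool(channel))  // [[2.5, 6.5]]
```

Since the choice is a value, the same network code can be run with either mode by passing a different case.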

Implement the VGG19 architecture without the classification head

This code should look fairly straightforward. Again, I could probably reduce some of the repetition, but this is just a first pass to get something up and running. I may do an update where I refactor some of this (especially once differentiable control flow is supported).
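For reference, the convolutional part of VGG-19 (configuration E in Simonyan and Zisserman) consists of five blocks of 3x3 convolutions with 64, 128, 256, 512, and 512 filters, each block followed by 2x2 pooling; dropping the classification head removes the three fully connected layers. The sketch below only encodes that layout as data and derives the conventional layer names. `vgg19Blocks` and `convLayerNames` are illustrative names.

```swift
// Filter counts for each 3x3 convolution in VGG-19's five blocks (classification head omitted).
// Every block is followed by a 2x2, stride-2 pooling layer.
let vgg19Blocks: [[Int]] = [
    [64, 64],
    [128, 128],
    [256, 256, 256, 256],
    [512, 512, 512, 512],
    [512, 512, 512, 512],
]

// Conventional names such as "conv1_1" ... "conv5_4": block number, then index within the block.
let convLayerNames: [String] = vgg19Blocks.enumerated().flatMap { pair in
    pair.element.indices.map { "conv\(pair.offset + 1)_\($0 + 1)" }
}

// Total number of convolution layers across all blocks.
let convCount = vgg19Blocks.reduce(0) { $0 + $1.count }

print(convCount)             // 16
print(convLayerNames.last!)  // conv5_4
```

Gatys et al. take style from conv1_1, conv2_1, conv3_1, conv4_1, and conv5_1 and content from conv4_2; the layers this post actually uses aren't shown here.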

Tie the two layers together

I chose to compose the two layers into one here. I wrote this prior to S4TF 0.3, which I believe introduced a better way to sequence layers together. I’ll probably revisit this as well.

Define an optimizer

This is probably the hackiest part of this. I’m sure there’s a way to use the existing Adam optimizer and freeze all the parameters in the network, updating only the image; however, I couldn’t figure it out, and this was the best I could come up with. It works, but I’ll definitely revisit this on my next pass.
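The core of an image-only optimizer can be sketched as a standard Adam update applied to a flat array of pixel values, with no network parameters involved. This is a plain-Swift illustration assuming a hypothetical `ImageAdam` type and the usual Adam defaults (learning rate 0.001, beta1 0.9, beta2 0.999, epsilon 1e-8); it is not the post's optimizer or the Swift for TensorFlow `Adam` class.

```swift
// Hypothetical Adam optimizer that updates only a flat array of image values.
struct ImageAdam {
    let learningRate: Float
    let beta1: Float
    let beta2: Float
    let epsilon: Float
    private var m: [Float] = []        // first-moment estimate per pixel value
    private var v: [Float] = []        // second-moment estimate per pixel value
    private var beta1Power: Float = 1  // beta1^t, for bias correction
    private var beta2Power: Float = 1  // beta2^t, for bias correction

    init(learningRate: Float = 0.001, beta1: Float = 0.9, beta2: Float = 0.999, epsilon: Float = 1e-8) {
        self.learningRate = learningRate
        self.beta1 = beta1
        self.beta2 = beta2
        self.epsilon = epsilon
    }

    mutating func update(_ image: inout [Float], gradient: [Float]) {
        precondition(image.count == gradient.count, "one gradient value per pixel value")
        if m.isEmpty {
            m = [Float](repeating: 0, count: image.count)
            v = m
        }
        beta1Power *= beta1
        beta2Power *= beta2
        for i in image.indices {
            m[i] = beta1 * m[i] + (1 - beta1) * gradient[i]
            v[i] = beta2 * v[i] + (1 - beta2) * gradient[i] * gradient[i]
            let mHat = m[i] / (1 - beta1Power)
            let vHat = v[i] / (1 - beta2Power)
            image[i] -= learningRate * mHat / (vHat.squareRoot() + epsilon)
        }
    }
}

var pixels: [Float] = [0, 0]
let target: [Float] = [1, -1]
var optimizer = ImageAdam(learningRate: 0.05)
for _ in 0..<500 {
    // Gradient of sum((pixels - target)^2) with respect to the pixels.
    let gradient = zip(pixels, target).map { 2 * ($0 - $1) }
    optimizer.update(&pixels, gradient: gradient)
}
print(pixels)  // should end up close to [1, -1]
```

Keeping the moment estimates inside the optimizer and passing the image `inout` means the network never appears in the update at all.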

Total variation loss

Total variation loss is a regularization technique commonly used to denoise images. Unfortunately, in this implementation it causes GPU memory to grow unbounded and triggers an out-of-memory (OOM) error. There’s probably a clever way to avoid this, but the results look pretty good without it, so I’ve kept it here for reference only. Perhaps someone can spot what I’m doing wrong and let me know ;)
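One common form of total variation loss sums the squared differences between each pixel and its right and lower neighbors, so smooth images score low and noisy ones score high. The sketch below computes it for a single channel stored as nested `Float` arrays; it is a CPU illustration of the formula only and says nothing about the GPU memory growth described above. `totalVariationLoss` is a hypothetical name.

```swift
// Squared (anisotropic) total variation of one channel.
// Assumes a rectangular grid: every row has the same length.
func totalVariationLoss(_ channel: [[Float]]) -> Float {
    var loss: Float = 0
    for r in channel.indices {
        for c in channel[r].indices {
            if c + 1 < channel[r].count {
                let dx = channel[r][c + 1] - channel[r][c]  // horizontal neighbor difference
                loss += dx * dx
            }
            if r + 1 < channel.count {
                let dy = channel[r + 1][c] - channel[r][c]  // vertical neighbor difference
                loss += dy * dy
            }
        }
    }
    return loss
}

print(totalVariationLoss([[0, 1], [2, 3]]))  // 10.0
print(totalVariationLoss([[5, 5], [5, 5]]))  // 0.0
```

In the full loss this term would be multiplied by a small weight and added to the perceptual loss.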

Perceptual Loss

This is the loss function we’ll use. We compute the mean squared error (MSE) between the styled image’s activations at the content layers and the target content activations from the content image. We then do the same for the style layers, but instead of computing the MSE of the raw activations, we compute the MSE between the Gram matrices of the activations. Each style and content layer loss is scaled by its own weight. Finally, everything is summed and returned as our final loss.
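Written out, the loss is a weighted sum of content MSE terms and Gram-matrix style MSE terms. Here is a minimal plain-Swift sketch under stated assumptions: activations are nested arrays laid out as [positions][channels], Gram matrices are normalized by the number of positions, and `LayerTerm`, `gram`, `meanSquaredError`, and `perceptualLoss` are hypothetical names rather than the post's code.

```swift
// Gram matrix of [positions][channels] activations, normalized by the number of positions.
func gram(_ features: [[Float]]) -> [[Float]] {
    guard let channels = features.first?.count else { return [] }
    var g = [[Float]](repeating: [Float](repeating: 0, count: channels), count: channels)
    for row in features {
        for i in 0..<channels {
            for j in 0..<channels { g[i][j] += row[i] * row[j] }
        }
    }
    let n = Float(features.count)
    return g.map { $0.map { $0 / n } }
}

// Mean of squared element-wise differences between two equally shaped nested arrays.
func meanSquaredError(_ a: [[Float]], _ b: [[Float]]) -> Float {
    let x = Array(a.joined()), y = Array(b.joined())
    precondition(x.count == y.count, "shapes must match")
    guard !x.isEmpty else { return 0 }
    let sum = zip(x, y).reduce(Float(0)) { $0 + ($1.0 - $1.1) * ($1.0 - $1.1) }
    return sum / Float(x.count)
}

// One chosen layer: the styled image's activations, the target activations, and the layer weight.
struct LayerTerm {
    let generated: [[Float]]
    let target: [[Float]]
    let weight: Float
}

func perceptualLoss(content: [LayerTerm], style: [LayerTerm]) -> Float {
    // Content: weighted MSE of raw activations.
    let contentLoss = content.reduce(Float(0)) { $0 + $1.weight * meanSquaredError($1.generated, $1.target) }
    // Style: weighted MSE of Gram matrices.
    let styleLoss = style.reduce(Float(0)) { $0 + $1.weight * meanSquaredError(gram($1.generated), gram($1.target)) }
    return contentLoss + styleLoss
}

let loss = perceptualLoss(
    content: [LayerTerm(generated: [[1, 2]], target: [[1, 4]], weight: 0.5)],
    style: [LayerTerm(generated: [[1, 0]], target: [[0, 1]], weight: 2)]
)
print(loss)  // 2.0
```

Here the content term contributes 0.5 * 2 = 1 and the style term 2 * 0.5 = 1.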

A clamping utility

This utility keeps the color values in the range that VGG-19 expects to see, i.e. +/- the ImageNet mean pixel values (in BGR). Without it, regions of the image would overexcite the network, causing clipped regions and aberrant noise. Note: It might be an interesting experiment to clamp these values with a smooth function instead of min/max; it might reduce the noise a bit.
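A minimal sketch of such a clamp for a single mean-subtracted BGR pixel. It assumes the commonly used Caffe ImageNet means (103.939, 116.779, 123.68 in BGR order) and, unlike the +/- mean bounds described above, clamps each channel to [-mean, 255 - mean], the exact range a 0 to 255 pixel occupies after mean subtraction. Swap in whichever bounds you prefer; `clampPixel` is a hypothetical helper.

```swift
// Commonly used ImageNet channel means in BGR order (assumption: they match the checkpoint).
let imagenetMeanBGR: [Float] = [103.939, 116.779, 123.68]

// Clamps one mean-subtracted BGR pixel so it still maps back to a valid 0...255 color.
// Assumed bounds per channel: [-mean, 255 - mean].
func clampPixel(_ bgr: [Float]) -> [Float] {
    precondition(bgr.count == 3, "expected a BGR triple")
    return zip(bgr, imagenetMeanBGR).map { pair in
        min(max(pair.0, -pair.1), 255 - pair.1)
    }
}

let overexcited: [Float] = [-200, 0, 200]
print(clampPixel(overexcited))  // first and last channels are pulled back inside their bounds
```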

Use Python interoperability to show images via matplotlib

Image processing utilities

These functions load the images into Float tensors and pre/post-process them. This pre-trained VGG-19 model was trained on images in BGR channel order that were normalized by subtracting the ImageNet mean. Note that they did not divide by the standard deviation, so each channel lies in the range [-mean, 255 - mean] rather than being standardized.
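The per-pixel part of that pre/post-processing might look like the following sketch, which works on a single pixel stored as a three-element array. It assumes the commonly used Caffe ImageNet means in BGR order (103.939, 116.779, 123.68) match this checkpoint; `preprocess` and `deprocess` are hypothetical names.

```swift
// Commonly used ImageNet channel means in BGR order (assumed to match the checkpoint).
let imagenetMeanBGR: [Float] = [103.939, 116.779, 123.68]

// RGB pixel in 0...255 -> BGR order with the mean subtracted (no division by std).
func preprocess(rgb: [Float]) -> [Float] {
    precondition(rgb.count == 3, "expected an RGB triple")
    let bgr = [rgb[2], rgb[1], rgb[0]]
    return zip(bgr, imagenetMeanBGR).map { $0 - $1 }
}

// Inverse for display: add the mean back, clip to 0...255, and reorder BGR -> RGB.
func deprocess(bgr: [Float]) -> [Float] {
    precondition(bgr.count == 3, "expected a BGR triple")
    let restored = zip(bgr, imagenetMeanBGR).map { min(max($0 + $1, 0), 255) }
    return [restored[2], restored[1], restored[0]]
}

let input: [Float] = [10, 20, 30]
let networkInput = preprocess(rgb: input)  // mean-subtracted BGR values
print(deprocess(bgr: networkInput))        // approximately [10, 20, 30] again
```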

Set up matplotlib for inline display


A utility to display an image tensor

Define a struct to hold the training results

This is where the results of training are stored so we can see how training progressed over time. It also has a function that plots the output images in a grid, which is nifty when trying out and comparing different hyperparameters.
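A minimal sketch of such a results container, assuming a hypothetical generic `TrainingResults` that stores the loss and a snapshot every `interval` iterations (matching the every-50-iterations logs below) and computes a near-square grid shape that a plotting helper could use for subplots. It does no plotting itself.

```swift
// Hypothetical record of a training run: loss plus a snapshot every `interval` iterations.
struct TrainingResults<Snapshot> {
    let interval: Int
    private(set) var entries: [(iteration: Int, loss: Float, snapshot: Snapshot)] = []

    init(interval: Int) {
        precondition(interval > 0, "interval must be positive")
        self.interval = interval
    }

    // The snapshot closure is only evaluated on logged iterations, so copies stay cheap.
    mutating func record(iteration: Int, loss: Float, snapshot: () -> Snapshot) {
        guard iteration % interval == 0 else { return }
        entries.append((iteration: iteration, loss: loss, snapshot: snapshot()))
    }

    // Smallest near-square grid (rows x columns) that fits every stored snapshot.
    var gridShape: (rows: Int, columns: Int) {
        let n = entries.count
        guard n > 0 else { return (0, 0) }
        var columns = 1
        while columns * columns < n { columns += 1 }
        return ((n + columns - 1) / columns, columns)
    }
}

var results = TrainingResults<Int>(interval: 50)
for iteration in 0...500 {
    results.record(iteration: iteration, loss: Float(1000 - iteration)) { iteration }
}
print(results.entries.count)  // 11
print(results.gridShape)      // (rows: 3, columns: 4)
```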

Peek at the style and content images

We’ll use the images from the graphic at the beginning of the post. Note: If you run into memory issues, you can drop the size down to 256.

TensorShape(dimensions: [1, 512, 512, 3]) 
TensorShape(dimensions: [1, 512, 512, 3])

Define the training method

Right now I’ve only really needed to change the style weights, content weight, iteration count and learning rate, so that’s what’s exposed.
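Bundling those knobs into a value type keeps experiments easy to compare. The `StyleTransferConfig` struct below is hypothetical, and its default values are placeholders rather than the settings used in the runs that follow; only the 500-iteration length and the 50-iteration logging cadence (an extra, assumed field) mirror the logs.

```swift
// Hypothetical bundle of the exposed training knobs. Defaults are placeholders.
struct StyleTransferConfig {
    var contentWeight: Float = 1
    var styleWeights: [Float] = [1, 1, 1, 1, 1]
    var iterations = 500
    var learningRate: Float = 1
    var logInterval = 50  // assumed extra field: how often progress is reported

    // Iterations at which progress would be reported, e.g. 0, 50, ..., 500.
    var loggedIterations: [Int] {
        Array(stride(from: 0, through: iterations, by: logInterval))
    }
}

let quickRun = StyleTransferConfig(iterations: 100, learningRate: 2)
print(quickRun.loggedIterations)  // [0, 50, 100]
```

Because it is a struct with default values, each experiment only has to spell out the settings it changes.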

Tie things up into a function we can experiment with

This will allow us to play around with the weights, learning rate, etc.

Let’s finally perform style transfer to make an image

[Iteration 0 - Perceptual Loss: 2.778563e+08] 
[Iteration 50 - Perceptual Loss: 4116344.5]
[Iteration 100 - Perceptual Loss: 2275590.2]
[Iteration 150 - Perceptual Loss: 1557501.5]
[Iteration 200 - Perceptual Loss: 1184193.0]
[Iteration 250 - Perceptual Loss: 1057808.2]
[Iteration 300 - Perceptual Loss: 863599.44]
[Iteration 350 - Perceptual Loss: 778174.4]
[Iteration 400 - Perceptual Loss: 737367.2]
[Iteration 450 - Perceptual Loss: 676374.0]
[Iteration 500 - Perceptual Loss: 946441.3]

Let’s look at how the image progressed over time

Notice how it converges to something pretty reasonable after only about 100 iterations.


Now let’s try initializing with the style image instead

We’ll use higher content weights this time around. As one might imagine, starting with the style image biases the result toward the global structure of the style image. We could alternatively tune the style weights down a bit, but let’s see where this gets us.

[Iteration 0 - Perceptual Loss: 16007397.0] 
[Iteration 50 - Perceptual Loss: 10292353.0]
[Iteration 100 - Perceptual Loss: 9144543.0]
[Iteration 150 - Perceptual Loss: 7475930.0]
[Iteration 200 - Perceptual Loss: 7086185.5]
[Iteration 250 - Perceptual Loss: 6741186.5]
[Iteration 300 - Perceptual Loss: 6937329.0]
[Iteration 350 - Perceptual Loss: 6654474.0]
[Iteration 400 - Perceptual Loss: 6679740.5]
[Iteration 450 - Perceptual Loss: 11030902.0]
[Iteration 500 - Perceptual Loss: 6526403.5]

Let’s try to tune the style weights

[Iteration 0 - Perceptual Loss: 16007397.0] 
[Iteration 50 - Perceptual Loss: 6788499.0]
[Iteration 100 - Perceptual Loss: 5649164.0]
[Iteration 150 - Perceptual Loss: 5271180.0]
[Iteration 200 - Perceptual Loss: 5076626.5]
[Iteration 250 - Perceptual Loss: 4990565.5]
[Iteration 300 - Perceptual Loss: 4901208.0]
[Iteration 350 - Perceptual Loss: 4857651.0]
[Iteration 400 - Perceptual Loss: 4806595.0]
[Iteration 450 - Perceptual Loss: 4777933.0]
[Iteration 500 - Perceptual Loss: 4765858.5]

Let’s try using average instead of max pooling

[Iteration 0 - Perceptual Loss: 279485.72] 
[Iteration 50 - Perceptual Loss: 202219.42]
[Iteration 100 - Perceptual Loss: 169860.62]
[Iteration 150 - Perceptual Loss: 156222.81]
[Iteration 200 - Perceptual Loss: 148178.36]
[Iteration 250 - Perceptual Loss: 143206.12]
[Iteration 300 - Perceptual Loss: 140567.17]
[Iteration 350 - Perceptual Loss: 137404.88]
[Iteration 400 - Perceptual Loss: 135017.0]
[Iteration 450 - Perceptual Loss: 133670.3]
[Iteration 500 - Perceptual Loss: 132263.81]
Left: Max Pooling | Right: Average Pooling

To close out, let’s test on some new images

Here’s our style image.

Starry Night Over the Rhône by Vincent van Gogh
Dllu [CC BY-SA 4.0 (]
[Iteration 0 - Perceptual Loss: 2.8845638e+07]
[Iteration 50 - Perceptual Loss: 15046891.0]
[Iteration 100 - Perceptual Loss: 12473184.0]
[Iteration 150 - Perceptual Loss: 12153522.0]
[Iteration 200 - Perceptual Loss: 11189293.0]
[Iteration 250 - Perceptual Loss: 10998533.0]
[Iteration 300 - Perceptual Loss: 11554881.0]
[Iteration 350 - Perceptual Loss: 10926209.0]
[Iteration 400 - Perceptual Loss: 12236732.0]
[Iteration 450 - Perceptual Loss: 11646822.0]
[Iteration 500 - Perceptual Loss: 10367661.0]

Closing thoughts

I’d like to take another pass at this and refactor it to be more “Swifty”. I’d also like to take a stab at implementing L-BFGS in Swift for TensorFlow. Hopefully I can contribute this stuff back to the Swift for TensorFlow community.
