TDS Archive

An archive of data science, data analytics, data engineering, machine learning, and artificial intelligence writing from the former Towards Data Science Medium publication.

Style Transfer for Line Drawings

Generating Images From Line Drawings With ML

Ryan Rudes
Published in TDS Archive · 8 min read · Sep 6, 2020

Here, I’ll walk through a machine learning project I recently did, in tutorial-like fashion: an approach to generating full images in an artistic style from line drawings.

Dataset

I trained on 10% of the ImageNet dataset, which is commonly used for benchmarking computer vision tasks. ImageNet is not openly available; access is restricted to those doing research that requires it for computing performance benchmarks to compare with other approaches, so you are typically required to submit a request form. But if you are just using it casually, it is available here. I just wouldn’t use it for anything beyond personal projects. Note that the dataset is very large, which is why I only used a tenth of it to train my model: it consists of 1,000 classes, of which I used 100 for training.

I used ImageNet for a different personal project a few weeks ago, so I already had a large collection of files in Google Drive. Unfortunately, it took approximately 20 hours to upload those 140,000 or so images to Google Drive. Training the model on Google Colab’s online GPU is necessary, but it requires you to upload the images to Google Drive, as you aren’t hosting your coding environment locally.

Data Input Pipeline

I have a Colab Pro account, but even with the additional RAM, I certainly can’t hold 140,000 line drawings, each 256x256 pixels in size, in memory along with their 256x256 pixel colored counterparts. Hence, I have to load the data on the go using a TensorFlow data input pipeline.

Before we start to set up the pipeline, let’s import the required libraries (these are all of the import statements in my code):

import matplotlib.pyplot as plt
import numpy as np
import cv2
from tqdm.notebook import tqdm
import glob
import random
import threading, queue
from tensorflow.keras.models import *
from tensorflow.keras.layers import *
from tensorflow.keras.optimizers import *
from tensorflow.keras.regularizers import *
from tensorflow.keras.utils import to_categorical
import tensorflow as tf

Now, let’s load the filepaths that refer to each image in our subset of ImageNet, assuming you have uploaded them to Drive under the appropriate directory structure and connected your Google Colab instance to your Google Drive.

filepaths = glob.glob("drive/My Drive/1-100/**/*.JPEG")
# Shuffle the filepaths
random.shuffle(filepaths)
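If your Drive isn’t mounted yet, the standard Colab call makes it visible to the runtime (the glob path above is relative to /content):

from google.colab import drive

# Mount Google Drive at /content/drive so the dataset files are accessible
drive.mount('/content/drive')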

If you don’t want to use the glob module, you can use functions from the os library, which are often more efficient.
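For instance, here’s a minimal equivalent using os.walk, assuming the same directory layout:

import os

# Collect the same file list by walking the directory tree
filepaths = []
for root, dirs, files in os.walk("drive/My Drive/1-100"):
    for name in files:
        if name.endswith(".JPEG"):
            filepaths.append(os.path.join(root, name))
random.shuffle(filepaths)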

Here are a few helper functions I need:

  • Normalizing data
  • Posterizing image data

def normalize(x):
    return (x - x.min()) / (x.max() - x.min())

Posterization

The aforementioned posterization process takes an image as input and transforms smooth gradients into more clearly separated color sections by rounding each color value to the nearest allowed value. Here’s an example:

Posterization

As you can see, the resulting image’s smooth gradients are replaced with separated color sections. The reason I am implementing this is that it lets me limit the output images to a fixed set of colors, allowing me to formulate the learning problem as a classification problem over each pixel in an image. I assign a label to each available color. The model outputs an image of shape (height, width, num_colors), activated by a softmax function over the last channel, num_colors. Given a variable num_values, I allow all combinations of RGB where the color values are limited to np.arange(0, 255, 255 / num_values). This means that num_colors = num_values ** 3. Here’s an example:

Posterization
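For concreteness, here is one way the colors array could be constructed — a minimal sketch assuming num_values = 3 (so num_colors = 27), with the values scaled to [0, 1] to match the normalize() helper:

import itertools

num_values = 3
# Allowed values per channel, scaled to [0, 1] to match normalize()
values = np.arange(0, 255, 255 / num_values) / 255.
# All RGB combinations of the allowed values: shape (num_values ** 3, 3)
colors = np.array(list(itertools.product(values, repeat = 3)))
num_colors = len(colors)  # 27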

Here’s how I implemented this:

def get_nearest_color(color, colors):
    """
    Args:
    - color: A vector of size 3 representing an RGB color
    - colors: NumPy array of shape (num_colors, 3)
    Returns:
    - The index of the color in the provided set of colors that is
      closest in value to the provided color
    """
    return np.argmin([np.linalg.norm(color - c) for c in colors])

def posterize_with_limited_colors(image, colors):
    """
    Args:
    - colors: NumPy array of shape (num_colors, 3)
    Returns:
    - Posterized image of shape (height, width, 1), where each value
      is an integer label associated with a particular index of the
      provided colors array
    """
    image = normalize(image)
    posterized = np.array([[get_nearest_color(x, colors) for x in y] for y in image])
    # Add a channel axis so the output shape matches the docstring
    return np.expand_dims(posterized, axis = -1)
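As a quick sanity check, here’s a hypothetical usage example on a random image, assuming the colors array sketched above:

# Posterize a random 4x4 image against the 27-color palette
image = np.random.rand(4, 4, 3)
labels = posterize_with_limited_colors(image, colors)
print(labels.shape)                # (4, 4, 1)
print(labels.min(), labels.max())  # integer labels in [0, num_colors)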

Edge Extraction

In order to create the input data from our colored images, we need a method of extracting edges from an image, akin to a trace or line drawing.

We’ll be using the Canny edge detection algorithm. Let’s write our helper function, which takes the path to an image and outputs the associated (X, Y) training pair: the black-and-white edge extraction alongside a posterization of the colored image:

def preprocess(path):
    color = cv2.imread(path)
    # Skip unreadable files and images smaller than the input resolution,
    # assuming your pipeline's generator function ignores None
    if (color is None or color.shape[0] < input_image_size[0]
            or color.shape[1] < input_image_size[1]):
        return None, None
    color = cv2.resize(color, input_image_size)
    color = (normalize(color) * 255).astype(np.uint8)
    # cv2.imread returns BGR, so convert to grayscale with COLOR_BGR2GRAY
    gray = cv2.cvtColor(color, cv2.COLOR_BGR2GRAY)
    # Removes noise while preserving edges
    filtered = cv2.bilateralFilter(gray, 3, 60, 120)
    # Automatically determine thresholds for the edge detection algorithm
    # based upon the median color value of the image
    m = np.median(filtered)
    preservation_factor = 0.33
    low = max(0, int(m - 255 * preservation_factor))
    high = int(min(255, m + 255 * preservation_factor))
    filtered_edges = cv2.Canny(filtered, low, high)
    filtered_edges = normalize(filtered_edges)
    filtered_edges = np.expand_dims(filtered_edges, axis = -1)
    color = cv2.resize(color, output_image_size)
    color = color / 255.
    color = posterize_with_limited_colors(color, colors)
    return filtered_edges, color

The automatic Canny edge detection is just my modification of the small function used in this article.

The Pipeline

As I said, I’m loading data on the fly using an input pipeline, so I need to define a generator that loads the data when needed. My generator function is simple, because we essentially just defined it: all it adds is filtering out the None outputs of the preprocess function (images of lower resolution than input_image_size) and filtering out any results containing nan or inf values.

def generate_data(paths):
    for path in paths:
        edges, color = preprocess(path.decode())
        if edges is None:
            continue
        # Skip any examples containing nan or inf values
        if np.any(np.isnan(edges)) or np.any(np.isnan(color)):
            continue
        if np.any(np.isinf(edges)) or np.any(np.isinf(color)):
            continue
        # Yield the clean data
        yield edges, color

I use (128, 128) for both input_image_size and output_image_size. A 128x128 pixel image isn’t that low-resolution, so there’s no significant disadvantage for our purposes. Also, Imagenet images are typically much higher resolution, so we can go higher if desired.
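For reference, these are the configuration variables the snippets in this section rely on; the batch size and steps per epoch below are hypothetical placeholders, since I haven’t stated particular values:

input_image_size = (128, 128)   # resolution of the edge-map inputs
output_image_size = (128, 128)  # resolution of the posterized targets
batch_size = 32                                 # hypothetical value
steps_per_epoch = len(filepaths) // batch_size  # hypothetical choice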

Now let’s build the pipeline. I’m using multithreading for improved speed; TensorFlow’s .interleave() allows us to do this:

thread_names = np.arange(0, 8).astype(str)
dataset = tf.data.Dataset.from_tensor_slices(thread_names)
dataset = dataset.interleave(lambda x:
    tf.data.Dataset.from_generator(
        generate_data,
        output_types = (tf.float32, tf.float32),
        output_shapes = ((*input_image_size, 1),
                         (*output_image_size, 1)),
        args = (train_paths,)),
    cycle_length = 8,
    block_length = 8,
    num_parallel_calls = 8)
dataset = dataset.batch(batch_size).repeat()

Testing The Pipeline

Let’s load in a training example through our pipeline:
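A minimal way to do this, assuming the colors palette from earlier, is to pull one batch and map the label indices back to RGB for display:

# Pull one batch from the pipeline
edges, labels = next(iter(dataset))
# Decode the integer labels back to RGB using the palette
decoded = colors[labels[0].numpy().astype(int).squeeze(-1)]
fig, (ax1, ax2) = plt.subplots(1, 2, figsize = (10, 5))
ax1.imshow(decoded)                                      # posterized color target
ax2.imshow(edges[0].numpy().squeeze(-1), cmap = "gray")  # edge-map input
plt.show()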

One training example with input line drawing/edges (right) and output colorization (left)

It’s exactly as desired. Note that the image depicted on the left is not exactly what the pipeline outputs; recall that the pipeline returns the index of each pixel’s color. I simply looked up each associated color to create the visualization. Here’s an example of one that came out much simpler:

Simpler training example

You’ll see that on the left we have the output: a posterized color image that partially resembles a painting. On the right is the input edge extraction, which resembles a sketch.

Of course, not all training examples will have as good an edge extraction as others. When the colors are more difficult to separate, the resulting outline might be a little noisy and/or scattered. However, this was the most accurate method for extracting edges I could think of.

Model Architecture

Let’s move on to the model architecture.

I begin at input_image_size = (128, 128), making the input shape (128, 128, 1) after expanding the last axis. I halve the spatial dimensions at each layer until they equal 1. Then, I apply two more convolutional layers with stride = 1, because we can’t decrease the shape of the first two axes any further. Then, I perform the reverse with transposed layers. Each convolutional layer uses padding = 'same' (stride-2 layers must pad to halve the spatial dimensions exactly), and there is a batch normalization layer between each pair of convolutional layers. All convolutional layers have ReLU activation, except the last, which of course has softmax activation over the final one-hot-encoded color-label channel.
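The model-building code isn’t shown here, so below is a minimal sketch of the architecture just described. The 3x3 kernel size is inferred from the parameter counts in the summary that follows; the optimizer, the sparse categorical cross-entropy loss, and the exact placement of the final softmax are assumptions consistent with the per-pixel classification setup, not a line-for-line reconstruction:

def build_model(input_image_size = (128, 128), num_colors = 27):
    x = inputs = Input(shape = (*input_image_size, 1))
    # Encoder: halve the spatial dims and triple the channels at each step
    for filters in [3, 9, 27, 81, 243, 729, 2187]:
        x = Conv2D(filters, 3, strides = 2, padding = 'same', activation = 'relu')(x)
        x = BatchNormalization()(x)
    # Two extra stride-1 convolutions at the 1x1 bottleneck
    for _ in range(2):
        x = Conv2D(2187, 3, padding = 'same', activation = 'relu')(x)
        x = BatchNormalization()(x)
    # Decoder: mirror the encoder with transposed convolutions
    for filters, strides in [(2187, 1), (2187, 1), (2187, 2), (2187, 2),
                             (729, 2), (243, 2), (81, 2), (27, 2)]:
        x = Conv2DTranspose(filters, 3, strides = strides, padding = 'same',
                            activation = 'relu')(x)
        x = BatchNormalization()(x)
    # Upsample 64x64 -> 128x128, then classify each pixel over the palette
    x = UpSampling2D()(x)
    outputs = Softmax(axis = -1)(x)
    model = Model(inputs, outputs)
    model.compile(optimizer = Adam(),  # assumed optimizer and loss
                  loss = 'sparse_categorical_crossentropy',
                  metrics = ['accuracy'])
    return model

model = build_model()
model.summary()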

_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
input_35 (InputLayer)        [(None, 128, 128, 1)]     0
_________________________________________________________________
conv2d_464 (Conv2D)          (None, 64, 64, 3)         30
_________________________________________________________________
batch_normalization_388 (Bat (None, 64, 64, 3)         12
_________________________________________________________________
conv2d_465 (Conv2D)          (None, 32, 32, 9)         252
_________________________________________________________________
batch_normalization_389 (Bat (None, 32, 32, 9)         36
_________________________________________________________________
conv2d_466 (Conv2D)          (None, 16, 16, 27)        2214
_________________________________________________________________
batch_normalization_390 (Bat (None, 16, 16, 27)        108
_________________________________________________________________
conv2d_467 (Conv2D)          (None, 8, 8, 81)          19764
_________________________________________________________________
batch_normalization_391 (Bat (None, 8, 8, 81)          324
_________________________________________________________________
conv2d_468 (Conv2D)          (None, 4, 4, 243)         177390
_________________________________________________________________
batch_normalization_392 (Bat (None, 4, 4, 243)         972
_________________________________________________________________
conv2d_469 (Conv2D)          (None, 2, 2, 729)         1595052
_________________________________________________________________
batch_normalization_393 (Bat (None, 2, 2, 729)         2916
_________________________________________________________________
conv2d_470 (Conv2D)          (None, 1, 1, 2187)        14351094
_________________________________________________________________
batch_normalization_394 (Bat (None, 1, 1, 2187)        8748
_________________________________________________________________
conv2d_471 (Conv2D)          (None, 1, 1, 2187)        43048908
_________________________________________________________________
batch_normalization_395 (Bat (None, 1, 1, 2187)        8748
_________________________________________________________________
conv2d_472 (Conv2D)          (None, 1, 1, 2187)        43048908
_________________________________________________________________
batch_normalization_396 (Bat (None, 1, 1, 2187)        8748
_________________________________________________________________
conv2d_transpose_229 (Conv2D (None, 1, 1, 2187)        43048908
_________________________________________________________________
batch_normalization_397 (Bat (None, 1, 1, 2187)        8748
_________________________________________________________________
conv2d_transpose_230 (Conv2D (None, 1, 1, 2187)        43048908
_________________________________________________________________
batch_normalization_398 (Bat (None, 1, 1, 2187)        8748
_________________________________________________________________
conv2d_transpose_231 (Conv2D (None, 2, 2, 2187)        43048908
_________________________________________________________________
batch_normalization_399 (Bat (None, 2, 2, 2187)        8748
_________________________________________________________________
conv2d_transpose_232 (Conv2D (None, 4, 4, 2187)        43048908
_________________________________________________________________
batch_normalization_400 (Bat (None, 4, 4, 2187)        8748
_________________________________________________________________
conv2d_transpose_233 (Conv2D (None, 8, 8, 729)         14349636
_________________________________________________________________
batch_normalization_401 (Bat (None, 8, 8, 729)         2916
_________________________________________________________________
conv2d_transpose_234 (Conv2D (None, 16, 16, 243)       1594566
_________________________________________________________________
batch_normalization_402 (Bat (None, 16, 16, 243)       972
_________________________________________________________________
conv2d_transpose_235 (Conv2D (None, 32, 32, 81)        177228
_________________________________________________________________
batch_normalization_403 (Bat (None, 32, 32, 81)        324
_________________________________________________________________
conv2d_transpose_236 (Conv2D (None, 64, 64, 27)        19710
_________________________________________________________________
up_sampling2d_1 (UpSampling2 (None, 128, 128, 27)      0
_________________________________________________________________
batch_normalization_404 (Bat (None, 128, 128, 27)      108
=================================================================
Total params: 290,650,308
Trainable params: 290,615,346
Non-trainable params: 34,962
_________________________________________________________________

Training

Let’s create some lists to store our metrics throughout training.

train_losses, train_accs = [], []

Also, a variable for the number of training epochs:

epochs = 100

And here’s our training script:

for epoch in range(epochs):
    random.shuffle(filepaths)
    history = model.fit(dataset,
                        steps_per_epoch = steps_per_epoch,
                        use_multiprocessing = True,
                        workers = 8,
                        max_queue_size = 10)
    train_loss = np.mean(history.history["loss"])
    train_acc = np.mean(history.history["accuracy"])
    train_losses = train_losses + history.history["loss"]
    train_accs = train_accs + history.history["accuracy"]
    print("Epoch: {}/{}, Train Loss: {:.6f}, Train Accuracy: {:.6f}".format(
        epoch + 1, epochs, train_loss, train_acc))
    # Plot the metrics collected so far
    if epoch > 0:
        fig = plt.figure(figsize = (10, 5))
        plt.subplot(1, 2, 1)
        plt.plot(train_losses)
        plt.xlim(0, len(train_losses) - 1)
        plt.xlabel("Epoch")
        plt.ylabel("Loss")
        plt.title("Loss")
        plt.subplot(1, 2, 2)
        plt.plot(train_accs)
        plt.xlim(0, len(train_accs) - 1)
        plt.ylim(0, 1)
        plt.xlabel("Epoch")
        plt.ylabel("Accuracy")
        plt.title("Accuracy")
        plt.show()
    # Checkpoint the model and metrics after every epoch
    model.save("model_{}.h5".format(epoch))
    np.save("train_losses.npy", train_losses)
    np.save("train_accs.npy", train_accs)
