Simplifying Rough Sketches using Deep Learning

Ashish Sinha
Published in Coinmonks, Jul 2, 2018


Sketching is the fundamental first step for expressing artistic ideas and beginning an iterative process of design refinement. It allows artists to quickly render their ideas on paper. The priority is to express concepts and ideas quickly rather than to exhibit fine details, which leads to coarse and rough sketches. After an initial sketch, feedback is used to iteratively refine the design until the final piece is produced. This iterative refinement forces artists to continuously clean up their rough sketches into simplified drawings, which adds to their workload. As one would expect, manually tracing a rough sketch to produce a clean drawing is fairly tedious and time-consuming.

So, wouldn't it be better if there were a way to get clean sketches from our rough sketches instantly, whatever pencil strokes they may contain? Fascinating, isn't it? In this article, I discuss a deep learning technique that uses fully convolutional networks to generate clean sketches from rough ones.

Left: rough sketches; right: the generated clean sketches. Source: the original paper, linked in the references.

Well, it's not that no software existed for this earlier; it did, but it worked on vector images rather than raster images. Let's start with what vector and raster images are!

Vector and Raster Images?

Raster images, also known as bitmaps, are composed of individual pixels of color. Each colored pixel contributes to the overall image.

Raster images might be compared to pointillist paintings, which are composed of a series of individually colored dots of paint. Each paint dot in a pointillist painting might represent a single pixel in a raster image. Viewed individually, a dot is just a color; viewed as a whole, the colored dots make up a vivid and detailed painting. The pixels in a raster image work in the same way, which allows for rich detail and pixel-by-pixel editing.

Left: the full picture (without zoom); right: when zoomed in, the individual pixels become visible.

Unlike raster graphics, which consist of colored pixels arranged to display an image, vector graphics are made up of paths, each with a mathematical formula (vector) that defines how the path is shaped and what color it is bordered with or filled by.

Since mathematical formulas dictate how the image is rendered, vector images retain their appearance regardless of size: they can be scaled to any size without losing quality.
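To make the scaling difference concrete, here is a small NumPy sketch (the helpers `upscale_raster` and `render_segment` are hypothetical, written only for this illustration). A diagonal stroke stored as a vector, i.e. just two endpoints, can be re-rendered crisply at any resolution, whereas enlarging an already rasterized 8x8 version simply turns every pixel into an 8x8 block.

```python
import numpy as np


def upscale_raster(img: np.ndarray, factor: int) -> np.ndarray:
    """Enlarge a raster image by repeating every pixel factor x factor times."""
    return np.repeat(np.repeat(img, factor, axis=0), factor, axis=1)


def render_segment(p0, p1, size: int) -> np.ndarray:
    """Rasterize a vector line segment onto a size x size grid.

    p0 and p1 are (x, y) endpoints in normalized [0, 1] coordinates;
    the result uses 0.0 for ink and 1.0 for paper.
    """
    img = np.ones((size, size))
    t = np.linspace(0.0, 1.0, 4 * size)  # dense samples along the segment
    xs = p0[0] + t * (p1[0] - p0[0])
    ys = p0[1] + t * (p1[1] - p0[1])
    cols = np.clip(np.rint(xs * (size - 1)).astype(int), 0, size - 1)
    rows = np.clip(np.rint(ys * (size - 1)).astype(int), 0, size - 1)
    img[rows, cols] = 0.0
    return img


# The same diagonal stroke, described once by its two endpoints.
small = render_segment((0.0, 0.0), (1.0, 1.0), 8)
blocky = upscale_raster(small, 8)                   # raster scaled up to 64x64
crisp = render_segment((0.0, 0.0), (1.0, 1.0), 64)  # vector re-rendered at 64x64

print(int((blocky == 0).sum()), int((crisp == 0).sum()))  # 512 64
```

Both images are 64x64, but the scaled-up raster line is 8 pixels thick (512 ink pixels), while the re-rendered vector line stays 1 pixel thick (64 ink pixels).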

The difference between vector and raster images when scaled up

The Model Architecture

Model architecture (source: the original paper)

The best part of this model is that it works with raster images and converts multiple rough sketch lines into a single clean line.

Multiple rough lines converted into a single clean line

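Another advantage of this architecture is that, being fully convolutional, it accepts images of any dimensions and outputs an image of the same dimensions as the input. Strictly speaking, because the network halves the resolution three times and then doubles it three times, the sizes match exactly when the height and width are multiples of 8; other sizes can be padded first.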
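One way to handle arbitrary sizes, assuming the three stride-2 stages of the model code below, is to pad the input up to the next multiple of 8 and crop the output back afterwards. The helpers `pad_to_multiple` and `crop_back` are hypothetical, and replicating the edge pixels (rather than, say, padding with white) is a design choice of this sketch, not something taken from the paper.

```python
import numpy as np


def pad_to_multiple(img: np.ndarray, multiple: int = 8):
    """Pad a 2-D grayscale image at the bottom and right so that both sides
    are divisible by `multiple`, replicating the edge pixels.

    Returns the padded image and the original (height, width).
    """
    h, w = img.shape
    pad_h = (-h) % multiple  # rows needed to reach the next multiple
    pad_w = (-w) % multiple  # columns needed to reach the next multiple
    padded = np.pad(img, ((0, pad_h), (0, pad_w)), mode="edge")
    return padded, (h, w)


def crop_back(out: np.ndarray, shape) -> np.ndarray:
    """Crop a network output back to the original (height, width)."""
    h, w = shape
    return out[:h, :w]


sketch = np.random.default_rng(0).random((500, 333))
padded, shape = pad_to_multiple(sketch)
print(padded.shape)                    # (504, 336)
print(crop_back(padded, shape).shape)  # (500, 333)
```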

The architecture is rather simple. The first part acts as an encoder and spatially compresses the image; the second part processes this compressed representation and extracts the essential lines; and the third and last part acts as a decoder, converting the small, simpler representation into a grayscale image of the same resolution as the input. All of this is done using convolutions.

The down- and up-convolution architecture may seem similar to a simple filter bank. However, it is important to realize that the number of channels is much larger where the resolution is lower, e.g., 1024 channels where the size is 1/8 of the input. This ensures that the information that leads to clean lines is carried through the low-resolution part; the network is trained, through the encoder-decoder architecture, to choose which information to carry.
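A quick bit of arithmetic shows why the low-resolution part is not a tight bottleneck: a feature map at 1/d of the input resolution with C channels holds C/d² values per input pixel. The stage list below is my own summary of the channel counts in the model code below, and the function name is hypothetical.

```python
# (stage name, resolution divisor d, output channels C), read off the model code.
STAGES = [
    ("input", 1, 1),
    ("downconv1", 2, 128),
    ("downconv2", 4, 256),
    ("downconv3", 8, 1024),
    ("flat", 8, 256),
    ("upconv1", 4, 128),
    ("upconv2", 2, 48),
    ("upconv3", 1, 1),
]


def values_per_input_pixel(divisor: int, channels: int) -> float:
    """A C-channel map at 1/d resolution of an H x W input holds
    C * (H/d) * (W/d) = (C / d**2) * H * W values."""
    return channels / divisor ** 2


for name, d, c in STAGES:
    print(f"{name:>9}: 1/{d} resolution, {c:4d} channels -> "
          f"{values_per_input_pixel(d, c):g} values per input pixel")
```

Even at 1/8 resolution, the 1024-channel output of `downconv3` still holds 16 values per input pixel, so the spatial compression is offset by much wider feature maps.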

Padding is used to compensate for the kernel size and ensure the output is the same size as the input when a stride of 1 is used. Pooling layers are replaced by convolutional layers with increased strides to lower the resolution from the previous layer.
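The size bookkeeping can be checked with the standard output-size formulas for `torch.nn.Conv2d` and `torch.nn.ConvTranspose2d` (dilation 1, no output padding). The sketch below (function names hypothetical) follows only the layers that change resolution in the model code below; every other layer uses a 3x3 kernel with stride 1 and padding 1, which leaves the size unchanged.

```python
def conv_out(n: int, k: int, s: int, p: int) -> int:
    """Output size of torch.nn.Conv2d along one axis (dilation 1)."""
    return (n + 2 * p - k) // s + 1


def deconv_out(n: int, k: int, s: int, p: int) -> int:
    """Output size of torch.nn.ConvTranspose2d along one axis
    (dilation 1, output_padding 0)."""
    return (n - 1) * s - 2 * p + k


def trace(n: int) -> int:
    """Follow one spatial side through the resolution-changing layers."""
    n = conv_out(n, 5, 2, 2)  # downconv1: 5x5, stride 2
    n = conv_out(n, 3, 2, 1)  # downconv2: 3x3, stride 2
    n = conv_out(n, 3, 2, 1)  # downconv3: 3x3, stride 2
    for _ in range(3):        # upconv1..3: 4x4 transposed, stride 2
        n = deconv_out(n, 4, 2, 1)
    return n


print(conv_out(100, 3, 1, 1))  # 100: stride 1 with padding 1 keeps the size
print(conv_out(100, 3, 2, 1))  # 50: stride 2 halves it, standing in for pooling
print(trace(512), trace(500))  # 512 504
```

A 512-pixel side round-trips exactly, while a 500-pixel side comes back as 504 because each stride-2 layer rounds up; this is why padding to a multiple of 8 and cropping afterwards is convenient.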

import torch


# Fully convolutional: accepts a 1-channel (grayscale) image of any height and width.
# Because of the three stride-2 stages, the output matches the input size exactly
# when the height and width are multiples of 8.
class Net(torch.nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        # Encoder, stage 1: a 5x5 stride-2 convolution halves the resolution.
        self.downconv1 = torch.nn.Sequential(
            torch.nn.Conv2d(1, 48, 5, 2, 2),
            torch.nn.BatchNorm2d(48),
            torch.nn.ReLU(),

            torch.nn.Conv2d(48, 128, 3, 1, 1),
            torch.nn.BatchNorm2d(128),
            torch.nn.ReLU(),

            torch.nn.Conv2d(128, 128, 3, 1, 1),
            torch.nn.BatchNorm2d(128),
            torch.nn.ReLU(),
        )
        # Encoder, stage 2: down to 1/4 resolution.
        self.downconv2 = torch.nn.Sequential(
            torch.nn.Conv2d(128, 256, 3, 2, 1),
            torch.nn.BatchNorm2d(256),
            torch.nn.ReLU(),

            torch.nn.Conv2d(256, 256, 3, 1, 1),
            torch.nn.BatchNorm2d(256),
            torch.nn.ReLU(),

            torch.nn.Conv2d(256, 256, 3, 1, 1),
            torch.nn.BatchNorm2d(256),
            torch.nn.ReLU(),
        )
        # Encoder, stage 3: down to 1/8 resolution, widened to 1024 channels.
        self.downconv3 = torch.nn.Sequential(
            torch.nn.Conv2d(256, 256, 3, 2, 1),
            torch.nn.BatchNorm2d(256),
            torch.nn.ReLU(),

            torch.nn.Conv2d(256, 512, 3, 1, 1),
            torch.nn.BatchNorm2d(512),
            torch.nn.ReLU(),

            torch.nn.Conv2d(512, 1024, 3, 1, 1),
            torch.nn.BatchNorm2d(1024),
            torch.nn.ReLU(),
        )
        # Flat part: stride-1 convolutions that process the 1/8-resolution features.
        self.flat = torch.nn.Sequential(
            torch.nn.Conv2d(1024, 1024, 3, 1, 1),
            torch.nn.BatchNorm2d(1024),
            torch.nn.ReLU(),

            torch.nn.Conv2d(1024, 1024, 3, 1, 1),
            torch.nn.BatchNorm2d(1024),
            torch.nn.ReLU(),

            torch.nn.Conv2d(1024, 1024, 3, 1, 1),
            torch.nn.BatchNorm2d(1024),
            torch.nn.ReLU(),

            torch.nn.Conv2d(1024, 512, 3, 1, 1),
            torch.nn.BatchNorm2d(512),
            torch.nn.ReLU(),

            torch.nn.Conv2d(512, 256, 3, 1, 1),
            torch.nn.BatchNorm2d(256),
            torch.nn.ReLU(),
        )

        # Decoder: each stage starts with an up-convolution (a "1/2-stride"
        # convolution), implemented here as a stride-2 transposed convolution
        # that doubles the resolution.
        self.upconv1 = torch.nn.Sequential(
            torch.nn.ConvTranspose2d(256, 256, 4, 2, 1),
            torch.nn.BatchNorm2d(256),
            torch.nn.ReLU(),

            torch.nn.Conv2d(256, 256, 3, 1, 1),
            torch.nn.BatchNorm2d(256),
            torch.nn.ReLU(),

            torch.nn.Conv2d(256, 128, 3, 1, 1),
            torch.nn.BatchNorm2d(128),
            torch.nn.ReLU(),
        )
        self.upconv2 = torch.nn.Sequential(
            torch.nn.ConvTranspose2d(128, 128, 4, 2, 1),
            torch.nn.BatchNorm2d(128),
            torch.nn.ReLU(),

            torch.nn.Conv2d(128, 128, 3, 1, 1),
            torch.nn.BatchNorm2d(128),
            torch.nn.ReLU(),

            torch.nn.Conv2d(128, 48, 3, 1, 1),
            torch.nn.BatchNorm2d(48),
            torch.nn.ReLU(),
        )
        self.upconv3 = torch.nn.Sequential(
            torch.nn.ConvTranspose2d(48, 48, 4, 2, 1),
            torch.nn.BatchNorm2d(48),
            torch.nn.ReLU(),

            torch.nn.Conv2d(48, 24, 3, 1, 1),
            torch.nn.BatchNorm2d(24),
            torch.nn.ReLU(),

            # Final layer: a single grayscale channel squashed into [0, 1].
            torch.nn.Conv2d(24, 1, 3, 1, 1),
            torch.nn.Sigmoid(),
        )

    def forward(self, x):
        # x: tensor of shape (N, 1, H, W)
        conv1_out = self.downconv1(x)
        conv2_out = self.downconv2(conv1_out)
        conv3_out = self.downconv3(conv2_out)
        flat_out = self.flat(conv3_out)
        upconv1_out = self.upconv1(flat_out)
        upconv2_out = self.upconv2(upconv1_out)
        upconv3_out = self.upconv3(upconv2_out)
        return upconv3_out
Model summary, where H and W are the height and width of the input image

The Loss Function

The model loss is calculated using the weighted mean square criterion,

$$ l(Y, Y^*, M) = \lVert M \odot (Y - Y^*) \rVert_{\mathrm{FRO}}^2 $$

where Y is the model output, Y* is the target output and M is the loss map, which is multiplied element-wise with the difference between the output and the target to calculate the loss. After testing various loss maps, the authors of the paper found the one given below to perform better than the alternatives. This loss map reduces the loss on thicker lines so that the model does not focus on the thicker lines at the expense of the thinner ones.
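A minimal NumPy sketch of this criterion, assuming the standard form in which the loss map weights the per-pixel error before squaring (the sum of squares of M ⊙ (Y − Y*)). The function name and its `reduction` switch are hypothetical; dividing by the number of pixels ("mean") only rescales the loss by a constant.

```python
import numpy as np


def weighted_mse(y: np.ndarray, y_star: np.ndarray, m: np.ndarray,
                 reduction: str = "sum") -> float:
    """Weighted squared error: squares of M * (Y - Y*), taken element-wise.

    reduction="sum" gives the squared Frobenius norm of M * (Y - Y*);
    reduction="mean" divides it by the number of pixels.
    """
    diff = m * (y - y_star)  # element-wise (Hadamard) product
    sq = diff ** 2
    return float(sq.sum() if reduction == "sum" else sq.mean())


y = np.array([[0.0, 1.0], [1.0, 1.0]])     # model output
y_star = np.ones((2, 2))                   # target
m = np.array([[2.0, 1.0], [1.0, 1.0]])     # loss map
print(weighted_mse(y, y_star, m))          # 4.0
print(weighted_mse(y, y_star, m, "mean"))  # 1.0
```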

loss map

The loss maps are constructed by looking at histograms around each pixel in the ground-truth (target) image I. H(I,u,v) is the value of the bin of the local normalized histogram into which the pixel I(u,v) falls. The histogram is constructed from all pixels within d_h pixels of the center, using b_h bins.
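The histogram term H(I, u, v) can be computed directly from the description above. In this sketch, "within d_h pixels" is read as a square window clipped at the image border, and intensities are assumed to lie in [0, 1] with 1 meaning white paper. The final weighting `min(alpha * exp(-H) + beta, 1)` and its default constants are an assumed form, chosen so that weights drop inside thick strokes; the paper's loss-map equation is the authority for the exact definition and values. Function names are hypothetical.

```python
import numpy as np


def local_hist_value(target: np.ndarray, d_h: int = 2, b_h: int = 10) -> np.ndarray:
    """H(I, u, v): value of the normalized local-histogram bin containing I(u, v).

    The histogram uses b_h bins over [0, 1] and all pixels within a square
    window of radius d_h (clipped at the image border).
    """
    h, w = target.shape
    bins = np.minimum((target * b_h).astype(int), b_h - 1)  # bin index per pixel
    out = np.empty((h, w))
    for u in range(h):
        for v in range(w):
            win = bins[max(u - d_h, 0):u + d_h + 1, max(v - d_h, 0):v + d_h + 1]
            out[u, v] = np.mean(win == bins[u, v])  # fraction of window in that bin
    return out


def loss_map(target: np.ndarray, alpha: float = 6.0, beta: float = -2.0,
             d_h: int = 2, b_h: int = 10) -> np.ndarray:
    """Assumed loss map: weight 1 on white paper; on ink, a weight that
    shrinks as H grows, i.e. inside thick strokes."""
    hist = local_hist_value(target, d_h, b_h)
    weights = np.minimum(alpha * np.exp(-hist) + beta, 1.0)
    return np.where(target >= 1.0, 1.0, weights)


thin = np.ones((10, 10))
thin[:, 5] = 0.0                               # 1-pixel-wide stroke
thick = np.ones((10, 10))
thick[2:8, 2:8] = 0.0                          # 6x6 solid blob
print(loss_map(thin)[5, 5])                    # 1.0: thin line keeps full weight
print(round(float(loss_map(thick)[4, 4]), 3))  # 0.207: thick interior down-weighted
```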

Visualization of the model's training

Since only a small number of training images was available, various data augmentation techniques were used to enlarge the dataset. Besides traditional transformations such as rotation, Adobe Photoshop was used to change the tone, slur the image, and add noise, creating more samples.
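The Photoshop edits are not reproduced here. As a rough, swapped-in stand-in, the NumPy sketch below applies a random 90° rotation and flip to both images of a training pair (so they stay aligned), and a gamma tone change, a 3x3 box blur and Gaussian noise to the rough input only. These choices, the parameter ranges and the function name `augment_pair` are all assumptions for illustration.

```python
import numpy as np


def augment_pair(rough: np.ndarray, clean: np.ndarray, rng: np.random.Generator):
    """Return one augmented (rough, clean) pair of grayscale images in [0, 1].

    Geometric transforms are applied identically to both images so they stay
    aligned; tone, blur and noise are applied to the rough input only.
    """
    k = int(rng.integers(4))                  # rotate by k * 90 degrees
    rough, clean = np.rot90(rough, k), np.rot90(clean, k)
    if rng.random() < 0.5:                    # random horizontal flip
        rough, clean = rough[:, ::-1], clean[:, ::-1]

    rough = rough ** rng.uniform(0.7, 1.5)    # tone change via a gamma curve

    h, w = rough.shape                        # 3x3 box blur with edge padding
    p = np.pad(rough, 1, mode="edge")
    rough = sum(p[i:i + h, j:j + w] for i in range(3) for j in range(3)) / 9.0

    noise = rng.normal(0.0, 0.02, rough.shape)  # additive Gaussian noise
    rough = np.clip(rough + noise, 0.0, 1.0)
    return np.ascontiguousarray(rough), np.ascontiguousarray(clean)


rng = np.random.default_rng(0)
rough = rng.random((6, 8))
clean = (rough > 0.5).astype(float)
aug_rough, aug_clean = augment_pair(rough, clean, rng)
print(aug_rough.shape == aug_clean.shape)  # True
```

Calling `augment_pair` repeatedly with the same generator yields a different variant of each pair every time, which is the point of augmentation with a small dataset.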

At the time of writing, this is a state-of-the-art model; the authors report that it performs even better than Potrace and Adobe Live Trace.

References:

Torch code: https://github.com/bobbens/sketch_simplification

Paper: http://hi.cs.waseda.ac.jp/~esimo/publications/SimoSerraSIGGRAPH2016.pdf

My implementation: https://github.com/sinAshish/Rough-Sketch-Simplification-Using-FCNN

P.S.: I'll implement a PyTorch version of the code; the difficulty is that the dataset used in the paper is hard to obtain, although the authors have provided pre-trained weights for their code.
