A Non-Expert’s Guide to Image Segmentation Using Deep Neural Nets

Rohan Relan
Oct 30, 2017

Given a single picture of a piece of furniture in context, can you automatically separate the furniture from the background?

In this post, I’ll walk through how we can use the current state of the art in deep learning to try to solve this problem. I’m not an expert in machine learning myself, so my hope is that this post will be useful to other non-experts looking to use this powerful new tool.

This problem is called segmentation. Essentially, we want to go from this:

Our source image: a chair in context

to this:

We can apply this mask to the source image to get just the chair without the background

We’re going to use a few tools to make our lives easier. These are:

Keras: An awesome library for building neural networks. Keras is a front-end to lower-level libraries like TensorFlow that handles a lot of the messy details of building neural networks for you. We won’t actually need to use Keras directly in this guide, but if you peek under the hood, Keras is what you’ll see.

U-Net: A neural network architecture for image segmentation. U-Net was originally designed for biomedical image segmentation (e.g. identifying lung nodules in a CT scan), but it also works for segmenting regular 2D images. As you’ll see, U-Net works surprisingly well even when you don’t have a large dataset.

brine (I’m one of the developers): A dataset manager that makes it easy to share and manipulate image datasets. I’ve found that one of the most annoying parts of building a model is getting and wrangling the datasets we need to train our models. I created brine to easily share datasets and use them with PyTorch/Keras models. We’ll be using it to download the dataset and interface it with Keras, so we won’t have to do any crufty data format wrangling ourselves.

This GitHub repo (Kaggle-Carvana-Image-Masking-Challenge): The Carvana Image Masking Challenge was a Kaggle competition posing a similar problem: segmenting cars out of their background. People often share their solutions to Kaggle competitions, and in this repo someone has helpfully shared a clean solution that uses Keras and U-Net. Our goal is to repurpose this solution to solve our furniture segmentation problem.

This dataset (rohan/chairs-with-masks): A dataset provided by the friend who posed this problem to me. Note that it’s a small dataset, containing only 97 images of chairs and the corresponding masks. Normally I wouldn’t expect to be able to do much with such limited data (the Carvana challenge provided thousands of examples), but let’s see how far we can get with it.


There’s a Jupyter notebook available here that contains all the code to build the model. I’ll step through the important parts and explain what they do.

Our first step is to install the dataset. Since it’s hosted on Brine, we can do this with a single command: brine install rohan/chairs-with-masks

Our next step is to load the dataset. We do this with Brine’s load_dataset function: chairs = brine.load_dataset('rohan/chairs-with-masks'). This dataset contains 97 samples, each of which is a pair of an image and its mask. The mask is a two-color image where blue represents the background and red represents the foreground.

Now that we have the dataset loaded, let’s load up the U-Net network. Copy over the model directory from the Kaggle-Carvana-Image-Masking-Challenge GitHub repo so we have it available to us. After importing it, we can call model = unet.get_unet_256(). Thanks to petrosgk’s work, that single function call returns a U-Net network built in Keras. Keras gives us the model.summary() method, which we can use to see the structure of the network. While there’s a ton of information in there, the most important lines to look at are the first and the last. These tell us the shapes of the inputs and outputs the network expects.

We can see that the input shape is (None, 256, 256, 3) and the output shape is (None, 256, 256, 1). The first element of each tuple refers to the mini-batch size, so we can ignore it for now. This tells us the network expects a batch of 256x256 3-channel images as input and will output a batch of 256x256 single-channel masks. Our masks need to match this shape as well.

Our next step then is to prepare our samples so they can be used with this network. We’ll define a processing function for our training data which will be applied to each sample before the sample is passed to the network during training.

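import cv2
import numpy as np

SIZE = (256, 256)  # the input size expected by get_unet_256()

# randomHueSaturationValue, randomShiftScaleRotate and randomHorizontalFlip
# are the augmentation helpers from petrosgk's Carvana repo

def fix_mask(mask):
    # Binarize in place: background -> 0, foreground -> 255
    mask[mask < 100] = 0.0
    mask[mask >= 100] = 255.0

def train_process(sample):
    img, mask = sample
    # Drop any alpha channel so both arrays have exactly 3 channels
    img = img[:, :, :3]
    mask = mask[:, :, :3]
    # Collapse the two-color mask to one grayscale channel, then binarize it
    mask = cv2.cvtColor(mask, cv2.COLOR_BGR2GRAY)
    fix_mask(mask)
    img = cv2.resize(img, SIZE)
    mask = cv2.resize(mask, SIZE)
    # Data augmentation: geometric transforms are applied to img and mask together
    img = randomHueSaturationValue(img,
                                   hue_shift_limit=(-50, 50),
                                   sat_shift_limit=(0, 0),
                                   val_shift_limit=(-15, 15))
    img, mask = randomShiftScaleRotate(img, mask,
                                       shift_limit=(-0.062, 0.062),
                                       scale_limit=(-0.1, 0.1),
                                       rotate_limit=(-20, 20))
    img, mask = randomHorizontalFlip(img, mask)
    # Scale pixel values to [0, 1] and add the channel axis the network expects
    img = img / 255.
    mask = mask / 255.
    mask = np.expand_dims(mask, axis=2)
    return (img, mask)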

There’s a lot going on here, so I’ll go through it step by step. We’re passed the sample as a tuple, so first we unpack it. The next two lines use NumPy slicing to ensure that we only have 3-channel images: if there’s a fourth alpha channel, we ignore it. Next, we convert the mask to grayscale using cv2 (the Python bindings for OpenCV), so we now have a single-channel mask, as our network expects. Rather than leaving the two colors as two arbitrary grayscale values, we use the fix_mask function to force the mask to 0 for the background and 255 for the foreground. We then resize both the image and the mask to 256x256 to match the size expected by the network.
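The channel slicing and thresholding steps can be illustrated with a small NumPy sketch on synthetic arrays. The helpers drop_alpha and binarize_mask are hypothetical names introduced here, not code from the notebook. binarize_mask applies the same threshold of 100 as fix_mask, but it returns a new array instead of modifying its input. The sketch skips the cv2 grayscale conversion entirely.

```python
import numpy as np

def drop_alpha(arr):
    # Keep only the first three channels, as img[:, :, :3] does above
    return arr[:, :, :3]

def binarize_mask(gray_mask, threshold=100):
    # Same rule as fix_mask, but returns a new array instead of mutating its input
    return np.where(gray_mask < threshold, 0, 255).astype(np.uint8)

rgba = np.zeros((4, 4, 4), dtype=np.uint8)     # a tiny synthetic RGBA image
print(drop_alpha(rgba).shape)                  # (4, 4, 3)

gray = np.array([[12, 99], [100, 230]], dtype=np.uint8)
print(binarize_mask(gray))                     # values below 100 become 0, the rest 255
```

A non-mutating version is easier to test in isolation. The in-place fix_mask works in the notebook only because cv2.cvtColor has already returned a fresh array.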

Since we don’t have a lot of data, we’re going to use data augmentation. Data augmentation means randomly modifying the image at training time, in ways that preserve the information we care about, to artificially generate more data. For example, a chair rotated by 5 degrees is still a chair, so the network should be able to identify it correctly. In our code, we’re using 3 functions from petrosgk’s Carvana example:

- The first randomly shifts the hue and value of the image. The saturation limit is set to (0, 0), so saturation is left unchanged.
- The second randomly shifts, scales, and rotates the image.
- The third randomly flips the image horizontally.

Whenever we shift, scale, rotate, or flip the image, we have to apply the same transformation to the mask so that the mask stays aligned with the image.
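The key constraint here is that geometric augmentations must hit the image and the mask identically. Below is a minimal sketch of that idea for a horizontal flip. paired_random_hflip is a hypothetical helper written for illustration, not petrosgk’s randomHorizontalFlip. It makes a single random draw and uses it to flip both arrays or neither.

```python
import numpy as np

def paired_random_hflip(img, mask, rng, p=0.5):
    """Hypothetical helper: flip image and mask together with probability p."""
    if rng.random() < p:
        # Axis 1 is the width axis for both (H, W, C) images and (H, W) masks
        return img[:, ::-1].copy(), mask[:, ::-1].copy()
    return img, mask

# A 2x4 image whose left half is the "chair", plus a matching mask
img = np.zeros((2, 4, 3), dtype=np.uint8)
img[:, :2] = 200
mask = np.zeros((2, 4), dtype=np.uint8)
mask[:, :2] = 255

rng = np.random.default_rng(0)
flipped_img, flipped_mask = paired_random_hflip(img, mask, rng, p=1.0)
# The chair is now on the right in both arrays, so they still line up
```

Drawing the random decision once per sample, rather than once per array, is what keeps the pair aligned.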

Finally, we normalize the data by dividing all the pixel values by 255, so that every value lies between 0 and 1. If you were to print img.shape at this point, you’d see that it’s 256x256x3, exactly what our network needs. However, mask.shape is 256x256, whereas the network requires 256x256x1, so we use np.expand_dims to add a trailing channel axis to the mask. The function then returns the new image and mask pair.
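Here is what the normalization and reshaping do to the arrays, shown on synthetic stand-ins for a resized image and a binarized mask. The random image and the square “chair” are made up for illustration.

```python
import numpy as np

# Synthetic stand-ins for a resized image and a binarized (0/255) mask
img = np.random.randint(0, 256, size=(256, 256, 3), dtype=np.uint8)
mask = np.zeros((256, 256), dtype=np.uint8)
mask[64:192, 64:192] = 255  # a square "chair"

img = img / 255.0                    # float64 values in [0, 1]
mask = mask / 255.0                  # 0.0 for background, 1.0 for foreground
mask = np.expand_dims(mask, axis=2)  # add the trailing channel axis

print(img.shape, mask.shape)  # (256, 256, 3) (256, 256, 1)
```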

Before we start training the network, we also need to set some of the samples aside for validation. The validation fold won’t be used for training; we’ll only use it to check the performance of the model. We can easily create the folds using Brine’s create_folds:

validation_fold, train_fold = chairs.create_folds((20,))

We’ve asked for 20 samples in the validation fold, and the remaining 77 samples will be assigned to the training fold.
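Conceptually, creating the folds is a hold-out split of the sample indices. The sketch below shows that split in plain NumPy. split_indices is a hypothetical helper, and the random shuffle before splitting is an assumption: I have not checked whether Brine’s create_folds shuffles in exactly this way.

```python
import numpy as np

def split_indices(n_samples, n_val, seed=0):
    """Hypothetical helper: randomly hold out n_val indices for validation."""
    rng = np.random.default_rng(seed)
    order = rng.permutation(n_samples)
    return order[:n_val], order[n_val:]

val_idx, train_idx = split_indices(97, 20)
print(len(val_idx), len(train_idx))  # 20 77
```

Fixing the seed keeps the split reproducible, so validation scores from different runs are computed on the same 20 images.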

We also define a processing function for the validation samples. It is very similar to the processing function we use for training, except that at validation time we don’t unnecessarily make the network’s job harder with data augmentation.

Finally, we ask Brine to return a generator we can use with Keras for both the train and validation folds:

train_generator = train_fold.to_keras(
validation_generator = validation_fold.to_keras(

This gives us generators that return batches of samples processed with the processing functions we defined earlier. These generators can be passed directly into Keras’ fit_generator method to train our model.
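Brine builds these generators for us, but it helps to see what Keras actually consumes. Below is a minimal sketch, not Brine’s implementation. batch_generator is a hypothetical helper with the following behavior:

- It loops forever, since fit_generator keeps requesting batches across epochs.
- It optionally reshuffles the samples on each pass.
- It runs every sample through a processing function.
- It yields (images, masks) tuples stacked along a new leading batch axis.

```python
import numpy as np

def batch_generator(samples, process_fn, batch_size=4, shuffle=True, seed=0):
    """Hypothetical helper: endlessly yield (images, masks) batches."""
    rng = np.random.default_rng(seed)
    indices = np.arange(len(samples))
    while True:  # fit_generator expects a generator that never runs out
        if shuffle:
            rng.shuffle(indices)
        for start in range(0, len(indices), batch_size):
            batch = [process_fn(samples[i]) for i in indices[start:start + batch_size]]
            images = np.stack([img for img, _ in batch])    # (batch, H, W, 3)
            masks = np.stack([mask for _, mask in batch])   # (batch, H, W, 1)
            yield images, masks

# Usage with dummy 8x8 samples and an identity processing function
samples = [(np.zeros((8, 8, 3)), np.zeros((8, 8, 1))) for _ in range(10)]
gen = batch_generator(samples, lambda s: s, batch_size=4)
images, masks = next(gen)
```

In this sketch the last batch of each pass can be smaller than batch_size (here 10 samples give batches of 4, 4, and 2). Keras handles variable-sized batches.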

Now we’re ready to train the model. We’ll use the callbacks from petrosgk as well for our training:

callbacks = [EarlyStopping(monitor='val_loss',

These callbacks modify Keras’ training loop. EarlyStopping stops training once the validation loss stops improving, ReduceLROnPlateau lowers the learning rate when improvement plateaus, and ModelCheckpoint saves the version of the model that performs best on our validation set.

Finally, we can start training by passing the two generators and the callbacks to Keras’ fit_generator method.


Keras will train the model, running through the dataset multiple times (each pass is slightly different because of data augmentation and shuffling), and report the loss and Dice score on both the training and validation sets. Training stops either when our EarlyStopping callback kicks in or when we reach 100 epochs.
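The Dice score measures the overlap between a predicted mask A and the true mask B as 2|A ∩ B| / (|A| + |B|). It ranges from 0 (no overlap) to 1 (perfect overlap). Below is a NumPy sketch of the “soft” variant, which multiplies probabilities instead of intersecting hard masks. The smoothing term smooth=1.0 is a common convention to avoid dividing by zero. It is an assumption here: the exact variant in petrosgk’s Keras code may differ.

```python
import numpy as np

def dice_coefficient(y_true, y_pred, smooth=1.0):
    """Soft Dice: (2 * sum(t * p) + smooth) / (sum(t) + sum(p) + smooth)."""
    y_true = np.asarray(y_true, dtype=np.float64).ravel()
    y_pred = np.asarray(y_pred, dtype=np.float64).ravel()
    intersection = np.sum(y_true * y_pred)
    return (2.0 * intersection + smooth) / (np.sum(y_true) + np.sum(y_pred) + smooth)

print(dice_coefficient([1, 1, 0, 0], [1, 1, 0, 0]))  # 1.0
print(dice_coefficient([1, 1, 0, 0], [1, 0, 0, 0]))  # 0.75
```

Unlike plain pixel accuracy, Dice isn’t inflated by a large, easy background. A network that predicts “background everywhere” on a small chair scores close to 0.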

And that’s it! We’ve trained a U-Net network that tries to segment chairs out of an image. Let’s see how well it does in practice.

Generating Predictions

To generate some predictions, we’ll first load the weights of the best model using model.load_weights('weights/best_weights.hdf5') and then use our model’s predict method on some of the images in our validation set:

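import matplotlib.pyplot as plt

# Assumes model, validation_generator and SIZE are defined as above
def predict_one():
    image_batch, mask_batch = next(validation_generator)
    predicted_mask_batch = model.predict(image_batch)
    image = image_batch[0]
    # Drop the trailing channel axis: (256, 256, 1) -> (256, 256)
    predicted_mask = predicted_mask_batch[0].reshape(SIZE)
    plt.imshow(image)                      # show the source image...
    plt.imshow(predicted_mask, alpha=0.6)  # ...with the predicted mask overlaid
    plt.show()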

Here’s a sampling of some of the results:

The original image with our network’s predicted mask overlaid. Yellower pixels mean that the network has higher confidence in its prediction for that pixel
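The overlay shows per-pixel confidence. To get an actual cutout like the chair-without-background image at the top of this post, one option is to threshold the predicted mask and zero out everything below the threshold. cut_out is a hypothetical helper written for illustration, and the 0.5 threshold is an assumption you would tune.

```python
import numpy as np

def cut_out(image, predicted_mask, threshold=0.5):
    """Hypothetical helper: keep pixels whose foreground confidence >= threshold."""
    keep = predicted_mask >= threshold     # boolean (H, W) mask
    return image * keep[..., np.newaxis]   # broadcast over the channel axis

# Tiny synthetic example: a 2x2 all-white image and made-up confidences
image = np.ones((2, 2, 3))
predicted_mask = np.array([[0.9, 0.1], [0.5, 0.4]])
chair_only = cut_out(image, predicted_mask)
```

With the notebook’s arrays, image would be image_batch[0] and predicted_mask the reshaped (256, 256) prediction from predict_one.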

While this isn’t good enough for production use yet, the network has learned something about foreground vs. background and chair vs. non-chair. Personally, I find this quite amazing given that the network was trained on only 77 images without any pre-training. With a larger dataset, the network might reach an accuracy that allows it to be used in production, or at least as a first pass that a human can then refine.

That’s it! If you have any thoughts, questions or suggestions for how to improve this model please share them in the responses.


Looking for some help bringing up ML or other new technologies within your organization? Shoot me a note at rohan@rohanrelan.com
