Key-point detection in flower images using deep learning

Wouter de Winter
HackerNoon.com
Sep 4, 2018

In this article, we describe how we used a Convolutional Neural Network (CNN) to estimate the location of key-points in flower images. Key-points such as the stem position and the flower-top position are needed to render these images on a 3D model.

BloomyPro

First, let’s introduce our client: Bloomy. Their software platform BloomyPro allows users to design their bouquets in the browser using a 3D model. It is used by breeders, retailers, wholesalers and suppliers in the flower industry.

Instead of creating a real physical bouquet, photographing it and sending the photo to the client, they can do the whole process online. This saves them a lot of time and money.

The BloomyPro User Interface

To be able to compete with photos of real bouquets, the images created have to be as photo-realistic as possible. This is achieved by using real photographs of flowers from many angles and rendering them on a 3D model.

For every new flower they take photos from 7 different angles. In the photo booth, the flowers are automatically rotated by a motor.

The flower photo booth

In contrast, the post-processing of the pictures is not fully automated yet. There are currently thousands of flowers in the database, and new flowers are added every day. Multiply this by the 7 angles per flower and you get a lot of pictures to process manually!

One of the post-processing steps is to locate a few key-points on each image, which the 3D model needs as attachment points. The most important ones are the stem position and the flower-top position. This is currently done manually; our solution aims to automate this step.

Dataset

Fortunately, thousands of images have already been manually annotated with key-points, so we have plenty of training data to work with!

Annotated images at different angles

Above are a few annotated flowers from the training set, showing the same flower from different angles. The stem position is marked in blue and the flower-top position in green.

In some pictures, the stem origin is hidden by the flower itself. In those cases, we have to make an ‘educated guess’ about where the stem most likely is.

Example with hidden stem

Network architecture

Because the model has to output numbers (coordinates) instead of a class, we are essentially doing regression. CNNs are best known for classification tasks but can also perform well on regression. For example, DensePose does human pose estimation with a CNN-based approach. Another example is this article about facial key-point detection.

I’m not going to explain how convolutional networks work in general; if you’re interested, you can read about CNN basics in this article.

The network begins with a few standard convolutional blocks. Each block consists of 3 convolutional layers followed by a max-pooling layer, a batch normalization layer and a dropout layer.

  • The convolutional layers contain a number of filters. Each filter works like a pattern recognizer. Later convolutional blocks have more filters, so the network can find patterns within patterns.
  • Max-pooling reduces the resolution of the feature maps. This cuts computation and, because we later flatten the feature maps, limits the number of parameters in the dense layers. In image classification we are usually not interested in where an object is located in the image, as long as it is there. In our case, we ARE interested in the location. Still, having a few max-pooling layers did not hurt performance.
  • Batch normalization layers help the model train (converge) faster. In some deep networks, training fails completely without them.
  • Dropout randomly disables nodes during training, which helps prevent the model from overfitting.
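
To make the max-pooling step concrete, here is a minimal pure-Python sketch of 2×2 max-pooling on a single-channel feature map. It is an illustration rather than code from the project; in Keras, a MaxPooling2D layer does this for every channel. Note that the output has half the height and width, and that the exact position of each maximum inside its 2×2 window is lost.

```python
# Minimal 2x2 max-pooling on a single-channel feature map (illustration only).
from typing import List

Matrix = List[List[float]]


def max_pool_2x2(feature_map: Matrix) -> Matrix:
    """Keep the maximum of each non-overlapping 2x2 window.

    With an odd height or width the last row/column is dropped,
    similar to 'valid' padding.
    """
    height = len(feature_map)
    width = len(feature_map[0])
    pooled = []
    for row in range(0, height - 1, 2):
        pooled_row = []
        for col in range(0, width - 1, 2):
            window = (
                feature_map[row][col],
                feature_map[row][col + 1],
                feature_map[row + 1][col],
                feature_map[row + 1][col + 1],
            )
            pooled_row.append(max(window))
        pooled.append(pooled_row)
    return pooled


fmap = [
    [1, 3, 0, 2],
    [4, 2, 1, 0],
    [0, 1, 5, 6],
    [2, 0, 7, 1],
]
print(max_pool_2x2(fmap))  # [[4, 2], [2, 7]]
```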

After the convolutional blocks, we flatten the tensor so it is compatible with the dense layers. Global max-pooling or global average pooling would also produce a flat tensor, but would discard all spatial information. Flattening worked better in our experiments, although it came at a computational cost: more model parameters and therefore a longer training time.
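
A rough parameter count shows why flattening makes training slower. The sketch below assumes, purely for illustration, a 128×128 input, two convolutional blocks that each end in 2×2 max-pooling (with 'same'-padded convolutions), 64 filters in the last block and a 256-unit dense layer; the article does not state these sizes.

```python
# Back-of-the-envelope comparison of Flatten vs. global pooling in front of a
# dense layer. All sizes are hypothetical, chosen only for illustration.


def feature_map_shape(height, width, filters, num_pool_layers):
    """Shape after `num_pool_layers` 2x2 max-poolings, assuming 'same'-padded
    convolutions so that only pooling changes the spatial size."""
    for _ in range(num_pool_layers):
        height //= 2
        width //= 2
    return height, width, filters


def dense_params(num_inputs, units):
    """Number of weights plus biases in a fully connected layer."""
    return num_inputs * units + units


h, w, c = feature_map_shape(128, 128, filters=64, num_pool_layers=2)
flatten_inputs = h * w * c   # Flatten keeps every spatial position
global_pool_inputs = c       # global pooling keeps one value per filter

print((h, w, c))                              # (32, 32, 64)
print(dense_params(flatten_inputs, 256))      # 16777472
print(dense_params(global_pool_inputs, 256))  # 16640
```

Under these assumptions, the first dense layer after flattening has roughly 1,000 times more weights than after global pooling, but it can still "see" where in the image each feature was found.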

After two dense hidden layers with ReLU activation comes the output layer. We want to predict the x and y coordinates of the 2 key-points, so the output layer has 4 nodes. Because the images can have different resolutions, we scale the coordinates to the range 0 to 1 and scale them back up before use.
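
The coordinate scaling can be sketched as follows. The key-point order (stem first, then flower top), the [x1, y1, x2, y2] layout of the 4 output values and the helper names to_target and from_prediction are assumptions for illustration.

```python
# Scale key-point pixel coordinates to [0, 1] for training, and back to
# pixels for use. Key-point order and output layout are assumptions.
from typing import List, Tuple

Point = Tuple[float, float]


def to_target(points: List[Point], width: int, height: int) -> List[float]:
    """Flatten [(x, y), ...] pixel points into [x1, y1, x2, y2] in [0, 1]."""
    target = []
    for x, y in points:
        target.extend([x / width, y / height])
    return target


def from_prediction(pred: List[float], width: int, height: int) -> List[Point]:
    """Undo the scaling: [x1, y1, x2, y2] -> [(x, y), ...] in pixels."""
    return [(pred[i] * width, pred[i + 1] * height) for i in range(0, len(pred), 2)]


stem, flower_top = (400.0, 900.0), (420.0, 150.0)
target = to_target([stem, flower_top], width=800, height=1000)
print(target)                             # [0.5, 0.9, 0.525, 0.15]
print(from_prediction(target, 800, 1000))
```

Scaling per axis by the image's own width and height is what lets photos of different resolutions share one 4-node output layer.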

The output layer has no activation function (a linear output). Even though the target values lie between 0 and 1, this worked better for us than a sigmoid.
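
One commonly cited factor, offered here as a hypothesis rather than something the article measured: a sigmoid output needs ever larger pre-activations to approach 0 or 1, and its gradient shrinks accordingly, whereas a linear output always passes back a gradient of 1. The sketch computes the pre-activation and sigmoid gradient needed to hit a few targets.

```python
# How hard a sigmoid output has to work to reach targets near 1.
import math


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


def sigmoid_grad(z):
    """Derivative of the sigmoid: s * (1 - s)."""
    s = sigmoid(z)
    return s * (1.0 - s)


def logit(p):
    """Pre-activation at which the sigmoid outputs p (inverse sigmoid)."""
    return math.log(p / (1.0 - p))


for target in (0.5, 0.9, 0.99):
    z = logit(target)
    print(f"target={target}: pre-activation={z:.2f}, gradient={sigmoid_grad(z):.4f}")
# A linear output (no activation) has gradient 1 everywhere.
```

For a target of 0.99 the sigmoid gradient is about 0.0099, roughly 25 times smaller than at 0.5.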

For reference, here is the complete model summary from Keras, the Python deep learning library we used:

You might ask: why 3 convolutional layers? Or why 2 convolutional blocks? We treated these numbers as hyperparameters. Together with parameters such as the number of dense layers, the dropout rate, whether to use batch normalization and the number of convolutional filters, we ran a randomized search to find the best combination of hyperparameters.

And why randomized search instead of grid search? It’s a little counterintuitive, but in practice it gives you better results for the same compute budget. See also this article about hyperparameter tuning.
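
A randomized search can be sketched in a few lines. The search space below is hypothetical (the article names the hyperparameters but not their ranges), and fake_evaluate is a toy stand-in for building and training a Keras model and returning its validation loss.

```python
# Randomized hyperparameter search: sample configurations at random instead
# of trying every grid combination. Search space values are hypothetical.
import math
import random

SEARCH_SPACE = {
    "conv_layers_per_block": [2, 3, 4],
    "num_conv_blocks": [1, 2, 3],
    "num_dense_layers": [1, 2, 3],
    "dropout": [0.0, 0.1, 0.25, 0.5],
    "batch_norm": [True, False],
    "base_filters": [16, 32, 64],
}


def sample_config(rng):
    """Pick one value per hyperparameter, uniformly at random."""
    return {name: rng.choice(values) for name, values in SEARCH_SPACE.items()}


def random_search(evaluate, num_trials, seed=0):
    """Return (best_loss, best_config) over `num_trials` random samples."""
    rng = random.Random(seed)
    best = None
    for _ in range(num_trials):
        config = sample_config(rng)
        loss = evaluate(config)
        if best is None or loss < best[0]:
            best = (loss, config)
    return best


def fake_evaluate(config):
    """Toy stand-in for 'train the model, return validation loss'."""
    return abs(config["dropout"] - 0.25) + 0.01 * config["num_dense_layers"]


grid_size = math.prod(len(v) for v in SEARCH_SPACE.values())
best_loss, best_config = random_search(fake_evaluate, num_trials=20)
print(grid_size)  # 648
print(best_loss, best_config)
```

With these ranges a full grid would need 648 training runs, while the random search trains only num_trials models. The usual argument for random search is that when only a few hyperparameters really matter, random sampling tries more distinct values of those for the same budget, especially when values are drawn from continuous ranges (e.g. rng.uniform(0.0, 0.5) for dropout).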

For training, we use the Adam optimizer with a learning rate of 0.005. The learning rate is automatically reduced when the validation loss has not improved for a few epochs.
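
Keras provides this behaviour as the ReduceLROnPlateau callback (with monitor, factor, patience and min_lr arguments). The pure-Python sketch below mimics a simplified version of that logic so the schedule is easy to follow. The factor, patience and min_lr values are hypothetical, since the article only gives the initial rate of 0.005, and the real callback has further options such as min_delta and cooldown that are left out here.

```python
# Simplified reduce-on-plateau schedule (illustration only). The defaults for
# factor, patience and min_lr are hypothetical; only initial_lr=0.005 comes
# from the article.


def plateau_schedule(val_losses, initial_lr=0.005, factor=0.5, patience=3,
                     min_lr=1e-5):
    """Return the learning rate used in each epoch, given per-epoch val losses."""
    lr = initial_lr
    best = float("inf")
    epochs_without_improvement = 0
    rates = []
    for loss in val_losses:
        rates.append(lr)  # rate used while training this epoch
        if loss < best:
            best = loss
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1
            if epochs_without_improvement >= patience:
                # No improvement for `patience` epochs: shrink the rate.
                lr = max(lr * factor, min_lr)
                epochs_without_improvement = 0
    return rates


rates = plateau_schedule([0.05, 0.04, 0.041, 0.042, 0.043, 0.039])
print(rates)  # [0.005, 0.005, 0.005, 0.005, 0.005, 0.0025]
```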

As the loss function, we use Mean Squared Error (MSE). Because the errors are squared, large errors are penalized relatively more than small ones.
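
A small numeric comparison shows the effect. The two hypothetical sets of coordinate errors below have the same Mean Absolute Error, but the one with a single large miss has a four times higher MSE.

```python
# Compare MSE and MAE on two hypothetical sets of coordinate errors that have
# the same total absolute error.


def mse(errors):
    """Mean Squared Error: squares each error before averaging."""
    return sum(e * e for e in errors) / len(errors)


def mae(errors):
    """Mean Absolute Error: averages the error magnitudes."""
    return sum(abs(e) for e in errors) / len(errors)


even = [0.1, 0.1, 0.1, 0.1]    # four moderate misses
spiky = [0.0, 0.0, 0.0, 0.4]   # three perfect hits and one large miss

print(round(mae(even), 4), round(mae(spiky), 4))  # 0.1 0.1
print(round(mse(even), 4), round(mse(spiky), 4))  # 0.01 0.04
```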

Training and performance

These are the loss (error) plots after training for 50 epochs:

Loss plots

After about 8 epochs, the validation loss becomes higher than the training loss. However, the validation loss keeps decreasing until the end of training, so we see no signs of strong overfitting.

The final loss (MSE) on the test set was 0.0064. MSE can be unintuitive to interpret; Mean Absolute Error (MAE) is easier to explain to humans.

The MAE is 0.017. Because the coordinates are scaled to the range 0 to 1, this means the predictions are off by 1.7% of the image width or height on average.
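
Because the coordinates are scaled to the range 0 to 1 before training, a normalized error converts back to pixels by multiplying by the image width or height, and to a percentage by multiplying by 100. A small sketch, with a hypothetical 1200×1600 photo and an error of 0.01:

```python
# Translate a normalized error (coordinates scaled to [0, 1]) into pixels and
# a percentage. The 1200x1600 image size is hypothetical.


def normalized_error_to_pixels(error, width, height):
    """Return the (horizontal, vertical) error in pixels."""
    return error * width, error * height


def as_percent(error):
    """Express a normalized error as a percentage of the image size."""
    return 100 * error


x_px, y_px = normalized_error_to_pixels(0.01, width=1200, height=1600)
print(x_px, y_px)        # 12.0 16.0
print(as_percent(0.01))  # 1.0
```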

Below are a few examples from the test set. The white circles mark the target key-points and the filled circles our predictions. In most cases they are very close (overlapping).

Some images from the test set

Deployment

The performance of the model is good enough to add value to the product. The key-points are now used to set default coordinates when uploading new flower images. In most cases no manual adjustment is needed!

The model itself is exposed via an API and packaged in a Docker container. The container is built on every push via Bitbucket Pipelines. The trained weights are included in the Docker image; since you don’t want large files in Git, we store them with Git LFS.
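
The API side can be sketched as a function that turns a JSON request into key-point coordinates in pixels. Everything here is an assumption for illustration: the request and response fields, the base64 image encoding, passing width and height in the request (a real service would read them from the decoded image) and predict_normalized, a stub standing in for the trained model. Wiring it into an HTTP server and a Docker image is left out.

```python
# Hypothetical JSON request handler for a key-point API. The schema and the
# `predict_normalized` stub are assumptions; a real service would decode the
# image and run the trained Keras model instead.
import base64
import json


def predict_normalized(image_bytes):
    """Stub model: returns [stem_x, stem_y, top_x, top_y] scaled to [0, 1]."""
    return [0.5, 0.9, 0.525, 0.15]


def handle_request(body: bytes) -> bytes:
    """Parse a JSON request, run the model and return pixel coordinates."""
    request = json.loads(body)
    image_bytes = base64.b64decode(request["image_base64"])
    width, height = request["width"], request["height"]
    stem_x, stem_y, top_x, top_y = predict_normalized(image_bytes)
    # Scale the normalized predictions back to this image's resolution.
    response = {
        "stem": {"x": stem_x * width, "y": stem_y * height},
        "flower_top": {"x": top_x * width, "y": top_y * height},
    }
    return json.dumps(response).encode("utf-8")


body = json.dumps({
    "image_base64": base64.b64encode(b"fake image bytes").decode("ascii"),
    "width": 800,
    "height": 1000,
}).encode("utf-8")
print(handle_request(body).decode("utf-8"))
```

Keeping the model output normalized and scaling only at the API boundary means one model can serve photos of any resolution.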

Further improvements

We have some ideas for improvement that we haven’t had time to implement yet:

  1. Currently, a single model estimates both key-points. Training a specific model per key-point might work better. This has the additional benefit that you can add new key-points later without having to retrain the complete model.
  2. Another idea is to take the angle of the photograph into account, for example by adding it as an input to the dense layers. You could argue that the angle changes the nature of the task, so providing this information might help the network. Following the same line of thought, training a separate network for each angle could also be beneficial.
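
Idea 2 amounts to giving the dense layers one more input. A minimal sketch, assuming the 7 photo angles are encoded as a one-hot vector and appended to the flattened CNN features (in Keras this would be a second model input merged with a Concatenate layer); one_hot and dense_input are hypothetical helpers:

```python
# Sketch of idea 2: append the camera angle to the flattened CNN features
# before the dense layers. One-hot encoding of the angle is an assumption.
NUM_ANGLES = 7  # photos are taken from 7 angles


def one_hot(index, size=NUM_ANGLES):
    """Encode an angle index as a vector with a single 1.0."""
    if not 0 <= index < size:
        raise ValueError(f"angle index {index} out of range 0..{size - 1}")
    return [1.0 if i == index else 0.0 for i in range(size)]


def dense_input(flat_features, angle_index):
    """Concatenate image features with the one-hot angle vector."""
    return list(flat_features) + one_hot(angle_index)


features = [0.3, 0.0, 1.2]  # stand-in for flattened CNN features
print(dense_input(features, angle_index=2))
# [0.3, 0.0, 1.2, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0]
```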

Next steps

Post-processing involves more steps than setting key-points, for example setting the stem color: the 3D engine draws artificial stems that match the stem color in the photo. We expect the same technique to work for this task too.

Conclusion

With this research, we demonstrated the feasibility of using a CNN to detect key-points in flower images. The methods used might also apply to post-processing tasks in other domains, such as product photography.

Any questions? Let us know in the comments. If you liked the article, please hit the clap button so more people can read this story!

About Artificial Industry: We help entrepreneurs change the world by turning their ideas quickly and efficiently into successful online businesses. We do this by creating (data) prototypes and MVPs for our clients.

Wouter de Winter
Data Scientist - writes about Data and Machine learning