Toilet Finder: Using Transfer Learning to Create a Rapid Image Classification Pipeline
Don’t get me wrong. I have absolutely no problems with toilets. They are one of the best inventions in history, and yet pictures of them aren’t really the best images to use to market a property!
So here is the problem: We have hundreds of thousands of listings, and our wonderful marketing team reported that from time to time an image of a toilet will appear in our marketing feed for CPC and email campaigns, which isn’t really ideal for us promoting our clients’ listings.
So, in the spirit of cross-functional problem solving, marketing asked us for help. “You’re the data people right? This should be an easy job, just find toilet images and exclude them from our marketing feed!”
From a Data Science standpoint, this isn’t a particularly difficult task. It’s a vanilla image classification problem where all we need to do is train a binary classifier to tell whether a given image contains a toilet/bathroom or not. We could get fancier (considering the predicted demand from our marketing gurus 🔮) and train a multi-class classifier that would not only be used to pick a non-toilet image, but would also help us test which property image types produce the highest conversion rate. At this stage, though, finding toilets is the name of the game.
Data, data, and more data:
As with most cases of data-science-best-laid-plans, this turned out to be not as simple as it sounds.
First, in order to train our model we need a labelled dataset. Since we don’t ask our agents to label the property images they upload, we needed to manually label a good enough sample to train the model. Out of the millions of property images our agents upload, we needed to decide which ones to include to represent an unbiased sample of our listing images. Selection bias can seriously undermine your model’s usability, even if the model scores well on the loss function and the evaluation criteria you select.
Selecting a statistically representative sample in our case was a weighted sampling problem where we assigned higher relative weights to the most advertised locations, that is, locations with the highest number of active listings. Once we had those listings identified, we passed that list to pandas to sample from.
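A minimal sketch of what such a weighted draw might look like in pandas. The table, its column names, and the sample size are hypothetical and only illustrate the idea; this is not our production code.

```python
import pandas as pd

# Hypothetical listing table: column names and values are illustrative only.
listings = pd.DataFrame({
    "listing_id": range(1, 9),
    "location": ["A", "A", "A", "A", "B", "B", "C", "D"],
})

# Weight each listing by the number of active listings in its location,
# so heavily advertised locations are more likely to be drawn.
listings["weight"] = listings.groupby("location")["listing_id"].transform("count")

# Weighted sample without replacement; random_state makes it repeatable.
sample = listings.sample(n=4, weights="weight", random_state=42)
print(sample)
```

pandas normalizes the weights internally, so raw counts can be passed directly.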
Grouping the population into sub-populations and sampling within each one is known as stratified sampling. In our case we grouped listings by location and type (“rent”, “buy”, “commercial”, etc.). We could expand this further by including the number of views, price brackets, and so on. Data quality is the most important part here, because we will be using a pre-trained model and only fine-tuning the final layer (more on that shortly!).
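As a rough illustration of sampling within each location and type group, the sketch below draws the same fraction from every group. The data and the 50% fraction are made up for the example, and `groupby(...).sample` assumes pandas 1.1 or later.

```python
import pandas as pd

# Hypothetical listings: two locations, two listing types.
listings = pd.DataFrame({
    "listing_id": range(1, 13),
    "location": ["A"] * 6 + ["B"] * 6,
    "type": ["rent", "rent", "rent", "rent", "buy", "buy"] * 2,
})

# Draw the same fraction from every (location, type) stratum.
stratified = listings.groupby(["location", "type"]).sample(frac=0.5, random_state=0)

# Each stratum keeps half of its rows.
print(stratified.groupby(["location", "type"]).size())
```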
We ended up selecting around 40,000 images among the millions we have to start training.
Selecting the model:
We use AWS and Google Cloud Platform to power our infrastructure. Our first step was to find out whether we could use the machine learning APIs available on either of them. For example, Amazon Rekognition is a great platform and can be really beneficial for many use cases. However, there are two factors to consider if you decide to rely on off-the-shelf machine learning APIs:
- With millions of property images to process, the cost of using an API could become very significant. Why pay a few thousand dollars a month when you could potentially train a model specific to your needs for a fraction of the cost? Especially when the accuracy of the results is not guaranteed!
- You also need to worry about how to integrate the API with your existing pipeline and how to extract (and possibly map) the prediction responses according to your needs. You also have to take care of moving your data if it lives in another environment, or even in another AWS region. The complexities can really add up when it comes to bringing this to life, so it’s not as simple as just calling an API.
Transfer Learning to the Rescue
I have explored both Keras/TensorFlow and PyTorch, the most popular machine learning frameworks, but I personally prefer PyTorch. Torchvision has an amazing selection of datasets and models to learn from and use in your projects. It’s also a great place to start if you are a beginner and would like to focus on Computer Vision problems.
So first, what is Transfer Learning?
Transfer learning is the process of leveraging a model trained in one domain to be used as a starting point on another (related) domain.
We used PyTorch’s ResNet18, a residual network pre-trained on the 1,000-class ImageNet dataset used in the ILSVRC challenges. An ensemble of residual networks won first place in the ILSVRC 2015 classification task. More information on this Kaggle dataset here.
What we did was simply replace the final layer of the model with a fully connected layer whose number of outputs matches the number of classes we wanted the model to predict.
In simpler words, we used the pre-trained model weights as a starting point, then fine-tuned it by only training the final layer on our own proprietary data.
If you have a basic understanding of PyTorch then you can follow the complete transfer learning tutorial here. You can also find more insight about the potential of building models in-house from this brilliant Airbnb post.
Since we don’t have much labelled data, as mentioned earlier, we had to start the manual labelling process before we could even train the model. A quick start was to label the data in-house (and a huge shout-out to my colleague Jemma Battie for her help here!). We don’t need bounding boxes at the moment, which made labelling as intuitive as dragging image files into a few directories, each representing a class label. We then used the Torchvision ImageFolder class to automatically pick up those classes from the directory structure.
The data transforms are just to make sure each image file in the training and validation datasets is resized and normalized as expected by the model.
We thankfully did not have a class imbalance problem so we used validation accuracy to measure the model performance. With 10 epochs of training we managed to get around 88% accuracy. Our data labelling process is still in progress and accuracy could improve with more data as long as the model still has room before it starts overfitting.
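For reference, validation accuracy can be computed along these lines. The function name and the toy data are illustrative only; the stand-in model simply returns its inputs as class scores so the example runs without a trained network.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset


def validation_accuracy(model: nn.Module, loader: DataLoader) -> float:
    """Fraction of samples whose highest-scoring class matches the label."""
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for inputs, labels in loader:
            predictions = model(inputs).argmax(dim=1)
            correct += (predictions == labels).sum().item()
            total += labels.size(0)
    return correct / total


# Toy check: nn.Identity passes the "scores" straight through.
scores = torch.tensor([[2.0, 1.0], [0.0, 3.0], [1.0, 0.0], [0.0, 1.0]])
labels = torch.tensor([0, 1, 1, 1])
loader = DataLoader(TensorDataset(scores, labels), batch_size=2)
accuracy = validation_accuracy(nn.Identity(), loader)
print(accuracy)  # 0.75
```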
In other words, if your model has not started overfitting yet, it can still learn more patterns from your data (by adding more data, introducing data augmentation, increasing model complexity, etc.).
- To train an image classifier, you don’t need to start from scratch as long as you can leverage transfer learning to solve your problem.
- Collecting and processing data is usually more important than selecting your model. We could have obtained similar results using another pre-trained model, possibly with more data and/or training iterations. Most of the time and effort spent on this problem went into data collection, labelling, and other infrastructure-related tasks.
Thanks for reading. Please feel free to comment, ask questions, or reach out.