Apples & Oranges: A Machine Learning Classifier

Mike Shi
ModelDepot
Jun 12, 2018

Ever wanted to use image recognition, but didn’t want to be limited to the fixed classes that come out of the box with pretrained models or cloud APIs? Today we’ll walk through how to train an ML model to recognize fruits it’s never seen before, with 95% accuracy!

A Brief Intro on Image Classification and Retraining

A computer’s ability to analyze an image and tell you what’s in it (image classification), whether it’s a banana, an apple, or a hotdog, is one of the most visible achievements of deep learning today.

Today we’ll be using ModelDepot Percept: a Docker image that serves a state-of-the-art image classifier behind a REST API, letting you quickly make predictions, retrain it, and deploy it to production for free. You can find the complete code here.

Step 1: Acquiring Data

In this guide, I’ll be using Fruits 360, a dataset of 32,000+ images of 65 different types of fruit, though we’ll only use a small fraction of it, to show how accurate a model can be with just a few examples!

Once we download the dataset, we’ll see that the Training data is already neatly organized into one folder per class. A short Python snippet is all we need to load a handful of images per class and convert them to Base64 so we can send them to our Percept instance.
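The loader snippet itself isn’t reproduced in this post, so here is a minimal sketch of what it could look like. It assumes a helper named get_data(directory, label_names, num_samples), the call signature used in Step 3, that returns a dict mapping each label to a list of Base64 strings; the extension filter and the sorting are our own assumptions, not part of Percept.

```python
import base64
import os

# Image formats listed in the Percept API docs (gif, jpg or png)
IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.gif')


def get_data(directory, label_names, num_samples):
    """Hypothetical loader: return {label: [base64_string, ...]}.

    Reads up to num_samples images from directory/<label>/ for each label
    and encodes each file's raw bytes as a Base64 text string.
    """
    data = {}
    for label in label_names:
        label_dir = os.path.join(directory, label)
        # Sort so the same files are picked on every run
        filenames = sorted(
            name for name in os.listdir(label_dir)
            if name.lower().endswith(IMAGE_EXTENSIONS)
        )[:num_samples]
        encoded = []
        for filename in filenames:
            with open(os.path.join(label_dir, filename), 'rb') as f:
                encoded.append(base64.b64encode(f.read()).decode('ascii'))
        data[label] = encoded
    return data
```

Note that slicing the file list means a folder with fewer than num_samples images silently yields fewer samples for that class.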

Step 2: Exploring the /train and /predict Endpoints

Now that we have our data all set up, let’s take a look at the API docs around training and predicting. We’ll first want to queue up the data for training, using the /batch_queue endpoint. After that we can train the model by calling /train and then start making predictions with /predict!

/v1/batch_queue/{label_name}

The batch_queue endpoint is fairly straightforward: it takes the desired label name (e.g. ‘banana’) in the URL and an array of image objects in the body of the POST request.

The value of each object’s image property can be either an image URL accessible to the Docker container (gif, jpg or png) or a Base64-encoded string.
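To make the two accepted forms concrete, here is a small sketch that builds the body of a batch_queue request, passing http(s) URLs through unchanged and Base64-encoding local files. The helper names image_object and batch_queue_body and the example URL are illustrative assumptions; only the {'batch_images': [{'image': ...}]} shape comes from the request code in Step 3.

```python
import base64
import json


def image_object(source):
    """Build one image object for Percept (hypothetical helper).

    http(s) URLs are passed through as-is (they must be reachable from
    inside the Docker container); anything else is treated as a local
    file path and sent as a Base64-encoded string.
    """
    if source.startswith(('http://', 'https://')):
        return {'image': source}
    with open(source, 'rb') as f:
        return {'image': base64.b64encode(f.read()).decode('ascii')}


def batch_queue_body(sources):
    """JSON body for POST /v1/batch_queue/{label_name}."""
    return {'batch_images': [image_object(s) for s in sources]}


if __name__ == '__main__':
    # Placeholder URL, not a real dataset location
    body = batch_queue_body(['http://example.com/banana.jpg'])
    print(json.dumps(body))
    # prints {"batch_images": [{"image": "http://example.com/banana.jpg"}]}
```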

You can use the API explorer to try queueing a sample by hand.

However, instead of going through the API explorer, we’ll use a Python script to queue the data and train the model programmatically.

/v1/train

Once we’ve queued up all of our data using batch_queue, we can tell the classifier to start training! Calling this endpoint kicks off a training session, and the request returns training information once training has finished.
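Because /v1/train only responds once training has finished, the client has to be willing to wait. Here is a minimal sketch using Python’s standard-library urllib (rather than requests, to keep it dependency-free) that posts to /v1/train with no timeout by default. The names start_training and PERCEPT_URL are illustrative; localhost:8000 is the address used in this post’s code.

```python
import json
import urllib.request

PERCEPT_URL = 'http://localhost:8000'  # assumed local Percept instance


def start_training(base_url=PERCEPT_URL, timeout=None):
    """POST /v1/train and return the decoded JSON training summary.

    The call blocks until training has finished, so by default we pass
    timeout=None (wait indefinitely) instead of a short socket timeout.
    """
    req = urllib.request.Request(base_url + '/v1/train', data=b'', method='POST')
    with urllib.request.urlopen(req, timeout=timeout) as resp:
        return json.loads(resp.read().decode('utf-8'))
```

With requests, as used in the rest of this post, requests.post also waits indefinitely unless you pass a timeout argument, so a long training run won’t be cut off.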

/v1/predict

The predict endpoint takes a single object with an ‘image’ property, in exactly the same format as one of the image objects passed to /v1/batch_queue, so we won’t repeat the details here. We’ll see this endpoint in action in the following code section.
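A matching sketch for /v1/predict, again with the standard-library urllib. It sends {'image': ...} as JSON and returns the decoded response unchanged, because this post doesn’t describe the response fields. The function name predict and the default base URL (the same localhost:8000 instance) are assumptions.

```python
import json
import urllib.request


def predict(image, base_url='http://localhost:8000', timeout=30):
    """POST one image object to /v1/predict and return the decoded JSON.

    `image` is an image URL or a Base64 string, exactly as in batch_queue.
    The response is returned as-is; extracting a label is left to the caller.
    """
    payload = json.dumps({'image': image}).encode('utf-8')
    req = urllib.request.Request(
        base_url + '/v1/predict',
        data=payload,
        headers={'Content-Type': 'application/json'},
        method='POST',
    )
    with urllib.request.urlopen(req, timeout=timeout) as resp:
        return json.loads(resp.read().decode('utf-8'))
```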

Step 3: Writing The Code

Now that we have the data and an understanding of the API, we’ll put together the code we need to train and test our custom fruit classifier. We’ll use Python here, but the same approach works in any language. See the final code here; below, we’ll walk through its major parts.

Loading Train/Validation Data

With the data loader helper from Step 1 already defined, we’ll load both training and validation data. The validation data is held out from training, so we can accurately assess how the model performs on images it hasn’t seen.

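import os

TRAIN_DIR = 'Fruit-Images-Dataset-master/Training'
VALIDATION_DIR = 'Fruit-Images-Dataset-master/Validation'
NUM_CLASSES = 15    # number of fruit types to use
TRAIN_SAMPLES = 20  # training images per fruit
VAL_SAMPLES = 20    # validation images per fruit

# One subdirectory per class; sorted because os.walk's order is arbitrary
_, all_label_names, _ = next(os.walk(TRAIN_DIR))
label_names = sorted(all_label_names)[:NUM_CLASSES]
# get_data is the loader helper from Step 1
train = get_data(TRAIN_DIR, label_names, TRAIN_SAMPLES)
val = get_data(VALIDATION_DIR, label_names, VAL_SAMPLES)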

Here we load 20 training samples and 20 validation samples for each of 15 types of fruit. Normally the train/validation split would be closer to 80/20, but since we’re demonstrating how few samples you need to train a model, we’re using a 50/50 split.
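If the loader can return fewer images than requested (as a simple slice-based one would when a folder is short), it’s worth confirming the counts before training. With 15 classes × 20 images, each split should hold 300 images. The check_split helper below is a hypothetical sanity check, not part of Percept.

```python
def check_split(train, val, num_classes, train_per_class, val_per_class):
    """Verify class and sample counts; return (train_total, val_total).

    Raises ValueError if a split has the wrong number of classes or any
    class came back with a different number of images than requested.
    """
    for name, split, per_class in (('train', train, train_per_class),
                                   ('val', val, val_per_class)):
        if len(split) != num_classes:
            raise ValueError('%s: expected %d classes, got %d'
                             % (name, num_classes, len(split)))
        for label, samples in split.items():
            if len(samples) != per_class:
                raise ValueError('%s/%s: expected %d samples, got %d'
                                 % (name, label, per_class, len(samples)))
    return (sum(len(s) for s in train.values()),
            sum(len(s) for s in val.values()))
```

Calling check_split(train, val, NUM_CLASSES, TRAIN_SAMPLES, VAL_SAMPLES) on a complete load should return (300, 300).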

Training Our Model

With our data split up, we can go ahead and train our model. First, we’ll write a helper function that sends the batch_queue request.

import requests

def queue(label, samples):
    # Send every image for one label to the batch queue in a single POST
    return requests.post(
        'http://localhost:8000/v1/batch_queue/%s' % label,
        json={'batch_images': [{'image': image} for image in samples]})

This function makes a POST request to the batch_queue endpoint for a given label name, passing the images (here, the Base64 strings from our loader) as a JSON-encoded body.

With the helper method done, we just need 3 more lines of code to actually run the training:

for label in train:
    print(queue(label, train[label]).json())  # queue each fruit's samples
print(requests.post('http://localhost:8000/v1/train').json())  # then train

The loop queues up all of our samples, one request per fruit, and the final line kicks off training on the server. This can take a while, since the model is being optimized on 300 samples (15 fruits × 20 images each).

We can also double-check that the model has learned everything we queued:

requests.get('http://localhost:8000/v1/train/status').json()

This should report 15 classes trained on 300 images.

Testing the Model

Now that we’ve trained a model, we want to know how well it will do on new data it has never seen. To find out, we’ll ask it to predict the images in our held-out validation set and compare its answers with the true labels.

Using scikit-learn’s accuracy metric, we can see our classifier does really well!

95% Validation Accuracy!
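The evaluation code isn’t shown in the post, so here is a minimal sketch: loop over the validation dict, collect true and predicted labels, and score them. Since the /v1/predict response format isn’t described here, predict_label is an assumed callable you would write to pull the top label out of each response. The accuracy function computes the same quantity as scikit-learn’s sklearn.metrics.accuracy_score(y_true, y_pred), written in plain Python so the sketch has no dependencies.

```python
def evaluate(val, predict_label):
    """Classify every held-out image; return (y_true, y_pred) label lists.

    `predict_label` is any function mapping one image (URL or Base64
    string) to a predicted label name, e.g. a wrapper around /v1/predict.
    """
    y_true, y_pred = [], []
    for label, images in val.items():
        for image in images:
            y_true.append(label)
            y_pred.append(predict_label(image))
    return y_true, y_pred


def accuracy(y_true, y_pred):
    """Fraction of predictions that match the true label."""
    if not y_true:
        raise ValueError('no predictions to score')
    correct = sum(t == p for t, p in zip(y_true, y_pred))
    return correct / len(y_true)
```

With scikit-learn installed, accuracy_score(y_true, y_pred) from sklearn.metrics gives the same number.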

We can dig deeper with a confusion matrix, which shows how often the classifier mistook one type of fruit for another. The strong diagonal and sparse entries everywhere else show how well our classifier does.
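A confusion matrix is just a table of counts: rows are true labels, columns are predicted labels. This plain-Python sketch follows the same row/column convention as sklearn.metrics.confusion_matrix and assumes every predicted label appears in labels; the function name is illustrative.

```python
def confusion_matrix(y_true, y_pred, labels):
    """Return a list of rows: matrix[i][j] counts images whose true label
    is labels[i] and whose predicted label is labels[j].

    A perfect classifier puts every count on the diagonal (i == j).
    """
    index = {label: i for i, label in enumerate(labels)}
    matrix = [[0] * len(labels) for _ in labels]
    for t, p in zip(y_true, y_pred):
        matrix[index[t]][index[p]] += 1
    return matrix
```

Each off-diagonal cell tells you which fruit was mistaken for which, e.g. matrix[0][1] counts images of labels[0] that were predicted as labels[1].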

Done!

We’ve successfully gone from nothing to our very own (and quite accurate) fruit classifier! If you’re interested in how the internals work, we’ve covered the underlying ML technology and some of our roadmap in this article here. We hope you’ve seen how easy it can be to use machine learning for custom classes, and that ModelDepot can help you bring ML into your product.

Interested in trying ModelDepot Percept? Get started for free here! If you have any questions, don’t hesitate to reach us via the in-site chat on the bottom right or via email at hi@modeldepot.io.
