How to Create a Pinterest Clone Part II: Image Classification

A basic image classification workflow with TensorFlow and the Python Imaging Library

Jim Chen
Geek Culture
Jul 6, 2022


Overview and Objective

In the last blog, we built a web page that can upload images and an API endpoint that receives them, storing the files in MongoDB and the metadata in TigerGraph. In this blog, we will integrate a basic image classification workflow built with TensorFlow and the Python Imaging Library into that API endpoint to tag the images and store the tags in TigerGraph. These tags can then be used for searching.

Section I: The Basic Image Classification Model with TensorFlow

The actual image classification algorithm used by Pinterest is complicated and beyond the scope of this blog. Instead, we will follow the instructions in this doc to train a simple neural network that classifies images of clothing, using the Fashion-MNIST dataset of 70,000 grayscale images (28 by 28 pixels) in 10 categories. Here are some examples:

Fashion-MNIST samples (by Zalando, MIT License).

We need to install TensorFlow first. Remember to activate the virtual environment so that the installation doesn’t pollute our system-wide Python environment.
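The setup might look like this (assuming a venv in the project root; the directory name "venv" is just a convention):

```shell
# Create a virtual environment in the project root and activate it
# before installing anything
python3 -m venv venv
source venv/bin/activate
```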

Then we can run this command to install TensorFlow (macOS only; please refer to this doc for other systems).
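A sketch of the install step — the TensorFlow docs recommend the `tensorflow-macos` package on Apple silicon, while the plain package works elsewhere:

```shell
# Inside the activated virtual environment; on Apple-silicon macOS the
# TensorFlow install docs recommend `pip install tensorflow-macos` instead
pip install tensorflow
```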

Let’s start create a Python program at to train and save the model!

After importing the modules, we will load the Fashion-MNIST dataset.

There are 60,000 training images and 10,000 test images. Each image has a corresponding label, a number from 0 to 9 representing a class name (e.g. 0 for T-shirt/top and 9 for ankle boot). To be more precise, the variable train_images is a NumPy array of shape (60000, 28, 28), representing 60,000 images of 28×28 pixels each with values from 0 to 255, and the variable train_labels is a NumPy array of shape (60000,), where each element takes a value from 0 to 9, giving the class of each image.
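Loading the dataset through Keras might look like this (the helper downloads Fashion-MNIST on first use):

```python
import tensorflow as tf

# Fashion-MNIST ships with Keras; load_data() returns two (images, labels) pairs
(train_images, train_labels), (test_images, test_labels) = (
    tf.keras.datasets.fashion_mnist.load_data()
)

print(train_images.shape)  # (60000, 28, 28) -- 60,000 images of 28x28 pixels
print(test_images.shape)   # (10000, 28, 28)
```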

Since the neural network model works best with numbers from 0 to 1, we need to normalize the pixel values from 0–255 to 0–1.
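The normalization itself is a single division; it is demonstrated here on stand-in data so the snippet runs on its own:

```python
import numpy as np

# Stand-in for train_images: four 28x28 "images" with pixel values 0-255
images = np.random.randint(0, 256, size=(4, 28, 28)).astype("float32")

# Dividing by 255 rescales every pixel into the 0-1 range the model prefers
images = images / 255.0

print(images.min() >= 0.0, images.max() <= 1.0)  # True True
```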

Then, we can train and save the model.

Basically, we chain several layers to create the model, compile the model with some recommended parameters, fit (or train) the model, and save the model to the given path. For a detailed description of each component, please check out this doc. For a high-level conceptual understanding, I would highly recommend this video series by 3Blue1Brown.
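A sketch of the whole training script, following the layer sizes from the TensorFlow tutorial. One assumption: newer Keras versions require a .keras file extension when saving, so this sketch uses one, whereas the post saves to a my_model directory:

```python
import tensorflow as tf

# Load and normalize the training data as above
(train_images, train_labels), _ = tf.keras.datasets.fashion_mnist.load_data()
train_images = train_images / 255.0

# Chain the layers: flatten the 28x28 grid, one hidden layer, one logit per class
model = tf.keras.Sequential([
    tf.keras.Input(shape=(28, 28)),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(128, activation="relu"),
    tf.keras.layers.Dense(10),
])

# Compile with the parameters recommended by the tutorial
model.compile(
    optimizer="adam",
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=["accuracy"],
)

# Fit (train) the model, then save it to disk
model.fit(train_images, train_labels, epochs=10)
model.save("my_model.keras")
```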

We can now run the program.

It will take less than 30 seconds on a regular Mac. No GPU needed! We can evaluate the model with the test dataset by adding the following two lines to the end of the program.
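Those two lines are essentially a call to model.evaluate, shown here wrapped in a minimal one-epoch training run so the snippet stands alone (the post's model trains for longer and scores around 90%):

```python
import tensorflow as tf

(train_images, train_labels), (test_images, test_labels) = (
    tf.keras.datasets.fashion_mnist.load_data()
)
train_images, test_images = train_images / 255.0, test_images / 255.0

model = tf.keras.Sequential([
    tf.keras.Input(shape=(28, 28)),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(128, activation="relu"),
    tf.keras.layers.Dense(10),
])
model.compile(
    optimizer="adam",
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=["accuracy"],
)
model.fit(train_images, train_labels, epochs=1)

# The two added lines: run the model on the held-out test set and report accuracy
test_loss, test_acc = model.evaluate(test_images, test_labels, verbose=2)
print("Test accuracy:", test_acc)
```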

The model will be used to predict the class (a number from 0 to 9) of each test image, and the predictions will be compared with the actual labels to get an accuracy, which will be around 90%. Although we use a small dataset and a simple model, the result is quite good!

The terminal output after we ran the program

We now have a basic image classification model! As a recap for this section, we put the following code at photo_library/backend/ML.py and ran it, which trained an image classification model and saved it at photo_library/backend/my_model.

Section II: Integrating the Image Classification Model into the Server

Currently, we have an API endpoint that stores the user uploaded photo file into MongoDB, generates an object ID, and puts the object ID into TigerGraph.

Now we want to tag the photo with the model and add the metadata to the TigerGraph database. First, we need to turn the photo file into an Image object and process it with Python Imaging Library (PIL).
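Reading the uploaded bytes into a PIL Image might look like this; the in-memory JPEG is fabricated so the snippet runs without a real upload:

```python
from io import BytesIO
from PIL import Image

# Fabricate a small JPEG in memory to stand in for the uploaded photo bytes
buf = BytesIO()
Image.new("RGB", (200, 150), color=(255, 105, 180)).save(buf, format="JPEG")
photo_bytes = buf.getvalue()

# In the endpoint, photo_bytes would be the contents of the uploaded file
image = Image.open(BytesIO(photo_bytes))
print(image.size, image.mode)  # (200, 150) RGB
```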

After we add the dependencies and code, we can run our project and try to upload an image.

Shirt Image from Pixabay
Shirt Image in a Pop-up Window

The photo is loaded as an Image object, and we need to transform it into a grayscale image of 28 by 28 pixels.
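With PIL, the transformation is a convert to "L" (grayscale) mode plus a resize:

```python
from PIL import Image

image = Image.new("RGB", (200, 150), color=(255, 105, 180))  # stand-in photo

# "L" mode is 8-bit grayscale; resize squashes the photo down to 28x28 pixels
small = image.convert("L").resize((28, 28))
print(small.size, small.mode)  # (28, 28) L
```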

Resized Shirt
Grayscale, 28 by 28 pixels image with pixel value from 0 to 255

Then, we convert the image pixels into a normalized NumPy array.
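The conversion adds a batch dimension as well, because the model predicts on batches of images:

```python
import numpy as np
from PIL import Image

small = Image.new("L", (28, 28), color=128)  # stand-in 28x28 grayscale image

arr = np.asarray(small, dtype="float32") / 255.0  # pixel values scaled to 0-1
batch = np.expand_dims(arr, axis=0)               # shape (1, 28, 28)
print(batch.shape)
```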

We load the trained model and use it to predict the class of the image.
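A sketch of loading and predicting — a freshly built stand-in model replaces the saved one here so the snippet runs on its own; in the project you would load the trained model from disk instead:

```python
import numpy as np
import tensorflow as tf

# In the project: model = tf.keras.models.load_model("photo_library/backend/my_model")
# A stand-in model is built here so the snippet runs without the saved file
model = tf.keras.Sequential([
    tf.keras.Input(shape=(28, 28)),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(10),
    tf.keras.layers.Softmax(),  # turn the logits into probabilities
])

batch = np.random.rand(1, 28, 28).astype("float32")  # stand-in processed image
predictions = model.predict(batch)[0]  # 10 probabilities, one per class
print(predictions)
```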

After removing the code that shows image, let’s run it and see the predictions!

The output contains 10 numbers between 0 and 1, each representing the probability that the image belongs to the corresponding class. Recall that the 10 classes are these:
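The Fashion-MNIST class names in label order, with a lookup from a prediction array (the probabilities here are made up to match the shirt example):

```python
import numpy as np

# Label order 0-9 as defined by the Fashion-MNIST dataset
class_names = ["T-shirt/top", "Trouser", "Pullover", "Dress", "Coat",
               "Sandal", "Shirt", "Sneaker", "Bag", "Ankle boot"]

# Hypothetical prediction array; argmax picks the most probable class
predictions = np.array([0.26, 0, 0, 0, 0, 0, 0.74, 0, 0, 0])
print(class_names[int(predictions.argmax())])  # Shirt
```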

We can see that ‘Shirt’ has the highest probability of 0.74, and ‘T-shirt/top’ has the second highest probability of 0.26, which is reasonable!

As a recap, you can put the following code in photo_library/backend/main.py. The API endpoint will now print a prediction array for every uploaded photo.

Section III: Store the metadata in TigerGraph

In the last blog, we created the ‘Photo’ vertex in TigerGraph. Now we will add a ‘Type’ vertex and an edge ‘PHOTO_HAS_TYPE’ to store the metadata.

Let’s navigate to https://tgcloud.io/, log in with the credentials, and go to the ‘My Solutions’ tab.

Start the solution that we created, which will take around a minute, and open the GraphStudio.

Switch to ‘photos’ graph, navigate to ‘Design Schema’, and click on the ‘+’ button to add the ‘Type’ vertex with the primary id called ‘name’, which has type string and is stored as an attribute.

Then, we can click on the right arrow button to add the ‘PHOTO_HAS_TYPE’ edge. The configuration includes the source vertex ‘Photo’, the target vertex ‘Type’, and an attribute ‘probability’ of type float.

Finally, we click on the top left button to ‘Publish schema’, and our TigerGraph solution is good to go!

In main.py, we can update the function that takes a photo ID and inserts it into TigerGraph to add the photo’s metadata.

We insert the ‘Photo’ vertex with an object ID that refers to the photo file stored in MongoDB, add all the ‘Type’ vertexes, and add edges whose probability attribute indicates how likely the photo is to belong to each class. We only need to insert the ‘Type’ vertexes once, so feel free to remove or comment out those two lines of code once the ‘Type’ vertexes are inserted. We also include a probability threshold to reduce redundant edges.
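A sketch of how that function might look with pyTigerGraph — the function name, the 0.2 threshold, and the fake connection used for demonstration are all assumptions; in the project, conn would be a real pyTigerGraph TigerGraphConnection:

```python
CLASS_NAMES = ["T-shirt/top", "Trouser", "Pullover", "Dress", "Coat",
               "Sandal", "Shirt", "Sneaker", "Bag", "Ankle boot"]
THRESHOLD = 0.2  # skip low-probability classes to avoid redundant edges

def insert_photo(conn, photo_id, predictions):
    # The 'Type' vertices only need to be upserted once; upserts are
    # idempotent, so this loop can be commented out after the first run
    for name in CLASS_NAMES:
        conn.upsertVertex("Type", name, {})
    conn.upsertVertex("Photo", photo_id, {})
    for name, prob in zip(CLASS_NAMES, predictions):
        if prob >= THRESHOLD:
            conn.upsertEdge("Photo", photo_id, "PHOTO_HAS_TYPE",
                            "Type", name, {"probability": float(prob)})

class FakeConn:
    """Stand-in for a TigerGraphConnection, just so this demo runs offline."""
    def __init__(self):
        self.edges = []
    def upsertVertex(self, *args):
        pass
    def upsertEdge(self, *args):
        self.edges.append(args)

conn = FakeConn()
insert_photo(conn, "some-object-id", [0.26, 0, 0, 0, 0, 0, 0.74, 0, 0, 0])
print(len(conn.edges))  # 2 -- only Shirt and T-shirt/top cross the threshold
```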

Let’s use this function in the API endpoint, and the code would look like this:

If we run the project and upload the pink shirt photo again, we will see it in the MongoDB.

We can copy its object ID and search it in TigerGraph.

It is connected to ‘Shirt’ and ‘T-shirt/top’ with probabilities 0.74 and 0.26, which is exactly what we expected!

Section IV: Next Steps and Resources

That’s all for this part! All the code is uploaded here with instructions to run it. If you are especially interested in some of the technologies used in this blog, here are the links to their documentation: TigerGraph Cloud, MongoDB Setup, FastAPI, and Quasar.

We will wrap it up in the next part by completing the frontend with searching. Feel free to join the TigerGraph Discord and the Developer Forum if you have any questions!
