Can I Create an Image Classifier Using Tensorflow to Identify the Theme of a Painting?

https://upload.wikimedia.org/wikipedia/commons/e/e4/Leonardo_Da_Vinci_-_Vergine_delle_Rocce_%28Louvre%29.jpg

Can I teach an AI model, using TensorFlow, how to identify the theme of a painting?

That’s the general question I want to answer with this project.

But I am going to narrow down my question further, for practical reasons.

I want to identify whether a painting is about the Nativity or not, which is a simpler binary classification problem.

But first, let me explain what a nativity painting is. A nativity painting is a painting whose subject is the birth of Jesus Christ, an event revered in Christianity. Over the last two thousand years, many famous artists, such as Leonardo Da Vinci, Michelangelo, and Caravaggio, were commissioned to create paintings for churches, so there should be plenty of nativity paintings to pick from.

https://upload.wikimedia.org/wikipedia/commons/9/9d/Botticelli_-_Adoration_of_the_Magi_%28Zanobi_Altar%29_-_Uffizi.jpg

Before Starting

I created this notebook with some code from the Tensorflow tutorial on Image Classification. You can find the original tutorial in the link below:

https://www.tensorflow.org/tutorials/images/classification

Import TensorFlow and other libraries

import matplotlib.pyplot as plt 
import PIL
import tensorflow as tf
import os
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.models import Sequential
import pandas as pd
import requests # to get image from the web
import shutil # to save it locally
import time
import numpy as np

Download Training, Validation and Test Image Data Sets

In order to train an image classifier, we need to have a training image dataset, validation dataset, and test dataset.

Since we are training a binary image classifier, we will have images for two different classes:

  • Nativity
  • Others

During the training of the model, we will use the training dataset to teach the model how to classify a painting: either a nativity painting or not (other).

At the end of each training cycle (epoch) we will use the validation dataset to score how well the model is doing by calculating the accuracy and the loss. The accuracy measures how often our model gets the right answer; higher is better. The loss measures the delta, i.e. the difference between the predicted value and the actual value; lower is better.
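
As a tiny worked example (illustrative only, not part of the notebook), here is how these two metrics behave on a toy batch of four predictions:

# Toy batch: ground-truth labels and predicted probabilities
y_true = tf.constant([[1.0], [0.0], [1.0], [0.0]])
y_pred = tf.constant([[0.9], [0.2], [0.4], [0.1]])
# Accuracy: 3 of the 4 predictions land on the right side of the 0.5 threshold -> 0.75
print(tf.keras.metrics.BinaryAccuracy()(y_true, y_pred).numpy())
# Loss: the average binary cross-entropy, which shrinks as the probabilities approach the labels
print(tf.keras.losses.BinaryCrossentropy()(y_true, y_pred).numpy())  # ~0.34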

It is important that the validation dataset is separate from the training dataset because the AI model is very good at cutting corners (i.e. cheating). If you don’t separate the two, the model will simply memorize the answers instead of learning the intrinsic characteristics of what we are trying to teach it.

At the very end of the training, we will also use a test dataset, separate from both the training and validation datasets, to do an independent benchmark of the model’s performance.

You will notice that we are downloading three files:

  • nativity_dataset.csv — contains all nativity paintings
  • other_dataset.csv — contains many paintings except nativity paintings
  • test_dataset.csv — contains labeled paintings of both classes, reserved for the final benchmark

Wait a moment! Didn’t I just say that the training dataset should be separate from the validation dataset? So why keep them in the same files?

Yes, but because we are doing data exploration, it is good to have some flexibility. Typically you are advised to use 80% of the data for training and 20% for validation. But this is not a hard and fast rule; we might want to change these percentages and see what gives us better results as part of our experimentation. This is also known as hyperparameter tuning. On the other hand, the test dataset should be fixed, so we can compare different models with different architectures in a consistent way.
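
For instance, here is a quick sketch (hypothetical, not part of the notebook) of how the split ratio could be varied during experimentation, using the resized dataset directory we create later:

# Hypothetical sketch: treat the train/validation split ratio as a hyperparameter
for split in (0.1, 0.2, 0.3):
    val_ds = tf.keras.preprocessing.image_dataset_from_directory(
        "/content/resized/dataset",
        validation_split=split,
        subset="validation",
        seed=123,
        image_size=(300, 300),
        batch_size=32,
        label_mode='binary')
    # ...train a fresh model with this split and compare validation metrics...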

Below we define some utility functions to help download the images from our image dataset. Notice that getFileNameFromUrl() does some very basic cleanup and extraction of the filename in the URL.

def getFileNameFromUrl(url):
    firstpos = url.rindex("/")
    lastpos = len(url)
    filename = url[firstpos+1:lastpos]
    print(f"url={url} firstpos={firstpos} lastpos={lastpos} filename={filename}")
    return filename

def downloadImage(imageUrl, destinationFolder):
    filename = getFileNameFromUrl(imageUrl)
    # Open the url image; setting stream to True returns the stream content.
    r = requests.get(imageUrl, stream=True)
    # Check if the image was retrieved successfully
    if r.status_code == 200:
        # Set decode_content to True, otherwise the downloaded image file's size will be zero.
        r.raw.decode_content = True
        # Open a local file with wb (write binary) permission.
        filePath = os.path.join(destinationFolder, filename)
        if not os.path.exists(filePath):
            with open(filePath, 'wb') as f:
                shutil.copyfileobj(r.raw, f)
            print('Image successfully downloaded:', filename)
            print("Sleeping for 1 second before attempting next download")
            time.sleep(1)
        else:
            print(f'Skipping image {filename} as it is already downloaded')
    else:
        print(f"Image url={imageUrl} and filename={filename} couldn't be retrieved. HTTP Status={r.status_code}")

df = pd.read_csv("nativity_dataset.csv")

# create the directory we download to, if it doesn't exist
destinationFolder = "/content/dataset/nativity"
os.makedirs(destinationFolder, exist_ok=True)

for i, row in df.iterrows():
    print(f"Index: {i}")
    print(f"{row['Image URL']}\n")
    downloadImage(row["Image URL"], destinationFolder)

Out:

Index: 0
https://d3d00swyhr67nd.cloudfront.net/w1200h1200/collection/LSE/CUMU/LSE_CUMU_TN07034-001.jpg
url=https://d3d00swyhr67nd.cloudfront.net/w1200h1200/collection/LSE/CUMU/LSE_CUMU_TN07034-001.jpg firstpos=68 lastpos=93 filename=LSE_CUMU_TN07034-001.jpg
Image successfully downloaded: LSE_CUMU_TN07034-001.jpg
Sleeping for 1 second before attempting next download
Index: 1
https://d3d00swyhr67nd.cloudfront.net/w1200h1200/collection/GMIII/MCAG/GMIII_MCAG_1947_188-001.jpg

Resize all images to be no bigger than 90000 pixels (width × height)

Some of the images in our dataset are over 80MB in size. If we try to resize these images directly from Python, it will load each image fully into memory. Not a great idea. So we are going to use ImageMagick to do the job super fast.

!apt install imagemagick

Out:

Reading package lists... Done
Building dependency tree
Reading state information... Done
imagemagick is already the newest version (8:6.9.7.4+dfsg-16ubuntu6.9).
0 upgraded, 0 newly installed, 0 to remove and 17 not upgraded.
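
The resizing relies on ImageMagick’s geometry syntax: the geometry 90000@ caps the total number of pixels while preserving the aspect ratio, and the trailing > makes it shrink-only, so smaller images are left untouched. A minimal usage example (hypothetical file names; note the quotes, which stop the shell from treating > as an output redirection):

# Shrink input.jpg to at most 90000 total pixels, keeping the aspect ratio
!convert input.jpg -resize "90000@>" output.jpg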

Now we define the utility function resizeImages to resize images and copy from a sourceFolder to a destinationFolder.

def resizeImages(sourceFolder, destinationFolder, maxPixels=1048576):
    os.makedirs(destinationFolder, exist_ok=True)
    for path, subdirs, files in os.walk(sourceFolder):
        relativeDir = path.replace(sourceFolder, "")
        destinationFolderPath = destinationFolder + relativeDir
        os.makedirs(destinationFolderPath, exist_ok=True)
        for fileName in files:
            sourceFilepath = os.path.join(path, fileName)
            destinationFilepath = os.path.join(destinationFolderPath, fileName)
            print(f"sourceFilepath={sourceFilepath} destinationFilepath={destinationFilepath}")
            # Quote the geometry so the shell doesn't treat '>' as a redirection,
            # and quote the paths in case they contain special characters.
            os.system(f'convert "{sourceFilepath}" -resize "{maxPixels}@>" "{destinationFilepath}"')

# resize training images
sourceFolder = "/content/dataset"
destinationFolder = "/content/resized/dataset"
resizeImages(sourceFolder, destinationFolder, maxPixels=90000)

# resize testing images
sourceFolder = "/content/test_dataset"
destinationFolder = "/content/resized/test_dataset"
resizeImages(sourceFolder, destinationFolder, maxPixels=90000)

Out:

sourceFilepath=/content/dataset/others/Quentin_Massys-The_Adoration_of_the_Magi-1526%2CMetropolitan_Museum_of_Art%2CNew_York.jpg destinationFilepath=/content/resized/dataset/others/Quentin_Massys-The_Adoration_of_the_Magi-1526%2CMetropolitan_Museum_of_Art%2CNew_York.jpg
...

Map image labels to numeric values

We are using binary cross-entropy for our classification, so we need to make sure our labels are either 0 or 1: Nativity = 1 and Others = 0.

We will rename the folders to a 0 and a 1 since that is what tf.keras.preprocessing.image_dataset_from_directory uses to create the labels for our data set.


!mv /content/resized/dataset/nativity /content/resized/dataset/1
!mv /content/resized/dataset/others /content/resized/dataset/0
!mv /content/resized/test_dataset/nativity /content/resized/test_dataset/1
!mv /content/resized/test_dataset/others /content/resized/test_dataset/0
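
As a quick sanity check, listing the dataset folder should now show the two class directories, 0 and 1:

!ls /content/resized/dataset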

After downloading and resizing, we should now have a copy of the dataset available. There are 454 images in total:

import pathlib

data_dir = pathlib.Path("/content/resized/dataset")
test_data_dir = pathlib.Path("/content/resized/test_dataset")
image_count = len(list(data_dir.glob('*/*')))
print(image_count)

Out:

454

Here are some paintings of the nativity:

nativity_label="1"
nativity = list(data_dir.glob(f'{nativity_label}/*'))
PIL.Image.open(str(nativity[0]))
PIL.Image.open(str(nativity[1]))

And some non-nativity paintings:

others_label="0"
others = list(data_dir.glob(f'{others_label}/*'))
PIL.Image.open(str(others[1]))

PIL.Image.open(str(others[2]))

Load using keras.preprocessing

Keras provides a bunch of really convenient functions to make our life easier when working with TensorFlow. tf.keras.preprocessing.image_dataset_from_directory is one of them. It loads images from files into the tf.data.Dataset format.

batch_size = 32
img_height = 300
img_width = 300

In general, it is advised to split the data into training and validation sets using an 80%/20% split. Remember, this is not a hard and fast rule.

train_ds = tf.keras.preprocessing.image_dataset_from_directory(
    data_dir,
    validation_split=0.2,
    subset="training",
    seed=123,
    image_size=(img_height, img_width),
    batch_size=batch_size,
    label_mode='binary')

Out:

Found 452 files belonging to 2 classes.
Using 362 files for training.

val_ds = tf.keras.preprocessing.image_dataset_from_directory(
    data_dir,
    validation_split=0.2,
    subset="validation",
    seed=123,
    image_size=(img_height, img_width),
    batch_size=batch_size,
    label_mode='binary')

Out:

Found 452 files belonging to 2 classes.
Using 90 files for validation.

# Retrieve a batch of images from the test set
test_data_dir = pathlib.Path("/content/resized/test_dataset")
test_batch_size = 37
test_ds = tf.keras.preprocessing.image_dataset_from_directory(
    test_data_dir,
    seed=200,
    image_size=(img_height, img_width),
    batch_size=test_batch_size,
    label_mode='binary')

Out:

Found 37 files belonging to 2 classes.

You can find the class names in the class_names attribute on these datasets. These correspond to the directory names in alphabetical order.

class_names = train_ds.class_names
print(class_names)

Out:

['0', '1']

Visualize the data

Here are the first 9 images from the training dataset.

import matplotlib.pyplot as plt

plt.figure(figsize=(10, 10))
for images, labels in train_ds.take(1):
    for i in range(9):
        ax = plt.subplot(3, 3, i + 1)
        plt.imshow(images[i].numpy().astype("uint8"))
        if labels[i] == 1.0:
            title = "Nativity"
        else:
            title = "Others"
        plt.title(title)
        plt.axis("off")

We inspect the image_batch and labels_batch variables.

The image_batch is a tensor of the shape (32, 300, 300, 3): a batch of 32 images of shape 300x300x3 (the last dimension refers to the RGB color channels). The label_batch is a tensor of the shape (32, 1): the corresponding label for each of the 32 images (a single column, because we used label_mode='binary').

You can call .numpy() on the image_batch and labels_batch tensors to convert them to a numpy.ndarray.
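
A minimal illustration of that conversion (not in the original notebook):

# Take one batch and convert the image tensor to a numpy.ndarray
for image_batch, labels_batch in train_ds.take(1):
    arr = image_batch.numpy()
    print(type(arr), arr.dtype, arr.shape)  # <class 'numpy.ndarray'> float32 (32, 300, 300, 3)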

for image_batch, labels_batch in train_ds:
    print(image_batch.shape)
    print(labels_batch.shape)
    break

Out:

(32, 300, 300, 3)
(32, 1)

Configure the dataset for performance

This code is taken directly from the Tensorflow tutorial and it is meant to help with the performance so we don’t always have to fetch images from disk.

AUTOTUNE = tf.data.AUTOTUNE

train_ds = train_ds.cache().shuffle(1000).prefetch(buffer_size=AUTOTUNE)
val_ds = val_ds.cache().prefetch(buffer_size=AUTOTUNE)

We define a utility function to measure the model's performance against our test dataset.

labelMappings = {"0": "Others", "1": "Nativity",
                 0.0: "Others", 1.0: "Nativity"}

def predictWithTestDataset(model):
    image_batch, label_batch = test_ds.as_numpy_iterator().next()
    predictions = model.predict_on_batch(image_batch).flatten()
    predictions = tf.where(predictions < 0.5, 0, 1)
    # print('Predictions:\n', predictions.numpy())
    # print('Labels:\n', label_batch)
    correctPredictions = 0
    plt.figure(figsize=(20, 20))
    print(f"number predictions={len(predictions)}")
    for i in range(len(predictions)):
        ax = plt.subplot(8, 5, i + 1)
        plt.imshow(image_batch[i].astype("uint8"))
        prediction = class_names[predictions[i]]
        predictionLabel = labelMappings[prediction]
        gtLabel = labelMappings[label_batch[i][0]]
        if gtLabel == predictionLabel:
            correctPredictions += 1
        plt.title(f"P={predictionLabel} GT={gtLabel}")
        plt.axis("off")
    accuracy = correctPredictions / len(predictions)
    print(f"Accuracy:{accuracy}")

Standardize the data

RGB channel values are in the [0, 255] range. We normalize them to the [0, 1] range, which is the preferred input range for a neural network.

normalization_layer = layers.experimental.preprocessing.Rescaling(1./255)

This normalization layer will be used later in the model definition.
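
If you want to verify the effect, a quick check (adapted from the TensorFlow tutorial; not required for training) is to map the layer over a batch and print the resulting value range:

# Apply the normalization layer to one batch and confirm the values fall in [0, 1]
normalized_ds = train_ds.map(lambda x, y: (normalization_layer(x), y))
image_batch, labels_batch = next(iter(normalized_ds))
first_image = image_batch[0]
print(np.min(first_image), np.max(first_image))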

Create the model

We define an initial architecture for our model, based on the architecture used in the TensorFlow image classification tutorial.

model = Sequential([
    layers.experimental.preprocessing.Rescaling(1./255, input_shape=(img_height, img_width, 3)),
    layers.Conv2D(16, 3, padding='same', activation='relu'),
    layers.MaxPooling2D(),
    layers.Conv2D(32, 3, padding='same', activation='relu'),
    layers.MaxPooling2D(),
    layers.Conv2D(64, 3, padding='same', activation='relu'),
    layers.MaxPooling2D(),
    layers.Flatten(),
    layers.Dense(128, activation='relu'),
    layers.Dense(1, activation='sigmoid')
])

Compile the model

For this tutorial, we choose the optimizers.Adam optimizer and losses.BinaryCrossentropy loss function. To view training and validation accuracy for each training epoch, we pass the metrics argument.

model.compile(optimizer='adam',
              # The last layer applies a sigmoid, so the model outputs probabilities, not logits
              loss=keras.losses.BinaryCrossentropy(from_logits=False),
              metrics=[keras.metrics.BinaryAccuracy()])

Model summary

View all the layers of the network using the model’s summary method:

model.summary()

Out:

Model: "sequential_31"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
rescaling_27 (Rescaling) (None, 300, 300, 3) 0
_________________________________________________________________
conv2d_108 (Conv2D) (None, 300, 300, 16) 448
_________________________________________________________________
max_pooling2d_60 (MaxPooling (None, 150, 150, 16) 0
_________________________________________________________________
conv2d_109 (Conv2D) (None, 150, 150, 32) 4640
_________________________________________________________________
max_pooling2d_61 (MaxPooling (None, 75, 75, 32) 0
_________________________________________________________________
conv2d_110 (Conv2D) (None, 75, 75, 64) 18496
_________________________________________________________________
max_pooling2d_62 (MaxPooling (None, 37, 37, 64) 0
_________________________________________________________________
flatten_20 (Flatten) (None, 87616) 0
_________________________________________________________________
dense_52 (Dense) (None, 128) 11214976
_________________________________________________________________
dense_53 (Dense) (None, 1) 129
=================================================================
Total params: 11,238,689
Trainable params: 11,238,689
Non-trainable params: 0
_________________________________________________________________

Train the model

epochs = 10
history = model.fit(
    train_ds,
    validation_data=val_ds,
    epochs=epochs
)

Out:

Epoch 1/10
12/12 [==============================] - 1s 85ms/step - loss: 1.9371 - binary_accuracy: 0.5107 - val_loss: 0.7001 - val_binary_accuracy: 0.4444
Epoch 2/10
12/12 [==============================] - 1s 49ms/step - loss: 0.6491 - binary_accuracy: 0.6737 - val_loss: 0.7258 - val_binary_accuracy: 0.4778
Epoch 3/10
12/12 [==============================] - 1s 49ms/step - loss: 0.5943 - binary_accuracy: 0.6958 - val_loss: 0.7169 - val_binary_accuracy: 0.5333
Epoch 4/10
12/12 [==============================] - 1s 49ms/step - loss: 0.5111 - binary_accuracy: 0.7762 - val_loss: 0.7201 - val_binary_accuracy: 0.5667
Epoch 5/10
12/12 [==============================] - 1s 49ms/step - loss: 0.4013 - binary_accuracy: 0.8427 - val_loss: 0.6920 - val_binary_accuracy: 0.5667
Epoch 6/10
12/12 [==============================] - 1s 49ms/step - loss: 0.3027 - binary_accuracy: 0.8921 - val_loss: 0.8354 - val_binary_accuracy: 0.5889
Epoch 7/10
12/12 [==============================] - 1s 50ms/step - loss: 0.2438 - binary_accuracy: 0.9049 - val_loss: 0.8499 - val_binary_accuracy: 0.5778
Epoch 8/10
12/12 [==============================] - 1s 49ms/step - loss: 0.1725 - binary_accuracy: 0.9292 - val_loss: 0.9742 - val_binary_accuracy: 0.5222
Epoch 9/10
12/12 [==============================] - 1s 50ms/step - loss: 0.2792 - binary_accuracy: 0.8878 - val_loss: 0.9390 - val_binary_accuracy: 0.5222
Epoch 10/10
12/12 [==============================] - 1s 50ms/step - loss: 0.1347 - binary_accuracy: 0.9658 - val_loss: 0.9914 - val_binary_accuracy: 0.5889

Visualize training results

print(history.history)
acc = history.history['binary_accuracy']
val_acc = history.history['val_binary_accuracy']
# acc = history.history['accuracy']
# val_acc = history.history['val_accuracy']
loss = history.history['loss']
val_loss = history.history['val_loss']
epochs_range = range(epochs)

plt.figure(figsize=(8, 8))
plt.subplot(1, 2, 1)
plt.plot(epochs_range, acc, label='Training Accuracy')
plt.plot(epochs_range, val_acc, label='Validation Accuracy')
plt.legend(loc='lower right')
plt.title('Training and Validation Accuracy')
plt.subplot(1, 2, 2)
plt.plot(epochs_range, loss, label='Training Loss')
plt.plot(epochs_range, val_loss, label='Validation Loss')
plt.legend(loc='upper right')
plt.title('Training and Validation Loss')
plt.show()
{'loss': [1.5215493440628052, 0.6543726325035095, 0.5897634625434875, 0.5006453990936279, 0.39839598536491394, 0.2903604209423065, 0.22604547441005707, 0.22543807327747345, 0.2558016777038574, 0.14820142090320587], 'binary_accuracy': [0.5580110549926758, 0.6491712927818298, 0.7099447250366211, 0.7845304012298584, 0.8425414562225342, 0.8867403268814087, 0.9116021990776062, 0.9060773253440857, 0.9033148884773254, 0.9558011293411255], 'val_loss': [0.7000985741615295, 0.7257931232452393, 0.7169376611709595, 0.7200638055801392, 0.6920430660247803, 0.8354127407073975, 0.8498525619506836, 0.9741556644439697, 0.9390344619750977, 0.9914490580558777], 'val_binary_accuracy': [0.4444444477558136, 0.47777777910232544, 0.5333333611488342, 0.5666666626930237, 0.5666666626930237, 0.5888888835906982, 0.5777778029441833, 0.5222222208976746, 0.5222222208976746, 0.5888888835906982]}

Looking at the plots, we see a typical sign of overfitting. Overfitting happens when the model fits the training data a bit too well but does poorly on the validation data. Notice that the training accuracy increases along with the epochs, but the validation accuracy doesn’t increase, and in this case the validation loss even grows.

predictWithTestDataset(model)

Out:

number predictions=37
Accuracy:0.6216216216216216

Data augmentation

Overfitting generally occurs when there is a small number of training examples. That is not surprising here, as we only have a total of 454 images across the two classes. So we need to find a way to generate more training data.

Keras has some really easy-to-use transformation functions for data augmentation on images, and they produce very decent images.

The idea of image augmentation, a form of regularization, is to make it harder for the model to overfit. Overfitting happens easily when there is very little data. By introducing variability, we hopefully help the model generalise better.

The keras functions for image augmentation are available under tf.keras.layers.experimental.preprocessing.

We create a layer with the image augmentation transformations and we can include it in the creation of the model, just like any other layer.

For the image augmentation, we do a random horizontal flip and a small random rotation.

data_augmentation = keras.Sequential(
    [
        layers.experimental.preprocessing.RandomFlip("horizontal",
                                                     input_shape=(img_height,
                                                                  img_width,
                                                                  3)),
        layers.experimental.preprocessing.RandomRotation(0.1)
    ])

Let’s visualize what a few augmented examples look like by applying data augmentation to the same image several times:

plt.figure(figsize=(10, 10))
for images, _ in train_ds.take(1):
    for i in range(9):
        augmented_images = data_augmentation(images)
        ax = plt.subplot(3, 3, i + 1)
        plt.imshow(augmented_images[0].numpy().astype("uint8"))
        plt.axis("off")

We will use data augmentation to train a model in a moment.

Dropout

Another technique to reduce overfitting is to introduce Dropout to the network, a form of regularization. Think of it as a random culling of neurons in the neural network during training; we want to prevent the model from depending on any single feature. The fraction of units dropped in each layer is set by the argument (0.2) that we pass to the Dropout function.
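
To see what that means in isolation, here is a toy illustration (not from the notebook) of a Dropout layer in action:

# Dropout(0.2) zeroes a random ~20% of activations during training
# and scales the survivors by 1/(1-0.2) to keep the expected sum unchanged
drop = layers.Dropout(0.2)
x = tf.ones((1, 10))
print(drop(x, training=True).numpy())   # some entries zeroed, the rest become 1.25
print(drop(x, training=False).numpy())  # identity at inference time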

Let’s create a new neural network using layers.Dropout, then train it using augmented images.

model = Sequential([
    data_augmentation,
    layers.experimental.preprocessing.Rescaling(1./255, input_shape=(img_height, img_width, 3)),
    layers.Conv2D(16, 3, padding='same', activation='relu'),
    layers.MaxPooling2D(),
    layers.Conv2D(32, 3, padding='same', activation='relu'),
    layers.MaxPooling2D(),
    layers.Conv2D(64, 3, padding='same', activation='relu'),
    layers.MaxPooling2D(),
    layers.Dropout(0.2),
    layers.Flatten(),
    layers.Dense(128, activation='relu'),
    layers.Dense(1, activation='sigmoid')
])

Compile and train the model

from tensorflow.keras import optimizers

model.compile(loss=keras.losses.BinaryCrossentropy(from_logits=False),  # sigmoid output, so not logits
              optimizer=optimizers.RMSprop(learning_rate=1e-4),
              metrics=[keras.metrics.BinaryAccuracy()])
model.summary()

Out:

Model: "sequential_34"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
sequential_32 (Sequential) (None, 300, 300, 3) 0
_________________________________________________________________
rescaling_29 (Rescaling) (None, 300, 300, 3) 0
_________________________________________________________________
conv2d_114 (Conv2D) (None, 300, 300, 16) 448
_________________________________________________________________
max_pooling2d_66 (MaxPooling (None, 150, 150, 16) 0
_________________________________________________________________
conv2d_115 (Conv2D) (None, 150, 150, 32) 4640
_________________________________________________________________
max_pooling2d_67 (MaxPooling (None, 75, 75, 32) 0
_________________________________________________________________
conv2d_116 (Conv2D) (None, 75, 75, 64) 18496
_________________________________________________________________
max_pooling2d_68 (MaxPooling (None, 37, 37, 64) 0
_________________________________________________________________
dropout_26 (Dropout) (None, 37, 37, 64) 0
_________________________________________________________________
flatten_22 (Flatten) (None, 87616) 0
_________________________________________________________________
dense_56 (Dense) (None, 128) 11214976
_________________________________________________________________
dense_57 (Dense) (None, 1) 129
=================================================================
Total params: 11,238,689
Trainable params: 11,238,689
Non-trainable params: 0
_________________________________________________________________

epochs = 25
history = model.fit(
    train_ds,
    validation_data=val_ds,
    epochs=epochs
)

Out:

Epoch 1/25
12/12 [==============================] - 2s 76ms/step - loss: 0.9417 - binary_accuracy: 0.5580 - val_loss: 0.7153 - val_binary_accuracy: 0.5222
Epoch 2/25
12/12 [==============================] - 1s 61ms/step - loss: 0.6869 - binary_accuracy: 0.5338 - val_loss: 0.7236 - val_binary_accuracy: 0.5333
Epoch 3/25
12/12 [==============================] - 1s 61ms/step - loss: 0.6557 - binary_accuracy: 0.5985 - val_loss: 0.8124 - val_binary_accuracy: 0.5222
Epoch 4/25
12/12 [==============================] - 1s 61ms/step - loss: 0.6447 - binary_accuracy: 0.6315 - val_loss: 0.6829 - val_binary_accuracy: 0.5556
Epoch 5/25
12/12 [==============================] - 1s 65ms/step - loss: 0.6482 - binary_accuracy: 0.6273 - val_loss: 0.6708 - val_binary_accuracy: 0.5778
Epoch 6/25
12/12 [==============================] - 1s 61ms/step - loss: 0.6482 - binary_accuracy: 0.6348 - val_loss: 0.6733 - val_binary_accuracy: 0.5556
Epoch 7/25
12/12 [==============================] - 1s 61ms/step - loss: 0.6325 - binary_accuracy: 0.6592 - val_loss: 0.6762 - val_binary_accuracy: 0.5333
Epoch 8/25
12/12 [==============================] - 1s 62ms/step - loss: 0.5994 - binary_accuracy: 0.6680 - val_loss: 0.6587 - val_binary_accuracy: 0.6111
Epoch 9/25
12/12 [==============================] - 1s 61ms/step - loss: 0.6204 - binary_accuracy: 0.6904 - val_loss: 0.7240 - val_binary_accuracy: 0.5333
Epoch 10/25
12/12 [==============================] - 1s 62ms/step - loss: 0.6343 - binary_accuracy: 0.6480 - val_loss: 0.6776 - val_binary_accuracy: 0.5667
Epoch 11/25
12/12 [==============================] - 1s 62ms/step - loss: 0.6439 - binary_accuracy: 0.6107 - val_loss: 0.6811 - val_binary_accuracy: 0.5556
Epoch 12/25
12/12 [==============================] - 1s 62ms/step - loss: 0.6361 - binary_accuracy: 0.6301 - val_loss: 0.6612 - val_binary_accuracy: 0.6222
Epoch 13/25
12/12 [==============================] - 1s 62ms/step - loss: 0.6025 - binary_accuracy: 0.6949 - val_loss: 0.6725 - val_binary_accuracy: 0.5778
Epoch 14/25
12/12 [==============================] - 1s 61ms/step - loss: 0.5977 - binary_accuracy: 0.6868 - val_loss: 0.7521 - val_binary_accuracy: 0.5444
Epoch 15/25
12/12 [==============================] - 1s 62ms/step - loss: 0.5713 - binary_accuracy: 0.6833 - val_loss: 0.6427 - val_binary_accuracy: 0.6444
Epoch 16/25
12/12 [==============================] - 1s 62ms/step - loss: 0.5918 - binary_accuracy: 0.6939 - val_loss: 0.6515 - val_binary_accuracy: 0.6333
Epoch 17/25
12/12 [==============================] - 1s 61ms/step - loss: 0.5831 - binary_accuracy: 0.7253 - val_loss: 0.6556 - val_binary_accuracy: 0.5889
Epoch 18/25
12/12 [==============================] - 1s 62ms/step - loss: 0.5626 - binary_accuracy: 0.7121 - val_loss: 0.6877 - val_binary_accuracy: 0.5667
Epoch 19/25
12/12 [==============================] - 1s 62ms/step - loss: 0.5476 - binary_accuracy: 0.7327 - val_loss: 0.6398 - val_binary_accuracy: 0.6556
Epoch 20/25
12/12 [==============================] - 1s 62ms/step - loss: 0.5551 - binary_accuracy: 0.7283 - val_loss: 0.6465 - val_binary_accuracy: 0.6333
Epoch 21/25
12/12 [==============================] - 1s 62ms/step - loss: 0.5436 - binary_accuracy: 0.7312 - val_loss: 0.7083 - val_binary_accuracy: 0.5667
Epoch 22/25
12/12 [==============================] - 1s 65ms/step - loss: 0.5987 - binary_accuracy: 0.6781 - val_loss: 0.8078 - val_binary_accuracy: 0.5222
Epoch 23/25
12/12 [==============================] - 1s 62ms/step - loss: 0.5534 - binary_accuracy: 0.7139 - val_loss: 0.6705 - val_binary_accuracy: 0.6111
Epoch 24/25
12/12 [==============================] - 1s 85ms/step - loss: 0.5617 - binary_accuracy: 0.7406 - val_loss: 0.6471 - val_binary_accuracy: 0.6111
Epoch 25/25
12/12 [==============================] - 1s 62ms/step - loss: 0.5541 - binary_accuracy: 0.7303 - val_loss: 0.6263 - val_binary_accuracy: 0.7000

Visualize training results

After applying data augmentation and Dropout, there is less overfitting than before, and the training and validation accuracy are more closely aligned.

print(history.history)
acc = history.history['binary_accuracy']
val_acc = history.history['val_binary_accuracy']
loss = history.history['loss']
val_loss = history.history['val_loss']
epochs_range = range(epochs)

plt.figure(figsize=(8, 8))
plt.subplot(1, 2, 1)
plt.plot(epochs_range, acc, label='Training Accuracy')
plt.plot(epochs_range, val_acc, label='Validation Accuracy')
plt.legend(loc='lower right')
plt.title('Training and Validation Accuracy')
plt.subplot(1, 2, 2)
plt.plot(epochs_range, loss, label='Training Loss')
plt.plot(epochs_range, val_loss, label='Validation Loss')
plt.legend(loc='upper right')
plt.title('Training and Validation Loss')
plt.show()
{'loss': [0.8289327025413513, 0.6810811758041382, 0.6626855731010437, 0.6704213619232178, 0.650795042514801, 0.6398268938064575, 0.6561762690544128, 0.6122907400131226, 0.6228107810020447, 0.6147234439849854, 0.6149318814277649, 0.6190508604049683, 0.607336699962616, 0.5861756801605225, 0.593088686466217, 0.6063793301582336, 0.5983501672744751, 0.5894269347190857, 0.5698645114898682, 0.5585014224052429, 0.5401753783226013, 0.5774908065795898, 0.5512883067131042, 0.5710932016372681, 0.5434897541999817], 'binary_accuracy': [0.5386740565299988, 0.5386740565299988, 0.6077347993850708, 0.6160221099853516, 0.6022099256515503, 0.6325966715812683, 0.6325966715812683, 0.6685082912445068, 0.6961326003074646, 0.6602209806442261, 0.6408839821815491, 0.6574585437774658, 0.6712707281112671, 0.6933701634407043, 0.6574585437774658, 0.6767956018447876, 0.7044199109077454, 0.6933701634407043, 0.7154695987701416, 0.7265193462371826, 0.7265193462371826, 0.6906077265739441, 0.7154695987701416, 0.7071823477745056, 0.7375690340995789], 'val_loss': [0.7153488993644714, 0.7235575318336487, 0.8124216794967651, 0.6829271912574768, 0.6708189249038696, 0.673344612121582, 0.676236629486084, 0.6586815714836121, 0.7239749431610107, 0.677582323551178, 0.6810950636863708, 0.6611502170562744, 0.6725294589996338, 0.7520950436592102, 0.642659068107605, 0.6514749526977539, 0.6556094884872437, 0.687703013420105, 0.639808714389801, 0.6464514136314392, 0.7082778811454773, 0.8077911138534546, 0.670492947101593, 0.6470986008644104, 0.6263118386268616], 'val_binary_accuracy': [0.5222222208976746, 0.5333333611488342, 0.5222222208976746, 0.5555555820465088, 0.5777778029441833, 0.5555555820465088, 0.5333333611488342, 0.6111111044883728, 0.5333333611488342, 0.5666666626930237, 0.5555555820465088, 0.6222222447395325, 0.5777778029441833, 0.5444444417953491, 0.644444465637207, 0.6333333253860474, 0.5888888835906982, 0.5666666626930237, 0.6555555462837219, 0.6333333253860474, 0.5666666626930237, 0.5222222208976746, 0.6111111044883728, 0.6111111044883728, 0.699999988079071]}

Predict on new data

Finally, let’s use our model to classify images on our test dataset that weren’t included in the training or validation sets.

predictWithTestDataset(model)

Out:

number predictions=37
Accuracy:0.7027027027027027

Transfer Learning

We have already used image augmentation to try to get better results from our model, and I have to say the results were not bad at all. We were able to get a model with about 70% accuracy on the validation dataset. Surely we can do better than that if we collect hundreds, perhaps thousands, more images for our training and validation datasets.

We can certainly do that, but there is another way that doesn’t involve the tedious and expensive process of collecting more training data: transfer learning.

With transfer learning, we can borrow a model that has already been trained on thousands of images and retrain it for our use case, with far fewer images than would be needed to train a model from scratch.

To do so, we can use Keras to download a pre-trained model with the Xception architecture, already trained on ImageNet.

To perform transfer learning, we need to freeze the weights of the base model and perform the training as we normally would. You will notice that we still apply the image augmentation and regularization.

base_model = keras.applications.Xception(
    weights="imagenet",  # Load weights pre-trained on ImageNet.
    input_shape=(img_height, img_width, 3),
    include_top=False,  # Do not include the ImageNet classifier at the top.
)

# Freeze the base_model
base_model.trainable = False

# Create new model on top
inputs = keras.Input(shape=(img_height, img_width, 3))
x = data_augmentation(inputs)  # Apply random data augmentation

# Pre-trained Xception weights require that the input be normalized
# from (0, 255) to a range of (-1., +1.); the normalization layer
# does the following: outputs = (inputs - mean) / sqrt(var)
norm_layer = keras.layers.experimental.preprocessing.Normalization()
mean = np.array([127.5] * 3)
var = mean ** 2
# Scale inputs to [-1, +1]
x = norm_layer(x)
norm_layer.set_weights([mean, var])

# The base model contains batchnorm layers. We want to keep them in inference mode
# when we unfreeze the base model for fine-tuning, so we make sure that the
# base_model is running in inference mode here.
x = base_model(x, training=False)
x = keras.layers.GlobalAveragePooling2D()(x)
x = keras.layers.Dropout(0.2)(x)  # Regularize with dropout
outputs = keras.layers.Dense(1, activation="sigmoid")(x)
model = keras.Model(inputs, outputs)
model.summary()

Out:

Model: "model_13"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
input_28 (InputLayer) [(None, 300, 300, 3)] 0
_________________________________________________________________
sequential_32 (Sequential) (None, 300, 300, 3) 0
_________________________________________________________________
normalization_13 (Normalizat (None, 300, 300, 3) 7
_________________________________________________________________
xception (Functional) (None, 10, 10, 2048) 20861480
_________________________________________________________________
global_average_pooling2d_13 (None, 2048) 0
_________________________________________________________________
dropout_28 (Dropout) (None, 2048) 0
_________________________________________________________________
dense_59 (Dense) (None, 1) 2049
=================================================================
Total params: 20,863,536
Trainable params: 2,049
Non-trainable params: 20,861,487
_________________________________________________________________
model.compile(
    optimizer=keras.optimizers.Adam(),
    loss=keras.losses.BinaryCrossentropy(from_logits=False),  # sigmoid output, so not logits
    metrics=[keras.metrics.BinaryAccuracy()],
)

epochs = 25
history = model.fit(
    train_ds,
    validation_data=val_ds,
    epochs=epochs
)

Out:

Epoch 1/25
12/12 [==============================] - 7s 376ms/step - loss: 0.6989 - binary_accuracy: 0.5116 - val_loss: 0.6324 - val_binary_accuracy: 0.5333
Epoch 2/25
12/12 [==============================] - 4s 322ms/step - loss: 0.6106 - binary_accuracy: 0.5943 - val_loss: 0.5748 - val_binary_accuracy: 0.6222
Epoch 3/25
12/12 [==============================] - 4s 322ms/step - loss: 0.5557 - binary_accuracy: 0.6647 - val_loss: 0.5378 - val_binary_accuracy: 0.6889
Epoch 4/25
12/12 [==============================] - 4s 326ms/step - loss: 0.5280 - binary_accuracy: 0.6333 - val_loss: 0.5127 - val_binary_accuracy: 0.7222
Epoch 5/25
12/12 [==============================] - 4s 329ms/step - loss: 0.4751 - binary_accuracy: 0.7638 - val_loss: 0.4912 - val_binary_accuracy: 0.7889
Epoch 6/25
12/12 [==============================] - 4s 331ms/step - loss: 0.4586 - binary_accuracy: 0.7535 - val_loss: 0.4775 - val_binary_accuracy: 0.7556
Epoch 7/25
12/12 [==============================] - 4s 335ms/step - loss: 0.4328 - binary_accuracy: 0.7778 - val_loss: 0.4625 - val_binary_accuracy: 0.8111
Epoch 8/25
12/12 [==============================] - 4s 339ms/step - loss: 0.3951 - binary_accuracy: 0.8387 - val_loss: 0.4519 - val_binary_accuracy: 0.8111
Epoch 9/25
12/12 [==============================] - 4s 344ms/step - loss: 0.3745 - binary_accuracy: 0.8427 - val_loss: 0.4435 - val_binary_accuracy: 0.8111
Epoch 10/25
12/12 [==============================] - 4s 348ms/step - loss: 0.3631 - binary_accuracy: 0.8373 - val_loss: 0.4395 - val_binary_accuracy: 0.7889
Epoch 11/25
12/12 [==============================] - 4s 350ms/step - loss: 0.3449 - binary_accuracy: 0.8705 - val_loss: 0.4302 - val_binary_accuracy: 0.8111
Epoch 12/25
12/12 [==============================] - 4s 355ms/step - loss: 0.3409 - binary_accuracy: 0.8623 - val_loss: 0.4249 - val_binary_accuracy: 0.8222
Epoch 13/25
12/12 [==============================] - 4s 356ms/step - loss: 0.3491 - binary_accuracy: 0.8848 - val_loss: 0.4214 - val_binary_accuracy: 0.8333
Epoch 14/25
12/12 [==============================] - 4s 356ms/step - loss: 0.3522 - binary_accuracy: 0.8569 - val_loss: 0.4173 - val_binary_accuracy: 0.8333
Epoch 15/25
12/12 [==============================] - 4s 354ms/step - loss: 0.3106 - binary_accuracy: 0.8641 - val_loss: 0.4120 - val_binary_accuracy: 0.8333
Epoch 16/25
12/12 [==============================] - 4s 348ms/step - loss: 0.3108 - binary_accuracy: 0.8973 - val_loss: 0.4059 - val_binary_accuracy: 0.8333
Epoch 17/25
12/12 [==============================] - 4s 348ms/step - loss: 0.3041 - binary_accuracy: 0.8840 - val_loss: 0.4043 - val_binary_accuracy: 0.8333
Epoch 18/25
12/12 [==============================] - 4s 364ms/step - loss: 0.3106 - binary_accuracy: 0.8548 - val_loss: 0.3994 - val_binary_accuracy: 0.8444
Epoch 19/25
12/12 [==============================] - 4s 343ms/step - loss: 0.3072 - binary_accuracy: 0.8774 - val_loss: 0.4031 - val_binary_accuracy: 0.8333
Epoch 20/25
12/12 [==============================] - 4s 341ms/step - loss: 0.3008 - binary_accuracy: 0.8870 - val_loss: 0.3960 - val_binary_accuracy: 0.8444
Epoch 21/25
12/12 [==============================] - 4s 342ms/step - loss: 0.2959 - binary_accuracy: 0.8738 - val_loss: 0.3969 - val_binary_accuracy: 0.8444
Epoch 22/25
12/12 [==============================] - 4s 340ms/step - loss: 0.2655 - binary_accuracy: 0.8874 - val_loss: 0.3959 - val_binary_accuracy: 0.8444
Epoch 23/25
12/12 [==============================] - 4s 340ms/step - loss: 0.2452 - binary_accuracy: 0.9098 - val_loss: 0.3957 - val_binary_accuracy: 0.8444
Epoch 24/25
12/12 [==============================] - 4s 359ms/step - loss: 0.2532 - binary_accuracy: 0.9214 - val_loss: 0.3906 - val_binary_accuracy: 0.8444
Epoch 25/25
12/12 [==============================] - 4s 340ms/step - loss: 0.2512 - binary_accuracy: 0.9202 - val_loss: 0.3963 - val_binary_accuracy: 0.8556
print(history.history)
acc = history.history['binary_accuracy']
val_acc = history.history['val_binary_accuracy']
loss = history.history['loss']
val_loss = history.history['val_loss']
epochs_range = range(epochs)

plt.figure(figsize=(8, 8))
plt.subplot(1, 2, 1)
plt.plot(epochs_range, acc, label='Training Accuracy')
plt.plot(epochs_range, val_acc, label='Validation Accuracy')
plt.legend(loc='lower right')
plt.title('Training and Validation Accuracy')
plt.subplot(1, 2, 2)
plt.plot(epochs_range, loss, label='Training Loss')
plt.plot(epochs_range, val_loss, label='Validation Loss')
plt.legend(loc='upper right')
plt.title('Training and Validation Loss')
plt.show()
{'loss': [0.6730840802192688, 0.6024077534675598, 0.5369390249252319, 0.5121317505836487, 0.47049736976623535, 0.43795397877693176, 0.4308575689792633, 0.409035861492157, 0.39621278643608093, 0.38378414511680603, 0.36360302567481995, 0.3390505313873291, 0.33790066838264465, 0.34375235438346863, 0.3241599202156067, 0.3240811824798584, 0.30652204155921936, 0.3121297359466553, 0.29941326379776, 0.31102144718170166, 0.29046544432640076, 0.2721157371997833, 0.2742222845554352, 0.273193895816803, 0.2644941210746765], 'binary_accuracy': [0.541436493396759, 0.580110490322113, 0.7099447250366211, 0.6546961069107056, 0.7762430906295776, 0.7817679643630981, 0.7734806537628174, 0.8121547102928162, 0.8093922734260559, 0.8066298365592957, 0.8563535809516907, 0.8535911440849304, 0.8701657652854919, 0.8618784546852112, 0.8535911440849304, 0.889502763748169, 0.8812154531478882, 0.8729282021522522, 0.8784530162811279, 0.8674033284187317, 0.8839778900146484, 0.8784530162811279, 0.8839778900146484, 0.9171270728111267, 0.8950276374816895], 'val_loss': [0.6324149966239929, 0.5748280882835388, 0.537817120552063, 0.5126944184303284, 0.4911538362503052, 0.47753846645355225, 0.4625410735607147, 0.45192599296569824, 0.44351860880851746, 0.4394892752170563, 0.430193156003952, 0.42491695284843445, 0.42138800024986267, 0.4172518849372864, 0.4119878113269806, 0.40589141845703125, 0.40429064631462097, 0.399353951215744, 0.40307578444480896, 0.39602288603782654, 0.39687782526016235, 0.39587682485580444, 0.39574024081230164, 0.39059093594551086, 0.3963099718093872], 'val_binary_accuracy': [0.5333333611488342, 0.6222222447395325, 0.6888889074325562, 0.7222222089767456, 0.7888888716697693, 0.7555555701255798, 0.8111110925674438, 0.8111110925674438, 0.8111110925674438, 0.7888888716697693, 0.8111110925674438, 0.8222222328186035, 0.8333333134651184, 0.8333333134651184, 0.8333333134651184, 0.8333333134651184, 0.8333333134651184, 0.8444444537162781, 0.8333333134651184, 0.8444444537162781, 0.8444444537162781, 0.8444444537162781, 0.8444444537162781, 0.8444444537162781, 0.855555534362793]}

Test model against test dataset

predictWithTestDataset(model)

Out:

WARNING:tensorflow:5 out of the last 12 calls to <function Model.make_predict_function.<locals>.predict_function at 0x7f58803470d0> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings could be due to (1) creating @tf.function repeatedly in a loop, (2) passing tensors with different shapes, (3) passing Python objects instead of tensors. For (1), please define your @tf.function outside of the loop. For (2), @tf.function has experimental_relax_shapes=True option that relaxes argument shapes that can avoid unnecessary retracing. For (3), please refer to https://www.tensorflow.org/guide/function#controlling_retracing and https://www.tensorflow.org/api_docs/python/tf/function for more details.
number predictions=37
Accuracy:0.7837837837837838

# Unfreeze the base_model. Note that it keeps running in inference mode
# since we passed `training=False` when calling it. This means that
# the batchnorm layers will not update their batch statistics.
# This prevents the batchnorm layers from undoing all the training
# we've done so far.
base_model.trainable = True
model.summary()

model.compile(
    optimizer=keras.optimizers.Adam(1e-5),  # Low learning rate
    loss=keras.losses.BinaryCrossentropy(from_logits=False),  # sigmoid output, so not logits
    metrics=[keras.metrics.BinaryAccuracy()],
)

epochs = 2
history = model.fit(train_ds, epochs=epochs, validation_data=val_ds)

Out:

Model: "model_13"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
input_28 (InputLayer) [(None, 300, 300, 3)] 0
_________________________________________________________________
sequential_32 (Sequential) (None, 300, 300, 3) 0
_________________________________________________________________
normalization_13 (Normalizat (None, 300, 300, 3) 7
_________________________________________________________________
xception (Functional) (None, 10, 10, 2048) 20861480
_________________________________________________________________
global_average_pooling2d_13 (None, 2048) 0
_________________________________________________________________
dropout_28 (Dropout) (None, 2048) 0
_________________________________________________________________
dense_59 (Dense) (None, 1) 2049
=================================================================
Total params: 20,863,536
Trainable params: 20,809,001
Non-trainable params: 54,535
_________________________________________________________________
Epoch 1/2
12/12 [==============================] - 19s 1s/step - loss: 0.1918 - binary_accuracy: 0.9218 - val_loss: 0.3842 - val_binary_accuracy: 0.8556
Epoch 2/2
12/12 [==============================] - 16s 1s/step - loss: 0.1469 - binary_accuracy: 0.9509 - val_loss: 0.3520 - val_binary_accuracy: 0.8556
print(history.history)
acc = history.history['binary_accuracy']
val_acc = history.history['val_binary_accuracy']
loss = history.history['loss']
val_loss = history.history['val_loss']
epochs_range = range(epochs)

plt.figure(figsize=(8, 8))
plt.subplot(1, 2, 1)
plt.plot(epochs_range, acc, label='Training Accuracy')
plt.plot(epochs_range, val_acc, label='Validation Accuracy')
plt.legend(loc='lower right')
plt.title('Training and Validation Accuracy')
plt.subplot(1, 2, 2)
plt.plot(epochs_range, loss, label='Training Loss')
plt.plot(epochs_range, val_loss, label='Validation Loss')
plt.legend(loc='upper right')
plt.title('Training and Validation Loss')
plt.show()
{'loss': [0.1816166639328003, 0.14842070639133453], 'binary_accuracy': [0.9226519465446472, 0.950276255607605], 'val_loss': [0.38415849208831787, 0.3520398437976837], 'val_binary_accuracy': [0.855555534362793, 0.855555534362793]}
predictWithTestDataset(model)

Out:

WARNING:tensorflow:6 out of the last 13 calls to <function Model.make_predict_function.<locals>.predict_function at 0x7f58f13847b8> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings could be due to (1) creating @tf.function repeatedly in a loop, (2) passing tensors with different shapes, (3) passing Python objects instead of tensors. For (1), please define your @tf.function outside of the loop. For (2), @tf.function has experimental_relax_shapes=True option that relaxes argument shapes that can avoid unnecessary retracing. For (3), please refer to https://www.tensorflow.org/guide/function#controlling_retracing and https://www.tensorflow.org/api_docs/python/tf/function for more details.
number predictions=37
Accuracy:0.8378378378378378

Conclusion

We have been able to get some pretty good results with a limited dataset. This is indeed very promising! It seems that neural networks are more than capable of categorizing art collections according to theme.

There are many ways to further improve these results, from gathering more images and experimenting with different image sizes to trying new model architectures.

Hope you found this article useful. And hope to see you for the next one. Happy coding!
