Published in Geek Culture

Can I Create an Image Classifier Using Tensorflow to Identify the Theme of a Painting?

Can I use TensorFlow to teach an AI model how to identify the theme of a painting?

That’s the general question I want to answer with this project.

For practical reasons, though, I am going to narrow the question down further.

I want to identify whether or not a painting is about the Nativity, which is a simpler binary classification problem.

But first, let me explain what a nativity painting is. A nativity painting is one whose subject is the birth of Jesus Christ, an event deeply revered in Christianity. Over the last two thousand years, many famous artists, such as Leonardo da Vinci, Michelangelo and Caravaggio, were commissioned to create paintings for churches, so there should be plenty of nativity paintings to choose from.

Before Starting

I created this notebook using some code from the TensorFlow tutorial on image classification. You can find the original tutorial at the link below:

https://www.tensorflow.org/tutorials/images/classification

Import TensorFlow and other libraries

```python
import os
import shutil  # to save downloaded images locally
import time

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import PIL.Image  # importing the submodule makes PIL.Image.open available
import requests  # to get images from the web
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.models import Sequential
```

To train an image classifier, we need a training dataset, a validation dataset and a test dataset.

Since we are training a binary image classifier, we will have images for two different classes:

• Nativity
• Others

During training, we use the training dataset to teach the model how to classify a painting as either a nativity painting or not (other).

At the end of each training cycle (epoch), we use the validation dataset to score how well the model is doing by calculating the accuracy and the loss. The accuracy is the fraction of paintings the model classifies correctly; higher is better. The loss measures how far the model's predictions are from the actual labels; lower is better.
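To make the two metrics concrete, here is a minimal NumPy sketch that computes them by hand for four made-up predictions (illustrative values, not taken from our dataset). It assumes, as in the rest of this notebook, that the model outputs a probability between 0 and 1, that a prediction counts as Nativity when that probability is above 0.5, and that the loss is binary cross-entropy, i.e. the mean of -[y·log(p) + (1 - y)·log(1 - p)]. The helper names are hypothetical.

```python
import numpy as np

def binary_accuracy(y_true, y_prob, threshold=0.5):
    """Fraction of predictions that match the label after thresholding the probability."""
    y_pred = (y_prob > threshold).astype(np.float64)
    return float(np.mean(y_pred == y_true))

def binary_cross_entropy(y_true, y_prob, eps=1e-7):
    """Mean of -[y*log(p) + (1-y)*log(1-p)]; clipping with eps keeps log() away from 0."""
    p = np.clip(y_prob, eps, 1 - eps)
    return float(np.mean(-(y_true * np.log(p) + (1 - y_true) * np.log(1 - p))))

# Illustrative labels (1 = Nativity, 0 = Others) and predicted probabilities
y_true = np.array([1.0, 0.0, 1.0, 0.0])
y_prob = np.array([0.9, 0.2, 0.4, 0.1])

print(binary_accuracy(y_true, y_prob))       # 3 of 4 correct -> 0.75
print(binary_cross_entropy(y_true, y_prob))
```

The third painting, a nativity painting given a probability of only 0.4, is the one the model gets wrong, and it contributes most of the loss; a confidently wrong prediction would push the loss up even further.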

It is important that the validation dataset is separate from the training dataset, because an AI model is very good at cutting corners (i.e. cheating). If you don't separate the two, the model can simply memorize the answers instead of learning the intrinsic characteristics of what we are trying to teach it, and the validation score will not reveal this.

At the very end of training, we also use a test dataset, separate from both the training and validation datasets, to benchmark the model's performance independently.

The image URLs are stored in three CSV files:

• nativity_dataset.csv: contains all the nativity paintings
• other_dataset.csv: contains paintings that are not about the Nativity
• test_dataset.csv: contains labeled paintings from both classes

Wait a moment! Didn't I just say that the training dataset should be separate from the validation dataset? So why keep them in the same files?

Yes, but because we are doing data exploration, it is useful to have some flexibility. A common recommendation is to use 80% of the data for training and 20% for validation, but this is not a hard and fast rule. As part of our experimentation, we might want to change these percentages and see what gives us better results; tweaking settings like this is known as hyperparameter tuning. The test dataset, on the other hand, should stay fixed, so that we can compare models with different architectures consistently.

Below, we define some utility functions to download the images listed in our dataset files. Note that getFileNameFromUrl() does only a very basic extraction of the filename: it takes everything after the last "/" in the URL.

```python
import os
import shutil
import time

import pandas as pd
import requests

def getFileNameFromUrl(url):
  # The filename is everything after the last "/" in the URL
  firstpos = url.rindex("/")
  lastpos = len(url)
  filename = url[firstpos+1:lastpos]
  print(f"url={url} firstpos={firstpos} lastpos={lastpos} filename={filename}")
  return filename

def downloadImage(imageUrl, destinationFolder):
  filename = getFileNameFromUrl(imageUrl)
  filePath = os.path.join(destinationFolder, filename)
  # Skip images we already have, before making any HTTP request
  if os.path.exists(filePath):
    print(f"Skipping image {filename} as it is already downloaded")
    return
  # Open the image URL with stream=True so the response body can be streamed to disk
  r = requests.get(imageUrl, stream=True)
  # Check if the image was retrieved successfully
  if r.status_code == 200:
    # Decode any Content-Encoding (e.g. gzip) while copying the raw stream
    r.raw.decode_content = True
    # Open a local file in "wb" (write binary) mode and copy the stream into it
    with open(filePath, "wb") as f:
      shutil.copyfileobj(r.raw, f)
    print("Image successfully downloaded: ", filename)
    print("Sleeping for 1 second before attempting next download")
    time.sleep(1)
  else:
    print(f"Image url={imageUrl} and filename={filename} couldn't be retrieved. HTTP Status={r.status_code}")

df = pd.read_csv("nativity_dataset.csv")
# Create the directory we download into, if it doesn't exist
destinationFolder = "/content/dataset/nativity"
os.makedirs(destinationFolder, exist_ok=True)
for i, row in df.iterrows():
  print(f"Index: {i}")
  print(f"{row['Image URL']}\n")
  downloadImage(row["Image URL"], destinationFolder)
```

Out:

```
Index: 0
https://d3d00swyhr67nd.cloudfront.net/w1200h1200/collection/LSE/CUMU/LSE_CUMU_TN07034-001.jpg

url=https://d3d00swyhr67nd.cloudfront.net/w1200h1200/collection/LSE/CUMU/LSE_CUMU_TN07034-001.jpg firstpos=68 lastpos=93 filename=LSE_CUMU_TN07034-001.jpg
Image successfully downloaded:  LSE_CUMU_TN07034-001.jpg
Sleeping for 1 second before attempting next download
Index: 1
https://d3d00swyhr67nd.cloudfront.net/w1200h1200/collection/GMIII/MCAG/GMIII_MCAG_1947_188-001.jpg
```

Resize all images to be no bigger than 90,000 pixels (width x height)

Some of the images in our dataset are over 80MB in size. If we resize these images directly from Python, each full-size image gets loaded into memory. Not a great idea. So we are going to use ImageMagick to do the job super fast.

```
!apt install -y imagemagick
```

```
Reading package lists... Done
Building dependency tree
Reading state information... Done
imagemagick is already the newest version (8:6.9.7.4+dfsg-16ubuntu6.9).
0 upgraded, 0 newly installed, 0 to remove and 17 not upgraded.
```

Now we define the utility function resizeImages, which resizes every image under sourceFolder and writes the result to the same relative path under destinationFolder.

```python
import os
import subprocess

def resizeImages(sourceFolder, destinationFolder, maxPixels=1048576):
  os.makedirs(destinationFolder, exist_ok=True)
  for path, subdirs, files in os.walk(sourceFolder):
    # Mirror the source sub-folders (e.g. nativity/, others/) under destinationFolder
    relativeDir = os.path.relpath(path, sourceFolder)
    destinationFolderPath = os.path.join(destinationFolder, relativeDir)
    os.makedirs(destinationFolderPath, exist_ok=True)
    for fileName in files:
      sourceFilepath = os.path.join(path, fileName)
      destinationFilepath = os.path.join(destinationFolderPath, fileName)
      print(f"sourceFilepath={sourceFilepath} destinationFilepath={destinationFilepath}")
      # "<maxPixels>@>" shrinks images whose area exceeds maxPixels, keeping the aspect ratio.
      # Passing a list of arguments (no shell) stops ">" from being read as a redirect
      # and keeps filenames with spaces or quotes intact.
      subprocess.run(["convert", sourceFilepath, "-resize", f"{maxPixels}@>", destinationFilepath])

# resize training images
sourceFolder = "/content/dataset"
destinationFolder = "/content/resized/dataset"
resizeImages(sourceFolder, destinationFolder, maxPixels=90000)

# resize testing images
sourceFolder = "/content/test_dataset"
destinationFolder = "/content/resized/test_dataset"
resizeImages(sourceFolder, destinationFolder, maxPixels=90000)
```

```
sourceFilepath=/content/dataset/others/Quentin_Massys-The_Adoration_of_the_Magi-1526%2CMetropolitan_Museum_of_Art%2CNew_York.jpg destinationFilepath=/content/resized/dataset/others/Quentin_Massys-The_Adoration_of_the_Magi-1526%2CMetropolitan_Museum_of_Art%2CNew_York.jpg
...
```
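To get a feel for what `-resize 90000@>` does to a large painting, here is a pure-Python approximation. The helper `target_size` is hypothetical, and ImageMagick's own rounding may differ by a pixel. Images whose area is already within the limit are left untouched, because `>` only shrinks; larger ones are scaled on both sides by sqrt(maxPixels / (width × height)), which preserves the aspect ratio.

```python
import math

def target_size(width, height, max_pixels=90000):
    """Approximate the size ImageMagick's "<max_pixels>@>" resize would produce.

    Images at or below max_pixels are left alone (the ">" flag only shrinks);
    larger ones are scaled uniformly so that width * height <= max_pixels.
    """
    area = width * height
    if area <= max_pixels:
        return width, height
    scale = math.sqrt(max_pixels / area)
    # Round down so the result never exceeds the pixel budget
    return max(1, math.floor(width * scale)), max(1, math.floor(height * scale))

print(target_size(4000, 3000))  # a 12-megapixel image -> (346, 259)
print(target_size(200, 300))    # already small enough -> (200, 300)
```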

Map image labels to numeric values

We are using binary cross-entropy as the loss for our classifier, so our labels must be either 0 or 1: Nativity = 1 and Others = 0.

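tf.keras.preprocessing.image_dataset_from_directory creates the labels from the sub-directory names, assigning integers in alphanumeric order. With the original folder names, nativity would get label 0 and others label 1, so we rename the folders to 1 and 0 to get the mapping we want.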
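To see why the rename matters, here is a tiny pure-Python sketch of the labelling rule: class indices follow the sorted order of the sub-directory names. `labels_from_dirs` is a hypothetical helper that only mimics this ordering; it does not call TensorFlow.

```python
def labels_from_dirs(dir_names):
    """Map each class sub-directory to an integer label in sorted (alphanumeric) order."""
    return {name: index for index, name in enumerate(sorted(dir_names))}

# Before renaming: "nativity" sorts first, so it would get label 0
print(labels_from_dirs(["others", "nativity"]))  # {'nativity': 0, 'others': 1}

# After renaming others -> "0" and nativity -> "1": Nativity gets label 1
print(labels_from_dirs(["1", "0"]))              # {'0': 0, '1': 1}
```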

```
!mv /content/resized/dataset/nativity /content/resized/dataset/1
!mv /content/resized/dataset/others /content/resized/dataset/0
!mv /content/resized/test_dataset/nativity /content/resized/test_dataset/1
!mv /content/resized/test_dataset/others /content/resized/test_dataset/0
```

After downloading and resizing, we should now have a local copy of the dataset. Let's count the images:

```python
import pathlib

data_dir = pathlib.Path("/content/resized/dataset")
test_data_dir = pathlib.Path("/content/resized/test_dataset")
image_count = len(list(data_dir.glob('*/*')))
print(image_count)
```

```
454
```

Here are some paintings of the nativity:

```python
nativity_label = "1"
nativity = list(data_dir.glob(f'{nativity_label}/*'))
PIL.Image.open(str(nativity[0]))
```
`PIL.Image.open(str(nativity[1]))`

And some non-nativity paintings:

```python
others_label = "0"
others = list(data_dir.glob(f'{others_label}/*'))
PIL.Image.open(str(others[1]))
```

`PIL.Image.open(str(others[2]))`

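Keras provides a bunch of really convenient functions to make our life easier when working with TensorFlow, and tf.keras.preprocessing.image_dataset_from_directory is one of them: it loads the image files from a directory into a tf.data.Dataset, resizing them to a common size and inferring the labels from the sub-directory names. First, we define some parameters for the loader: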

```python
batch_size = 32
img_height = 300
img_width = 300
```

As mentioned earlier, it is common to split the data into 80% training data and 20% validation data. Remember, this is not a hard and fast rule.

```python
train_ds = tf.keras.preprocessing.image_dataset_from_directory(
  data_dir,
  validation_split=0.2,
  subset="training",
  seed=123,
  image_size=(img_height, img_width),
  batch_size=batch_size,
  label_mode='binary')
```

```
Found 452 files belonging to 2 classes.
Using 362 files for training.
```

```python
val_ds = tf.keras.preprocessing.image_dataset_from_directory(
  data_dir,
  validation_split=0.2,
  subset="validation",
  seed=123,  # same seed as the training subset, so the two subsets don't overlap
  image_size=(img_height, img_width),
  batch_size=batch_size,
  label_mode='binary')
```

```
Found 452 files belonging to 2 classes.
Using 90 files for validation.
```

```python
# Load the test set; a batch size of 37 puts all 37 test images in a single batch
test_data_dir = pathlib.Path("/content/resized/test_dataset")
test_batch_size = 37
test_ds = tf.keras.preprocessing.image_dataset_from_directory(
  test_data_dir,
  seed=200,
  image_size=(img_height, img_width),
  batch_size=test_batch_size,
  label_mode='binary')
```

```
Found 37 files belonging to 2 classes.
```
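The 362/90 split reported above follows from simple arithmetic. The sketch below is a minimal pure-Python illustration, assuming the split shuffles the file list with the given seed and then sets aside the last `int(validation_split * n)` files for validation. The helper `split_files` and the file names are hypothetical, and Python's `random` module will not produce the same order as Keras, only the same counts. Under this model, using the same seed for the training and validation calls is what keeps the two subsets disjoint.

```python
import random

def split_files(files, validation_split=0.2, seed=123):
    """Shuffle a copy of `files` deterministically, then carve off a validation slice."""
    shuffled = list(files)
    random.Random(seed).shuffle(shuffled)
    num_val = int(validation_split * len(shuffled))
    if num_val == 0:
        return shuffled, []
    return shuffled[:-num_val], shuffled[-num_val:]

files = [f"painting_{i}.jpg" for i in range(452)]  # hypothetical stand-ins for our 452 images
train_files, val_files = split_files(files)
print(len(train_files), len(val_files))  # 362 90
```

Changing `validation_split` (say, to 0.3) is all it takes to experiment with a different ratio.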

You can find the class names in the class_names attribute on these datasets. These correspond to the directory names in alphabetical order.

```python
class_names = train_ds.class_names
print(class_names)
```

```
['0', '1']
```

Visualize the data

Here are the first 9 images from the training dataset.

```python
import matplotlib.pyplot as plt

plt.figure(figsize=(10, 10))
for images, labels in train_ds.take(1):
  for i in range(9):
    ax = plt.subplot(3, 3, i + 1)
    plt.imshow(images[i].numpy().astype("uint8"))
    if labels[i] == 1.0:
      title = "Nativity"
    else:
      title = "Others"
    plt.title(title)
    plt.axis("off")
```

Next, we inspect a single batch through the image_batch and labels_batch variables.

The image_batch is a tensor of shape (32, 300, 300, 3): a batch of 32 images of 300x300 pixels, where the last dimension holds the RGB color channels. The labels_batch is a tensor of shape (32, 1) holding the labels for those 32 images; the extra dimension of size 1 comes from label_mode='binary'.

You can call .numpy() on the image_batch and labels_batch tensors to convert them to a numpy.ndarray.

```python
for image_batch, labels_batch in train_ds:
  print(image_batch.shape)
  print(labels_batch.shape)
  break
```

```
(32, 300, 300, 3)
(32, 1)
```

Configure the dataset for performance

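This code is taken directly from the TensorFlow tutorial and helps with performance: cache() keeps the images in memory after they are loaded from disk during the first epoch, so we don't have to fetch them from disk every time, and prefetch() prepares the next batches while the model is training on the current one.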

```python
AUTOTUNE = tf.data.AUTOTUNE

train_ds = train_ds.cache().shuffle(1000).prefetch(buffer_size=AUTOTUNE)
val_ds = val_ds.cache().prefetch(buffer_size=AUTOTUNE)
```

We define a utility function that measures the model's accuracy on our test dataset and plots each test image with its predicted label (P) and ground-truth label (GT).

```python
labelMappings = {"0": "Others", "1": "Nativity",
                 0.0: "Others", 1.0: "Nativity"}

def predictWithTestDataset(model):
  # test_ds holds all 37 test images in one batch, so one batch covers the whole test set
  image_batch, label_batch = test_ds.as_numpy_iterator().next()
  # The model outputs a probability per image; threshold it at 0.5 to get a class index
  predictions = model.predict_on_batch(image_batch).flatten()
  predictions = tf.where(predictions < 0.5, 0, 1)
  correctPredictions = 0
  plt.figure(figsize=(20, 20))
  print(f"number predictions={len(predictions)}")
  for i in range(len(predictions)):
    ax = plt.subplot(8, 5, i + 1)
    plt.imshow(image_batch[i].astype("uint8"))
    prediction = class_names[int(predictions[i])]  # "0" or "1"
    predictionLabel = labelMappings[prediction]
    gtLabel = labelMappings[label_batch[i][0]]     # ground truth: 0.0 or 1.0
    if gtLabel == predictionLabel:
      correctPredictions += 1
    plt.title(f"P={predictionLabel} GT={gtLabel}")
    plt.axis("off")
  accuracy = correctPredictions / len(predictions)
  print(f"Accuracy:{accuracy}")
```

Standardize the data

RGB channel values are in the range [0, 255]. Neural networks generally train better on small input values, so we rescale them to the range [0, 1].
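Rescaling amounts to nothing more than dividing every channel value by 255. Here is a NumPy equivalent on a made-up two-pixel image; the Keras layer multiplies by 1/255 instead, which gives the same result up to floating-point rounding.

```python
import numpy as np

# A made-up 1x2 RGB image (one black pixel, one white pixel), stored as uint8 like a decoded JPEG
image = np.array([[[0, 0, 0], [255, 255, 255]]], dtype=np.uint8)

# Map [0, 255] onto [0, 1]
rescaled = image.astype(np.float32) / 255.0

print(rescaled.dtype, rescaled.min(), rescaled.max())  # float32 0.0 1.0
```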

`normalization_layer = layers.experimental.preprocessing.Rescaling(1./255)`

The model definitions below include an equivalent Rescaling layer, so images are normalized inside the model itself.

Create the model

We define an initial architecture for our model, based on the one used in the TensorFlow image classification tutorial.

```python
model = Sequential([
  layers.experimental.preprocessing.Rescaling(1./255, input_shape=(img_height, img_width, 3)),
  layers.Conv2D(16, 3, padding='same', activation='relu'),
  layers.MaxPooling2D(),
  layers.Conv2D(32, 3, padding='same', activation='relu'),
  layers.MaxPooling2D(),
  layers.Conv2D(64, 3, padding='same', activation='relu'),
  layers.MaxPooling2D(),
  layers.Flatten(),
  layers.Dense(128, activation='relu'),
  layers.Dense(1, activation='sigmoid')
])
```

Compile the model

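For this tutorial, we choose the optimizers.Adam optimizer and the losses.BinaryCrossentropy loss function. Because the last layer already applies a sigmoid activation, the model outputs probabilities rather than logits, so the loss is created with from_logits=False. To view training and validation accuracy for each training epoch, we pass the metrics argument.

```python
model.compile(optimizer='adam',
              # The final Dense layer uses a sigmoid, so its outputs are probabilities, not logits
              loss=keras.losses.BinaryCrossentropy(from_logits=False),
              metrics=[keras.metrics.BinaryAccuracy()])
```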

Model summary

View all the layers of the network using the model’s summary method:

```python
model.summary()
```

```
Model: "sequential_31"
_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
rescaling_27 (Rescaling)     (None, 300, 300, 3)       0
_________________________________________________________________
conv2d_108 (Conv2D)          (None, 300, 300, 16)      448
_________________________________________________________________
max_pooling2d_60 (MaxPooling (None, 150, 150, 16)      0
_________________________________________________________________
conv2d_109 (Conv2D)          (None, 150, 150, 32)      4640
_________________________________________________________________
max_pooling2d_61 (MaxPooling (None, 75, 75, 32)        0
_________________________________________________________________
conv2d_110 (Conv2D)          (None, 75, 75, 64)        18496
_________________________________________________________________
max_pooling2d_62 (MaxPooling (None, 37, 37, 64)        0
_________________________________________________________________
flatten_20 (Flatten)         (None, 87616)             0
_________________________________________________________________
dense_52 (Dense)             (None, 128)               11214976
_________________________________________________________________
dense_53 (Dense)             (None, 1)                 129
=================================================================
Total params: 11,238,689
Trainable params: 11,238,689
Non-trainable params: 0
_________________________________________________________________
```
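The parameter counts in the summary can be checked by hand. A Conv2D layer with a k×k kernel has (k·k·in_channels + 1)·filters parameters (the +1 is the bias), a Dense layer has (inputs + 1)·units, and each MaxPooling2D() halves the spatial size, rounding down. This pure-Python sketch (helper names are hypothetical) reproduces the totals above and shows that almost all of the 11.2 million parameters sit in the first Dense layer.

```python
def conv2d_params(kernel, in_channels, filters):
    # Each filter has kernel*kernel*in_channels weights plus one bias
    return (kernel * kernel * in_channels + 1) * filters

def dense_params(inputs, units):
    # One weight per input per unit, plus one bias per unit
    return (inputs + 1) * units

size, channels, total = 300, 3, 0
for filters in (16, 32, 64):
    total += conv2d_params(3, channels, filters)  # padding='same' keeps the spatial size
    channels = filters
    size //= 2                                    # MaxPooling2D(): 2x2 window, stride 2
flattened = size * size * channels                # 37 * 37 * 64
total += dense_params(flattened, 128) + dense_params(128, 1)

print(flattened, total)  # 87616 11238689
```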

Train the model

```python
epochs = 10
history = model.fit(
  train_ds,
  validation_data=val_ds,
  epochs=epochs
)
```

```
Epoch 1/10
12/12 [==============================] - 1s 85ms/step - loss: 1.9371 - binary_accuracy: 0.5107 - val_loss: 0.7001 - val_binary_accuracy: 0.4444
Epoch 2/10
12/12 [==============================] - 1s 49ms/step - loss: 0.6491 - binary_accuracy: 0.6737 - val_loss: 0.7258 - val_binary_accuracy: 0.4778
Epoch 3/10
12/12 [==============================] - 1s 49ms/step - loss: 0.5943 - binary_accuracy: 0.6958 - val_loss: 0.7169 - val_binary_accuracy: 0.5333
Epoch 4/10
12/12 [==============================] - 1s 49ms/step - loss: 0.5111 - binary_accuracy: 0.7762 - val_loss: 0.7201 - val_binary_accuracy: 0.5667
Epoch 5/10
12/12 [==============================] - 1s 49ms/step - loss: 0.4013 - binary_accuracy: 0.8427 - val_loss: 0.6920 - val_binary_accuracy: 0.5667
Epoch 6/10
12/12 [==============================] - 1s 49ms/step - loss: 0.3027 - binary_accuracy: 0.8921 - val_loss: 0.8354 - val_binary_accuracy: 0.5889
Epoch 7/10
12/12 [==============================] - 1s 50ms/step - loss: 0.2438 - binary_accuracy: 0.9049 - val_loss: 0.8499 - val_binary_accuracy: 0.5778
Epoch 8/10
12/12 [==============================] - 1s 49ms/step - loss: 0.1725 - binary_accuracy: 0.9292 - val_loss: 0.9742 - val_binary_accuracy: 0.5222
Epoch 9/10
12/12 [==============================] - 1s 50ms/step - loss: 0.2792 - binary_accuracy: 0.8878 - val_loss: 0.9390 - val_binary_accuracy: 0.5222
Epoch 10/10
12/12 [==============================] - 1s 50ms/step - loss: 0.1347 - binary_accuracy: 0.9658 - val_loss: 0.9914 - val_binary_accuracy: 0.5889
```

Visualize training results

```python
print(history.history)

acc = history.history['binary_accuracy']
val_acc = history.history['val_binary_accuracy']
loss = history.history['loss']
val_loss = history.history['val_loss']

epochs_range = range(epochs)

plt.figure(figsize=(8, 8))
plt.subplot(1, 2, 1)
plt.plot(epochs_range, acc, label='Training Accuracy')
plt.plot(epochs_range, val_acc, label='Validation Accuracy')
plt.legend(loc='lower right')
plt.title('Training and Validation Accuracy')

plt.subplot(1, 2, 2)
plt.plot(epochs_range, loss, label='Training Loss')
plt.plot(epochs_range, val_loss, label='Validation Loss')
plt.legend(loc='upper right')
plt.title('Training and Validation Loss')
plt.show()
```

```
{'loss': [1.5215493440628052, 0.6543726325035095, 0.5897634625434875, 0.5006453990936279, 0.39839598536491394, 0.2903604209423065, 0.22604547441005707, 0.22543807327747345, 0.2558016777038574, 0.14820142090320587], 'binary_accuracy': [0.5580110549926758, 0.6491712927818298, 0.7099447250366211, 0.7845304012298584, 0.8425414562225342, 0.8867403268814087, 0.9116021990776062, 0.9060773253440857, 0.9033148884773254, 0.9558011293411255], 'val_loss': [0.7000985741615295, 0.7257931232452393, 0.7169376611709595, 0.7200638055801392, 0.6920430660247803, 0.8354127407073975, 0.8498525619506836, 0.9741556644439697, 0.9390344619750977, 0.9914490580558777], 'val_binary_accuracy': [0.4444444477558136, 0.47777777910232544, 0.5333333611488342, 0.5666666626930237, 0.5666666626930237, 0.5888888835906982, 0.5777778029441833, 0.5222222208976746, 0.5222222208976746, 0.5888888835906982]}
```

Looking at the plots, we see a typical sign of overfitting, which happens when the model fits the training data too closely and does poorly on the validation data. Training accuracy climbs steadily over the epochs (from about 0.56 to 0.96), but validation accuracy never gets past 0.59, and validation loss actually increases, from about 0.70 to 0.99.
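One way to read the validation curve numerically is to find the epoch where validation loss bottoms out; after that point, further training mostly improves the fit to the training images. A minimal sketch using the val_loss values printed above, rounded to four decimals:

```python
# Validation loss per epoch, rounded from history.history['val_loss'] above
val_loss = [0.7001, 0.7258, 0.7169, 0.7201, 0.6920, 0.8354, 0.8499, 0.9742, 0.9390, 0.9914]

# Index of the smallest validation loss
best_index = min(range(len(val_loss)), key=lambda i: val_loss[i])
best_epoch = best_index + 1  # the training log numbers epochs from 1

print(f"Lowest validation loss {val_loss[best_index]} at epoch {best_epoch} of {len(val_loss)}")
```

In this run the minimum comes at epoch 5, and every later epoch has a higher validation loss.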

```python
predictWithTestDataset(model)
```

```
number predictions=37
Accuracy:0.6216216216216216
```

Data augmentation

Overfitting generally occurs when there is only a small number of training examples, so it is not surprising here: we only have 452 images across two classes, of which 362 are used for training. So we need to find a way to generate more training data.

Keras has some really easy-to-use transformation layers for image data augmentation, and they produce very decent-looking images.

Image augmentation is a form of regularization: the idea is to make it harder for the model to overfit. By applying random transformations to the training images, we introduce variability that, hopefully, helps the model generalize better.

The Keras layers for image augmentation are available under tf.keras.layers.experimental.preprocessing (newer TensorFlow releases also expose them directly under tf.keras.layers).

We create a layer with the image augmentation transformations and we can include it in the creation of the model, just like any other layer.

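For the image augmentation, we apply a random horizontal flip and a random rotation. RandomRotation(0.1) rotates each image by a random angle of up to 10% of a full turn in either direction, i.e. up to ±36°.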
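What RandomFlip("horizontal") does to each image can be mimicked in a few lines of NumPy: reversing the width axis mirrors the painting left to right, and each image in the batch is flipped independently with probability 0.5. `random_horizontal_flip` is a hypothetical helper for illustration only; in the notebook, the Keras layer does this on the fly for every training batch.

```python
import numpy as np

def random_horizontal_flip(batch, rng):
    """Mirror each image in a (batch, height, width, channels) array left-right with probability 0.5."""
    flip = rng.random(batch.shape[0]) < 0.5
    out = batch.copy()
    out[flip] = out[flip][:, :, ::-1, :]  # reverse the width axis of the selected images
    return out

rng = np.random.default_rng(0)
batch = np.arange(2 * 2 * 3 * 3).reshape(2, 2, 3, 3)  # two tiny 2x3 RGB "images"
flipped = random_horizontal_flip(batch, rng)
print(flipped.shape)  # (2, 2, 3, 3)
```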

```python
data_augmentation = keras.Sequential(
  [
    layers.experimental.preprocessing.RandomFlip("horizontal",
                                                 input_shape=(img_height,
                                                              img_width,
                                                              3)),
    layers.experimental.preprocessing.RandomRotation(0.1),
  ]
)
```

Let’s visualize what a few augmented examples look like by applying data augmentation to the same image several times:

```python
plt.figure(figsize=(10, 10))
for images, _ in train_ds.take(1):
  for i in range(9):
    augmented_images = data_augmentation(images)
    ax = plt.subplot(3, 3, i + 1)
    plt.imshow(augmented_images[0].numpy().astype("uint8"))
    plt.axis("off")
```

We will use data augmentation to train a model in a moment.

Dropout

Another technique to reduce overfitting is to add Dropout, another form of regularization, to the network. Think of it as a random culling of neurons in the neural network during training: at each training step, the Dropout layer sets a random fraction of its inputs to zero, which prevents the model from depending too heavily on any single feature. That fraction is the argument we pass to layers.Dropout, here 0.2 (20%). Dropout is only active during training; at inference time all units are kept.
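Here is a NumPy sketch of the idea, assuming the "inverted dropout" scheme described in the Keras documentation: during training each value is zeroed with probability `rate` and the survivors are scaled by 1/(1 - rate), so the expected sum stays the same; at inference the input passes through untouched. `dropout` is a hypothetical helper, not the Keras implementation.

```python
import numpy as np

def dropout(x, rate, rng, training=True):
    """Inverted dropout: zero each value with probability `rate`, rescale the survivors."""
    if not training or rate == 0:
        return x  # inference: nothing is dropped
    keep = rng.random(x.shape) >= rate
    return np.where(keep, x / (1.0 - rate), 0.0)

rng = np.random.default_rng(42)
activations = np.ones(1000)
dropped = dropout(activations, rate=0.2, rng=rng)

print((dropped == 0).mean())  # roughly 0.2 of the units are zeroed
print(dropped.mean())         # roughly 1.0: the 1/(1 - rate) scaling preserves the mean
print(np.array_equal(dropout(activations, 0.2, rng, training=False), activations))  # True
```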

Let’s create a new neural network using layers.Dropout, then train it using augmented images.

```python
model = Sequential([
  data_augmentation,
  layers.experimental.preprocessing.Rescaling(1./255, input_shape=(img_height, img_width, 3)),
  layers.Conv2D(16, 3, padding='same', activation='relu'),
  layers.MaxPooling2D(),
  layers.Conv2D(32, 3, padding='same', activation='relu'),
  layers.MaxPooling2D(),
  layers.Conv2D(64, 3, padding='same', activation='relu'),
  layers.MaxPooling2D(),
  layers.Dropout(0.2),
  layers.Flatten(),
  layers.Dense(128, activation='relu'),
  layers.Dense(1, activation='sigmoid')
])
```

Compile and train the model

```python
from tensorflow.keras import optimizers

model.compile(
  # The final Dense layer uses a sigmoid, so its outputs are probabilities, not logits
  loss=keras.losses.BinaryCrossentropy(from_logits=False),
  optimizer=optimizers.RMSprop(learning_rate=1e-4),
  metrics=[keras.metrics.BinaryAccuracy()])

model.summary()
```

```
Model: "sequential_34"
_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
sequential_32 (Sequential)   (None, 300, 300, 3)       0
_________________________________________________________________
rescaling_29 (Rescaling)     (None, 300, 300, 3)       0
_________________________________________________________________
conv2d_114 (Conv2D)          (None, 300, 300, 16)      448
_________________________________________________________________
max_pooling2d_66 (MaxPooling (None, 150, 150, 16)      0
_________________________________________________________________
conv2d_115 (Conv2D)          (None, 150, 150, 32)      4640
_________________________________________________________________
max_pooling2d_67 (MaxPooling (None, 75, 75, 32)        0
_________________________________________________________________
conv2d_116 (Conv2D)          (None, 75, 75, 64)        18496
_________________________________________________________________
max_pooling2d_68 (MaxPooling (None, 37, 37, 64)        0
_________________________________________________________________
dropout_26 (Dropout)         (None, 37, 37, 64)        0
_________________________________________________________________
flatten_22 (Flatten)         (None, 87616)             0
_________________________________________________________________
dense_56 (Dense)             (None, 128)               11214976
_________________________________________________________________
dense_57 (Dense)             (None, 1)                 129
=================================================================
Total params: 11,238,689
Trainable params: 11,238,689
Non-trainable params: 0
_________________________________________________________________
```

```python
epochs = 25
history = model.fit(
  train_ds,
  validation_data=val_ds,
  epochs=epochs
)
```

```
Epoch 1/25
12/12 [==============================] - 2s 76ms/step - loss: 0.9417 - binary_accuracy: 0.5580 - val_loss: 0.7153 - val_binary_accuracy: 0.5222
Epoch 2/25
12/12 [==============================] - 1s 61ms/step - loss: 0.6869 - binary_accuracy: 0.5338 - val_loss: 0.7236 - val_binary_accuracy: 0.5333
Epoch 3/25
12/12 [==============================] - 1s 61ms/step - loss: 0.6557 - binary_accuracy: 0.5985 - val_loss: 0.8124 - val_binary_accuracy: 0.5222
Epoch 4/25
12/12 [==============================] - 1s 61ms/step - loss: 0.6447 - binary_accuracy: 0.6315 - val_loss: 0.6829 - val_binary_accuracy: 0.5556
Epoch 5/25
12/12 [==============================] - 1s 65ms/step - loss: 0.6482 - binary_accuracy: 0.6273 - val_loss: 0.6708 - val_binary_accuracy: 0.5778
Epoch 6/25
12/12 [==============================] - 1s 61ms/step - loss: 0.6482 - binary_accuracy: 0.6348 - val_loss: 0.6733 - val_binary_accuracy: 0.5556
Epoch 7/25
12/12 [==============================] - 1s 61ms/step - loss: 0.6325 - binary_accuracy: 0.6592 - val_loss: 0.6762 - val_binary_accuracy: 0.5333
Epoch 8/25
12/12 [==============================] - 1s 62ms/step - loss: 0.5994 - binary_accuracy: 0.6680 - val_loss: 0.6587 - val_binary_accuracy: 0.6111
Epoch 9/25
12/12 [==============================] - 1s 61ms/step - loss: 0.6204 - binary_accuracy: 0.6904 - val_loss: 0.7240 - val_binary_accuracy: 0.5333
Epoch 10/25
12/12 [==============================] - 1s 62ms/step - loss: 0.6343 - binary_accuracy: 0.6480 - val_loss: 0.6776 - val_binary_accuracy: 0.5667
Epoch 11/25
12/12 [==============================] - 1s 62ms/step - loss: 0.6439 - binary_accuracy: 0.6107 - val_loss: 0.6811 - val_binary_accuracy: 0.5556
Epoch 12/25
12/12 [==============================] - 1s 62ms/step - loss: 0.6361 - binary_accuracy: 0.6301 - val_loss: 0.6612 - val_binary_accuracy: 0.6222
Epoch 13/25
12/12 [==============================] - 1s 62ms/step - loss: 0.6025 - binary_accuracy: 0.6949 - val_loss: 0.6725 - val_binary_accuracy: 0.5778
Epoch 14/25
12/12 [==============================] - 1s 61ms/step - loss: 0.5977 - binary_accuracy: 0.6868 - val_loss: 0.7521 - val_binary_accuracy: 0.5444
Epoch 15/25
12/12 [==============================] - 1s 62ms/step - loss: 0.5713 - binary_accuracy: 0.6833 - val_loss: 0.6427 - val_binary_accuracy: 0.6444
Epoch 16/25
12/12 [==============================] - 1s 62ms/step - loss: 0.5918 - binary_accuracy: 0.6939 - val_loss: 0.6515 - val_binary_accuracy: 0.6333
Epoch 17/25
12/12 [==============================] - 1s 61ms/step - loss: 0.5831 - binary_accuracy: 0.7253 - val_loss: 0.6556 - val_binary_accuracy: 0.5889
Epoch 18/25
12/12 [==============================] - 1s 62ms/step - loss: 0.5626 - binary_accuracy: 0.7121 - val_loss: 0.6877 - val_binary_accuracy: 0.5667
Epoch 19/25
12/12 [==============================] - 1s 62ms/step - loss: 0.5476 - binary_accuracy: 0.7327 - val_loss: 0.6398 - val_binary_accuracy: 0.6556
Epoch 20/25
12/12 [==============================] - 1s 62ms/step - loss: 0.5551 - binary_accuracy: 0.7283 - val_loss: 0.6465 - val_binary_accuracy: 0.6333
Epoch 21/25
12/12 [==============================] - 1s 62ms/step - loss: 0.5436 - binary_accuracy: 0.7312 - val_loss: 0.7083 - val_binary_accuracy: 0.5667
Epoch 22/25
12/12 [==============================] - 1s 65ms/step - loss: 0.5987 - binary_accuracy: 0.6781 - val_loss: 0.8078 - val_binary_accuracy: 0.5222
Epoch 23/25
12/12 [==============================] - 1s 62ms/step - loss: 0.5534 - binary_accuracy: 0.7139 - val_loss: 0.6705 - val_binary_accuracy: 0.6111
Epoch 24/25
12/12 [==============================] - 1s 85ms/step - loss: 0.5617 - binary_accuracy: 0.7406 - val_loss: 0.6471 - val_binary_accuracy: 0.6111
Epoch 25/25
12/12 [==============================] - 1s 62ms/step - loss: 0.5541 - binary_accuracy: 0.7303 - val_loss: 0.6263 - val_binary_accuracy: 0.7000
```

Visualize training results

After applying data augmentation and Dropout, there is less overfitting than before: training and validation accuracy now stay much closer together, ending at about 0.74 and 0.70 respectively.

```python
print(history.history)

acc = history.history['binary_accuracy']
val_acc = history.history['val_binary_accuracy']
loss = history.history['loss']
val_loss = history.history['val_loss']

epochs_range = range(epochs)

plt.figure(figsize=(8, 8))
plt.subplot(1, 2, 1)
plt.plot(epochs_range, acc, label='Training Accuracy')
plt.plot(epochs_range, val_acc, label='Validation Accuracy')
plt.legend(loc='lower right')
plt.title('Training and Validation Accuracy')

plt.subplot(1, 2, 2)
plt.plot(epochs_range, loss, label='Training Loss')
plt.plot(epochs_range, val_loss, label='Validation Loss')
plt.legend(loc='upper right')
plt.title('Training and Validation Loss')
plt.show()
```

```
{'loss': [0.8289327025413513, 0.6810811758041382, 0.6626855731010437, 0.6704213619232178, 0.650795042514801, 0.6398268938064575, 0.6561762690544128, 0.6122907400131226, 0.6228107810020447, 0.6147234439849854, 0.6149318814277649, 0.6190508604049683, 0.607336699962616, 0.5861756801605225, 0.593088686466217, 0.6063793301582336, 0.5983501672744751, 0.5894269347190857, 0.5698645114898682, 0.5585014224052429, 0.5401753783226013, 0.5774908065795898, 0.5512883067131042, 0.5710932016372681, 0.5434897541999817], 'binary_accuracy': [0.5386740565299988, 0.5386740565299988, 0.6077347993850708, 0.6160221099853516, 0.6022099256515503, 0.6325966715812683, 0.6325966715812683, 0.6685082912445068, 0.6961326003074646, 0.6602209806442261, 0.6408839821815491, 0.6574585437774658, 0.6712707281112671, 0.6933701634407043, 0.6574585437774658, 0.6767956018447876, 0.7044199109077454, 0.6933701634407043, 0.7154695987701416, 0.7265193462371826, 0.7265193462371826, 0.6906077265739441, 0.7154695987701416, 0.7071823477745056, 0.7375690340995789], 'val_loss': [0.7153488993644714, 0.7235575318336487, 0.8124216794967651, 0.6829271912574768, 0.6708189249038696, 0.673344612121582, 0.676236629486084, 0.6586815714836121, 0.7239749431610107, 0.677582323551178, 0.6810950636863708, 0.6611502170562744, 0.6725294589996338, 0.7520950436592102, 0.642659068107605, 0.6514749526977539, 0.6556094884872437, 0.687703013420105, 0.639808714389801, 0.6464514136314392, 0.7082778811454773, 0.8077911138534546, 0.670492947101593, 0.6470986008644104, 0.6263118386268616], 'val_binary_accuracy': [0.5222222208976746, 0.5333333611488342, 0.5222222208976746, 0.5555555820465088, 0.5777778029441833, 0.5555555820465088, 0.5333333611488342, 0.6111111044883728, 0.5333333611488342, 0.5666666626930237, 0.5555555820465088, 0.6222222447395325, 0.5777778029441833, 0.5444444417953491, 0.644444465637207, 0.6333333253860474, 0.5888888835906982, 0.5666666626930237, 0.6555555462837219, 0.6333333253860474, 0.5666666626930237, 0.5222222208976746, 0.6111111044883728, 0.6111111044883728, 0.699999988079071]}
```
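To put a number on "less overfitting", we can compare the gap between final training and validation accuracy in the two runs, using the last values of the two history dictionaries printed earlier (rounded to four decimals). This is a minimal sketch of that comparison, nothing more:

```python
# Final-epoch accuracies, rounded from the two history dictionaries printed earlier
runs = {
    "without augmentation/dropout": {"train": 0.9558, "val": 0.5889},
    "with augmentation + dropout": {"train": 0.7376, "val": 0.7000},
}

for name, acc in runs.items():
    gap = acc["train"] - acc["val"]  # a large positive gap is the classic overfitting signature
    print(f"{name}: train={acc['train']:.2f}  val={acc['val']:.2f}  gap={gap:.2f}")
```

The gap shrinks from roughly 0.37 to roughly 0.04, even though the training accuracy of the second model is lower.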

Predict on new data

Finally, let’s use our model to classify images on our test dataset that weren’t included in the training or validation sets.

```python
predictWithTestDataset(model)
```

```
number predictions=37
Accuracy:0.7027027027027027
```

Transfer Learning

We have already used image augmentation and Dropout to try to get better results from our model, and I have to say the results were not bad at all: we ended up with about 70% accuracy on both the validation and the test datasets. We could surely do better than that if we were able to collect hundreds, perhaps thousands, more images for our training and validation datasets.

We can certainly do that, but there is another way that doesn’t involve the tedious and expensive process of collecting more training data: Transfer Learning.

With transfer learning, we borrow a model that has already been trained on a large image dataset (ImageNet, with over a million training images) and re-train it for our use case, using far fewer images than we would need to train a model from scratch.

To do so, we can use Keras to download an Xception model that has already been trained on ImageNet.

To perform transfer learning, we freeze the weights of the base model, add a new classification layer on top of it, and train as we normally would. Only the weights of the new layer are updated. You will notice that we still apply image augmentation and regularization (dropout).
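The idea of freezing can be illustrated without TensorFlow at all. In this minimal NumPy sketch, a fixed random projection followed by ReLU stands in for the frozen, pre-trained base model (a hypothetical stand-in, not Xception), and a logistic-regression "head" (the equivalent of a `Dense(1, activation="sigmoid")` layer) is trained on its features with gradient descent. The data is synthetic and only serves to show that the head learns while the base weights never change.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical stand-in for the frozen, pre-trained base model:
# a fixed random projection followed by ReLU.
W_base = rng.normal(size=(20, 64)) / np.sqrt(20)
W_base_before = W_base.copy()


def frozen_features(x):
    """Feature extractor whose weights are never updated."""
    return np.maximum(x @ W_base, 0.0)


# Synthetic binary data, for illustration only (not the painting dataset).
X = rng.normal(size=(200, 20))
y = (X[:, 0] + X[:, 1] > 0).astype(float)

# Because the base is frozen, its features can be computed once up front.
feats = frozen_features(X)

# The new trainable head: logistic regression on the frozen features.
w = np.zeros(feats.shape[1])
b = 0.0
lr = 0.1
for _ in range(2000):
    p = 1.0 / (1.0 + np.exp(-(feats @ w + b)))  # sigmoid probabilities
    grad = p - y  # gradient of binary cross-entropy w.r.t. the logit
    w -= lr * feats.T @ grad / len(y)  # only the head's weights are updated
    b -= lr * grad.mean()
# W_base is never touched, which is what base_model.trainable = False achieves.

acc = np.mean((feats @ w + b > 0) == (y == 1))
print(f"head-only training accuracy: {acc:.2f}")
print("base weights unchanged:", np.array_equal(W_base, W_base_before))
```

Precomputing the features is only possible because the base is frozen. In the Keras model, the base still runs on every batch, because data augmentation changes the input images each epoch.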

```python
base_model = keras.applications.Xception(
    weights="imagenet",  # Load weights pre-trained on ImageNet.
    input_shape=(img_height, img_width, 3),
    include_top=False,  # Do not include the ImageNet classifier at the top.
)

# Freeze the base_model
base_model.trainable = False

# Create new model on top
inputs = keras.Input(shape=(img_height, img_width, 3))
x = data_augmentation(inputs)  # Apply random data augmentation

# Pre-trained Xception weights require that the input be normalized
# from (0, 255) to a range (-1., +1.). The normalization layer
# does the following: outputs = (inputs - mean) / sqrt(var)
norm_layer = keras.layers.experimental.preprocessing.Normalization()
mean = np.array([127.5] * 3)
var = mean ** 2
# Scale inputs to [-1, +1]
x = norm_layer(x)
norm_layer.set_weights([mean, var])

# The base model contains batchnorm layers. We want to keep them in inference mode
# when we unfreeze the base model for fine-tuning, so we make sure that the
# base_model is running in inference mode here.
x = base_model(x, training=False)
x = keras.layers.GlobalAveragePooling2D()(x)
x = keras.layers.Dropout(0.2)(x)  # Regularize with dropout
outputs = keras.layers.Dense(1, activation="sigmoid")(x)
model = keras.Model(inputs, outputs)

model.summary()
```

```
Model: "model_13"
_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
input_28 (InputLayer)        [(None, 300, 300, 3)]     0
_________________________________________________________________
sequential_32 (Sequential)   (None, 300, 300, 3)       0
_________________________________________________________________
normalization_13 (Normalizat (None, 300, 300, 3)       7
_________________________________________________________________
xception (Functional)        (None, 10, 10, 2048)      20861480
_________________________________________________________________
global_average_pooling2d_13  (None, 2048)              0
_________________________________________________________________
dropout_28 (Dropout)         (None, 2048)              0
_________________________________________________________________
dense_59 (Dense)             (None, 1)                 2049
=================================================================
Total params: 20,863,536
Trainable params: 2,049
Non-trainable params: 20,861,487
_________________________________________________________________
```

```python
model.compile(
    optimizer=keras.optimizers.Adam(),
    # The last layer already applies a sigmoid, so the model outputs
    # probabilities rather than logits.
    loss=keras.losses.BinaryCrossentropy(from_logits=False),
    metrics=[keras.metrics.BinaryAccuracy()],
)

epochs = 25
history = model.fit(
    train_ds,
    validation_data=val_ds,
    epochs=epochs,
)
```

```
Epoch 1/25
12/12 [==============================] - 7s 376ms/step - loss: 0.6989 - binary_accuracy: 0.5116 - val_loss: 0.6324 - val_binary_accuracy: 0.5333
Epoch 2/25
12/12 [==============================] - 4s 322ms/step - loss: 0.6106 - binary_accuracy: 0.5943 - val_loss: 0.5748 - val_binary_accuracy: 0.6222
Epoch 3/25
12/12 [==============================] - 4s 322ms/step - loss: 0.5557 - binary_accuracy: 0.6647 - val_loss: 0.5378 - val_binary_accuracy: 0.6889
Epoch 4/25
12/12 [==============================] - 4s 326ms/step - loss: 0.5280 - binary_accuracy: 0.6333 - val_loss: 0.5127 - val_binary_accuracy: 0.7222
Epoch 5/25
12/12 [==============================] - 4s 329ms/step - loss: 0.4751 - binary_accuracy: 0.7638 - val_loss: 0.4912 - val_binary_accuracy: 0.7889
Epoch 6/25
12/12 [==============================] - 4s 331ms/step - loss: 0.4586 - binary_accuracy: 0.7535 - val_loss: 0.4775 - val_binary_accuracy: 0.7556
Epoch 7/25
12/12 [==============================] - 4s 335ms/step - loss: 0.4328 - binary_accuracy: 0.7778 - val_loss: 0.4625 - val_binary_accuracy: 0.8111
Epoch 8/25
12/12 [==============================] - 4s 339ms/step - loss: 0.3951 - binary_accuracy: 0.8387 - val_loss: 0.4519 - val_binary_accuracy: 0.8111
Epoch 9/25
12/12 [==============================] - 4s 344ms/step - loss: 0.3745 - binary_accuracy: 0.8427 - val_loss: 0.4435 - val_binary_accuracy: 0.8111
Epoch 10/25
12/12 [==============================] - 4s 348ms/step - loss: 0.3631 - binary_accuracy: 0.8373 - val_loss: 0.4395 - val_binary_accuracy: 0.7889
Epoch 11/25
12/12 [==============================] - 4s 350ms/step - loss: 0.3449 - binary_accuracy: 0.8705 - val_loss: 0.4302 - val_binary_accuracy: 0.8111
Epoch 12/25
12/12 [==============================] - 4s 355ms/step - loss: 0.3409 - binary_accuracy: 0.8623 - val_loss: 0.4249 - val_binary_accuracy: 0.8222
Epoch 13/25
12/12 [==============================] - 4s 356ms/step - loss: 0.3491 - binary_accuracy: 0.8848 - val_loss: 0.4214 - val_binary_accuracy: 0.8333
Epoch 14/25
12/12 [==============================] - 4s 356ms/step - loss: 0.3522 - binary_accuracy: 0.8569 - val_loss: 0.4173 - val_binary_accuracy: 0.8333
Epoch 15/25
12/12 [==============================] - 4s 354ms/step - loss: 0.3106 - binary_accuracy: 0.8641 - val_loss: 0.4120 - val_binary_accuracy: 0.8333
Epoch 16/25
12/12 [==============================] - 4s 348ms/step - loss: 0.3108 - binary_accuracy: 0.8973 - val_loss: 0.4059 - val_binary_accuracy: 0.8333
Epoch 17/25
12/12 [==============================] - 4s 348ms/step - loss: 0.3041 - binary_accuracy: 0.8840 - val_loss: 0.4043 - val_binary_accuracy: 0.8333
Epoch 18/25
12/12 [==============================] - 4s 364ms/step - loss: 0.3106 - binary_accuracy: 0.8548 - val_loss: 0.3994 - val_binary_accuracy: 0.8444
Epoch 19/25
12/12 [==============================] - 4s 343ms/step - loss: 0.3072 - binary_accuracy: 0.8774 - val_loss: 0.4031 - val_binary_accuracy: 0.8333
Epoch 20/25
12/12 [==============================] - 4s 341ms/step - loss: 0.3008 - binary_accuracy: 0.8870 - val_loss: 0.3960 - val_binary_accuracy: 0.8444
Epoch 21/25
12/12 [==============================] - 4s 342ms/step - loss: 0.2959 - binary_accuracy: 0.8738 - val_loss: 0.3969 - val_binary_accuracy: 0.8444
Epoch 22/25
12/12 [==============================] - 4s 340ms/step - loss: 0.2655 - binary_accuracy: 0.8874 - val_loss: 0.3959 - val_binary_accuracy: 0.8444
Epoch 23/25
12/12 [==============================] - 4s 340ms/step - loss: 0.2452 - binary_accuracy: 0.9098 - val_loss: 0.3957 - val_binary_accuracy: 0.8444
Epoch 24/25
12/12 [==============================] - 4s 359ms/step - loss: 0.2532 - binary_accuracy: 0.9214 - val_loss: 0.3906 - val_binary_accuracy: 0.8444
Epoch 25/25
12/12 [==============================] - 4s 340ms/step - loss: 0.2512 - binary_accuracy: 0.9202 - val_loss: 0.3963 - val_binary_accuracy: 0.8556
```

```python
print(history.history)

acc = history.history['binary_accuracy']
val_acc = history.history['val_binary_accuracy']

loss = history.history['loss']
val_loss = history.history['val_loss']

epochs_range = range(epochs)

plt.figure(figsize=(8, 8))
plt.subplot(1, 2, 1)
plt.plot(epochs_range, acc, label='Training Accuracy')
plt.plot(epochs_range, val_acc, label='Validation Accuracy')
plt.legend(loc='lower right')
plt.title('Training and Validation Accuracy')

plt.subplot(1, 2, 2)
plt.plot(epochs_range, loss, label='Training Loss')
plt.plot(epochs_range, val_loss, label='Validation Loss')
plt.legend(loc='upper right')
plt.title('Training and Validation Loss')
plt.show()
```

```
{'loss': [0.6730840802192688, 0.6024077534675598, 0.5369390249252319, 0.5121317505836487, 0.47049736976623535, 0.43795397877693176, 0.4308575689792633, 0.409035861492157, 0.39621278643608093, 0.38378414511680603, 0.36360302567481995, 0.3390505313873291, 0.33790066838264465, 0.34375235438346863, 0.3241599202156067, 0.3240811824798584, 0.30652204155921936, 0.3121297359466553, 0.29941326379776, 0.31102144718170166, 0.29046544432640076, 0.2721157371997833, 0.2742222845554352, 0.273193895816803, 0.2644941210746765],
 'binary_accuracy': [0.541436493396759, 0.580110490322113, 0.7099447250366211, 0.6546961069107056, 0.7762430906295776, 0.7817679643630981, 0.7734806537628174, 0.8121547102928162, 0.8093922734260559, 0.8066298365592957, 0.8563535809516907, 0.8535911440849304, 0.8701657652854919, 0.8618784546852112, 0.8535911440849304, 0.889502763748169, 0.8812154531478882, 0.8729282021522522, 0.8784530162811279, 0.8674033284187317, 0.8839778900146484, 0.8784530162811279, 0.8839778900146484, 0.9171270728111267, 0.8950276374816895],
 'val_loss': [0.6324149966239929, 0.5748280882835388, 0.537817120552063, 0.5126944184303284, 0.4911538362503052, 0.47753846645355225, 0.4625410735607147, 0.45192599296569824, 0.44351860880851746, 0.4394892752170563, 0.430193156003952, 0.42491695284843445, 0.42138800024986267, 0.4172518849372864, 0.4119878113269806, 0.40589141845703125, 0.40429064631462097, 0.399353951215744, 0.40307578444480896, 0.39602288603782654, 0.39687782526016235, 0.39587682485580444, 0.39574024081230164, 0.39059093594551086, 0.3963099718093872],
 'val_binary_accuracy': [0.5333333611488342, 0.6222222447395325, 0.6888889074325562, 0.7222222089767456, 0.7888888716697693, 0.7555555701255798, 0.8111110925674438, 0.8111110925674438, 0.8111110925674438, 0.7888888716697693, 0.8111110925674438, 0.8222222328186035, 0.8333333134651184, 0.8333333134651184, 0.8333333134651184, 0.8333333134651184, 0.8333333134651184, 0.8444444537162781, 0.8333333134651184, 0.8444444537162781, 0.8444444537162781, 0.8444444537162781, 0.8444444537162781, 0.8444444537162781, 0.855555534362793]}
```
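Two numbers in this setup are easy to sanity-check by hand. First, the only trainable layer is the final `Dense(1)` on the 2048 pooled Xception features, so it has 2048 weights plus 1 bias, which gives the 2,049 trainable parameters in the summary. Second, with `mean = 127.5` and `var = 127.5²`, the normalization maps pixel values 0, 127.5 and 255 to -1, 0 and +1. The split of the normalization layer's 7 parameters into 3 means, 3 variances and 1 sample count is my assumption about how this TensorFlow version stores the layer's state. The `normalize` function is a plain-Python illustration of the formula in the code comment, not a Keras API.

```python
# Parameter counts, with the Xception figure copied from the model summary.
xception_params = 20_861_480
norm_params = 3 + 3 + 1      # assumption: mean (3) + variance (3) + count (1)
dense_params = 2048 * 1 + 1  # one weight per pooled feature, plus one bias

total = xception_params + norm_params + dense_params
trainable = dense_params  # base_model.trainable = False freezes everything else
non_trainable = xception_params + norm_params
print(total, trainable, non_trainable)  # 20863536 2049 20861487


def normalize(pixel, mean=127.5, var=127.5 ** 2):
    """outputs = (inputs - mean) / sqrt(var), as in the code comment above."""
    return (pixel - mean) / var ** 0.5


print([normalize(p) for p in (0.0, 127.5, 255.0)])  # [-1.0, 0.0, 1.0]
```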

Test model against test dataset

```python
predictWithTestDataset(model)
```

```
WARNING:tensorflow:5 out of the last 12 calls to <function Model.make_predict_function.<locals>.predict_function at 0x7f58803470d0> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings could be due to (1) creating @tf.function repeatedly in a loop, (2) passing tensors with different shapes, (3) passing Python objects instead of tensors. For (1), please define your @tf.function outside of the loop. For (2), @tf.function has experimental_relax_shapes=True option that relaxes argument shapes that can avoid unnecessary retracing. For (3), please refer to https://www.tensorflow.org/guide/function#controlling_retracing and https://www.tensorflow.org/api_docs/python/tf/function for  more details.
number predictions=37
Accuracy:0.7837837837837838
```

As a last step, we can fine-tune the model: we unfreeze the base model and re-train the whole model for a couple of epochs with a very low learning rate (1e-5), so that the pre-trained weights are only slightly adjusted to our paintings.

```python
# Unfreeze the base_model. Note that it keeps running in inference mode
# since we passed `training=False` when calling it. This means that
# the batchnorm layers will not update their batch statistics.
# This prevents the batchnorm layers from undoing all the training
# we've done so far.
base_model.trainable = True
model.summary()

model.compile(
    optimizer=keras.optimizers.Adam(1e-5),  # Low learning rate
    # The last layer already applies a sigmoid, so the model outputs
    # probabilities rather than logits.
    loss=keras.losses.BinaryCrossentropy(from_logits=False),
    metrics=[keras.metrics.BinaryAccuracy()],
)

epochs = 2
history = model.fit(train_ds, epochs=epochs, validation_data=val_ds)
```

```
Model: "model_13"
_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
input_28 (InputLayer)        [(None, 300, 300, 3)]     0
_________________________________________________________________
sequential_32 (Sequential)   (None, 300, 300, 3)       0
_________________________________________________________________
normalization_13 (Normalizat (None, 300, 300, 3)       7
_________________________________________________________________
xception (Functional)        (None, 10, 10, 2048)      20861480
_________________________________________________________________
global_average_pooling2d_13  (None, 2048)              0
_________________________________________________________________
dropout_28 (Dropout)         (None, 2048)              0
_________________________________________________________________
dense_59 (Dense)             (None, 1)                 2049
=================================================================
Total params: 20,863,536
Trainable params: 20,809,001
Non-trainable params: 54,535
_________________________________________________________________
Epoch 1/2
12/12 [==============================] - 19s 1s/step - loss: 0.1918 - binary_accuracy: 0.9218 - val_loss: 0.3842 - val_binary_accuracy: 0.8556
Epoch 2/2
12/12 [==============================] - 16s 1s/step - loss: 0.1469 - binary_accuracy: 0.9509 - val_loss: 0.3520 - val_binary_accuracy: 0.8556
```

```python
print(history.history)

acc = history.history['binary_accuracy']
val_acc = history.history['val_binary_accuracy']

loss = history.history['loss']
val_loss = history.history['val_loss']

epochs_range = range(epochs)

plt.figure(figsize=(8, 8))
plt.subplot(1, 2, 1)
plt.plot(epochs_range, acc, label='Training Accuracy')
plt.plot(epochs_range, val_acc, label='Validation Accuracy')
plt.legend(loc='lower right')
plt.title('Training and Validation Accuracy')

plt.subplot(1, 2, 2)
plt.plot(epochs_range, loss, label='Training Loss')
plt.plot(epochs_range, val_loss, label='Validation Loss')
plt.legend(loc='upper right')
plt.title('Training and Validation Loss')
plt.show()
```

```
{'loss': [0.1816166639328003, 0.14842070639133453], 'binary_accuracy': [0.9226519465446472, 0.950276255607605], 'val_loss': [0.38415849208831787, 0.3520398437976837], 'val_binary_accuracy': [0.855555534362793, 0.855555534362793]}
```
```python
predictWithTestDataset(model)
```

```
WARNING:tensorflow:6 out of the last 13 calls to <function Model.make_predict_function.<locals>.predict_function at 0x7f58f13847b8> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings could be due to (1) creating @tf.function repeatedly in a loop, (2) passing tensors with different shapes, (3) passing Python objects instead of tensors. For (1), please define your @tf.function outside of the loop. For (2), @tf.function has experimental_relax_shapes=True option that relaxes argument shapes that can avoid unnecessary retracing. For (3), please refer to https://www.tensorflow.org/guide/function#controlling_retracing and https://www.tensorflow.org/api_docs/python/tf/function for  more details.
number predictions=37
Accuracy:0.8378378378378378
```

Conclusion

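We have been able to get some pretty good results with a limited dataset: on our 37 test paintings, accuracy went from about 70% with the model trained from scratch, to about 78% with transfer learning, and to about 84% after fine-tuning. This is indeed very promising! It suggests that neural networks are capable of categorizing art collections by theme, although with a test set this small these numbers should be read as rough estimates.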

There are many ways to further improve these results: gathering more images, experimenting with different image sizes, or even trying other model architectures.

Hope you found this article useful. And hope to see you for the next one. Happy coding!
