Transfer Learning and Fine-Tuning an NSFW Image Detector with TensorFlow

Part 2: Transfer Learning and Fine-Tuning. Full code is available here

Diko Sakti Prabowo
7 min read · Sep 15, 2023

In the previous part, we demonstrated the steps to build a basic NSFW image detector from scratch and applied several techniques for handling overfitting.

In this tutorial, you’ll discover how to classify NSFW (Not Safe for Work) images using a technique called transfer learning from a pre-trained network.

A pre-trained model is like a smart student who has already studied a lot of pictures and learned from them. You can use this smart student’s knowledge as a starting point for your own learning task.

The idea behind transfer learning for image classification is that if a model has learned a lot about the visual world from a big dataset, you can use its knowledge to help you with your own image classification task. Instead of starting from scratch, you can build on what this model already knows.

In this tutorial, we’ll explore two ways to customize a pre-trained model:

1. Feature Extraction: Think of this like using a pre-trained artist’s brushes to paint a new picture. You take the useful things the model learned before and add a new part (classifier) to help it classify your pictures.

2. Fine-Tuning: It’s like a trained artist making small adjustments to get even better at a specific style of art. You’ll change a few things in the pre-trained model to make it work really well for your specific classification task.

Create the base model from the pre-trained model

We’ll begin by using MobileNet V2, a smart model made by Google. It has learned a lot from a massive dataset called ImageNet, which has tons of pictures of different things like fruits and tools. This knowledge will help us identify NSFW images in our own collection of pictures.

Now, to make the best use of MobileNet V2, we’ll pick a specific part of it for feature extraction. We don’t want the very last classification layer because it’s not so useful for us. Instead, we’ll grab the last layer before everything gets flattened, which is commonly called the bottleneck layer. Its features are more general than those of the final layer, which is what we need.

To set this up, we’ll get a MobileNet V2 model that’s already learned from ImageNet. But we’ll leave out the top part where it does the final classification. That way, we can use it for our own classification job.

import tensorflow as tf

# img_height and img_width are the image dimensions defined in Part 1
IMG_SHAPE = (img_height, img_width) + (3,)
base_model = tf.keras.applications.MobileNetV2(input_shape=IMG_SHAPE,
                                               include_top=False,
                                               weights='imagenet')
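
As a rough sanity check on what the headless base model produces, the sketch below estimates the spatial size of its output. It assumes MobileNet V2 halves the resolution five times (an overall stride of 32) and rounds up at each step. `feature_grid_size` is a hypothetical helper, not part of TensorFlow.

```python
import math


def feature_grid_size(input_size, num_downsamples=5):
    """Estimate the side length of the final feature map.

    Assumption: each stride-2 stage maps n pixels to ceil(n / 2).
    """
    size = input_size
    for _ in range(num_downsamples):
        size = math.ceil(size / 2)
    return size


for side in (160, 180, 224):
    grid = feature_grid_size(side)
    print(f"{side}x{side} input -> {grid}x{grid}x1280 feature map")
```

This is only arithmetic. Calling `base_model(image_batch).shape` on a real batch is the authoritative check.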

Feature Extraction

Now, in this step, we’ll keep the base model we just created exactly as it is and treat it as a feature extractor. Then, we’ll add a new classifier on top and train it for our specific task.

Freeze the convolutional base

Before we compile and train the model, it’s essential to freeze the convolutional base. Freezing a layer (setting layer.trainable = False) prevents its weights from being updated during training. MobileNet V2 has many layers, so setting the entire model’s “trainable” flag to False freezes all of them at once.

Freezing ensures that the knowledge acquired by these layers from the initial dataset remains intact and isn’t altered during the subsequent training process.

base_model.trainable = False
base_model.summary()
Zero trainable params because it has been frozen
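
To see what the trainable flag does without loading MobileNet V2, here is a minimal sketch on a small stand-in network. `toy_base` is a hypothetical model used only for illustration.

```python
import tensorflow as tf

# Small stand-in for the convolutional base (hypothetical, not MobileNet V2)
toy_base = tf.keras.Sequential([
    tf.keras.Input(shape=(32, 32, 3)),
    tf.keras.layers.Conv2D(8, 3, activation='relu'),
    tf.keras.layers.BatchNormalization(),
    tf.keras.layers.Conv2D(16, 3, activation='relu'),
])

trainable_before = len(toy_base.trainable_weights)

# Freezing the whole model moves every weight to the non-trainable list
toy_base.trainable = False
trainable_after = len(toy_base.trainable_weights)
non_trainable_after = len(toy_base.non_trainable_weights)

print("trainable weight tensors before freezing:", trainable_before)
print("trainable weight tensors after freezing:", trainable_after)
print("non-trainable weight tensors after freezing:", non_trainable_after)
```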

Add a classification head

To create predictions from the set of features, we’ll average each feature map over its spatial locations (for our 180x180 inputs, MobileNet V2 outputs a 6x6 grid per feature map). This compresses the information into a single vector of 1280 numbers for each image. We’ll achieve this using a layer called tf.keras.layers.GlobalAveragePooling2D.

image_batch, label_batch = next(iter(train_ds))
feature_batch = base_model(image_batch)
global_average_layer = layers.GlobalAveragePooling2D()
feature_batch_average = global_average_layer(feature_batch)
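
Global average pooling is simply a mean over the two spatial axes. The sketch below checks this on random data shaped like a batch of bottleneck features. The batch size and grid size used here are illustrative, and the values are made up.

```python
import numpy as np
import tensorflow as tf

# Made-up features: (batch, height, width, channels)
rng = np.random.default_rng(0)
fake_features = rng.normal(size=(4, 6, 6, 1280)).astype("float32")

pooled = tf.keras.layers.GlobalAveragePooling2D()(tf.constant(fake_features))
manual = fake_features.mean(axis=(1, 2))  # average over height and width

print(pooled.shape)
print(np.allclose(pooled.numpy(), manual, atol=1e-5))
```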

Next, we’ll use a tf.keras.layers.Dense layer to turn these features into num_classes raw scores for each image, one per class. No activation function is needed here because these scores are treated as logits, that is, raw prediction values that the loss function converts to probabilities.

prediction_layer = tf.keras.layers.Dense(num_classes)
prediction_batch = prediction_layer(feature_batch_average)

We’ll also use the preprocessing function that ships with MobileNet V2. It plays the same role as the rescaling layer in our own custom model, but it scales pixel values to the [-1, 1] range that MobileNet V2 expects.

preprocess_input = tf.keras.applications.mobilenet_v2.preprocess_input
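
For MobileNet V2, `preprocess_input` rescales pixel values from [0, 255] to [-1, 1]. The sketch below applies it to a tiny made-up image and compares the result with the equivalent formula `x / 127.5 - 1`.

```python
import numpy as np
import tensorflow as tf

preprocess_input = tf.keras.applications.mobilenet_v2.preprocess_input

# A made-up 2x2 "image" whose three channels hold 0, 127.5 and 255
image = np.tile(np.array([0.0, 127.5, 255.0], dtype="float32"), (1, 2, 2, 1))

# Pass a copy: for NumPy inputs the function may modify the array in place
scaled = np.asarray(preprocess_input(image.copy()))
manual = image / 127.5 - 1.0

print(scaled[0, 0, 0])
print(np.allclose(scaled, manual))
```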

Create a model by chaining the data augmentation, preprocessing, base_model, pooling, dropout, and prediction layers using the Keras Functional API. Pass training=False when calling the base model: it contains BatchNormalization layers, and this keeps them in inference mode so their moving statistics are not updated.

inputs = tf.keras.Input(shape=(180, 180, 3))
x = data_augmentation(inputs)
x = preprocess_input(x)
x = base_model(x, training=False)
x = global_average_layer(x)
x = tf.keras.layers.Dropout(0.2)(x)
outputs = prediction_layer(x)
model = tf.keras.Model(inputs, outputs)
model.summary()
Classification layer/dense layer are trainable while mobilenetv2 remains frozen
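
The effect of the training flag on BatchNormalization can be seen in isolation. In this sketch, a standalone layer (not the one inside MobileNet V2) is fed a made-up constant batch, and its moving mean only changes when the layer is called with training=True.

```python
import numpy as np
import tensorflow as tf

bn = tf.keras.layers.BatchNormalization()
batch = tf.fill((8, 4), 3.0)  # made-up batch in which every value is 3

bn(batch, training=False)  # first call builds the layer
mean_initial = np.array(bn.moving_mean)

bn(batch, training=False)  # inference mode: statistics stay fixed
mean_after_inference = np.array(bn.moving_mean)

bn(batch, training=True)   # training mode: moving mean drifts toward 3
mean_after_training = np.array(bn.moving_mean)

print(mean_initial, mean_after_inference, mean_after_training)
```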

Compile and train

base_learning_rate = 0.0001
model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=base_learning_rate),
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])

initial_epochs = 10
history = model.fit(train_ds,
                    epochs=initial_epochs,
                    validation_data=val_ds)
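
The from_logits=True argument tells the loss to apply the softmax itself. As a small check with made-up labels and logits, the sketch below shows that it gives the same value as applying softmax first and using from_logits=False.

```python
import tensorflow as tf

y_true = tf.constant([0, 2])              # made-up integer labels
logits = tf.constant([[2.0, 0.5, -1.0],
                      [0.1, 0.2, 3.0]])   # made-up raw model outputs

loss_on_logits = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
loss_on_probs = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False)

value_from_logits = float(loss_on_logits(y_true, logits))
value_from_probs = float(loss_on_probs(y_true, tf.nn.softmax(logits, axis=-1)))

print(value_from_logits, value_from_probs)
```

Keeping the Dense layer free of activation and letting the loss handle the softmax is generally the more numerically stable of the two options.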

Visualize Training and Validation History

acc = history.history['accuracy']
val_acc = history.history['val_accuracy']

loss = history.history['loss']
val_loss = history.history['val_loss']

epochs_range = range(initial_epochs)

plt.figure(figsize=(8, 8))
plt.subplot(1, 2, 1)
plt.plot(epochs_range, acc, label='Training Accuracy')
plt.plot(epochs_range, val_acc, label='Validation Accuracy')
plt.legend(loc='lower right')
plt.title('Training and Validation Accuracy')

plt.subplot(1, 2, 2)
plt.plot(epochs_range, loss, label='Training Loss')
plt.plot(epochs_range, val_loss, label='Validation Loss')
plt.legend(loc='upper right')
plt.title('Training and Validation Loss')
plt.show()

From the graphs above, we can see that validation accuracy has improved significantly compared to our previous model. This demonstrates how useful transfer learning is when building a neural network.

loss, accuracy = model.evaluate(test_ds)
print('Test accuracy :', accuracy)
y_prob = model.predict(test_ds)
y_pred = y_prob.argmax(axis=1)
y_pred = list(map(lambda x: class_names[x],y_pred))
ConfusionMatrixDisplay.from_predictions(y_test,y_pred)
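
Because the final Dense layer has no activation, `model.predict` returns logits rather than probabilities. Taking `argmax` on logits is fine for picking a class, but if you also want a confidence score you can apply a softmax first. Here is a small sketch with made-up logits and placeholder class names (the real `class_names` comes from the dataset):

```python
import numpy as np
import tensorflow as tf

# Placeholder names and made-up logits, for illustration only
example_class_names = ["class_a", "class_b", "class_c"]
example_logits = np.array([[2.0, 0.5, -1.0],
                           [0.1, 0.2, 3.0]], dtype="float32")

# Softmax turns each row of logits into probabilities that sum to 1
probs = tf.nn.softmax(example_logits, axis=-1).numpy()
pred_idx = probs.argmax(axis=1)
pred_names = [example_class_names[i] for i in pred_idx]

for name, p in zip(pred_names, probs.max(axis=1)):
    print(f"{name}: {p:.2f}")
```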

Some interesting takeaways from the confusion matrix:

  1. Overall performance is already better than our own custom model.
  2. This model still has difficulty differentiating between drawings and hentai, as well as between porn and sexy images.

As a reminder, this is the previous model’s confusion matrix:

We can improve the model by adjusting it to our specific dataset with fine tuning.

Fine Tuning

When we did feature extraction, we didn’t change the pre-trained MobileNetV2 model. We just added a few new layers on top to help with our specific task.

Now, if we want to make our model even better, we can do something called “fine-tuning.” This means we’ll adjust the weights of some top layers in the pre-trained model while training our new layers. This way, the model will learn to recognize things that are important for our particular dataset.

Rather than fine-tuning every layer in the pre-trained model, we should try to fine-tune a small number of top layers. The lower layers learned basic stuff that’s useful for almost any picture. As we move up in the layers, they learned things specific to the dataset they were trained on.

So, when we fine-tune, we’re mainly tweaking the specialized stuff to fit our dataset better, without forgetting the general knowledge.

Un-freeze the top layers of the model

We are going to unfreeze the base_model but keep its bottom layers frozen. Then we recompile the model and resume training.

base_model.trainable = True

# Fine-tune from this layer onwards
fine_tune_at = 100

# Freeze all the layers before the `fine_tune_at` layer
for layer in base_model.layers[:fine_tune_at]:
    layer.trainable = False
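
The slicing pattern above can be tried on a toy stack of layers to see which ones end up trainable. `toy_base`, `toy_fine_tune_at` and the five Dense layers are hypothetical stand-ins for MobileNet V2 and its layers.

```python
import tensorflow as tf

toy_base = tf.keras.Sequential(
    [tf.keras.Input(shape=(8,))]
    + [tf.keras.layers.Dense(4, name=f"dense_{i}") for i in range(5)]
)

toy_base.trainable = True
toy_fine_tune_at = 3  # freeze layers 0..2, keep layers 3..4 trainable

for layer in toy_base.layers[:toy_fine_tune_at]:
    layer.trainable = False

flags = {layer.name: layer.trainable for layer in toy_base.layers}
print(flags)
print("trainable weight tensors:", len(toy_base.trainable_weights))
```

Only the kernel and bias of the last two layers remain trainable, which mirrors how the top of MobileNet V2 is fine-tuned while its lower layers stay fixed.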

model.compile(loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              optimizer=tf.keras.optimizers.RMSprop(learning_rate=base_learning_rate/10),
              metrics=['accuracy'])
model.summary()

fine_tune_epochs = 10
total_epochs = initial_epochs + fine_tune_epochs

# Resume from the epoch after the last completed one so that
# exactly fine_tune_epochs additional epochs are run
history_fine = model.fit(train_ds,
                         epochs=total_epochs,
                         initial_epoch=len(history.epoch),
                         validation_data=val_ds)
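
Keras trains until the epoch counter reaches `epochs`, starting from `initial_epoch`, so the number of epochs actually run is the difference between the two. A small pure-Python sketch of that bookkeeping follows (`epochs_to_run` is a hypothetical helper, not a Keras function):

```python
def epochs_to_run(initial_epoch, total_epochs):
    """Return the zero-based epoch indices that fit() would execute."""
    return list(range(initial_epoch, total_epochs))


initial_epochs = 10
fine_tune_epochs = 10
total_epochs = initial_epochs + fine_tune_epochs

# After the first run, history.epoch would be [0, 1, ..., 9]
first_run_epochs = list(range(initial_epochs))

resume_at_len = epochs_to_run(len(first_run_epochs), total_epochs)
resume_at_last = epochs_to_run(first_run_epochs[-1], total_epochs)

print(len(resume_at_len))   # 10
print(len(resume_at_last))  # 11
```

Resuming at `len(history.epoch)` gives exactly `fine_tune_epochs` additional epochs, while resuming at `history.epoch[-1]` runs one extra.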

Visualize Training Results

acc += history_fine.history['accuracy']
val_acc += history_fine.history['val_accuracy']

loss += history_fine.history['loss']
val_loss += history_fine.history['val_loss']

plt.figure(figsize=(8, 8))
plt.subplot(2, 1, 1)
plt.plot(acc, label='Training Accuracy')
plt.plot(val_acc, label='Validation Accuracy')
plt.ylim([0, 1.0])
plt.plot([initial_epochs-1,initial_epochs-1],
         plt.ylim(), label='Start Fine Tuning')
plt.legend(loc='lower right')
plt.title('Training and Validation Accuracy')

plt.subplot(2, 1, 2)
plt.plot(loss, label='Training Loss')
plt.plot(val_loss, label='Validation Loss')
plt.ylim([0, 1.0])
plt.plot([initial_epochs-1,initial_epochs-1],
         plt.ylim(), label='Start Fine Tuning')
plt.legend(loc='upper right')
plt.title('Training and Validation Loss')
plt.xlabel('epoch')
plt.show()

Fine-tuning is prone to overfitting, as seen in this graph, but we still get better accuracy on the test data.

loss, accuracy = model.evaluate(test_ds)
print('Test accuracy :', accuracy)
y_prob = model.predict(test_ds)
y_pred = y_prob.argmax(axis=1)
y_pred = list(map(lambda x: class_names[x],y_pred))
ConfusionMatrixDisplay.from_predictions(y_test,y_pred)

Although this fine-tuned model still shares some behaviour with the previous model, such as confusing hentai with drawings and porn with sexy images, its ability to recognize hentai images increases significantly while performance on the other classes is maintained.

Let’s save this model to be used in future projects.

model.save('nsfw_mobilenet_fine_tuned.keras')
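
To reuse the saved file later, the model can be loaded back with tf.keras.models.load_model. The sketch below demonstrates the save and load round trip on a tiny stand-in model written to a temporary directory. `toy_model` and its file name are hypothetical, and loading the real model may need extra care depending on the layers it contains and the TensorFlow version.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf

# Tiny stand-in model (hypothetical), used only to show the round trip
toy_model = tf.keras.Sequential([
    tf.keras.Input(shape=(4,)),
    tf.keras.layers.Dense(3),
])

path = os.path.join(tempfile.mkdtemp(), "toy_model.keras")
toy_model.save(path)

restored = tf.keras.models.load_model(path)

# The restored model should give the same outputs as the original
x = np.ones((2, 4), dtype="float32")
same = np.allclose(toy_model.predict(x, verbose=0),
                   restored.predict(x, verbose=0))
print(same)
```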

Conclusion

Congratulations! You’ve reached the end of this two-part tutorial on building an NSFW image classifier with TensorFlow. We’ve covered a lot of ground, and you’ve acquired some powerful skills in deep learning and computer vision.

Throughout this journey, we’ve explored the importance of responsible content filtering in today’s digital landscape and delved into the inner workings of convolutional neural networks, data augmentation, dropout regularization, transfer learning and fine-tuning.

I encourage you to continue experimenting, learning, and sharing your insights with the world. The future holds exciting opportunities in the field of artificial intelligence, and you are well-equipped to be a part of it.

Thank you for joining me on this journey. I hope this tutorial helps you further improve your skills and deepen your interest in deep learning.

Diko Sakti Prabowo

Data Scientist with 3 years experience, passionate about football, AC Milan and rock music.