Custom Pipeline for Every TensorFlow Keras Image Classification Model

Mar 2, 2024


When I first started working with Keras, I used to dive straight into the code without following any structured workflow. However, I soon realized that this approach made the process quite challenging. Therefore, I decided to create a pipeline for image classification tasks.

  • In this article, I will share the pipeline that I use for every image classification task.

Figure: Visualization of Augmented Images

I developed a set of functions that I now utilize consistently in my deep learning notebooks and projects.

Functions for:

→ Creating Dataset (Augmented - Not Augmented)

→ Visualization of Images from Dataset (Augmented and Not Augmented images)

→ Visualization of Class Distribution

→ Visualization of Model History (training and validation loss, training and validation accuracy)

1. Creating Dataset (Augmented - Not Augmented)

To train any model, you need to prepare two datasets: one for training and one for validation. You can split them in any ratio you prefer. I typically allocate 70% of the images to the training set and 30% to the validation set.
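As a minimal sketch of such a 70/30 split, assuming you start from a list of image file names for one class: the helper `split_files` and the file names below are hypothetical and only illustrate the idea, they are not part of Keras.

```python
import random

def split_files(files, train_ratio=0.7, seed=42):
    """Shuffle a list of file names and split it into training and validation parts."""
    files = sorted(files)        # sort first so the result does not depend on input order
    rng = random.Random(seed)    # a fixed seed makes the split reproducible
    rng.shuffle(files)
    cut = round(len(files) * train_ratio)
    return files[:cut], files[cut:]

# hypothetical file names for a single class
example_files = [f"img_{i:03d}.jpg" for i in range(10)]
train_files, validation_files = split_files(example_files)
print(len(train_files), len(validation_files))
```

Splitting each class separately like this keeps the class proportions roughly the same in both sets.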

If you don’t have enough images, or if you want to increase the diversity of your dataset, you can use data augmentation. However, if you already have sufficient data, you can split your dataset into training and validation sets and train your model directly on them.

What is Data Augmentation?

Data augmentation is a technique for increasing the diversity of your training set by applying random transformations, such as rotation, scaling, and more.

Augmented images are generated from the original images with different orientations, scales, brightness levels, and more. Keras provides a very useful class called ‘ImageDataGenerator’ for the augmentation process.

Augmented datasets are used to reduce overfitting, which helps the model generalize better.

Note: Augmentation is applied only to the training set.
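To make the idea of random transformations concrete, here is a small NumPy sketch that applies a random horizontal flip and a random brightness change to one image. It is only an illustration: the `augment` helper, the 0.8 to 1.2 brightness range, and the random stand-in image are assumptions, and this is not how ‘ImageDataGenerator’ is implemented internally.

```python
import numpy as np

def augment(image, rng):
    """Return a randomly transformed copy of an image with values in [0, 1]."""
    out = image.copy()
    if rng.random() < 0.5:
        out = out[:, ::-1, :]              # horizontal flip: reverse the width axis
    factor = rng.uniform(0.8, 1.2)         # random brightness factor (assumed range)
    out = np.clip(out * factor, 0.0, 1.0)  # keep pixel values in the valid range
    return out

rng = np.random.default_rng(0)
image = rng.random((180, 180, 3))          # stand-in for a real, already rescaled image
augmented = [augment(image, rng) for _ in range(4)]
print([a.shape for a in augmented])
```

Each call produces a slightly different version of the same image, which is what gives the model more variety to learn from.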

Function for Creating Dataset

from tensorflow.keras.preprocessing.image import ImageDataGenerator

# your directories
train_dir = "path_to_training_dir"
validation_dir = "path_to_validation_dir"

# if you want an augmented dataset, use it like this: prep_data(True)
def prep_data(augmented, batch_size=16):
    if augmented:
        # you can change these parameters (example values, adjust them for your data)
        train_datagen = ImageDataGenerator(
            rescale=1.0 / 255.0,
            rotation_range=20,
            zoom_range=0.2,
            brightness_range=(0.8, 1.2),
            horizontal_flip=True)
        # augmentation is applied only to the training set
        validation_datagen = ImageDataGenerator(rescale=1.0 / 255.0)
    else:
        # if you set augmented=False, images are just rescaled
        train_datagen = ImageDataGenerator(rescale=1.0 / 255.0)
        validation_datagen = ImageDataGenerator(rescale=1.0 / 255.0)

    # training set
    train_set = train_datagen.flow_from_directory(
        train_dir,
        target_size=(180, 180),  # the dimensions to which all images found will be resized
        batch_size=batch_size,
        class_mode="sparse")  # integer labels; use "categorical" for one-hot encoded labels

    # validation set
    validation_set = validation_datagen.flow_from_directory(
        validation_dir,
        target_size=(180, 180),
        batch_size=batch_size,
        class_mode="sparse")

    return train_set, validation_set

The function returns the training and validation datasets, which you can use to train your models.
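The `class_mode` setting decides how the labels in these datasets are encoded: "sparse" gives one integer per image, while "categorical" gives a one-hot vector per image. The NumPy sketch below shows how the two formats relate; the number of classes and the label values are made up for illustration.

```python
import numpy as np

num_classes = 3                           # assumed number of classes
sparse_labels = np.array([0, 2, 1, 2])    # what class_mode="sparse" labels look like

# one-hot version of the same labels, the format used by class_mode="categorical"
one_hot_labels = np.eye(num_classes)[sparse_labels]
print(one_hot_labels)

# going back from one-hot vectors to integer labels
recovered = one_hot_labels.argmax(axis=1)
print(recovered)
```

The label format also decides the loss function: integer labels go with sparse_categorical_crossentropy, and one-hot labels go with categorical_crossentropy.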

Figure: Creating an augmented dataset with a batch size of 16

2. Visualization of Images

This step is not strictly necessary, but I think it is good to see some examples from your dataset. It is especially helpful with an augmented dataset, because the augmentation sometimes does not produce the desired results.

  • By observing some examples, you can adjust the augmentation parameters.

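import matplotlib.pyplot as plt

# create dataset (augmented or not augmented, it is up to you; the process is the same in both cases)
train_set, validation_set = prep_data(True)

# take one batch of images and labels from the training set
images, labels = next(train_set)

# map label indices back to class names
class_names = train_set.class_indices
class_names = {v: k for k, v in class_names.items()}

fig, axes = plt.subplots(1, 4, figsize=(15, 5))

for i in range(4):
    label_index = int(labels[i])
    class_name = class_names[label_index]
    axes[i].imshow(images[i])
    axes[i].set_title(class_name)
    axes[i].axis("off")

plt.show()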

Figure: Example Images (Augmented)

Figure: Example Images (Not Augmented)

3. Visualization of Class Distribution

Balancing a dataset is a crucial step because it helps prevent the model from becoming biased towards one class. If you have an imbalanced dataset, the results may not be satisfactory. Visualizing the distribution can be useful for determining whether the dataset is balanced or imbalanced.

import os
import matplotlib.pyplot as plt

# train and validation folder paths
train_dir = "path_to_training_dir"
validation_dir = "path_to_validation_dir"

# calculate distribution in training set
train_class_counts = {}
for class_folder in os.listdir(train_dir):
    class_path = os.path.join(train_dir, class_folder)
    if os.path.isdir(class_path):
        num_images = len(os.listdir(class_path))
        train_class_counts[class_folder] = num_images

# calculate distribution in validation set
validation_class_counts = {}
for class_folder in os.listdir(validation_dir):
    class_path = os.path.join(validation_dir, class_folder)
    if os.path.isdir(class_path):
        num_images = len(os.listdir(class_path))
        validation_class_counts[class_folder] = num_images

plt.figure(figsize=(15, 6))

# training
plt.subplot(1, 2, 1)
plt.bar(list(train_class_counts.keys()), list(train_class_counts.values()))
plt.title('Training set Distribution')
plt.xlabel('Classes')
plt.ylabel('Sample Numbers')

# validation
plt.subplot(1, 2, 2)
plt.bar(list(validation_class_counts.keys()), list(validation_class_counts.values()))
plt.title('Validation set Distribution')
plt.xlabel('Classes')
plt.ylabel('Sample Numbers')

plt.tight_layout()
plt.show()

Figure: Class Distribution

  • In this case, the “Starfish” class has far more examples than the other classes, so you may want to consider reducing the number of images in the “Starfish” class.
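One simple way to reduce an over-represented class is random undersampling: keep only a random subset of its files. The sketch below is a minimal illustration of that idea. The `undersample` helper, the "Crab" class, and all counts are hypothetical, and the function only selects file names without deleting anything.

```python
import random

def undersample(files_by_class, seed=42):
    """Randomly keep at most as many files per class as the smallest class has."""
    rng = random.Random(seed)
    target = min(len(files) for files in files_by_class.values())
    return {
        name: sorted(rng.sample(files, target))
        for name, files in files_by_class.items()
    }

# hypothetical file lists; the class names and counts are made up for illustration
files_by_class = {
    "Starfish": [f"starfish_{i}.jpg" for i in range(50)],
    "Crab": [f"crab_{i}.jpg" for i in range(20)],
}
balanced = undersample(files_by_class)
print({name: len(files) for name, files in balanced.items()})
```

Undersampling throws away data, so it fits best when the larger class has plenty of images to spare.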

4. Visualization of Model History

When I read deep learning notebooks on Kaggle or other platforms, I often see authors repeat the same plotting script for every model they train. I think creating one function and using it for every model is more efficient and makes the notebook easier for readers to follow.

In Keras, the .fit() method is used for training. It returns a History object, and if you save it in a variable, you can access the training history of the model through its .history attribute.

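# model is your compiled Keras model
# you can create train_set and validation_set by following step 1
# epochs=10 is only an example value
history = model.fit(train_set, validation_data=validation_set, epochs=10)

  • By saving this output, you can use it to plot accuracy, loss, or any other metric that you want to visualize.
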
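The `history.history` attribute is a plain dictionary that maps each metric name to a list with one value per epoch. As a small sketch of what you can do with it beyond plotting, the snippet below finds the epoch with the lowest validation loss. The dictionary values are made up and stand in for a real training run, and `best_epoch` is a hypothetical helper.

```python
# made-up values standing in for history.history from a real training run
history_dict = {
    "loss": [1.2, 0.8, 0.6, 0.5],
    "val_loss": [1.1, 0.9, 0.7, 0.8],
    "accuracy": [0.55, 0.70, 0.78, 0.82],
    "val_accuracy": [0.58, 0.66, 0.74, 0.72],
}

def best_epoch(history_dict, metric="val_loss"):
    """Return the 1-based epoch with the lowest value of the given metric."""
    values = history_dict[metric]
    return min(range(len(values)), key=values.__getitem__) + 1

print(best_epoch(history_dict))
```

A validation loss that starts rising while the training loss keeps falling, as in these made-up numbers, is a typical sign of overfitting.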
import matplotlib.pyplot as plt

# visualization function for models
# expects the model to be compiled with metrics=['accuracy']
def visualize(history):
    acc = history.history['accuracy']
    val_acc = history.history['val_accuracy']
    loss = history.history['loss']
    val_loss = history.history['val_loss']
    epochs = range(1, len(acc) + 1)

    fig, axs = plt.subplots(1, 2, figsize=(12, 5))

    axs[0].plot(epochs, acc, 'r', label='Training acc')
    axs[0].plot(epochs, val_acc, 'b', label='Validation acc')
    axs[0].set_title('Training and validation accuracy')
    axs[0].set_xlabel('Epochs')
    axs[0].legend()

    axs[1].plot(epochs, loss, 'r', label='Training loss')
    axs[1].plot(epochs, val_loss, 'b', label='Validation loss')
    axs[1].set_title('Training and validation loss')
    axs[1].set_xlabel('Epochs')
    axs[1].legend()

    plt.show()


