Classifying healthy and bleached corals using CNNs

Gabriela Padilla
Published in LatinXinAI · 7 min read · May 2, 2024

Coral bleaching is a major global issue that has been intensifying in recent decades due to climate change. About half of the world’s living corals have died since the 1950s, and reef biodiversity has dropped by 63%.

That’s why, in this project, we will classify images of healthy and bleached corals using convolutional neural networks (CNNs).

Let’s first get an overview of what CNNs are:

A CNN works similarly to how a detective might investigate a crime scene. Just as a detective carefully examines different pieces of evidence, a CNN breaks down the image puzzle into smaller sections and analyzes them for specific features.

Here’s how it works:

  1. Puzzle Pieces (Pixels): In the image puzzle, each tiny piece represents a pixel, which contains information about color and intensity.
  2. Detective (CNN): The CNN acts as the detective, tasked with finding specific objects or patterns within the image.
  3. Evidence Examination (Convolutional Layers): Just as a detective might use a magnifying glass to examine evidence closely, a CNN uses convolutional layers to analyze small sections of the image puzzle. These layers detect features like edges, shapes, and textures.
  4. Detective’s Notebook (Feature Maps): As the detective uncovers important clues, they jot them down in their notebook. Similarly, CNNs create feature maps that highlight significant features found in the image, helping to identify patterns and objects.
  5. Solving the Case (Classification or Detection): Once the detective has gathered enough evidence and pieced together the puzzle, they can solve the case. Similarly, CNNs use the information from the feature maps to classify objects in the image (e.g., “cat,” “tree”) or detect specific patterns (e.g., “road,” “building”).
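To make the analogy concrete, here is a minimal NumPy sketch of one "evidence examination" step on a made-up 6x6 grayscale image: a 3x3 vertical-edge kernel slides over the image to produce a feature map, ReLU keeps only the positive responses, and 2x2 max pooling keeps the strongest clue in each region. The image, the kernel, and the helper names (conv2d_valid, relu, max_pool_2x2) are illustrative assumptions. Like deep-learning libraries, the sketch does not flip the kernel, and in a real CNN the kernel values are learned during training rather than chosen by hand.

```python
import numpy as np


def conv2d_valid(image, kernel):
    """Slide `kernel` over a 2D `image` (stride 1, no padding) and return the feature map."""
    kh, kw = kernel.shape
    out_h = image.shape[0] - kh + 1
    out_w = image.shape[1] - kw + 1
    feature_map = np.zeros((out_h, out_w))
    for i in range(out_h):
        for j in range(out_w):
            # Multiply one small patch of the image by the kernel and sum: one "clue".
            feature_map[i, j] = np.sum(image[i:i + kh, j:j + kw] * kernel)
    return feature_map


def relu(x):
    """Keep positive responses and set negative ones to zero."""
    return np.maximum(x, 0)


def max_pool_2x2(x):
    """Keep the largest value in each non-overlapping 2x2 window."""
    h, w = (x.shape[0] // 2) * 2, (x.shape[1] // 2) * 2
    x = x[:h, :w]  # drop an odd last row/column, if any
    return x.reshape(h // 2, 2, w // 2, 2).max(axis=(1, 3))


# Hypothetical 6x6 image: dark everywhere except the last two columns.
image = np.zeros((6, 6))
image[:, 4:] = 1.0

# Hand-picked kernel that responds to dark-to-bright vertical edges.
kernel = np.array([[-1.0, 0.0, 1.0],
                   [-1.0, 0.0, 1.0],
                   [-1.0, 0.0, 1.0]])

feature_map = relu(conv2d_valid(image, kernel))  # 4x4 "notebook page" of clues
pooled = max_pool_2x2(feature_map)               # 2x2 summary of the strongest clues
print(feature_map)
print(pooled)
```

The feature map is non-zero only where the kernel's window spans the boundary between the dark and bright regions, and pooling keeps that signal while halving the resolution.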

Step-by-Step project explanation

Dataset

For this project, I used a dataset of almost 5,000 images of bleached and healthy corals. Of these, 80% were used to train the model and 20% for validation.

Before training, we preprocess the images in our dataset. The code snippets below use TensorFlow’s image_dataset_from_directory function to create datasets from the image directories, split them into training and validation sets, convert the images to grayscale, resize them to a fixed size, and group them into batches. This step prepares the data in the format the model expects during training and validation.

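import tensorflow as tf

# Assumed settings (not given in the article): DATASET_PATH and BATCH_SIZE are
# placeholders, and 64x64 matches the resize used later in the prediction step.
DATASET_PATH = "path/to/coral_dataset"  # hypothetical folder with one subfolder per class
DATASET_VALIDATION_SPLIT = 0.2  # 80% training / 20% validation
IMG_HEIGHT, IMG_WIDTH = 64, 64
BATCH_SIZE = 32

TRAIN_DATASET = tf.keras.utils.image_dataset_from_directory(
    DATASET_PATH,
    validation_split=DATASET_VALIDATION_SPLIT,
    subset="training",
    color_mode="grayscale",
    seed=123,  # must match the validation call so the two subsets don't overlap
    image_size=(IMG_HEIGHT, IMG_WIDTH),
    batch_size=BATCH_SIZE)

VALIDATION_DATASET = tf.keras.utils.image_dataset_from_directory(
    DATASET_PATH,
    validation_split=DATASET_VALIDATION_SPLIT,
    subset="validation",
    color_mode="grayscale",
    seed=123,
    image_size=(IMG_HEIGHT, IMG_WIDTH),
    batch_size=BATCH_SIZE)

# Class names are inferred from the subfolder names, in alphanumerical order
CLASS_NAMES = TRAIN_DATASET.class_names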
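Both calls above pass the same validation_split and seed. That matters because each call computes its split on its own: with identical settings they carve the same shuffled file list into complementary parts, whereas a different seed could put some images in both the training and validation sets. The plain-Python sketch below illustrates the idea with a hypothetical split_files helper and made-up file names; it is a simplified model of the behavior, not TensorFlow’s actual implementation.

```python
import random


def split_files(file_paths, validation_split=0.2, seed=123):
    """Shuffle file paths with a fixed seed, then reserve the last fraction for validation."""
    paths = sorted(file_paths)          # start from a deterministic order
    random.Random(seed).shuffle(paths)  # same seed -> same shuffled order on every call
    num_val = int(validation_split * len(paths))
    split_index = len(paths) - num_val
    return paths[:split_index], paths[split_index:]


# Made-up file names standing in for a dataset of about 5,000 images.
files = [f"img_{i:04d}.jpg" for i in range(5000)]

train_files, _ = split_files(files, seed=123)  # like subset="training"
_, val_files = split_files(files, seed=123)    # like subset="validation"
print(len(train_files), len(val_files))

# With a different seed for the training call, validation images leak into training.
other_train, _ = split_files(files, seed=7)
leaked = set(other_train) & set(val_files)
print(len(leaked))
```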

CNN Model

Let’s start with an overview of the model architecture:

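model = tf.keras.Sequential([
    tf.keras.layers.Rescaling(1./255, input_shape=(IMG_HEIGHT, IMG_WIDTH, 1)),
    tf.keras.layers.Conv2D(16, 3, padding='same', activation='relu'),
    tf.keras.layers.MaxPooling2D(),
    tf.keras.layers.Conv2D(32, 3, padding='same', activation='relu'),
    tf.keras.layers.MaxPooling2D(),
    tf.keras.layers.Conv2D(64, 3, padding='same', activation='relu'),
    tf.keras.layers.MaxPooling2D(),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(128, activation='relu'),
    tf.keras.layers.Dense(len(CLASS_NAMES))  # one raw score (logit) per class
])

# model.fit() needs a compiled model. This step isn't shown in the original snippet;
# the settings below are an assumption that fits integer labels and a logit output.
model.compile(
    optimizer='adam',
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=['accuracy'])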
  • Rescaling Layer: Normalizes the input pixel values to the range [0, 1] by dividing each value by 255.
  • Conv2D Layer: Performs 2D convolution on the input using 16 filters of size 3x3, followed by a ReLU activation. padding='same' keeps the height and width unchanged.
  • MaxPooling2D Layer: Keeps the maximum value in each 2x2 window (the default pool size), halving the height and width. This reduces the computation and the number of parameters in the later dense layers while retaining the strongest feature responses.
  • Conv2D Layer: Same as the first convolutional layer, but with 32 filters of size 3x3.
  • MaxPooling2D Layer: Another max-pooling layer that further reduces spatial resolution.
  • Conv2D Layer: Same as the previous ones, but with 64 filters of size 3x3.
  • MaxPooling2D Layer: Max pooling is applied once again.
  • Flatten Layer: Flattens the feature maps into a single vector so they can be fed into the fully connected (dense) layers.
  • Dense Layer: A fully connected layer with 128 neurons and a ReLU activation.
  • Dense Layer: The output layer, with one neuron per class. It has no activation, so it outputs raw scores (logits) rather than probabilities.
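The shapes and parameter counts of this architecture can be worked out by hand, which is a useful check on the total_parameters value recorded later with model.count_params(). The sketch below assumes 64x64 grayscale inputs (the size used in the prediction step) and two classes; the helper names are hypothetical. Each Conv2D layer has kernel_height x kernel_width x input_channels x filters weights plus one bias per filter, each Dense layer has inputs x units weights plus one bias per unit, and the Rescaling, MaxPooling2D, and Flatten layers have no trainable parameters.

```python
def conv2d_params(kernel_size, in_channels, filters):
    """Weights (k x k x in_channels per filter) plus one bias per filter."""
    return kernel_size * kernel_size * in_channels * filters + filters


def dense_params(inputs, units):
    """Weights (inputs x units) plus one bias per unit."""
    return inputs * units + units


def describe_model(img_height=64, img_width=64, num_classes=2):
    """Trace the feature-map shape through the model and add up its parameters."""
    h, w, channels = img_height, img_width, 1  # grayscale input: one channel
    total = 0
    for filters in (16, 32, 64):
        total += conv2d_params(3, channels, filters)  # padding='same' keeps h and w
        channels = filters
        h, w = h // 2, w // 2  # default MaxPooling2D pool_size=(2, 2) halves h and w
    flat_size = h * w * channels  # length of the vector produced by Flatten
    total += dense_params(flat_size, 128)
    total += dense_params(128, num_classes)
    return flat_size, total


flat_size, total_params = describe_model()
print(flat_size, total_params)
```

Under these assumptions the feature maps shrink from 64x64 to 8x8, Flatten produces a 4,096-value vector, and the first Dense layer alone holds 524,416 of the 547,970 parameters, so most of this model’s capacity sits in its fully connected part.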

It’s time to train the model and save it for later use! This is the code snippet that trains our model:

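from time import time

EPOCHS = 10  # assumed value; the number of epochs is not stated in the article

start = time()

history = model.fit(
    TRAIN_DATASET,
    validation_data=VALIDATION_DATASET,
    epochs=EPOCHS
)

end = time()

elapsed_time = end - start  # training time in seconds

final_results = {
    "dataset": "Grayscale",
    "total_parameters": model.count_params(),
    "elapsed_time": elapsed_time,
    "accuracy": history.history["val_accuracy"][-1]  # needs metrics=['accuracy'] in compile()
}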
  1. start = time(): Records the start time of training.
  2. history = model.fit(...): Trains the model using training and validation datasets.
  3. end = time(): Records the end time of training.
  4. elapsed_time = end - start: Calculates the training duration.
  5. final_results = { ... }: Stores dataset type, total parameters, elapsed time, and final validation accuracy.

This code snippet captures the training process, measures its duration, and collects key metrics for evaluation.
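One detail worth knowing: final_results stores the validation accuracy of the last epoch, which is not necessarily the best one. The sketch below uses a hypothetical summarize_history helper and made-up numbers shaped like history.history to report both.

```python
def summarize_history(history_dict):
    """Compare last-epoch and best-epoch validation accuracy in a Keras-style history dict."""
    val_acc = history_dict["val_accuracy"]
    best_index = max(range(len(val_acc)), key=lambda i: val_acc[i])
    return {
        "final_val_accuracy": val_acc[-1],
        "best_val_accuracy": val_acc[best_index],
        "best_epoch": best_index + 1,  # Keras counts epochs from 1 in its logs
    }


# Made-up values for illustration only (not results from this project).
example_history = {
    "accuracy": [0.58, 0.66, 0.73, 0.79, 0.83],
    "val_accuracy": [0.60, 0.68, 0.76, 0.74, 0.75],
}
print(summarize_history(example_history))
```

If the best epoch matters more than the last one, Keras callbacks such as tf.keras.callbacks.EarlyStopping (with restore_best_weights=True) or tf.keras.callbacks.ModelCheckpoint (with save_best_only=True) can help keep the best weights.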

To save the model, we use this code snippet:

MODEL_PATH = "corals_detection_model.h5"  # the same file is loaded again in the prediction step
model.save(MODEL_PATH)

Results and making predictions

This model reached an accuracy of 80%. To improve it, the most direct step would be to collect more images and add them to the dataset.

To make predictions, we follow these steps:

1. Load the model we saved before

import tensorflow as tf

model_path = "corals_detection_model.h5"
model = tf.keras.models.load_model(model_path)

2. Load the image and preprocess it

import matplotlib.pyplot as plt
import tensorflow as tf

img_path = '/content/t1.jpg'
original_img = tf.keras.preprocessing.image.load_img(img_path)
plt.imshow(original_img)
plt.axis('off')
plt.show()
# Convert the PIL image to an array, then keep a single grayscale channel
gray_img = tf.image.rgb_to_grayscale(tf.keras.preprocessing.image.img_to_array(original_img))
gray_img = tf.image.resize(gray_img, (64, 64))  # must match the size used for training
plt.imshow(gray_img[..., 0], cmap='gray')  # drop the channel axis and use a gray colormap
plt.axis('off')
plt.show()
# Convert to array
img_array = tf.keras.preprocessing.image.img_to_array(gray_img)
img_array = tf.expand_dims(img_array, axis=0)  # Create a batch
print(img_array.shape)  # expected: (1, 64, 64, 1)

This code snippet performs image preprocessing using TensorFlow and Matplotlib:

Load and Display Original Image:

  • img_path = '/content/t1.jpg': Specifies the file path of the original image.
  • original_img = tf.keras.preprocessing.image.load_img(img_path): Loads the original image using TensorFlow's image loading function.
  • plt.imshow(original_img): Displays the original image using Matplotlib.
  • plt.axis('off'): Turns off the axis for cleaner visualization.
  • plt.show(): Displays the original image without axis.

Convert to Grayscale and Resize:

  • gray_img = tf.image.rgb_to_grayscale(tf.keras.preprocessing.image.img_to_array(original_img)): Converts the loaded RGB image to an array and then to a single grayscale channel using TensorFlow's function.
  • gray_img = tf.image.resize(gray_img, (64, 64)): Resizes the grayscale image to 64x64 pixels using TensorFlow's resizing function.
  • plt.imshow(gray_img[..., 0], cmap='gray'): Displays the resized grayscale image with a gray colormap (otherwise Matplotlib would apply its default color map).
  • plt.axis('off'): Turns off the axis for cleaner visualization.
  • plt.show(): Displays the resized grayscale image without axis.

Convert to Array and Create Batch:

  • img_array = tf.keras.preprocessing.image.img_to_array(gray_img): Converts the resized grayscale image to a NumPy array using Keras's conversion function.
  • img_array = tf.expand_dims(img_array, axis=0): Adds a batch dimension in front, because the model expects input of shape (batch, height, width, channels), even for a single image.
  • print(img_array.shape): Prints the shape of the batch, (1, 64, 64, 1): one 64x64 image with one channel.

This code prepares the original image for the model: it converts the image to grayscale, resizes it to 64x64 pixels, converts it to a NumPy array, and wraps it in a batch of one image, which is the input format the model expects.
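The grayscale step is a weighted sum of the red, green, and blue channels. The NumPy sketch below reproduces it with the weights 0.2989, 0.5870, and 0.1140 (to my knowledge, the ones tf.image.rgb_to_grayscale uses) and then adds the channel and batch axes the model expects. The to_model_batch helper and the solid-red test image are hypothetical, and the sketch assumes the image has already been resized to 64x64.

```python
import numpy as np

# Luma weights for R, G, B (assumed to match tf.image.rgb_to_grayscale).
RGB_WEIGHTS = np.array([0.2989, 0.5870, 0.1140], dtype=np.float32)


def to_model_batch(rgb_image):
    """Turn an (H, W, 3) RGB array into a (1, H, W, 1) float32 grayscale batch."""
    rgb = rgb_image.astype(np.float32)
    gray = rgb @ RGB_WEIGHTS      # weighted sum over the color axis -> (H, W)
    gray = gray[..., np.newaxis]  # add the channel axis -> (H, W, 1)
    return gray[np.newaxis, ...]  # add the batch axis -> (1, H, W, 1)


# Hypothetical 64x64 image that is pure red everywhere.
red_image = np.zeros((64, 64, 3), dtype=np.uint8)
red_image[..., 0] = 255

batch = to_model_batch(red_image)
print(batch.shape, batch[0, 0, 0, 0])
```

Pixel values stay in the 0 to 255 range here; the model’s Rescaling layer divides them by 255 itself, so no extra normalization is needed before calling predict.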

3. Make the prediction!

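import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf

# Run the model on the batch prepared in step 2: one raw score (logit) per class
result = model.predict(img_array)

# Get probabilities (softmax turns the logits into probabilities that sum to 1)
probabilities = tf.nn.softmax(result, axis=1).numpy()

# Get prediction label
prediction = np.argmax(result, axis=1)[0]
class_names = ['bleached_corals', 'healthy_corals']  # same alphabetical order as in training

plt.figure(figsize=(8,8))
plt.subplots_adjust(hspace=0.4)
plt.imshow(original_img)
plt.title(f"{class_names[prediction]} ({probabilities[0][prediction]:.0%})", fontdict={"fontsize": 13})
plt.axis("off")
plt.show()

This code snippet processes the output of a neural network model:

  1. result: Runs the model on the preprocessed image batch and returns one raw score (logit) per class.
  2. probabilities: Converts the logits into class probabilities with the softmax function, so they sum to 1.
  3. prediction: Determines the predicted class by finding the index with the highest score.
  4. class_names: Defines the class labels, in the same order used during training.
  5. plt.figure(): Initializes a Matplotlib figure for visualization.
  6. plt.imshow(): Displays the original image.
  7. plt.title(): Sets the title of the plot to the predicted class name and its probability.
  8. plt.axis(): Turns off the axis for cleaner visualization.
  9. plt.show(): Displays the plot with the original image and predicted class name.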
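Because the output layer has no activation, the model returns raw scores (logits). A softmax over those scores gives class probabilities that sum to 1, while applying a sigmoid to each score separately does not; both keep the same argmax, so the predicted label is unaffected. Here is a small NumPy comparison using made-up logits for one image:

```python
import numpy as np


def softmax(logits):
    """Row-wise softmax; subtracting the max avoids overflow in exp."""
    shifted = logits - np.max(logits, axis=-1, keepdims=True)
    exp = np.exp(shifted)
    return exp / np.sum(exp, axis=-1, keepdims=True)


def sigmoid(logits):
    """Element-wise logistic function."""
    return 1 / (1 + np.exp(-logits))


# Made-up logits for one image: [bleached_corals, healthy_corals]
logits = np.array([[2.0, -1.0]])

softmax_probs = softmax(logits)
sigmoid_scores = sigmoid(logits)
print(softmax_probs, softmax_probs.sum())    # the row sums to 1
print(sigmoid_scores, sigmoid_scores.sum())  # the row does not sum to 1 in general
```

With two classes, the softmax probability of the first class equals the sigmoid of the difference between the two logits (sigmoid(3.0) here), which is why a model with a single output unit and a sigmoid activation is an equivalent alternative design for binary classification.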

Here are some examples of the predictions from the model:

Same approach but different applications

The same CNN-based image classification approach can be applied to many domains beyond marine ecology. Let’s explore some alternative applications:

  1. Medical Imaging: CNNs can aid in the detection of diseases from medical images such as X-rays, MRIs, and CT scans. For instance, they can be utilized to identify different types of tumors, classify skin lesions for melanoma detection, or diagnose neurological disorders based on brain scans.
  2. Food Quality Inspection: In the food industry, CNNs can be employed to inspect the quality of food products. They can classify fruits and vegetables based on ripeness or identify defects such as mold or bruises on produce, ensuring the quality and safety of food items before they reach consumers.
  3. Object Recognition in Autonomous Vehicles: CNNs play a crucial role in the development of object recognition systems for autonomous vehicles. By analyzing images captured by onboard cameras, CNNs can identify various objects on the road, including vehicles, pedestrians, traffic signs, and obstacles, enabling the vehicle to make informed decisions in real time.
  4. Security and Surveillance: CNNs are widely used in security and surveillance systems for detecting and recognizing objects and individuals. They can analyze video footage to identify suspicious activities, track the movement of objects or people, and enhance overall security measures in public spaces, airports, and other sensitive locations.
  5. Environmental Monitoring: CNNs can contribute to environmental monitoring efforts by analyzing satellite images or aerial photographs. They can detect changes in land use, monitor deforestation, track wildlife populations, and assess the impact of climate change on ecosystems, providing valuable insights for conservation and resource management initiatives.
  6. Retail and E-commerce: In retail settings, CNNs can be employed for various tasks such as product recognition, visual search, and recommendation systems. By analyzing images of products, they can assist customers in finding visually similar items or offer personalized recommendations based on their preferences and browsing history.

You can check out the code in this GitHub repository.

Thank you for reading this! If you want to see more of my work, connect with me on LinkedIn!
