Exploratory image analysis — Part 1 : Advanced density plots
Exploratory data analysis and visualization techniques are essential to get insight from the data. Unlock the full power of AI approaches by understanding and focusing on data quality!
Introduction
Exploratory data analysis uses descriptive statistics and visualization techniques to provide insights into the data. Descriptive statistics aim at summarizing and analyzing data, and visualization techniques allow highlighting patterns, correlations, trends, outliers and errors in the data, as well as communicating the results. Typical visualization techniques include histograms, scatter plots, box plots, non-linear dimensionality reduction techniques, projection embeddings, dataset sprite plots, and interactive versions of these plots. In brief, descriptive statistics and visualization techniques are key to gaining insights into the characteristics and distribution of the data.
In this article, we look at advanced density visualization techniques, which are useful to understand the data distribution, identify shifts in test data, and reduce prediction errors. Specifically, we focus on image data, the CIFAR-10 dataset, and the joypy library, which produces advanced density plots. Let’s get started!
Installation
First, we’ll install the necessary libraries within an environment.
!pip install matplotlib==3.8.2 \
numpy==1.26.3 \
pillow==10.2.0 \
pandas==2.2.0 \
seaborn==0.13.1 \
joypy==0.2.6
import os
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import PIL.Image as Image
np.random.seed(0)
Data
We look at the CIFAR-10 dataset, which is a collection of 60,000 images in 10 classes, with 6,000 images per class. There are 50,000 training images and 10,000 test images. We will use the test set. Images have 32x32 pixels and 3 channels (RGB).
Data download and loading
The data can be downloaded from the CIFAR-10 website or Kaggle, but the simplest method is to download it using Keras or PyTorch. We show how to do it using Keras, which requires installing TensorFlow.
from keras.datasets import cifar10
# Download the CIFAR-10 dataset
(x_train, y_train), (x_test, y_test) = cifar10.load_data()
print(x_train.shape, y_train.shape, x_test.shape, y_test.shape)
# We retain only the test set
images = x_test
labels = y_test
We define CIFAR-10 labels, as given on the website. We also define different groups (transport, wild, pets) for the classes that are expected to present similarities.
# Channels and CIFAR-10 classes
channels = ['r', 'g', 'b']
cifar10_labels = ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']
cifar10_labels_idx = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
#cifar10_labels_idx = os.listdir(data_dir)
# Labels names
labels_names = [cifar10_labels[label[0]] for label in labels]
# Define groups with similar classes
cifar10_groups = {'transport': ['airplane', 'ship', 'automobile', 'truck'],
                  'pet': ['cat', 'dog'],
                  'wild': ['deer', 'horse', 'frog', 'bird']}
# Name for dataset (for saving results)
result_name = 'cifar10'
# Path for results
result_dir = '../../../Results/cifar10/data_anal/'
# Create the results directory if it does not exist yet
os.makedirs(result_dir, exist_ok=True)
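Since the groups are written by hand, it can be worth checking that every class belongs to exactly one group before plotting per group. The following is a minimal sketch of such a check; the helper `check_groups` and the reverse lookup `class_to_group` are hypothetical names introduced here for illustration and are not part of the original code.

```python
# Class names and groups copied from the tutorial code above
cifar10_labels = ['airplane', 'automobile', 'bird', 'cat', 'deer',
                  'dog', 'frog', 'horse', 'ship', 'truck']
cifar10_groups = {'transport': ['airplane', 'ship', 'automobile', 'truck'],
                  'pet': ['cat', 'dog'],
                  'wild': ['deer', 'horse', 'frog', 'bird']}


def check_groups(labels, groups):
    """Return (missing, duplicated, unknown) class names for a grouping."""
    grouped = [name for members in groups.values() for name in members]
    missing = sorted(set(labels) - set(grouped))        # classes in no group
    duplicated = sorted({name for name in grouped if grouped.count(name) > 1})
    unknown = sorted(set(grouped) - set(labels))        # names that are not classes
    return missing, duplicated, unknown


missing, duplicated, unknown = check_groups(cifar10_labels, cifar10_groups)
print(missing, duplicated, unknown)

# Hypothetical reverse lookup: class name -> group name
class_to_group = {name: group
                  for group, members in cifar10_groups.items()
                  for name in members}
```

Three empty lists mean that the grouping covers all ten classes without overlap.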
After reading the data, we display a few sample images and then a sprite of the dataset.
# Display a grid of images with labels as titles
def display_grid_images_labels(images, labels, dim_resize=None, num_subplots=(3, 3),
                               figsize=(6, 6), path_file=None):
    fig, axes = plt.subplots(num_subplots[0], num_subplots[1], figsize=figsize)
    axes = axes.flatten()
    for i, (img, label) in enumerate(zip(images, labels)):
        ax = axes[i]
        ax.imshow(img)
        ax.set_title(label)
        ax.axis('off')
    # Save the figure if a path is given, otherwise show it
    if path_file is not None:
        fig.savefig(path_file)
    else:
        plt.show()

num_selected = 16
images_selected = [images[i] for i in range(num_selected)]
labels_selected = [cifar10_labels[labels[i][0]] for i in range(num_selected)]
file_save = os.path.join(result_dir, f'{result_name}_img.png')
display_grid_images_labels(images_selected, labels_selected,
                           path_file=file_save, figsize=(5, 6), num_subplots=(4, 4))
It is common practice to create a sprite plot of the dataset, which is a single image that contains all or a subset of the images in the dataset.
def images_to_sprite(data, invert_colors=False):
    # Grayscale input (N, H, W): repeat the single channel three times
    if len(data.shape) == 3:
        data = np.tile(data[..., np.newaxis], (1, 1, 1, 3))
    data = data.astype(np.float32)
    # Rescale every image to [0, 1] using its own minimum and maximum
    min_vals = np.min(data.reshape((data.shape[0], -1)), axis=1)
    data = (data.transpose(1, 2, 3, 0) - min_vals).transpose(3, 0, 1, 2)
    max_vals = np.max(data.reshape((data.shape[0], -1)), axis=1)
    max_vals[max_vals == 0] = 1  # constant images: avoid division by zero
    data = (data.transpose(1, 2, 3, 0) / max_vals).transpose(3, 0, 1, 2)
    # Inverting the colors seems to look better for MNIST
    if invert_colors:
        data = 1 - data
    # Pad with black images so that the count is a perfect square n * n
    n = int(np.ceil(np.sqrt(data.shape[0])))
    padding = ((0, n ** 2 - data.shape[0]), (0, 0),
               (0, 0)) + ((0, 0),) * (data.ndim - 3)
    data = np.pad(data, padding, mode='constant',
                  constant_values=0)
    # Tile the individual thumbnails into an image.
    data = data.reshape((n, n) + data.shape[1:]).transpose((0, 2, 1, 3)
                                                           + tuple(range(4, data.ndim + 1)))
    data = data.reshape((n * data.shape[1], n * data.shape[3]) + data.shape[4:])
    data = (data * 255).astype(np.uint8)
    return data, n
# Subsample the dataset
num_selected = 25*25
images_2d = np.array(images[:num_selected,:,:,:]).reshape(-1, 32, 32, 3)
# Create the sprite image
sprite, n = images_to_sprite(images_2d)
sprite_path = os.path.join(result_dir, f'{result_name}_sprite.png')
Image.fromarray(sprite).save(sprite_path)
#Image.fromarray(sprite).show()
fig = plt.figure(figsize=(10,10))
plt.imshow(sprite)
plt.axis('off')
Descriptive statistics
Summary statistics
Summary statistics provide a quantitative summary of the data with a few numbers. They are a fast way to get a general idea of the data. We use pandas to create a dataframe of the dataset, in which each row is one pixel and the columns are the three color channels and the class. Then, we group the data by class and use the describe method to provide a summary of the central tendency, dispersion and shape of the dataset.
For large image datasets, we would need to subsample the dataset/images or preprocess the data.
def create_df_class(images, labels, labels_names, channels, class_name='class'):
    frames = []
    for label_idx in range(len(labels_names)):
        # Combine all pixels of the images that belong to this class
        images_subgroup = images[labels.flatten() == label_idx]
        images_subgroup = images_subgroup.reshape(-1, len(channels))
        # Create dataframe per class: one row per pixel, one column per channel
        images_subgroup_df = pd.DataFrame(images_subgroup, columns=channels)
        images_subgroup_df[class_name] = labels_names[label_idx]
        frames.append(images_subgroup_df)
    # Concatenate the per-class dataframes
    images_df = pd.concat(frames, axis=0, ignore_index=True)
    return images_df
# Create a dataframe with all images
num_selected = 500
images_df = create_df_class(images[:num_selected,:,:,:], labels[:num_selected], cifar10_labels, channels)
# print(images_df)
# Group by class
images_df_group = images_df.groupby('class')
# Display stats
stats = images_df_group.describe()
stats = stats.transpose()
stats
class airplane automobile bird cat deer \
r count 58368.000000 41984.000000 52224.000000 50176.000000 40960.000000
mean 135.941338 118.818288 124.331093 121.899454 124.246802
std 69.458980 68.817497 59.446712 61.540625 58.691362
min 0.000000 0.000000 0.000000 0.000000 0.000000
25% 78.000000 61.000000 79.000000 75.000000 80.000000
50% 135.000000 116.000000 131.000000 123.000000 121.000000
75% 197.000000 172.000000 170.000000 168.000000 164.000000
max 255.000000 255.000000 255.000000 255.000000 255.000000
g count 58368.000000 41984.000000 52224.000000 50176.000000 40960.000000
mean 145.330198 115.106159 127.130706 113.891920 119.134888
std 67.454871 67.767542 56.622067 59.125744 55.025181
min 0.000000 0.000000 0.000000 0.000000 0.000000
25% 90.000000 61.000000 90.000000 69.000000 79.000000
50% 146.000000 108.000000 132.000000 113.000000 114.000000
75% 202.000000 164.000000 169.000000 158.000000 156.000000
max 255.000000 255.000000 255.000000 255.000000 255.000000
b count 58368.000000 41984.000000 52224.000000 50176.000000 40960.000000
mean 150.140591 113.582960 112.737247 100.180863 102.528516
std 73.820860 70.371585 63.826265 60.017407 54.218043
min 0.000000 0.000000 0.000000 0.000000 0.000000
25% 86.000000 55.000000 64.000000 52.000000 62.000000
50% 160.000000 106.000000 109.000000 95.000000 97.000000
75% 215.000000 166.000000 158.000000 142.000000 134.000000
max 255.000000 255.000000 255.000000 255.000000 255.000000
...
25% 61.000000 46.000000 60.000000 99.000000 63.000000
50% 104.000000 74.000000 102.000000 150.000000 116.000000
75% 161.000000 110.000000 155.000000 197.000000 180.000000
max 255.000000 255.000000 255.000000 255.000000 255.000000
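The counts in the table differ between classes because the first 500 test images are not evenly spread over the classes. When subsampling a larger dataset, one option is to draw the same number of images from each class. The sketch below illustrates this on synthetic arrays that only mimic the CIFAR-10 layout; the function `stratified_subsample` and its parameters are hypothetical and not part of the original code.

```python
import numpy as np


def stratified_subsample(images, labels, n_per_class, seed=0):
    """Pick up to n_per_class random images for every class label."""
    rng = np.random.default_rng(seed)
    flat_labels = np.asarray(labels).flatten()
    selected = []
    for label in np.unique(flat_labels):
        idx = np.flatnonzero(flat_labels == label)
        n_take = min(n_per_class, idx.size)
        selected.append(rng.choice(idx, size=n_take, replace=False))
    selected = np.sort(np.concatenate(selected))
    return images[selected], np.asarray(labels)[selected]


# Synthetic stand-in for the test set (assumption: same array layout as CIFAR-10)
rng = np.random.default_rng(0)
fake_images = rng.integers(0, 256, size=(200, 32, 32, 3), dtype=np.uint8)
fake_labels = (np.arange(200) % 10).reshape(-1, 1)   # 20 images per class

sub_images, sub_labels = stratified_subsample(fake_images, fake_labels, n_per_class=5)
print(sub_images.shape, np.bincount(sub_labels.flatten()))
```

With the real data, the returned arrays could be passed to create_df_class in place of images[:num_selected] and labels[:num_selected], so that every class contributes the same number of pixels.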
Data visualization common plots
Common plots that provide a general description of the data are histograms, scatter plots, box plots, and correlation plots. Pandas and Seaborn are common libraries for these plots. Here, we focus on histograms. We will start by showing, with two examples, that it is quite tricky to get useful insight from these plots. Then, we will look at more advanced or dedicated tools for displaying densities.
Histograms provide a visual representation of the distribution of the data, as they represent the frequency of the values in the dataset. They highlight the shape and skewness of the distribution and can help identify outliers in the data. The bins are the intervals into which the range of values is divided.
As a first example, we use the pandas hist method to plot the histogram for one class.
# Select class bird and plot histogram using pandas
class_selected = 'bird'
images_df_class = images_df[images_df['class'] == class_selected]
fig, axes = plt.subplots(1, 3, figsize=(7, 4))
for i, channel in enumerate(channels):
    images_df_class.hist(column=channel, ax=axes[i], bins=20, alpha=0.5)
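To make the notion of bins and frequencies concrete, the numbers behind such a histogram can be computed directly with np.histogram. This is a small sketch on synthetic pixel values, not on the CIFAR-10 data; fixing the range to the full 8-bit interval is an assumption made here so that bins are comparable across classes (by default the bins span the minimum and maximum of the data).

```python
import numpy as np

# Synthetic 8-bit pixel values standing in for one color channel
rng = np.random.default_rng(0)
pixels = rng.integers(0, 256, size=10_000)

# 20 equal-width bins spanning the full 8-bit range
counts, bin_edges = np.histogram(pixels, bins=20, range=(0, 256))

# Relative frequency per bin
frequency = counts / counts.sum()

# Density: relative frequency divided by the bin width, so the area sums to 1
density, _ = np.histogram(pixels, bins=20, range=(0, 256), density=True)

print(len(counts), len(bin_edges))   # number of bins and number of bin edges
print(counts.sum())                  # every pixel falls in exactly one bin
```

The density version is the quantity that density plots approximate with a smooth curve.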
Other libraries such as seaborn yield higher quality plots. Seaborn is a well known Python library based on matplotlib that provides a high-level interface for displaying statistical graphics.
A useful plot to start the exploratory analysis is the pairplot, also known as scatterplot matrix. It creates a grid of scatter plots, where each variable in the dataset is plotted against each other. The diagonal plots are histograms of the corresponding variable.
fig = sns.pairplot(images_df.sample(10000), hue='class' )
fig.savefig(os.path.join(result_dir, f'{result_name}_pairplot.png'))
Cannot see much? Well, I can’t!
- Histograms are cluttered!
- Scatter plots do not provide much information in this case either! Scatter plots are useful to find relations between variables. In this case, the variables are the RGB channels.
Even though seaborn yields high quality plots, the standard choices do not provide great insight. We need another tool to come to the rescue for comparing densities.
Advanced density plots
A useful library for plotting densities is joypy, a matplotlib- and pandas-based library for partially overlapping plots (joyplots, also known as ridgeline plots).
To get further insight from the data, we draw a separate plot for each group of classes, which improves clarity and makes similar distributions easier to identify. Estimating the density for each class may take a minute or so, depending on the number of images.
import joypy
# Plot histogram per subclass
for subclass in cifar10_groups.keys():
    images_subclass = images_df[images_df['class'].isin(cifar10_groups[subclass])]
    fig, axes = joypy.joyplot(images_subclass,
                              legend=True,
                              color=['r', 'g', 'b'],
                              fade=True,
                              by='class',
                              figsize=(6, 3))
    fig.savefig(os.path.join(result_dir, f'{result_name}_jp_{subclass}.png'))
Do you notice the similarities between classes? Cat and dog have very similar distributions, as do automobile and truck, and deer and horse. We also notice that ship and bird are skewed to the left (left-tailed) because of the blue colors of sea and sky, respectively. These distributions contrast with those of the other classes and with the overall distribution of the dataset, which are mostly right-tailed.
Skewness and kurtosis are two important measures of the shape of a distribution. Skewness measures the asymmetry of the distribution (negative for a longer left tail, positive for a longer right tail), while kurtosis describes how heavy its tails are.
skewness = images_df_group.skew()
print(skewness)
r g b
class
airplane -0.012009 -0.160037 -0.303892
automobile 0.208095 0.335396 0.332865
bird -0.147719 -0.242198 0.246500
cat 0.001902 0.087687 0.420626
deer 0.242688 0.328370 0.650615
dog -0.059597 0.141171 0.406403
frog 0.232138 0.381372 0.932765
horse 0.021530 0.112794 0.456886
ship 0.006660 -0.095590 -0.167799
truck 0.086183 0.150249 0.231329
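The code above computes only the skewness. As a hedged illustration of how the sign of the skewness relates to the tails, and of how kurtosis could be obtained per group as well, here is a sketch on synthetic samples (not CIFAR-10 pixels). The column names `value` and `shape` are made up for this example, and the kurtosis is computed by applying the pandas Series method kurt to each group.

```python
import numpy as np
import pandas as pd

rng = np.random.default_rng(0)
n = 20_000

# Synthetic samples: symmetric, right-tailed, and left-tailed (mirrored)
df = pd.DataFrame({
    'value': np.concatenate([
        rng.normal(128, 20, n),          # symmetric
        rng.exponential(30, n),          # long tail to the right
        255 - rng.exponential(30, n),    # long tail to the left
    ]),
    'shape': ['symmetric'] * n + ['right-tailed'] * n + ['left-tailed'] * n,
})

grouped = df.groupby('shape')['value']
summary = pd.DataFrame({
    'skewness': grouped.skew(),
    # Excess kurtosis (about 0 for a normal distribution), computed per group
    'kurtosis': grouped.apply(lambda s: s.kurt()),
})
print(summary.round(2))
```

The left-tailed sample yields a negative skewness and the right-tailed sample a positive one, while mirroring leaves the kurtosis essentially unchanged.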
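Note that, in the CIFAR-10 skewness values above, the blue channel is left-skewed (negative) for airplane and ship, whereas bird is left-skewed only in the red and green channels.
We can also compare to the distribution of the entire dataset.
# Plot histogram of all dataset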
fig, axes = joypy.joyplot(images_df,
legend=True,
color=['r', 'g', 'b'],
figsize=(6,3))
fig.savefig(os.path.join(result_dir, f'{result_name}_jp.png'))
We can see that comparing densities by groups of classes provides a first insight into the dataset and that joypy yields off-the-shelf plots for this purpose.
Summary and discussion
Histograms can provide first insights into the data, especially for applications in which different classes can have different ranges of values. Histograms are also key in machine learning, as learning tasks can be seen as explicitly or implicitly learning the distribution of the data. In addition, this analysis can be useful for error analysis; for instance, by looking at cases that are not well classified.
In this dataset, we consider the channels as variables and the pixels as observations. This choice is an example for illustration purposes, and other more meaningful variables can be defined depending on the application.
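As one possible example of alternative variables, each image could be treated as an observation described by a few summary features, such as the mean and standard deviation of each channel. The sketch below uses synthetic images; the function `per_image_features` and the column names are hypothetical choices made for illustration, not variables defined in this tutorial.

```python
import numpy as np
import pandas as pd


def per_image_features(images, class_names):
    """One row per image: mean and std of each channel, plus mean brightness."""
    images = np.asarray(images, dtype=np.float32)
    means = images.mean(axis=(1, 2))   # shape (n_images, 3)
    stds = images.std(axis=(1, 2))     # shape (n_images, 3)
    return pd.DataFrame({
        'r_mean': means[:, 0], 'g_mean': means[:, 1], 'b_mean': means[:, 2],
        'r_std': stds[:, 0], 'g_std': stds[:, 1], 'b_std': stds[:, 2],
        'brightness': means.mean(axis=1),
        'class': list(class_names),
    })


# Synthetic images with the CIFAR-10 layout (assumption, for illustration only)
rng = np.random.default_rng(0)
fake_images = rng.integers(0, 256, size=(8, 32, 32, 3), dtype=np.uint8)
fake_names = ['cat', 'dog'] * 4

features_df = per_image_features(fake_images, fake_names)
print(features_df.groupby('class')['brightness'].mean())
```

Such a dataframe has one row per image instead of one row per pixel, so it stays small even for large datasets, and it can be grouped by class in the same way as before.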
A word of caution: summary statistics only provide global metrics, so over-relying on them can be misleading [Alabi 2023]. However, they are a great starting point for extracting insight from data, and they can be monitored in a production environment to alert of potential shifts in the data.
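A simple way to monitor such shifts is to compare the per-channel histograms of a new batch against those of a reference set. The following is a minimal sketch using the total variation distance between normalized histograms, on synthetic data. The function names, the simulated "darker" batch and the alert threshold are all assumptions made for illustration; other distances or statistical tests could be used instead.

```python
import numpy as np


def channel_histograms(images, bins=32):
    """Normalized per-channel histograms for a batch of 8-bit RGB images."""
    pixels = np.asarray(images).reshape(-1, 3)
    hists = []
    for c in range(3):
        counts, _ = np.histogram(pixels[:, c], bins=bins, range=(0, 256))
        hists.append(counts / counts.sum())
    return np.stack(hists)   # shape (3, bins)


def total_variation(p, q):
    """Total variation distance per channel: 0 (identical) to 1 (disjoint)."""
    return 0.5 * np.abs(p - q).sum(axis=1)


rng = np.random.default_rng(0)
reference = rng.integers(0, 256, size=(100, 32, 32, 3))
same_dist = rng.integers(0, 256, size=(100, 32, 32, 3))
# Simulated shift: a darker batch (assumption, for illustration only)
darker = (same_dist * 0.6).astype(np.int64)

ref_hist = channel_histograms(reference)
THRESHOLD = 0.1   # hypothetical alert threshold; it should be tuned on real data
for name, batch in [('same', same_dist), ('darker', darker)]:
    tv = total_variation(ref_hist, channel_histograms(batch))
    print(name, tv.round(3), 'ALERT' if (tv > THRESHOLD).any() else 'ok')
```

The batch drawn from the same distribution stays well below the threshold, while the darker batch exceeds it in all three channels.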
In this article, we have briefly introduced descriptive statistics and data exploration, focusing on density plots. Other great techniques for image data are projection embeddings and non-linear dimensionality reduction techniques, which allow visualizing all data points in a 2D or 3D space, and interactive plots with libraries such as Plotly and Vega-Altair, among others. We will look at these techniques in future articles.
We got to the end of this tutorial! The full code is available in the link below!
References
- Olubunmi Alabi and Tosin Bukola. Introduction to Descriptive Statistics. Recent Advances in Biostatistics, 2023.
- Alex Krizhevsky. Learning Multiple Layers of Features from Tiny Images, 2009.
- P. Kaur et al. Descriptive statistics. International Journal of Academic Medicine 4(1):60–63, Jan–Apr 2018.
- T. Nick. Descriptive statistics. Topics in Biostatistics, Methods in Molecular Biology, Chapter 3, 2007.
Code
A colab notebook for this tutorial: here