All about Structural Similarity Index (SSIM): Theory + Code in PyTorch
Recently, while implementing a depth estimation paper, I came across the term Structural Similarity Index (SSIM). SSIM is used as a metric to measure the similarity between two given images. As this technique has been around since 2004, a lot of material exists explaining the theory behind SSIM, but very few resources go deep into the details, and even fewer do so for a gradient-based implementation, which matters because SSIM is often used as a loss function. Hence, this article is my humble attempt to plug this gap!
The objective of this article is two-fold,
- To explain the theory and intuition behind SSIM and explore some of its applications in current cutting-edge deep learning.
- To go deep into a PyTorch implementation. You can skip to the code here. The full implementation can be found as a standalone notebook here. Just click on the “Open in Colab” link to start running the code!
So let’s begin!
The Theory
SSIM was first introduced in the 2004 IEEE paper, Image Quality Assessment: From Error Visibility to Structural Similarity. The abstract provides good intuition for the idea behind the proposed system:
Objective methods for assessing perceptual image quality traditionally attempted to quantify the visibility of errors (differences) between a distorted image and a reference image using a variety of known properties of the human visual system. Under the assumption that human visual perception is highly adapted for extracting structural information from a scene, we introduce an alternative complementary framework for quality assessment based on the degradation of structural information.
Summary: The authors make 2 essential points,
- Most image quality assessment techniques rely on quantifying errors between a reference and a sample image. A common approach is to quantify the difference between the values of each pair of corresponding pixels in the sample and reference images (by using, for example, Mean Squared Error).
- The human visual perception system is highly capable of identifying structural information in a scene, and hence of identifying the differences between the information extracted from a reference and a sample scene. Therefore, a metric that replicates this behavior should perform better on tasks that involve differentiating between a sample and a reference image.
The Structural Similarity Index (SSIM) metric extracts 3 key features from an image:
- Luminance
- Contrast
- Structure
The comparison between the two images is performed on the basis of these 3 features.
Fig 1 given below shows the arrangement and flow of the Structural Similarity Measurement system. Signal X and Signal Y refer to the Reference and Sample Images.
But what does this metric calculate?
This system calculates the Structural Similarity Index between 2 given images, which is a value between -1 and +1. A value of +1 indicates that the 2 given images are identical, while a value of -1 indicates that they are very different. Often these values are adjusted to be in the range [0, 1], where the extremes hold the same meaning.
Now, let’s briefly explore how these features are represented mathematically and how they contribute to the final SSIM score.
- Luminance: Luminance is measured by averaging over all the pixel values. It is denoted by μ (mu), and the formula is given below, where N is the number of pixels,
$$\mu_x = \frac{1}{N}\sum_{i=1}^{N} x_i$$
- Contrast: It is measured by taking the standard deviation (square root of variance) of all the pixel values. It is denoted by σ (sigma) and represented by the formula below,
$$\sigma_x = \left(\frac{1}{N-1}\sum_{i=1}^{N}(x_i - \mu_x)^2\right)^{1/2}$$
- Structure: The structural comparison is done by using a consolidated formula (more on that later), but in essence, we subtract the mean from the input signal and divide the result by its standard deviation, so that it has zero mean and unit standard deviation, which allows for a more robust comparison:
$$\frac{x - \mu_x}{\sigma_x}$$
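To make the three quantities concrete, here is a small NumPy sketch on a made-up row of pixel values (the data is purely illustrative). It uses the 1/(N−1) normalization for σ, as in the formula above, and checks that the normalized "structure" signal has zero mean and unit standard deviation.

```python
import numpy as np

# Toy 1-D "image": made-up 8-bit pixel values, for illustration only
x = np.array([52.0, 55.0, 61.0, 66.0, 70.0, 61.0, 64.0, 73.0])

mu_x = x.mean()                     # luminance: plain average of the pixel values
sigma_x = x.std(ddof=1)             # contrast: standard deviation with the 1/(N-1) factor
structure_x = (x - mu_x) / sigma_x  # structure: mean removed, scaled to unit std

print("mu:", mu_x)
print("sigma:", sigma_x)
print("structure std:", structure_x.std(ddof=1))
```

Because the mean and the contrast are divided out, two signals that differ only in brightness or contrast end up with the same structure vector.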
So now we have established the mathematical intuition behind the three parameters. But hold on! We are not done with the math yet; there is a little bit more. What we still lack are comparison functions that compare the two given images on these parameters, and a combination function that combines them all. Here, we define the comparison functions and finally the combination function that yields the similarity index value.
- Luminance comparison function: It is defined by the function l(x, y), shown below. μ (mu) represents the mean of a given image, and x and y are the two images being compared.
$$l(x, y) = \frac{2\mu_x\mu_y + C_1}{\mu_x^2 + \mu_y^2 + C_1}$$
where C1 is a constant that keeps the division stable when the denominator is close to 0. C1 is given by,
$$C_1 = (K_1 L)^2$$
Update: Throughout the article, we had not explored what the K and L constants in this equation are. Thankfully a reader pointed that out in the comments, and so in the interest of making this article a bit more helpful, I’ll just define them here.
L is the dynamic range of the pixel values (we set it to 255 since we are dealing with standard 8-bit images). You can read more about the different image types and what they mean here.
K1 and K2 are just small constants (K1 ≪ 1, K2 ≪ 1), nothing much there! The original paper uses K1 = 0.01 and K2 = 0.03.
- Contrast comparison function: It is defined by the function c(x, y), shown below. σ denotes the standard deviation of a given image, and x and y are the two images being compared.
$$c(x, y) = \frac{2\sigma_x\sigma_y + C_2}{\sigma_x^2 + \sigma_y^2 + C_2}$$
where C2 is given by,
$$C_2 = (K_2 L)^2$$
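- Structure comparison function: It is defined by the function s(x, y), shown below. σ denotes the standard deviation of a given image, x and y are the two images being compared, and C3 is another small stabilizing constant.
$$s(x, y) = \frac{\sigma_{xy} + C_3}{\sigma_x\sigma_y + C_3}$$
where σ(xy), the covariance of x and y, is defined as,
$$\sigma_{xy} = \frac{1}{N-1}\sum_{i=1}^{N}(x_i - \mu_x)(y_i - \mu_y)$$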
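And finally, the SSIM score is given by,
$$\text{SSIM}(x, y) = [l(x, y)]^{\alpha} \cdot [c(x, y)]^{\beta} \cdot [s(x, y)]^{\gamma}$$
where α > 0, β > 0, γ > 0 denote the relative importance of each of the metrics. To simplify the expression, if we assume α = β = γ = 1 and C3 = C2/2, we get,
$$\text{SSIM}(x, y) = \frac{(2\mu_x\mu_y + C_1)(2\sigma_{xy} + C_2)}{(\mu_x^2 + \mu_y^2 + C_1)(\sigma_x^2 + \sigma_y^2 + C_2)}$$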
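As a sanity check of the algebra, the following NumPy sketch computes l, c and s separately, multiplies them, and compares the product with the simplified formula, using K1 = 0.01, K2 = 0.03 and L = 255 (so C1 = 6.5025 and C2 = 58.5225). The helper name global_ssim_terms and the synthetic test images (a ramp, a brightness-shifted copy, and a copy with a ±10 checkerboard pattern) are my own illustrative choices; both distorted copies have exactly the same MSE of 100.

```python
import numpy as np

def global_ssim_terms(x, y, L=255.0, K1=0.01, K2=0.03):
    """Whole-image SSIM terms with alpha = beta = gamma = 1 and C3 = C2 / 2."""
    x = np.asarray(x, dtype=np.float64).ravel()
    y = np.asarray(y, dtype=np.float64).ravel()
    C1 = (K1 * L) ** 2
    C2 = (K2 * L) ** 2
    C3 = C2 / 2

    mu_x, mu_y = x.mean(), y.mean()
    sigma_x, sigma_y = x.std(ddof=1), y.std(ddof=1)
    sigma_xy = np.sum((x - mu_x) * (y - mu_y)) / (x.size - 1)

    l = (2 * mu_x * mu_y + C1) / (mu_x ** 2 + mu_y ** 2 + C1)
    c = (2 * sigma_x * sigma_y + C2) / (sigma_x ** 2 + sigma_y ** 2 + C2)
    s = (sigma_xy + C3) / (sigma_x * sigma_y + C3)

    simplified = ((2 * mu_x * mu_y + C1) * (2 * sigma_xy + C2)) / (
        (mu_x ** 2 + mu_y ** 2 + C1) * (sigma_x ** 2 + sigma_y ** 2 + C2)
    )
    return l, c, s, l * c * s, simplified

# Deterministic 64x64 ramp image with values 0..252
i, j = np.meshgrid(np.arange(64), np.arange(64), indexing="ij")
ref = 2.0 * i + 2.0 * j

shifted = ref + 10.0                                # uniform brightness shift
checker = np.where((i + j) % 2 == 0, 1.0, -1.0)
noisy = ref + 10.0 * checker                        # +/-10 checkerboard pattern

for name, img in [("shifted", shifted), ("noisy", noisy)]:
    l, c, s, product, simplified = global_ssim_terms(ref, img)
    mse = np.mean((ref - img) ** 2)
    print(f"{name}: MSE={mse:.1f}  l*c*s={product:.4f}  simplified={simplified:.4f}")
```

Both distortions have the same MSE, but they are not equally similar in the SSIM sense: the brightness shift leaves the contrast and structure terms at 1 and only lowers l slightly, whereas the checkerboard pattern lowers the combined contrast-structure term.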
But there’s a plot twist!
While you would be able to implement SSIM using the above formulas, chances are it won’t be as good as the ready-to-use implementations, because, as the authors explain,
For image quality assessment, it is useful to apply the SSIM index locally rather than globally. First, image statistical features are usually highly spatially nonstationary. Second, image distortions, which may or may not depend on the local image statistics, may also be space-variant. Third, at typical viewing distances, only a local area in the image can be perceived with high resolution by the human observer at one time instance (because of the foveation feature of the HVS [49], [50]). And finally, localized quality measurement can provide a spatially varying quality map of the image, which delivers more information about the quality degradation of the image and may be useful in some applications.
Summary: Instead of applying the above metrics globally (i.e. all over the image at once), it’s better to apply them locally (i.e. in small sections of the image) and then take the mean over all of them.
This method is often referred to as the Mean Structural Similarity Index (MSSIM).
Due to this change in approach, our formulas need to be modified accordingly (note that this local approach is the more common one, and it is the one used to explain the code).
(Note: If the content below seems a bit overwhelming, no worries! If you get the gist of it, then going through the code will give you a much clearer idea.)
The authors use an 11x11 circular-symmetric Gaussian weighting function (basically, an 11x11 matrix whose values are derived from a Gaussian distribution with a standard deviation of 1.5) which moves pixel-by-pixel over the entire image. At each step, the local statistics and SSIM index are calculated within the local window. Since we are now calculating the metrics locally, our formulas are revised as,
$$\mu_x = \sum_{i=1}^{N} w_i x_i$$
$$\sigma_x = \left(\sum_{i=1}^{N} w_i (x_i - \mu_x)^2\right)^{1/2}$$
$$\sigma_{xy} = \sum_{i=1}^{N} w_i (x_i - \mu_x)(y_i - \mu_y)$$
where wi is the Gaussian weighting function, normalized so that its weights sum to 1.
If you found this a bit unintuitive, no worries! It suffices to think of wi as a set of weights that turns each plain average into a weighted average, giving pixels near the center of the window more influence than those near its edges.
Once computations are performed all over the image, we simply take the mean of all the local SSIM values and arrive at the global SSIM value.
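To see the whole local procedure without any convolution tricks, here is a deliberately slow NumPy sketch: it builds the 11x11 Gaussian window (σ = 1.5), slides it one pixel at a time, evaluates the local formulas above at each position, and averages the resulting SSIM map. The helper names, the random test image, and the choice to evaluate only windows that fit entirely inside the image (no padding, unlike the PyTorch code later) are assumptions of this sketch.

```python
import numpy as np

def gaussian_window(size=11, sigma=1.5):
    """2-D circular-symmetric Gaussian weights that sum to 1."""
    coords = np.arange(size) - size // 2
    g = np.exp(-(coords ** 2) / (2 * sigma ** 2))
    g /= g.sum()
    return np.outer(g, g)  # outer product of two 1-D Gaussians

def mean_ssim_naive(x, y, L=255.0, K1=0.01, K2=0.03, size=11, sigma=1.5):
    """Local weighted statistics at every fully-contained window, then the mean SSIM."""
    x = np.asarray(x, dtype=np.float64)
    y = np.asarray(y, dtype=np.float64)
    w = gaussian_window(size, sigma)
    C1, C2 = (K1 * L) ** 2, (K2 * L) ** 2
    rows, cols = x.shape[0] - size + 1, x.shape[1] - size + 1
    ssim_map = np.empty((rows, cols))
    for r in range(rows):
        for c in range(cols):
            px = x[r:r + size, c:c + size]
            py = y[r:r + size, c:c + size]
            mu_x, mu_y = np.sum(w * px), np.sum(w * py)
            var_x = np.sum(w * (px - mu_x) ** 2)
            var_y = np.sum(w * (py - mu_y) ** 2)
            cov_xy = np.sum(w * (px - mu_x) * (py - mu_y))
            ssim_map[r, c] = ((2 * mu_x * mu_y + C1) * (2 * cov_xy + C2)) / (
                (mu_x ** 2 + mu_y ** 2 + C1) * (var_x + var_y + C2)
            )
    return ssim_map.mean(), ssim_map

rng = np.random.default_rng(0)
img = rng.uniform(0, 255, size=(32, 32))
noisy = np.clip(img + rng.normal(0, 30, size=img.shape), 0, 255)

print("identical:", mean_ssim_naive(img, img)[0])
print("noisy:", mean_ssim_naive(img, noisy)[0])
```

The double loop makes the cost grow with image size times window size, which is why practical implementations compute the same weighted sums with convolutions instead.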
Finally done with the theory! Now onto the code!
The Code
Before we plunge into the code, it’s important to note that we won’t be going through every line but we will explore in-depth the essential ones. Let’s get started!
The full code can be found as a standalone notebook here. Just click on the “Open in Colab” button to start running the code! The explanation in this section will be referring to the notebook mentioned above.
First, let’s explore some utility functions that perform some essential tasks.
Function #1: gaussian(window_size, sigma)
This function generates a list of numbers (of length equal to window_size) sampled from a Gaussian distribution. The values are normalized so that they sum to 1. sigma is the standard deviation of the Gaussian distribution.
Note: This is used to generate the 11x11 gaussian window mentioned above.
Example:
Code:
gauss_dis = gaussian(11, 1.5)
print("Distribution: ", gauss_dis)
print("Sum of Gauss Distribution:", torch.sum(gauss_dis))
Output:
Distribution:  tensor([0.0010, 0.0076, 0.0360, 0.1094, 0.2130, 0.2660, 0.2130, 0.1094, 0.0360, 0.0076, 0.0010])
Sum of Gauss Distribution: tensor(1.)
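The article does not show the body of gaussian(). A minimal sketch that reproduces the output above, assuming the window is centered at window_size // 2, could look like this:

```python
import math
import torch

def gaussian(window_size, sigma):
    """1-D Gaussian weights centered on window_size // 2, normalized to sum to 1."""
    center = window_size // 2
    gauss = torch.tensor(
        [math.exp(-((x - center) ** 2) / (2 * sigma ** 2)) for x in range(window_size)]
    )
    return gauss / gauss.sum()

gauss_dis = gaussian(11, 1.5)
print("Distribution:", gauss_dis)
print("Sum of Gauss Distribution:", torch.sum(gauss_dis))
```

Dividing by the sum is what makes the weights add up to 1, so the later convolutions compute weighted averages rather than rescaled sums.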
Function #2: create_window(window_size, channel)
While we generated a 1D tensor of Gaussian values, the 1D tensor by itself is of no use to us. Hence we have to convert it to a 2D tensor (the 11x11 tensor we talked about earlier). The steps taken in this function are as follows,
- Generate the 1D tensor using the gaussian function.
- Convert it to a 2D tensor by multiplying it (as a column vector) with its transpose, i.e. taking the outer product. Because the 2D Gaussian is separable, the result is exactly the circular-symmetric 2D Gaussian window, and it still sums to 1.
- Add two extra dimensions to make it 4D, since F.conv2d expects a 4D weight tensor.
- Expand it to shape (channel, 1, window_size, window_size), the weight format F.conv2d expects for a grouped convolution with groups=channel.
Code:
window = create_window(11, 3)
print(window.shape)
Output:
torch.Size([3, 1, 11, 11])
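Like gaussian(), create_window() is only described in steps here. The sketch below follows those four steps; it repeats a small gaussian() helper so that it runs on its own, and the sigma=1.5 default (matching the paper's window) is an assumption about how the notebook fixes σ.

```python
import math
import torch

def gaussian(window_size, sigma):
    """1-D Gaussian weights centered on window_size // 2, normalized to sum to 1."""
    center = window_size // 2
    gauss = torch.tensor(
        [math.exp(-((x - center) ** 2) / (2 * sigma ** 2)) for x in range(window_size)]
    )
    return gauss / gauss.sum()

def create_window(window_size, channel, sigma=1.5):
    # Step 1: 1-D Gaussian as a column vector, shape (window_size, 1)
    _1d_window = gaussian(window_size, sigma).unsqueeze(1)
    # Step 2: outer product with its transpose -> 2-D Gaussian, shape (window_size, window_size)
    _2d_window = _1d_window @ _1d_window.t()
    # Step 3: two extra leading dims -> shape (1, 1, window_size, window_size)
    _2d_window = _2d_window.float().unsqueeze(0).unsqueeze(0)
    # Step 4: one copy per channel -> (channel, 1, window_size, window_size),
    # the weight layout F.conv2d expects when groups == channel
    return _2d_window.expand(channel, 1, window_size, window_size).contiguous()

window = create_window(11, 3)
print(window.shape)  # torch.Size([3, 1, 11, 11])
```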
Now that we have explored the two utility functions, let’s go through the main code! The core SSIM is implemented through the ssim() function which is explored below.
Function #3: ssim(img1, img2, val_range, window_size=11, window=None, size_average=True, full=False)
Before we move onto the essentials, let us explore what happens in the function before the ssim metrics are calculated,
- We set the maximum value of the normalized pixels (implementation detail; needn’t worry)
- We initialize the gaussian window by means of the create_window() function IF a window was not provided during the function call.
Once these steps are completed, we go about calculating the various values (the sigmas and the mus of the world) which are needed to arrive at the final SSIM score.
Note: Since we are calculating local statistics and we need to make it computationally efficient, the formulas used are slightly different (they are just rearrangements of the formulas discussed above; the relevant math is provided in the appendix).
- We first calculate μ(x), μ(y), their squares, and the product μ(x)μ(y). channels stores the number of color channels of the input image. Setting groups=channels makes this a depthwise convolution: each input channel is filtered independently with its own copy of the Gaussian window. More information regarding groups can be found here.
_, channels, height, width = img1.size()  # img1 is a batched (N, C, H, W) tensor

mu1 = F.conv2d(img1, window, padding=pad, groups=channels)  # local weighted means
mu2 = F.conv2d(img2, window, padding=pad, groups=channels)

mu1_sq = mu1 ** 2
mu2_sq = mu2 ** 2
mu12 = mu1 * mu2
- We then go on to calculate σ(x)², σ(y)², and the covariance σ(xy). For more math, check Appendix 1.1.
sigma1_sq = F.conv2d(img1 * img1, window, padding=pad, groups=channels) - mu1_sq
sigma2_sq = F.conv2d(img2 * img2, window, padding=pad, groups=channels) - mu2_sq
sigma12 = F.conv2d(img1 * img2, window, padding=pad, groups=channels) - mu12
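- Thirdly, we calculate the contrast metric according to the formula mentioned here. Strictly speaking, because it uses the covariance σ(xy), this is the combined contrast-structure term c(x, y)·s(x, y) of the simplified SSIM formula (with C3 = C2/2):
contrast_metric = (2.0 * sigma12 + C2) / (sigma1_sq + sigma2_sq + C2)
contrast_metric = torch.mean(contrast_metric)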
- Finally, we calculate the SSIM score and return the mean according to the formula mentioned here.
numerator1 = 2 * mu12 + C1
numerator2 = 2 * sigma12 + C2
denominator1 = mu1_sq + mu2_sq + C1
denominator2 = sigma1_sq + sigma2_sq + C2

ssim_score = (numerator1 * numerator2) / (denominator1 * denominator2)
return ssim_score.mean()
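Putting the pieces together, here is a compact, self-contained sketch of the same computation. It is not the notebook's code: it assumes batched (N, C, H, W) float tensors with values in [0, val_range], scales C1 and C2 by val_range as in the paper (K1 = 0.01, K2 = 0.03), drops the full flag from the article's signature, and builds the window inline with torch.outer. The last few lines show it used as a differentiable loss on synthetic data.

```python
import torch
import torch.nn.functional as F

def create_window(window_size, channel, sigma=1.5):
    """Gaussian window of shape (channel, 1, window_size, window_size); each slice sums to 1."""
    coords = torch.arange(window_size, dtype=torch.float32) - window_size // 2
    g = torch.exp(-(coords ** 2) / (2 * sigma ** 2))
    g = g / g.sum()
    window_2d = torch.outer(g, g)  # separable 2-D Gaussian
    return window_2d.expand(channel, 1, window_size, window_size).contiguous()

def ssim(img1, img2, val_range, window_size=11, window=None, size_average=True):
    """Mean SSIM for (N, C, H, W) tensors with pixel values in [0, val_range]."""
    _, channels, height, width = img1.size()
    if window is None:
        real_size = min(window_size, height, width)
        window = create_window(real_size, channels).to(img1.device, img1.dtype)
    pad = window.shape[-1] // 2

    C1 = (0.01 * val_range) ** 2  # K1 = 0.01
    C2 = (0.03 * val_range) ** 2  # K2 = 0.03

    # Local Gaussian-weighted means (depthwise convolution, one window per channel)
    mu1 = F.conv2d(img1, window, padding=pad, groups=channels)
    mu2 = F.conv2d(img2, window, padding=pad, groups=channels)
    mu1_sq, mu2_sq, mu12 = mu1 ** 2, mu2 ** 2, mu1 * mu2

    # Local variances and covariance via E[XY] - E[X]E[Y] (see the Appendix)
    sigma1_sq = F.conv2d(img1 * img1, window, padding=pad, groups=channels) - mu1_sq
    sigma2_sq = F.conv2d(img2 * img2, window, padding=pad, groups=channels) - mu2_sq
    sigma12 = F.conv2d(img1 * img2, window, padding=pad, groups=channels) - mu12

    ssim_map = ((2 * mu12 + C1) * (2 * sigma12 + C2)) / (
        (mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2)
    )
    if size_average:
        return ssim_map.mean()
    return ssim_map.mean(dim=(1, 2, 3))  # one score per image in the batch

# Usage on synthetic data (not the article's images)
torch.manual_seed(0)
x = torch.rand(1, 3, 64, 64) * 255
y = (x + 40 * torch.randn_like(x)).clamp(0, 255)
print(ssim(x, x, val_range=255))  # identical inputs give 1
print(ssim(x, y, val_range=255))

# As a loss, 1 - SSIM can be minimized with autograd
pred = y.clone().requires_grad_(True)
loss = 1 - ssim(pred, x, val_range=255)
loss.backward()
print(pred.grad.shape)
```

Every operation here is a differentiable tensor op, which is what lets SSIM serve as a training loss; 1 - SSIM is one common way to turn the similarity into a quantity to minimize.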
That was a lot! Now let’s see how the code performs!
We are going to test the code in three cases to check how it performs. Let’s get going!
- Case #1: True Image vs False Image
In the first scenario, we are going to run 2 very different images through SSIM. One of them is considered the True Image while the other is considered the False Image. (SSIM is symmetric in its two inputs, so the True and False labels are interchangeable; they are used only as reference points.)
The images are,
The code below is for illustration only, although it is not much different from the code in the notebook. For more detail and visualization, check the notebook.
Code:
img1 = load_images("img1.jpg")  # helper function to load images
img2 = load_images("img2.jpg")

_img1 = tensorify(img1)  # helper function to convert cv2 images to tensors
_img2 = tensorify(img2)

ssim_score = ssim(_img1, _img2, val_range=255)
print("True vs False Image SSIM Score:", ssim_score)
Output:
True vs False Image SSIM Score: tensor(0.3385)
- Case #2: True Image vs True Image with Random Noise
In this scenario, we compare the true image with a heavily noised version of it. The images are shown below,
Running essentially the same code as above, with the noisy image as the second input, we get,
Code:
noise = np.random.randint(0, 255, img1.shape).astype(np.float32)  # uniform integer noise, same shape as img1
noisy_img = img1 + noise

_img1 = tensorify(img1)
_img2 = tensorify(noisy_img)

true_vs_false = ssim(_img1, _img2, val_range=255)
print("True vs Noised True Image SSIM Score:", true_vs_false)
Output:
True vs Noised True Image SSIM Score: tensor(0.0185)
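Note that np.random.randint draws uniformly distributed integers, and because noisy_img is not clipped, many of its pixels end up above 255. If actual Gaussian noise is wanted, a sketch like the following would do it. The helper name add_gaussian_noise, σ = 25, the clipping to [0, 255] and the flat gray stand-in image are all my assumptions, so SSIM scores obtained with it will differ from the 0.0185 reported above.

```python
import numpy as np

def add_gaussian_noise(img, sigma=25.0, seed=None):
    """Return img + N(0, sigma^2) noise, clipped back to the valid 8-bit range [0, 255]."""
    rng = np.random.default_rng(seed)
    noise = rng.normal(loc=0.0, scale=sigma, size=img.shape)
    return np.clip(img.astype(np.float32) + noise, 0, 255).astype(np.float32)

# Stand-in for an 8-bit image loaded with cv2 (height x width x channels)
img1 = np.full((480, 640, 3), 128, dtype=np.uint8)
noisy_img = add_gaussian_noise(img1, sigma=25.0, seed=0)
print(noisy_img.shape, noisy_img.min(), noisy_img.max())
```

Clipping keeps the noisy image inside the same dynamic range L that SSIM's constants assume.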
- Case #3: True Image vs True Image
In the final case, we compare the True Image against itself. Hence, the image shown below is compared to itself. If our SSIM code is working correctly, the score should be exactly 1.
On running the piece of code shown below, we can confirm that the SSIM score for this given scenario is indeed one.
Code:
_img1 = tensorify(img1)
true_vs_false = ssim(_img1, _img1, val_range=255)
print("True vs True Image SSIM Score:", true_vs_false)
Output:
True vs True Image SSIM Score: tensor(1.)
Conclusion
Finally, we are here! In this article, we covered the theory behind SSIM and the code that goes into implementing it. In the References, some additional materials are provided including links to Computer Vision literature where SSIM is used in some form.
Hope understanding SSIM was much easier for you than it was for me :). I tried to focus on the areas that I personally found complicated and difficult to understand, hoping to not only consolidate my learnings but also in the process, help somebody else stumbling along the same path ;).
I would really appreciate feedback, positive or negative. You can drop it in the comments section or reach out to me at pranjaldatta99@gmail.com.
References
- The Original Paper: https://www.cns.nyu.edu/pub/eero/wang03-reprint.pdf
- High-Quality Monocular Depth Estimation via Transfer Learning: The depth estimation paper where SSIM is used as one of the loss functions. The paper can be found here. You can also check my PyTorch Implementation of this paper here.
- More on SSIM Applications: https://www.imatest.com/docs/ssim/
Appendix
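1.0: Since we are using local statistics, the formula for the mean (μ) changes from
$$\mu_x = \frac{1}{N}\sum_{i=1}^{N} x_i \quad \text{to} \quad \mu_x = \sum_{i=1}^{N} w_i x_i$$
which is exactly what convolving the image with the normalized Gaussian window computes at every pixel.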
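1.1: The variance (square of standard deviation) formula used in the Python code can be derived as,
$$\sigma^2 = E\left[(X - \mu)^2\right] = E[X^2] - 2\mu E[X] + \mu^2 = E[X^2] - \mu^2$$
While the derivation above shows the general method, in our context the expectation is the Gaussian-weighted average over the window (whose weights sum to 1), so the final formulas become,
$$\sigma_x^2 = \sum_{i=1}^{N} w_i x_i^2 - \mu_x^2, \qquad \sigma_{xy} = \sum_{i=1}^{N} w_i x_i y_i - \mu_x \mu_y$$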
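The identity can also be checked numerically. This sketch (PyTorch in float64, a random 11x11 patch; all names are illustrative) computes the weighted mean and variance directly from their definitions and compares them with the convolution-based shortcut used in the ssim() code. F.conv2d computes a cross-correlation, so with no padding its single output value is exactly the weighted sum over the patch.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
size, sigma = 11, 1.5

# Normalized 11x11 Gaussian window
coords = torch.arange(size, dtype=torch.float64) - size // 2
g = torch.exp(-(coords ** 2) / (2 * sigma ** 2))
g = g / g.sum()
w = torch.outer(g, g)

# One 11x11 patch, batched as (1, 1, 11, 11)
x = torch.rand(1, 1, size, size, dtype=torch.float64) * 255

# Direct definitions: weighted mean and weighted variance
mu_direct = (w * x[0, 0]).sum()
var_direct = (w * (x[0, 0] - mu_direct) ** 2).sum()

# Shortcut used in the code: E_w[x^2] - (E_w[x])^2 via convolution
kernel = w.view(1, 1, size, size)
mu_conv = F.conv2d(x, kernel)  # no padding -> a single 1x1 output
var_conv = F.conv2d(x * x, kernel) - mu_conv ** 2

print(mu_direct.item(), mu_conv.item())
print(var_direct.item(), var_conv.item())
```

The shortcut needs only one convolution each of x, x² (and x·y for the covariance), instead of re-centering every window around its own mean.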