Demystifying Neural Networks: Masked Image Recovery with AutoEncoders

The Art of Learning from Incomplete Data

Dagang Wei
4 min read · Feb 5, 2024
Image generated with Bard

This article is part of the series Demystifying Neural Networks.

Introduction

In the vast and expanding universe of artificial intelligence, neural networks power many innovations and solve problems that once seemed insurmountable. Among their many applications, one particularly fascinating area is image recovery, where a neural network fills in missing parts of an image. Today, we'll demystify this process by focusing on a specific type of neural network, the AutoEncoder, and demonstrate how it recovers masked images with a simple yet instructive example built on the MNIST dataset.

What is an AutoEncoder?

An AutoEncoder is a type of artificial neural network used for unsupervised learning. It is designed to encode input data into a compressed representation and then decode that representation back into the original input, or something very close to it. Because the compressed representation is smaller than the input, the network has to learn which features of the data matter most. This makes AutoEncoders good at learning efficient representations of data and well suited to tasks like dimensionality reduction, denoising, and, intriguingly, image recovery.
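To make the encode/decode idea concrete, here is a minimal NumPy sketch of a single forward pass. It is not the article's trained model: the weights are random and untrained, and the layer sizes (784→128→64→128→784) are borrowed from the Keras code later in the article. It only shows how a flattened 28×28 image is squeezed into a 64-number code and expanded back to 784 pixel values.

```python
import numpy as np

def relu(x):
    return np.maximum(x, 0.0)

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

rng = np.random.default_rng(0)

# Hypothetical, untrained weights using the same layer sizes as the Keras model:
# 784 -> 128 -> 64 (bottleneck) -> 128 -> 784
sizes = [784, 128, 64, 128, 784]
weights = [rng.normal(0.0, 0.05, size=(m, n)) for m, n in zip(sizes[:-1], sizes[1:])]
biases = [np.zeros(n) for n in sizes[1:]]

def encode(x):
    """Compress a batch of 784-pixel vectors into 64-number codes."""
    h = relu(x @ weights[0] + biases[0])
    return relu(h @ weights[1] + biases[1])

def decode(z):
    """Expand 64-number codes back into 784 pixel values in (0, 1)."""
    h = relu(z @ weights[2] + biases[2])
    return sigmoid(h @ weights[3] + biases[3])

x = rng.random((5, 784))      # a batch of 5 fake flattened images
z = encode(x)
x_hat = decode(z)
print(z.shape, x_hat.shape)   # (5, 64) (5, 784)
```

Training (as in the Keras code) is what turns these random weights into ones whose reconstructions actually resemble the input.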

The Challenge: Masked Image Recovery

Imagine you have a photograph, but parts of it are obscured or missing. The challenge is to restore the missing parts accurately, using the information in the unmasked portions together with the patterns the model has learned from many complete examples. This is what we aim to achieve with the MNIST dataset: part of each digit is covered, and the goal is to recover the full digit image.

How AutoEncoders Come to the Rescue

To tackle this challenge, we use an AutoEncoder structured into two main components: the encoder and the decoder.

  • The encoder learns to compress the input image into a compact, lower-dimensional representation. Despite the masking, it captures the essential features of the visible parts of the digit.
  • The decoder takes this compressed form and attempts to reconstruct the original image, filling in the masked parts in the process.

This process is akin to an artist who, upon seeing only a portion of a painting, uses their understanding of shapes, colors, and textures to complete the artwork.
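In symbols, and assuming the layer sizes used in the Keras code below, the setup can be written as follows. Here x is a flattened image with pixel values in [0, 1], m is a 0/1 mask that is 0 on covered pixels, f_φ is the encoder, and g_θ is the decoder; this notation is ours, not the article's. The loss is the per-pixel binary cross-entropy averaged over the 784 outputs, matching the `binary_crossentropy` loss used in the code. Note that it compares the reconstruction with the original x rather than the masked input, which is what forces the decoder to fill in the hole.

```latex
\tilde{\mathbf{x}} = \mathbf{m} \odot \mathbf{x}, \qquad
\mathbf{z} = f_\phi(\tilde{\mathbf{x}}) \in \mathbb{R}^{64}, \qquad
\hat{\mathbf{x}} = g_\theta(\mathbf{z}) \in (0,1)^{784}

\mathcal{L}(\mathbf{x}, \hat{\mathbf{x}}) = -\frac{1}{784} \sum_{i=1}^{784} \Big[ x_i \log \hat{x}_i + (1 - x_i) \log (1 - \hat{x}_i) \Big]
```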

Implementation with Keras

For our example, we use the MNIST dataset, which consists of 70,000 grayscale 28×28 images of handwritten digits (60,000 for training and 10,000 for testing). We modify these images by covering a randomly placed square in each one, then train an AutoEncoder to recover the original images from their masked versions.

Step 1: Data Preparation

We start by loading the MNIST dataset, scaling the pixel values to the range [0, 1], and setting a randomly placed 10×10 square in each image to zero. Each training example then pairs the masked image (the input) with the original, unmasked image (the target output).
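The masking step can also be written without a Python loop. Below is a minimal NumPy sketch, a variant of the article's `randomly_cover_images` rather than the original code; the names `cover_batch` and `masked_batches` are hypothetical. It additionally returns a boolean mask of the hidden pixels (handy for evaluating only the covered region), and the generator draws fresh masks on every pass over the data, which is an assumption of ours: the original code fixes one mask per image before training.

```python
import numpy as np

def cover_batch(images, rng, cover_size=10):
    """Zero out one cover_size x cover_size square per image at a random position.

    images: array of shape (n, height, width).
    Returns the covered copy and a boolean array that is True on hidden pixels.
    """
    n, h, w = images.shape
    # integers() excludes its upper bound, so +1 lets the square touch the edge
    ys = rng.integers(0, h - cover_size + 1, size=n)
    xs = rng.integers(0, w - cover_size + 1, size=n)
    rows = np.arange(h)[None, :, None]  # shape (1, h, 1)
    cols = np.arange(w)[None, None, :]  # shape (1, 1, w)
    top, left = ys[:, None, None], xs[:, None, None]
    hidden = ((rows >= top) & (rows < top + cover_size) &
              (cols >= left) & (cols < left + cover_size))  # shape (n, h, w)
    covered = np.where(hidden, 0.0, images).astype(images.dtype)
    return covered, hidden

def masked_batches(images, batch_size, rng, cover_size=10):
    """Yield (covered_flat, original_flat) batches, sampling new masks each call."""
    order = rng.permutation(len(images))
    for start in range(0, len(images), batch_size):
        batch = images[order[start:start + batch_size]]
        covered, _ = cover_batch(batch, rng, cover_size)
        yield covered.reshape(len(batch), -1), batch.reshape(len(batch), -1)

rng = np.random.default_rng(42)
fake = rng.random((8, 28, 28)).astype("float32")  # stand-in for MNIST images
covered, hidden = cover_batch(fake, rng)
print(hidden.sum(axis=(1, 2)))  # 100 hidden pixels in every image
```

Resampling masks acts as a form of data augmentation; how such a generator is best fed into training depends on the Keras version in use.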

Step 2: Building the AutoEncoder

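Our AutoEncoder uses fully connected (Dense) layers for both encoding and decoding. The encoder maps the flattened 784-pixel image through layers of 128 and 64 units, compressing it into a 64-dimensional code; the decoder mirrors this with layers of 128 and 784 units, and its final sigmoid activation keeps every output pixel between 0 and 1. The entire model is trained end-to-end to minimize the per-pixel binary cross-entropy between the original and reconstructed images.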
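As a quick sanity check on the model's size: a Dense layer with n_in inputs and n_out outputs has n_in × n_out weights plus n_out biases. The short computation below assumes the 784→128→64→128→784 layout from the code; the total should match what `autoencoder.summary()` reports for that model.

```python
# Layer sizes of the dense autoencoder: input, hidden, bottleneck, hidden, output
sizes = [784, 128, 64, 128, 784]

def dense_params(n_in, n_out):
    """Weights (n_in x n_out matrix) plus one bias per output unit."""
    return n_in * n_out + n_out

per_layer = [dense_params(a, b) for a, b in zip(sizes[:-1], sizes[1:])]
print(per_layer)       # [100480, 8256, 8320, 101136]
print(sum(per_layer))  # 218192
print(784 / 64)        # 12.25
```

The last line shows the compression: each image is represented by a code with 12.25 times fewer numbers than it has pixels.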

Step 3: Training and Evaluation

With our AutoEncoder built, we train it for 50 epochs (batch size 256, Adam optimizer), using the masked images as input and the original images as targets. Training adjusts the weights of the network to minimize the reconstruction error. Afterward, we evaluate the model by reconstructing the masked test images, computing the mean squared error (MSE) against the original, unmasked images, and plotting a few examples side by side.
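One caveat about the MSE: it is averaged over all 784 pixels, and with a 10×10 cover about 87% of them (684 of 784) are visible, so the model only has to copy them through. The error on the covered square, which is the part being recovered, can be reported separately. The sketch below assumes a boolean array `hidden` marking the covered pixels; the original `randomly_cover_images` does not return one, so it would need a small change to record it. The function name `region_mse` is hypothetical.

```python
import numpy as np

def region_mse(original, reconstructed, hidden):
    """MSE over hidden (covered) pixels, visible pixels, and all pixels.

    original, reconstructed: arrays of shape (n, 28, 28) or (n, 784)
    hidden: boolean array with the same number of elements, True where covered
    """
    original = np.asarray(original, dtype=np.float64)
    reconstructed = np.asarray(reconstructed, dtype=np.float64).reshape(original.shape)
    hidden = np.asarray(hidden, dtype=bool).reshape(original.shape)
    sq_err = (original - reconstructed) ** 2
    return {
        "hidden": sq_err[hidden].mean(),
        "visible": sq_err[~hidden].mean(),
        "overall": sq_err.mean(),
    }

# Toy check: a "model" that copies visible pixels perfectly but leaves the hole black
rng = np.random.default_rng(0)
orig = rng.random((3, 784))
hidden = np.zeros((3, 784), dtype=bool)
hidden[:, :100] = True
recon = np.where(hidden, 0.0, orig)
print(region_mse(orig, recon, hidden))  # all of the error sits in "hidden"
```

In this toy case the overall MSE is the hidden-region MSE diluted by a factor of 100/784, which is exactly why a single overall number can look better than the recovery really is.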

The full code, also available in a Colab notebook, is shown below:

```python
import numpy as np
import matplotlib.pyplot as plt
from keras.layers import Input, Dense
from keras.models import Model
from keras.datasets import mnist
from sklearn.metrics import mean_squared_error

def randomly_cover_images(images, cover_size=(10, 10)):
    """Return a copy of `images` in which a cover_size (width, height) square
    at a random position of each image is set to 0."""
    covered_images = images.copy()
    img_height, img_width = images.shape[1:3]
    for img in covered_images:  # `img` is a view, so writes go into covered_images
        # randint excludes its upper bound, so +1 lets the square reach the edge
        top_left_x = np.random.randint(0, img_width - cover_size[0] + 1)
        top_left_y = np.random.randint(0, img_height - cover_size[1] + 1)
        img[top_left_y:top_left_y + cover_size[1], top_left_x:top_left_x + cover_size[0]] = 0
    return covered_images

# Load MNIST and scale pixel values from [0, 255] to [0, 1]
(x_train, _), (x_test, _) = mnist.load_data()
x_train = x_train.astype('float32') / 255.
x_test = x_test.astype('float32') / 255.
x_train = np.reshape(x_train, (len(x_train), 28, 28))
x_test = np.reshape(x_test, (len(x_test), 28, 28))

# Cover a random 10x10 square in every image
x_train_covered = randomly_cover_images(x_train)
x_test_covered = randomly_cover_images(x_test)

# Flatten the images to 784-element vectors for the Dense layers
x_train_flat = x_train.reshape((len(x_train), 28*28))
x_test_flat = x_test.reshape((len(x_test), 28*28))
x_train_covered_flat = x_train_covered.reshape((len(x_train_covered), 28*28))
x_test_covered_flat = x_test_covered.reshape((len(x_test_covered), 28*28))

# Build the autoencoder: 784 -> 128 -> 64 (bottleneck) -> 128 -> 784
input_img = Input(shape=(784,))
encoded = Dense(128, activation='relu')(input_img)
encoded = Dense(64, activation='relu')(encoded)

decoded = Dense(128, activation='relu')(encoded)
decoded = Dense(784, activation='sigmoid')(decoded)  # pixel values in [0, 1]

autoencoder = Model(input_img, decoded)
autoencoder.compile(optimizer='adam', loss='binary_crossentropy')

# Train: covered images as input, original images as target
autoencoder.fit(x_train_covered_flat, x_train_flat,
                epochs=50,
                batch_size=256,
                shuffle=True,
                validation_data=(x_test_covered_flat, x_test_flat))

# Evaluate: MSE between original and reconstructed test images over all pixels
x_test_decoded = autoencoder.predict(x_test_covered_flat)
mse = mean_squared_error(x_test_flat.flatten(), x_test_decoded.flatten())
print(f"Mean Squared Error: {mse}")

# Visualization: originals (top row), covered inputs (middle), reconstructions (bottom)
def visualize_reconstruction(original, covered, reconstructed, n=10):
    plt.figure(figsize=(20, 4))
    for i in range(n):
        # Display original
        ax = plt.subplot(3, n, i + 1)
        plt.imshow(original[i])
        plt.gray()
        ax.get_xaxis().set_visible(False)
        ax.get_yaxis().set_visible(False)

        # Display covered
        ax = plt.subplot(3, n, i + 1 + n)
        plt.imshow(covered[i])
        plt.gray()
        ax.get_xaxis().set_visible(False)
        ax.get_yaxis().set_visible(False)

        # Display reconstruction (reshape the flat output back to 28x28)
        ax = plt.subplot(3, n, i + 1 + 2*n)
        plt.imshow(reconstructed[i].reshape(28, 28))
        plt.gray()
        ax.get_xaxis().set_visible(False)
        ax.get_yaxis().set_visible(False)
    plt.show()

visualize_reconstruction(x_test, x_test_covered, x_test_decoded)
```
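On its own, the printed Mean Squared Error is hard to judge. One way to put it in context is to compare it with two trivial baselines: returning the covered image unchanged (the hole stays black), and ignoring the input entirely and always predicting the average training image. The sketch below is our addition, not part of the original notebook, and the function name `baseline_mses` is hypothetical. It computes the same quantity as the `mean_squared_error` call above, averaged over all pixels.

```python
import numpy as np

def baseline_mses(original, covered, train_images):
    """Reference MSEs for two trivial 'recovery' strategies."""
    original = original.reshape(len(original), -1).astype(np.float64)
    covered = covered.reshape(len(covered), -1).astype(np.float64)
    mean_img = train_images.reshape(len(train_images), -1).mean(axis=0)
    return {
        # Baseline 1: output the covered input unchanged
        "identity": np.mean((original - covered) ** 2),
        # Baseline 2: always output the average training image
        "mean_image": np.mean((original - mean_img) ** 2),
    }

# Toy check with synthetic "images": all-white digits with a 10x10 hole
orig = np.ones((2, 28, 28))
cov = orig.copy()
cov[:, :10, :10] = 0.0
print(baseline_mses(orig, cov, np.ones((5, 28, 28))))
```

With the variables from the code above, it would be called as `baseline_mses(x_test, x_test_covered, x_train)`. A model that has learned something useful should beat both numbers.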

Conclusion

The AutoEncoder’s ability to recover masked images from the MNIST dataset is a testament to the power of neural networks to learn from incomplete data. This process not only highlights the potential of AutoEncoders in image recovery tasks but also serves as a stepping stone toward more complex applications, such as restoring historical documents, enhancing low-resolution images, or even generating new content based on partial inputs.
