Demystifying Neural Networks: Text-to-Image with AutoEncoder

Dagang Wei
4 min read · Feb 4, 2024


Image Generated with DALL-E

This article is part of the series Demystifying Neural Networks.

Introduction

In the world of artificial intelligence, the ability to generate detailed images from textual descriptions is a major advance. Among the many techniques for achieving this, autoencoders offer a simple and efficient pathway. In this blog post, we’ll explore how to use an autoencoder neural network, trained on the MNIST dataset, to build a system that generates images from text descriptions, where the “text” is a digit label such as “5”.

How does it work?

Learning Image Embeddings

The journey begins with the construction of an autoencoder neural network. An autoencoder is a type of artificial neural network used to learn efficient codings of unlabeled data. The network is designed to encode inputs into a compact representation (embeddings) and then decode these embeddings to reconstruct the inputs as closely as possible.
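To make the encode/decode data flow concrete, here is a minimal NumPy sketch of the forward pass of an autoencoder with the same layer sizes as the Keras code later in the post (784 → 256 → 128 → 256 → 784). It is an assumption-laden illustration, not the author's implementation: the weights are random and untrained, the helper names (`init_layer`, `dense`, `encode`, `decode`, `bce`) are hypothetical, and the input batch is random noise standing in for images.

```python
import numpy as np

rng = np.random.default_rng(0)
embedding_dim = 128  # same embedding size as the Keras code in this post


def relu(z):
    return np.maximum(z, 0.0)


def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))


def init_layer(n_in, n_out):
    """Random (untrained) weights and zero bias for one fully connected layer."""
    return rng.normal(0.0, 0.05, size=(n_in, n_out)), np.zeros(n_out)


def dense(x, layer, activation):
    """Fully connected layer: activation(x @ W + b)."""
    w, b = layer
    return activation(x @ w + b)


# Encoder: 784 -> 256 -> 128; the decoder mirrors it: 128 -> 256 -> 784.
enc = [init_layer(784, 2 * embedding_dim), init_layer(2 * embedding_dim, embedding_dim)]
dec = [init_layer(embedding_dim, 2 * embedding_dim), init_layer(2 * embedding_dim, 784)]


def encode(x):
    return dense(dense(x, enc[0], relu), enc[1], relu)


def decode(z):
    return dense(dense(z, dec[0], relu), dec[1], sigmoid)


def bce(x, x_hat, eps=1e-7):
    """Mean binary cross-entropy, the reconstruction loss the post trains with."""
    x_hat = np.clip(x_hat, eps, 1 - eps)
    return float(np.mean(-(x * np.log(x_hat) + (1 - x) * np.log(1 - x_hat))))


x = rng.random((4, 784))     # stand-in batch of 4 flattened "images" in [0, 1)
z = encode(x)                # compact embeddings
x_hat = decode(z)            # reconstructions
print(z.shape, x_hat.shape)  # (4, 128) (4, 784)
print(bce(x, x_hat))         # the loss that training would drive down
```

Training consists of adjusting the weights in `enc` and `dec` so that `bce(x, decode(encode(x)))` gets small; Keras does that with backpropagation in the `autoencoder.fit` call below.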

For our purpose, we use the MNIST dataset, a collection of 28×28 grayscale images of handwritten digits, to train our autoencoder. MNIST suits this task because the images are small and simple and each digit has distinctive strokes, which makes useful embeddings easier to learn. During training, we feed the MNIST images into the autoencoder, which compresses each one and then tries to reconstruct it; the training target is the input image itself. Through this process, the network learns the essential features of the digits, and the encoder’s output becomes an embedding that represents each 784-pixel image in a much lower-dimensional space (128 values in the code below).
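Before any of this, the raw images need light preprocessing: MNIST stores pixels as 8-bit integers, so they are scaled to [0, 1] (which also matches the sigmoid output of the decoder) and each 28×28 grid is flattened into a vector. The sketch below assumes a small random batch in place of the real `mnist.load_data()` output, so it runs without downloading anything.

```python
import numpy as np

# Stand-in for mnist.load_data(): a batch of 8-bit grayscale 28x28 images.
# (Random pixels here; the post loads the real data via keras.datasets.)
rng = np.random.default_rng(0)
images_uint8 = rng.integers(0, 256, size=(5, 28, 28)).astype(np.uint8)

# Scale pixel intensities from [0, 255] to [0, 1], as the post does with / 255.
images = images_uint8.astype("float32") / 255.0

# Flatten each image into a 784-dimensional vector (what Keras' Flatten layer does).
flat = images.reshape(len(images), -1)

embedding_dim = 128
print(flat.shape)                     # (5, 784)
print(flat.shape[1] / embedding_dim)  # 6.125: each embedding is about 6x smaller
```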

Learning Label Embeddings

With the image embeddings in hand, the next step is to train a second neural network. This network works the other way around: its input is a label (in our case, one of the digits 0 through 9, one-hot encoded) and its output is an embedding. Its training targets are the embeddings that the encoder from the first step produces for the training images, so for each image it learns to predict, from the label alone, where that image sits in the embedding space.
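One consequence is worth spelling out. Thousands of different images share each label, and the model is trained with mean squared error, so for a given one-hot input the loss-minimizing output is the average embedding of all training images with that label. The NumPy sketch below checks this with a closed-form least-squares fit (a linear model solved with `np.linalg.lstsq`, swapped in for the post’s gradient-trained Dense layer) on synthetic labels and embeddings; all data here is made up for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)
num_classes, embedding_dim, n = 10, 128, 1000

# Synthetic stand-ins: 100 examples per label, and non-negative "encoder" embeddings
# (non-negative because the post's encoder ends in a ReLU).
labels = rng.permutation(np.arange(n) % num_classes)
embeddings = np.abs(rng.normal(size=(n, embedding_dim))) + labels[:, None]

# One-hot encode the labels (same result as keras.utils.to_categorical).
one_hot = np.eye(num_classes)[labels]

# Least squares: the W that minimizes the squared error ||one_hot @ W - embeddings||^2.
W, *_ = np.linalg.lstsq(one_hot, embeddings, rcond=None)

# Per-label average embedding, computed directly.
class_means = np.stack([embeddings[labels == c].mean(axis=0) for c in range(num_classes)])

print(np.allclose(W, class_means))  # True: row c of W is the mean embedding of label c
```

In practice this means the decoder will be shown an “average 5” rather than any particular 5, which tends to produce a smooth, somewhat blurry prototype of each digit.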

This step is crucial as it bridges the gap between raw numerical labels and their visual representations in the form of embeddings. By training the network with labels as inputs and their corresponding image embeddings as outputs, we effectively teach the model to map textual descriptions (in this case, simple digit labels) to a space that represents visual content.
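Because the input is one-hot, the single Dense layer in the post’s label-to-embedding model behaves like a learnable lookup table: multiplying a one-hot vector by the weight matrix just selects one row. The sketch below (random weights, hypothetical values) shows that the Dense computation and a plain row lookup give identical results; Keras’ `Embedding` layer is built on the same idea, taking integer labels directly.

```python
import numpy as np

rng = np.random.default_rng(0)
num_classes, embedding_dim = 10, 128

# Weights of one Dense layer: kernel W (10 x 128) and bias b (128,).
W = rng.normal(size=(num_classes, embedding_dim))
b = rng.normal(size=embedding_dim)

labels = np.array([5, 0, 7])
one_hot = np.eye(num_classes)[labels]        # shape (3, 10)

dense_out = np.maximum(one_hot @ W + b, 0)   # Dense(128, activation='relu') on one-hot input
lookup_out = np.maximum(W[labels] + b, 0)    # select row `label` of W, then add bias and ReLU

print(np.allclose(dense_out, lookup_out))    # True
```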

Generating Images from Text

The final step in our journey is the most exciting part: generating images from text descriptions. With the label-to-embedding model trained in the previous step, we can now input a text description (again, one of the digits 0 through 9 for our MNIST example) and obtain its predicted embedding. This embedding is then fed into the decoder half of the autoencoder built in the first step.

The decoder turns the embedding back into pixels, producing a visual representation of the input text. For instance, if we input the label “5”, the label-to-embedding model predicts an embedding for “5”, and the decoder uses it to generate an image of a handwritten digit 5.
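The whole generation path is just function composition: label → embedding → decoder → 28×28 image. The NumPy sketch below wires that up with two hypothetical stand-ins for the trained pieces: `embedding_table` plays the role of the label-to-embedding model, and `decode` is an untrained two-layer decoder with random weights, so its output shows the right shape and value range but not a recognizable digit.

```python
import numpy as np

rng = np.random.default_rng(0)
num_classes, embedding_dim = 10, 128

# Hypothetical stand-in for the trained label-to-embedding model: one vector per label.
embedding_table = np.abs(rng.normal(size=(num_classes, embedding_dim)))

# Hypothetical stand-in for the trained decoder: random weights, 128 -> 256 -> 784.
W1 = rng.normal(0.0, 0.05, size=(embedding_dim, 2 * embedding_dim))
W2 = rng.normal(0.0, 0.05, size=(2 * embedding_dim, 784))


def decode(z):
    """Map embeddings to pixel intensities in (0, 1)."""
    h = np.maximum(z @ W1, 0.0)              # ReLU hidden layer
    return 1.0 / (1.0 + np.exp(-(h @ W2)))   # sigmoid output, one value per pixel


def generate(labels):
    """Label -> embedding -> decoder -> 28x28 image."""
    z = embedding_table[np.asarray(labels)]
    return decode(z).reshape(-1, 28, 28)


images = generate([5])
print(images.shape)  # (1, 28, 28)
```

Nothing in this path is random at inference time, so the same label always yields the same image. The same holds for the Keras models in the post, which contain no dropout or sampling layers; producing varied images of one digit would need a source of randomness, which this design does not have.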

Implementation with Keras

The code is available in this colab notebook.

import numpy as np
import matplotlib.pyplot as plt
from keras.datasets import mnist
from keras.layers import Input, Dense, Flatten, Reshape
from keras.models import Sequential, Model
from keras.optimizers import Adam
from keras.utils import to_categorical


# Load and preprocess MNIST data
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train = x_train.astype('float32') / 255.
x_test = x_test.astype('float32') / 255.
x_train = np.reshape(x_train, (len(x_train), 28, 28, 1))
x_test = np.reshape(x_test, (len(x_test), 28, 28, 1))

embedding_dim = 128

# Define the encoder using the Sequential API
encoder = Sequential([
    Flatten(input_shape=(28, 28, 1)),                # 28x28x1 image -> 784-vector
    Dense(2 * embedding_dim, activation='relu'),
    Dense(embedding_dim, activation='relu')          # 128-dimensional embedding
])

# Define the decoder using the Sequential API
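decoder = Sequential([
    # input_shape belongs on the first layer so the decoder can be used on its own
    Dense(2 * embedding_dim, activation='relu', input_shape=(embedding_dim,)),
    Dense(784, activation='sigmoid'),                # one value in [0, 1] per pixel
    Reshape((28, 28, 1))
])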

# Combine encoder and decoder to create the autoencoder using the functional API
input_img = Input(shape=(28, 28, 1))
encoded_img = encoder(input_img)
decoded_img = decoder(encoded_img)
autoencoder = Model(inputs=input_img, outputs=decoded_img)
autoencoder.compile(optimizer=Adam(), loss='binary_crossentropy')

# Train autoencoder
autoencoder.fit(x_train, x_train, epochs=50, batch_size=256, shuffle=True, validation_data=(x_test, x_test))

# Prepare label data for the label-to-embedding model
y_train_cat = to_categorical(y_train, 10)
y_test_cat = to_categorical(y_test, 10)

# Train a label-to-embedding model
model_label_to_embedding = Sequential([
    # One-hot label (10 values) -> predicted 128-dimensional embedding
    Dense(embedding_dim, activation='relu', input_shape=(10,))
])

# Compile the model
model_label_to_embedding.compile(optimizer=Adam(), loss='mean_squared_error')

# Use encoder to generate embeddings
y_train_embeddings = encoder.predict(x_train)
y_test_embeddings = encoder.predict(x_test)

model_label_to_embedding.fit(y_train_cat, y_train_embeddings, epochs=50, batch_size=256, shuffle=True, validation_data=(y_test_cat, y_test_embeddings))

# Generate images from labels
def generate_images_from_labels(labels, model_label_to_embedding, decoder):
    labels_cat = to_categorical(labels, 10)  # Convert labels to one-hot
    predicted_embeddings = model_label_to_embedding.predict(labels_cat)
    decoded_images = decoder.predict(predicted_embeddings)
    return decoded_images.reshape(-1, 28, 28)

labels = np.arange(10) # Example labels: 0 through 9
generated_images = generate_images_from_labels(labels, model_label_to_embedding, decoder)

# Visualization of generated images
plt.figure(figsize=(20, 4))
for i, img in enumerate(generated_images, 1):
    ax = plt.subplot(1, 10, i)
    plt.imshow(img.squeeze(), cmap='gray')  # squeeze() drops any single-dimensional entries from the shape
    ax.set_title(f"Label: {labels[i-1]}")
    plt.axis('off')
plt.tight_layout()
plt.show()

Conclusion

The text-to-image generation process using an autoencoder and MNIST data showcases the power of neural networks in bridging the gap between textual descriptions and visual representations. This method, while demonstrated on a relatively simple dataset, lays the groundwork for more complex and detailed image generation tasks. By understanding and leveraging the underlying principles, researchers and developers can apply similar techniques to generate images from more detailed text descriptions, opening up new possibilities in the field of AI and machine learning.
