Optimizing Autoencoders with Weight Tying for Unlabelled Data

Ruben Khachaturyan
Jul 10, 2024

Imagine that a significant portion of your data is unlabelled. What can you do besides labeling it, which is not always feasible? Should you just set it aside until better times?

A tip from the book Hands-On Machine Learning with Scikit-Learn, Keras, and TensorFlow by Aurélien Géron suggests that unlabelled data can still be used to learn the essential features of a dataset before training on the labelled data. The method involves using an autoencoder.

Suppose you have a neural network for image classification. Here’s how it works:

  • Take the core of your model, which can be a Stack of Convolution Layers, and transform it into an autoencoder, as illustrated in Fig.1.
  • Feed the data into the autoencoder, which will then try to reconstruct the input data.
  • Through this process, the core of your network learns essential features of the dataset.
Fig.1. a) the main model that has a Stack of Convolution Layers in its core; b) the Stack of Convolution Layers is used to build an autoencoder

In this setup, the Stack of Convolution Transpose Layers consists of transpose convolution layers acting as a decoder. It upsamples the features downsampled by the encoder. The key point is that the input and output are the same image, not a label.

  • Once trained, the Stack of Convolution Layers can be transferred back to the main neural network and used as a pre-trained block.
  • Train the main neural network with the pre-trained Stack of Convolution Layers; a minimal sketch of this workflow is given below.
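
To make the workflow concrete, here is a minimal sketch in Keras. The layer sizes, the unlabelled_images / labelled_images variables, and the 10-class head are illustrative assumptions, not part of the original setup:

import tensorflow as tf
from tensorflow import keras

# The "core" of the model: a small Stack of Convolution Layers (illustrative sizes)
conv_stack = keras.Sequential([
    keras.layers.Conv2D(64, 3, padding="same", activation="relu"),
    keras.layers.MaxPooling2D(),
], name="conv_stack")

# 1) Wrap the core into an autoencoder and train it on the unlabelled images,
#    reconstructing the input (x -> x); no labels are involved.
autoencoder = keras.Sequential([
    conv_stack,
    keras.layers.Conv2DTranspose(1, 3, strides=2, padding="same",
                                 activation="sigmoid"),
])
autoencoder.compile(optimizer="adam", loss="mse")
# autoencoder.fit(unlabelled_images, unlabelled_images, epochs=10)

# 2) Reuse the now pre-trained core in the main network and train it on the
#    (smaller) labelled subset. The weights of conv_stack are shared.
classifier = keras.Sequential([
    conv_stack,
    keras.layers.GlobalAveragePooling2D(),
    keras.layers.Dense(10, activation="softmax"),
])
classifier.compile(optimizer="adam", loss="sparse_categorical_crossentropy")
# classifier.fit(labelled_images, labels, epochs=10)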

Improving Efficiency with a Non-Trainable Decoder

Sounds nice, doesn’t it? But I found one more trick to improve the efficiency of this approach, and it is called Weight Tying. See its implementation on the MNIST dataset using Dense layers [1], or a more advanced version using Convolution Layers [2].

Typically, both the encoder and the decoder learn generic features of the dataset. However, since the decoder will be discarded afterwards, training it might lead to information loss. To avoid this, we can build the autoencoder with a non-trainable decoder. Instead of having trainable parameters, the decoder inherits the weights from the encoder and uses them to reconstruct the downsampled image.

To implement this, we replace Conv2DTranspose with a custom layer that performs the same operation but uses the encoder’s weights. Here’s what the custom layer can look like:

import tensorflow as tf
from tensorflow import keras


class CustomConvTranspose2d(keras.layers.Layer):
    """Transpose convolution with no weights of its own: the kernel
    (and optionally the bias) is passed in at call time."""

    def __init__(self, filters: int,
                 kernel_size: int = 3,
                 strides: tuple = (2, 2),
                 padding: str = "same",
                 data_format: str = "channels_last",
                 **kwargs):
        super(CustomConvTranspose2d, self).__init__(**kwargs)
        self.filters = filters
        self.kernel_size = kernel_size
        # tf.nn.conv2d_transpose expects "SAME" / "VALID"
        self.padding = padding.upper()
        self.strides = [1, *strides, 1]
        self.data_format = "NHWC" if data_format == "channels_last" else "NCHW"

    def call(self, inputs, weights, bias=None):
        # output spatial size is the input size scaled by the stride
        batch_size = tf.shape(inputs)[0]
        height = inputs.shape[1] * self.strides[1]
        width = inputs.shape[2] * self.strides[2]
        output_shape = [batch_size, height, width, self.filters]

        # `weights` is the encoder's Conv2D kernel of shape
        # (kernel_h, kernel_w, self.filters, in_channels)
        outputs = tf.nn.conv2d_transpose(inputs, weights, output_shape,
                                         strides=self.strides,
                                         padding=self.padding,
                                         data_format=self.data_format)
        if bias is not None:
            outputs = tf.nn.bias_add(outputs, bias,
                                     data_format=self.data_format)

        return outputs

By using this custom layer, we ensure that CustomConvTranspose2d is not trained but simply mirrors the encoder’s weights.
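
As a quick sanity check, the layer can be fed the kernel of any built Conv2D layer. The layer and shapes below are illustrative, not taken from the article’s model:

# Borrow the kernel of an encoder Conv2D that maps 64 -> 128 channels
# and use it to upsample a 128-channel feature map back to 64 channels.
encoder_conv = keras.layers.Conv2D(128, kernel_size=3, padding="same")
_ = encoder_conv(tf.random.normal((1, 16, 16, 64)))  # build the layer so the kernel exists

up = CustomConvTranspose2d(filters=64)                # weight-less decoder layer
upsampled = up(tf.random.normal((1, 8, 8, 128)),
               weights=encoder_conv.kernel)           # kernel shape (3, 3, 64, 128)
print(upsampled.shape)                                # (1, 16, 16, 64)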

Example

To illustrate, I suggest evaluating the weight-tying approach on the People Clothing Segmentation dataset, which can easily be found on Kaggle. An example of an image with its mask is shown in Fig.2.

Fig.2. Examples from the dataset with images (upper row) and corresponding masks (lower row)

Because of limited computational resources, I will perform a segmentation task instead of full image recognition. Such an approach can be useful when one needs to use a Stack of Convolution Transpose Layers to classify particular objects in a dataset. The goal here is to segment the person in an image; the masks were modified accordingly.

I suggest considering the following Stack of Convolution Layers used as an encoder:

down_block = Conv2D → BatchNorm → ReLU →
Conv2D → BatchNorm → ReLU →
MaxPooling2D

The encoder will consist of three successive down_blocks:

encoder = [down_block_1, down_block_2, down_block_3]

Note that using two down_blocks results in a Dice Loss of 0.26, whereas three blocks brings it below 0.05.

The actual code looks like this:

import tensorflow as tf
from tensorflow import keras

# set of hard-coded parameters
CONV_KERNEL = (3, 3)
POOL_KERNEL = (2, 2)
STRIDE_SIZE = (1, 1)
PADDING_TYPE = "same"


class DownBlock(keras.layers.Layer):
    """
    Description
    -------
    Down block has the following architecture
        Conv2D → BatchNorm → ReLU →
        Conv2D → BatchNorm → ReLU →
        MaxPooling2D

    Parameters
    -------
    filter_in: int, number of output filters

    Returns
    -------
    tf.tensor
    """
    def __init__(self, filter_in, **kwargs):
        super(DownBlock, self).__init__(**kwargs)

        self.conv_1 = keras.layers.Conv2D(filters=filter_in,
                                          kernel_size=CONV_KERNEL,
                                          strides=STRIDE_SIZE,
                                          padding=PADDING_TYPE)
        self.conv_2 = keras.layers.Conv2D(filters=filter_in,
                                          kernel_size=CONV_KERNEL,
                                          strides=STRIDE_SIZE,
                                          padding=PADDING_TYPE)
        self.batch_1 = keras.layers.BatchNormalization()
        self.batch_2 = keras.layers.BatchNormalization()
        self.relu = keras.layers.ReLU()
        self.maxpool = keras.layers.MaxPooling2D(pool_size=POOL_KERNEL)

    def call(self, x):
        x = self.conv_1(x)
        x = self.batch_1(x)
        x = self.relu(x)
        x = self.conv_2(x)
        x = self.batch_2(x)
        x = self.relu(x)
        x = self.maxpool(x)
        return x

Going through the encoder, the image is downsampled to its main features. To perform the upsampling, a decoder (a Stack of Convolution Transpose Layers) consisting of three consecutive Conv2DTranspose layers is added.

class Autoencoder(keras.models.Model):
    """
    Description
    -------
    Autoencoder consists of
        encoder: three down blocks
        decoder: three upsampling blocks

    Parameters
    -------
    None

    Returns
    -------
    x: tf.tensor, restored image/mask
    """
    def __init__(self, **kwargs):
        super(Autoencoder, self).__init__(**kwargs)
        # encoder declaration
        self.down_l1 = DownBlock(64)
        self.down_l2 = DownBlock(128)
        self.down_l3 = DownBlock(256)
        # decoder declaration: each block upsamples by a factor of 2
        self.up_l3 = keras.layers.Conv2DTranspose(128, kernel_size=CONV_KERNEL,
                                                  strides=POOL_KERNEL,
                                                  padding=PADDING_TYPE)
        self.up_l2 = keras.layers.Conv2DTranspose(64, kernel_size=CONV_KERNEL,
                                                  strides=POOL_KERNEL,
                                                  padding=PADDING_TYPE)
        self.up_l1 = keras.layers.Conv2DTranspose(1, kernel_size=CONV_KERNEL,
                                                  strides=POOL_KERNEL,
                                                  padding=PADDING_TYPE)

    def call(self, x):
        x = self.down_l1(x)
        x = self.down_l2(x)
        x = self.down_l3(x)
        x = self.up_l3(x)
        x = self.up_l2(x)
        x = self.up_l1(x)
        x = keras.activations.sigmoid(x)
        return x

To see how replacing the decoder with a non-trainable one affects the efficiency of the autoencoder, Autoencoder_Custom is built using the custom Stack of Convolution Transpose Layers.

class Autoencoder_Custom(keras.models.Model):
    """
    Description
    -------
    Autoencoder consists of
        encoder: three down blocks
        decoder: three upsampling blocks
                 which do not have
                 trainable parameters

    Parameters
    -------
    None

    Returns
    -------
    x: tf.tensor, restored image/mask
    """
    def __init__(self, **kwargs):
        super(Autoencoder_Custom, self).__init__(**kwargs)
        # encoder declaration
        self.down_l1 = DownBlock(64)
        self.down_l2 = DownBlock(128)
        self.down_l3 = DownBlock(256)
        # decoder declaration: no trainable parameters of its own
        self.up_l3 = CustomConvTranspose2d(128)
        self.up_l2 = CustomConvTranspose2d(64)
        self.up_l1 = CustomConvTranspose2d(1)

    def call(self, x):
        x = self.down_l1(x)
        x = self.down_l2(x)
        x = self.down_l3(x)
        # get the encoder weights (kernel of the first Conv2D in each block)
        w1 = self.down_l1.weights[0]
        w2 = self.down_l2.weights[0]
        w3 = self.down_l3.weights[0]
        # use the custom decoder with the encoder's weights
        x = self.up_l3(x, w3)
        x = self.up_l2(x, w2)
        x = self.up_l1(x, w1)
        x = keras.activations.sigmoid(x)
        return x

When built for an image of size (1, 128, 128), i.e. grey-scale, we can see that the model with the non-trainable decoder has fewer parameters and is around 25% smaller than the classical autoencoder, as seen in Fig.3.

Fig.3 The classical autoencoder (left) has 15 million parameters and weighs 5.8 MB, compared to the autoencoder with a non-trainable decoder (right) with 11 million parameters and 4.4 MB
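
To reproduce this comparison yourself, a quick sketch could look like the following (the exact numbers depend on the image size and filter configuration):

# Build both models on a grey-scale 128x128 input and compare parameter counts.
classic = Autoencoder()
tied = Autoencoder_Custom()

dummy = tf.zeros((1, 128, 128, 1))
classic(dummy)  # a forward pass builds all layers
tied(dummy)

print("classical autoencoder:", classic.count_params())
print("tied-weight autoencoder:", tied.count_params())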

To estimate the models’ performance, a Dice Loss was chosen, which essentially measures the overlap between predictions and masks. Both models show very similar performance on the given dataset, see Fig.4. That means the encoder alone was enough to learn the key features of the dataset.

Fig.4 Comparison of the Dice Loss decay between the classical autoencoder (orange) and the autoencoder with a non-trainable decoder (blue) after 20 epochs, for an image size of 128x128 pixels and a batch size of 32 images.
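
The Dice Loss is not spelled out in the text; a common formulation (one possible implementation, not necessarily the exact one used for these experiments) is:

def dice_loss(y_true, y_pred, smooth=1.0):
    """1 - Dice coefficient, computed over the whole batch."""
    y_true = tf.cast(y_true, y_pred.dtype)
    intersection = tf.reduce_sum(y_true * y_pred)
    union = tf.reduce_sum(y_true) + tf.reduce_sum(y_pred)
    return 1.0 - (2.0 * intersection + smooth) / (union + smooth)

# model.compile(optimizer="adam", loss=dice_loss)
# model.fit(images, person_masks, epochs=20, batch_size=32)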

The quality of a generated mask is then shown in Fig.5:

Fig.5 A generated mask (left) and the original mask (right) shown next to each other

Conclusion

By utilizing this enhanced method, it is possible to leverage unlabelled data to pre-train an encoder block effectively. A comparison between a classical autoencoder and an autoencoder with a non-trainable decoder in a segmentation task demonstrates that focusing on the encoder can achieve equivalent segmentation quality. This approach effectively consolidates all learned information into the encoder, minimizing the need for a trainable decoder.

The source code can be found on my GitHub under Weight_Tied_Autoencoders.

References

  1. Building an Autoencoder with Tied Weights in Keras
  2. Efficient Medical Image Segmentation with Intermediate Supervision Mechanism
