Optimizing Autoencoders with Weight Tying for Unlabelled Data
Imagine that a significant portion of your data is unlabelled. What can you do besides labeling it, which is not always feasible? Should you just set it aside until better times?
A tip from the book Hands-On Machine Learning with Scikit-Learn, Keras, and TensorFlow by Aurelien Geron suggests that unlabelled data can still be used to learn essential features of a dataset before training the model on labelled data. The method involves using an autoencoder.
Suppose you have a neural network for image classification. Here’s how the approach works:
- Take the core of your model, for example a Stack of Convolution Layers, and turn it into an autoencoder, as illustrated in Fig.1.
- Feed the unlabelled data into the autoencoder, which tries to reconstruct its own input.
- Through this process, the core of your network learns essential features of the dataset.
In this setup, the Stack of Convolution Transpose Layers acts as the decoder: it upsamples the features that the encoder has downsampled. The key point is that the input and the target output are the same image, not a label.
- Once trained, the Stack of Convolution Layers can be transferred back to the main neural network and used as a pre-trained block.
- Train the main neural network, now with the pre-trained Stack of Convolution Layers, on the labelled data.
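The steps above can be sketched in Keras as follows. This is a minimal sketch under stated assumptions, not the code from the book or from this article: the random 28×28 images, the layer sizes and the 10-class classification head are hypothetical placeholders.

```python
import numpy as np
from tensorflow import keras

# Hypothetical toy data: 32 unlabelled and 16 labelled 28x28 grey-scale images.
rng = np.random.default_rng(0)
x_unlabelled = rng.random((32, 28, 28, 1)).astype("float32")
x_labelled = rng.random((16, 28, 28, 1)).astype("float32")
y_labelled = rng.integers(0, 10, size=16)

# 1. The "core" of the model: a small stack of convolution layers (encoder).
encoder = keras.Sequential([
    keras.Input(shape=(28, 28, 1)),
    keras.layers.Conv2D(16, 3, padding="same", activation="relu"),
    keras.layers.MaxPooling2D(2),
    keras.layers.Conv2D(32, 3, padding="same", activation="relu"),
    keras.layers.MaxPooling2D(2),
], name="encoder")

# 2. Decoder: transpose convolutions upsample 7x7 back to 28x28.
decoder = keras.Sequential([
    keras.Input(shape=(7, 7, 32)),
    keras.layers.Conv2DTranspose(16, 3, strides=2, padding="same", activation="relu"),
    keras.layers.Conv2DTranspose(1, 3, strides=2, padding="same", activation="sigmoid"),
], name="decoder")

# 3. Pre-train on unlabelled data: the target is the input itself.
autoencoder = keras.Sequential([encoder, decoder])
autoencoder.compile(optimizer="adam", loss="mse")
autoencoder.fit(x_unlabelled, x_unlabelled, epochs=1, verbose=0)

# 4. Reuse the same (pre-trained) encoder object in the classifier.
classifier = keras.Sequential([
    encoder,
    keras.layers.Flatten(),
    keras.layers.Dense(10, activation="softmax"),
])
classifier.compile(optimizer="adam", loss="sparse_categorical_crossentropy")
classifier.fit(x_labelled, y_labelled, epochs=1, verbose=0)
```

Because the classifier holds the very same `encoder` object, it starts from the pre-trained weights. A common variant is to set `encoder.trainable = False` for the first epochs so that the randomly initialised head does not disturb them.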
Improving Efficiency with a Non-Trainable Decoder
Sounds nice, doesn’t it? But I found one more trick that improves the efficiency of this approach, called Weight Tying. See its implementation on the MNIST dataset using Dense layers [1], or a more advanced version using Convolution Layers [2].
Typically, both the encoder and the decoder learn generic features of the dataset. However, since the decoder is discarded afterwards, whatever it has learned is lost with it. To avoid this, we can build the autoencoder with a non-trainable decoder: instead of having trainable parameters of its own, the decoder reuses the encoder’s weights to reconstruct the image from its downsampled representation.
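For Dense layers, weight tying boils down to one line: the decoder multiplies by the transpose of the encoder’s weight matrix. The NumPy sketch below illustrates the idea with assumed sizes (784 inputs, 64 hidden units) and hypothetical `encode`/`decode` helpers; it is not the implementation referenced in [1].

```python
import numpy as np

rng = np.random.default_rng(0)

# Assumed sizes: flattened 28x28 images compressed to 64 features.
n_in, n_hidden = 784, 64

# The encoder owns the only weight matrix in the model.
W = rng.normal(0.0, 0.05, size=(n_in, n_hidden))
b = np.zeros(n_hidden)


def encode(x):
    return np.maximum(x @ W + b, 0.0)          # ReLU(x W + b)


def decode(h):
    # Weight tying: reuse the transposed encoder matrix, no new parameters.
    return 1.0 / (1.0 + np.exp(-(h @ W.T)))   # sigmoid(h W^T)


x = rng.random((8, n_in))                      # a batch of 8 fake images in [0, 1)
x_hat = decode(encode(x))

tied_params = W.size + b.size                  # encoder only
untied_params = tied_params + W.size + n_in    # plus a separate decoder matrix and bias
print(x_hat.shape, tied_params, untied_params)
```

Dropping the decoder’s own matrix roughly halves the parameter count here. During training, the gradient with respect to `W` collects contributions from both the encoding and the decoding path, so everything the model learns ends up in the encoder’s weights.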
To implement this, we replace Conv2DTranspose with a custom layer that performs the same operation but uses the encoder’s weights. Here’s what the custom layer can look like:
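import tensorflow as tf
from tensorflow import keras
class CustomConvTranspose2d(keras.layers.Layer):
    """Transpose convolution without trainable weights of its own.
    The kernel (and an optional bias) are passed in at call time,
    so the layer can reuse the weights of an encoder's Conv2D layer.
    """
    def __init__(self, filters: int,
                 kernel_size: int = 3,
                 strides: tuple = (2, 2),
                 padding: str = "same",
                 data_format: str = "channels_last",
                 **kwargs):
        super(CustomConvTranspose2d, self).__init__(**kwargs)
        self.filters = filters
        self.kernel_size = kernel_size  # informational: the kernel comes from the encoder
        self.padding = padding.upper()  # tf.nn expects "SAME" or "VALID"
        self.data_format = "NHWC" if data_format == "channels_last" else "NCHW"
        # tf.nn expects 4-element strides laid out like the data format
        if self.data_format == "NHWC":
            self.strides = [1, *strides, 1]
        else:
            self.strides = [1, 1, *strides]
    def call(self, inputs, kernel, bias=None):
        # `kernel` is an encoder Conv2D kernel of shape (kh, kw, c_in, c_out);
        # conv2d_transpose reads it as (kh, kw, out_channels, in_channels),
        # so it maps c_out input channels back to c_in output channels.
        batch_size = tf.shape(inputs)[0]
        if self.data_format == "NHWC":
            height = inputs.shape[1] * self.strides[1]
            width = inputs.shape[2] * self.strides[2]
            output_shape = [batch_size, height, width, self.filters]
        else:
            height = inputs.shape[2] * self.strides[2]
            width = inputs.shape[3] * self.strides[3]
            output_shape = [batch_size, self.filters, height, width]
        outputs = tf.nn.conv2d_transpose(inputs, kernel, output_shape,
                                         strides=self.strides,
                                         padding=self.padding,
                                         data_format=self.data_format)
        if bias is not None:
            outputs = tf.nn.bias_add(outputs, bias,
                                     data_format=self.data_format)
        return outputs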
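Because this layer creates no weights of its own, CustomConvTranspose2d has nothing to train: it mirrors the encoder’s weights. The reconstruction error is still backpropagated through these shared kernels, so it is the encoder that gets updated.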
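Why is it legitimate to reuse a convolution kernel in the decoder? A transpose convolution with kernel k is exactly the matrix transpose (adjoint) of the strided convolution with the same k, which is why TensorFlow describes `tf.nn.conv2d_transpose` as the transpose (gradient) of `conv2d`. The pure-Python sketch below checks this for the 1-D, single-channel case; `conv1d`, `conv1d_transpose` and `dot` are hypothetical helpers written for illustration.

```python
def conv1d(x, k, stride):
    """'VALID' strided 1-D convolution (cross-correlation, as in Conv layers)."""
    n_out = (len(x) - len(k)) // stride + 1
    return [sum(k[j] * x[i * stride + j] for j in range(len(k)))
            for i in range(n_out)]


def conv1d_transpose(y, k, stride, n_in):
    """Transpose of conv1d with the same kernel: scatter each y[i]
    back onto the input positions it was computed from."""
    x = [0.0] * n_in
    for i, y_i in enumerate(y):
        for j, k_j in enumerate(k):
            x[i * stride + j] += k_j * y_i
    return x


def dot(a, b):
    return sum(p * q for p, q in zip(a, b))


x = [1, 2, 3, 4, 5, 6, 7]   # "image" of length 7
k = [1, 0, -1]              # kernel shared by encoder and decoder
y = [1, 2, 3]               # any signal with the encoder's output length

encoded = conv1d(x, k, stride=2)                   # length 7 -> 3
decoded = conv1d_transpose(y, k, 2, n_in=len(x))   # length 3 -> 7
# Adjoint identity: <conv(x), y> == <x, conv_T(y)>
print(len(encoded), len(decoded), dot(encoded, y), dot(x, decoded))
```

Note that the transpose is not an inverse: `conv1d_transpose(encoded, ...)` does not give back `x`. The shared kernels still have to be learned so that the reconstruction matches the input.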
Example
To illustrate, I evaluate the Weight Tying approach on the People Clothing Segmentation dataset, which is available on Kaggle. An example image with its mask is shown in Fig.2.
Because of limited computational resources, I perform a segmentation task instead of full image recognition. Such an approach can be useful when a Stack of Convolution Transpose Layers is needed in the final model anyway, for example to classify particular objects in a dataset. The goal here is to segment the person in each image, so the masks were modified accordingly.
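The exact mask preprocessing is not shown here. One plausible way, sketched below under the assumption that each mask stores a class index per pixel with 0 meaning background, is to merge every non-background class into a single "person" class. The helper name `to_person_mask` is hypothetical.

```python
import numpy as np


def to_person_mask(class_mask: np.ndarray) -> np.ndarray:
    """Collapse a per-pixel class-index mask into a binary person mask.

    Assumes label 0 is background and every other label (clothing items,
    skin, hair, ...) belongs to the person. A trailing channel axis is added
    so the mask matches a single-channel sigmoid output of shape (H, W, 1).
    """
    return (class_mask > 0).astype(np.float32)[..., np.newaxis]


example = np.array([[0, 3],
                    [12, 0]])
print(to_person_mask(example)[..., 0])
```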
The encoder is built from the following stack of convolution layers:
down_block = Conv2D → BatchNorm → ReLU →
Conv2D → BatchNorm → ReLU →
MaxPooling2D
The encoder consists of three successive down blocks:
encoder = [down_block_1, down_block_2, down_block_3]
Note that with two down blocks the Dice loss stays around 0.26, whereas with three blocks it drops below 0.05.
The actual code looks like this:
import tensorflow as tf
from tensorflow import keras
# set of hard-coded parameters
CONV_KERNEL = (3, 3)
POOL_KERNEL = (2, 2)
STRIDE_SIZE = (1, 1)
PADDING_TYPE = "same"
class DownBlock(keras.layers.Layer):
    """
    Description
    -------
    Down block has the following architecture
    Conv2D → BatchNorm → ReLU →
    Conv2D → BatchNorm → ReLU →
    MaxPooling2D
    Parameters
    -------
    filter_in: int, number of output filters
    Returns
    -------
    tf.Tensor
    """
    def __init__(self, filter_in, **kwargs):
        super(DownBlock, self).__init__(**kwargs)
        self.conv_1 = keras.layers.Conv2D(filters=filter_in,
                                          kernel_size=CONV_KERNEL,
                                          padding=PADDING_TYPE)
        self.conv_2 = keras.layers.Conv2D(filters=filter_in,
                                          kernel_size=CONV_KERNEL,
                                          padding=PADDING_TYPE)
        self.batch_1 = keras.layers.BatchNormalization()
        self.batch_2 = keras.layers.BatchNormalization()
        self.relu = keras.layers.ReLU()
        self.maxpool = keras.layers.MaxPooling2D(pool_size=POOL_KERNEL)
    def call(self, x):
        x = self.conv_1(x)
        x = self.batch_1(x)
        x = self.relu(x)
        x = self.conv_2(x)
        x = self.batch_2(x)
        x = self.relu(x)
        x = self.maxpool(x)  # halves height and width
        return x
Passing through the encoder, the image is downsampled to its main features. To upsample it back, a decoder is added: a Stack of Convolution Transpose Layers made of three consecutive Conv2DTranspose layers.
class Autoencoder(keras.models.Model):
    """
    Description
    -------
    Autoencoder consists of
    encoder: three down blocks
    decoder: three upsampling blocks
    Parameters
    -------
    None
    Returns
    -------
    x: tf.Tensor, restored image/mask
    """
    def __init__(self, **kwargs):
        super(Autoencoder, self).__init__(**kwargs)
        # encoder declaration
        self.down_l1 = DownBlock(64)
        self.down_l2 = DownBlock(128)
        self.down_l3 = DownBlock(256)
        # decoder declaration: each layer doubles height and width
        # (kernel_size=3 is assumed here; Conv2DTranspose requires it)
        self.up_l3 = keras.layers.Conv2DTranspose(128, kernel_size=3, strides=2, padding="same")
        self.up_l2 = keras.layers.Conv2DTranspose(64, kernel_size=3, strides=2, padding="same")
        self.up_l1 = keras.layers.Conv2DTranspose(1, kernel_size=3, strides=2, padding="same")
    def call(self, x):
        x = self.down_l1(x)
        x = self.down_l2(x)
        x = self.down_l3(x)
        x = self.up_l3(x)
        x = self.up_l2(x)
        x = self.up_l1(x)
        x = keras.activations.sigmoid(x)
        return x
To see how replacing the decoder with a non-trainable one affects the efficiency of the autoencoder, Autoencoder_Custom is built using the custom transpose convolution stack.
class Autoencoder_Custom(keras.models.Model):
    """
    Description
    -------
    Autoencoder consists of
    encoder: three down blocks
    decoder: three upsampling blocks
             which do not have
             trainable parameters
    Parameters
    -------
    None
    Returns
    -------
    x: tf.Tensor, restored image/mask
    """
    def __init__(self, **kwargs):
        super(Autoencoder_Custom, self).__init__(**kwargs)
        # encoder declaration
        self.down_l1 = DownBlock(64)
        self.down_l2 = DownBlock(128)
        self.down_l3 = DownBlock(256)
        # decoder declaration
        self.up_l3 = CustomConvTranspose2d(128)
        self.up_l2 = CustomConvTranspose2d(64)
        self.up_l1 = CustomConvTranspose2d(1)
    def call(self, x):
        x = self.down_l1(x)
        x = self.down_l2(x)
        x = self.down_l3(x)
        # get encoder weights: the kernel of the first Conv2D in each block,
        # shape (3, 3, c_in, c_out); the blocks are already built at this point
        w1 = self.down_l1.conv_1.kernel
        w2 = self.down_l2.conv_1.kernel
        w3 = self.down_l3.conv_1.kernel
        # use custom decoder with encoder's weights, in reverse order
        x = self.up_l3(x, w3)
        x = self.up_l2(x, w2)
        x = self.up_l1(x, w1)
        x = keras.activations.sigmoid(x)
        return x
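Why are the kernels passed in reverse order (w3, w2, w1)? The pure-Python shape trace below follows a 128×128 grey-scale input through the encoder and the tied decoder. It relies on two documented conventions: a Keras Conv2D kernel has shape (kh, kw, in_channels, out_channels), while `tf.nn.conv2d_transpose` reads its filter as (kh, kw, output_channels, in_channels). The helper names are hypothetical.

```python
# Kernel shapes of the first Conv2D in each DownBlock for a 1-channel input,
# in Keras layout (kernel_h, kernel_w, in_channels, out_channels).
ENCODER_KERNELS = {
    "down_l1": (3, 3, 1, 64),
    "down_l2": (3, 3, 64, 128),
    "down_l3": (3, 3, 128, 256),
}


def down_block_shape(shape, filters):
    """'same' convolutions keep H and W; the 2x2 max-pool halves them."""
    n, h, w, _ = shape
    return (n, h // 2, w // 2, filters)


def tied_transpose_shape(shape, kernel, stride=2):
    """Output shape of conv2d_transpose (padding 'SAME') with a tied kernel.

    The filter is read as (kh, kw, out_channels, in_channels), so an encoder
    kernel (kh, kw, c_in, c_out) maps c_out channels back to c_in.
    """
    n, h, w, c = shape
    _, _, dec_out, dec_in = kernel
    if c != dec_in:
        raise ValueError(f"input has {c} channels, tied kernel expects {dec_in}")
    return (n, h * stride, w * stride, dec_out)


shape = (1, 128, 128, 1)
for filters in (64, 128, 256):
    shape = down_block_shape(shape, filters)
bottleneck = shape

for name in ("down_l3", "down_l2", "down_l1"):   # decoder walks back in reverse
    shape = tied_transpose_shape(shape, ENCODER_KERNELS[name])
reconstruction = shape
print(bottleneck, reconstruction)
```

Any other order raises a channel mismatch, which is why each CustomConvTranspose2d is paired with the encoder block that mirrors it.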
When built for an input of size (1, 128, 128), i.e. a single-channel grey-scale image, the model with the non-trainable decoder has fewer parameters and is around 25% smaller than the classical autoencoder, as seen in Fig.3.
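The size difference can be checked by hand. The sketch below counts parameters with the standard Keras formulas (convolution: k·k·c_in·c_out weights plus c_out biases; BatchNormalization: 4 values per channel, including the non-trainable moving statistics). It assumes kernel_size=3 with bias for the classical Conv2DTranspose layers and no bias for the tied decoder; the helper names are hypothetical.

```python
def conv_params(k, c_in, c_out):
    """Conv2D or Conv2DTranspose with bias: k*k*c_in*c_out weights + c_out biases."""
    return k * k * c_in * c_out + c_out


def batchnorm_params(c):
    """gamma, beta (trainable) + moving mean, moving variance (non-trainable)."""
    return 4 * c


def down_block_params(c_in, c_out, k=3):
    """Two 3x3 convolutions and two BatchNorms; MaxPooling2D has no weights."""
    return (conv_params(k, c_in, c_out) + conv_params(k, c_out, c_out)
            + 2 * batchnorm_params(c_out))


encoder = (down_block_params(1, 64) + down_block_params(64, 128)
           + down_block_params(128, 256))
# Classical decoder (assumed kernel_size=3, with bias): 256 -> 128 -> 64 -> 1 channels.
decoder = conv_params(3, 256, 128) + conv_params(3, 128, 64) + conv_params(3, 64, 1)

classic_total = encoder + decoder
tied_total = encoder        # the tied decoder owns no weights (bias=None)
print(f"classic: {classic_total:,}  tied: {tied_total:,}  "
      f"saving: {decoder / classic_total:.1%}")
```

Under these assumptions it prints `classic: 1,517,249  tied: 1,147,840  saving: 24.3%`, consistent with the roughly 25% reduction.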
To estimate the models’ performance, the Dice loss was chosen; it measures how well the predicted masks overlap the ground-truth masks. Both models show very similar performance on this dataset, see Fig.4. This suggests that the encoder alone was enough to learn the key features of the dataset.
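The exact Dice formulation used for the figures is not shown. A common soft variant is sketched below in NumPy; the `smooth` constant is an assumption that avoids division by zero when both masks are empty. The same formula carries over to a Keras loss by replacing `np.sum` with `tf.reduce_sum`.

```python
import numpy as np


def dice_loss(y_true, y_pred, smooth=1.0):
    """Soft Dice loss: 1 - (2*sum(A*B) + s) / (sum(A) + sum(B) + s).

    y_true: binary ground-truth mask; y_pred: predicted probabilities in [0, 1].
    """
    y_true = np.asarray(y_true, dtype=np.float64).ravel()
    y_pred = np.asarray(y_pred, dtype=np.float64).ravel()
    intersection = np.sum(y_true * y_pred)
    denominator = np.sum(y_true) + np.sum(y_pred)
    return 1.0 - (2.0 * intersection + smooth) / (denominator + smooth)


mask = np.array([[1, 1],
                 [0, 0]])
print(dice_loss(mask, mask))                          # perfect overlap: 0.0
print(dice_loss(mask, [[1, 0], [0, 0]], smooth=0))    # only half of the person found
```

A loss of 0 means perfect overlap and values close to 1 mean almost none, so the 0.05 reached by the three-block models corresponds to a Dice coefficient of about 0.95.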
An example of a generated mask is shown in Fig.5:
Conclusion
With this method, one can make use of unlabelled data to pre-train an encoder block. A comparison between a classical autoencoder and an autoencoder with a non-trainable decoder in a segmentation task shows that concentrating the learning in the encoder achieves comparable segmentation quality. This approach consolidates all learned information in the encoder, removing the need for a trainable decoder.
The source code can be found on my GitHub in the Weight_Tied_Autoencoders section.