Explainable Data Generators: A Solution to Overfitting in Deep Learning

Debasrita Chakraborty
Published in codelogicx · Jun 13, 2022

Generative adversarial networks (GANs) are a type of neural network that is particularly good at generating artificial data. This data is usually generated from an existing dataset and follows the same statistical distribution as that dataset. This means that, unlike traditional machine learning models, generative adversarial networks can create realistic-looking images, handwriting, or voice. This is helpful for training deep neural networks, which are data-hungry, and so GANs are often used for data augmentation.


How does GAN data generation work?

GANs are composed of two neural networks: a generator and a discriminator (Figure 1). The generator creates fake samples, while the discriminator tries to tell the difference between real and fake samples. Training therefore alternates between two phases: one updates the generator to produce convincing fakes, and the other updates the discriminator to separate real samples from fakes. The error rate is high at first, but it drops as the two networks improve against each other and the generated samples become harder to distinguish from real data.
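To make this alternating training concrete, here is a minimal sketch of such a loop. It assumes PyTorch and toy fully connected networks; the layer sizes, optimizers, and the stand-in batch of "real" data are illustrative assumptions, not something prescribed in this article.

```python
# Minimal two-phase GAN training loop (a sketch, assuming PyTorch).
import torch
import torch.nn as nn

latent_dim, data_dim = 16, 64

# Toy generator and discriminator; real GANs use far deeper networks.
generator = nn.Sequential(nn.Linear(latent_dim, 128), nn.ReLU(), nn.Linear(128, data_dim))
discriminator = nn.Sequential(nn.Linear(data_dim, 128), nn.ReLU(), nn.Linear(128, 1))

opt_g = torch.optim.Adam(generator.parameters(), lr=2e-4)
opt_d = torch.optim.Adam(discriminator.parameters(), lr=2e-4)
bce = nn.BCEWithLogitsLoss()

real_batch = torch.randn(32, data_dim)  # stand-in for a batch of real samples

for step in range(1000):
    # Phase 1: the discriminator learns to label real samples 1 and fakes 0.
    noise = torch.randn(32, latent_dim)
    fake_batch = generator(noise).detach()
    d_loss = bce(discriminator(real_batch), torch.ones(32, 1)) + \
             bce(discriminator(fake_batch), torch.zeros(32, 1))
    opt_d.zero_grad(); d_loss.backward(); opt_d.step()

    # Phase 2: the generator learns to make the discriminator say "real".
    noise = torch.randn(32, latent_dim)
    g_loss = bce(discriminator(generator(noise)), torch.ones(32, 1))
    opt_g.zero_grad(); g_loss.backward(); opt_g.step()
```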

Figure 1

This kind of neural network can be used for many applications, from generating fonts to producing realistic images from sketches; it has also been used to predict Reddit posts and even to create fake news stories with sensational headlines.

Is the data generation reliable?

GANs are fascinating for their broad applicability. However, we must question how practical the data they generate really is. The amusing side of the story is that if a GAN is trained on images of people from all over the world (wearing all sorts of clothes, in all sorts of climates), it may end up generating a picture of a person wearing a tuxedo in a bathroom. The image may look realistic, but it is nonsensical. The quality of any AI algorithm depends heavily on the quality of the data it sees. This level of common sense in AI is still a distant dream, but we can make data generation somewhat more reliable.

For now, the best we can do is try to make the GAN generate samples that serve as valid data augmentations. Validity is itself a challenge, since it depends on the data the model sees, which is a chicken-and-egg problem. This brings us to overfitting, where the algorithm sees only a small part of the overall data and memorizes it completely.

Overfitting

Overfitting is when a network has memorized specific features of the training data and cannot generalize to new data. It may seem that a GAN should not suffer from this, since it can generate its own data. However, overfitting happens when the discriminator relies too heavily on the training data. In GANs, it shows up as generated data that is very similar to the real images, as if they were only tiny modifications of them. Suppose we wish to generate images of giraffes and our training data only contains standing giraffes. If the generator produces an unrealistic standing giraffe (as in Figure 2, where the head is out of proportion with the rest of the body), the discriminator might identify it as real, but when shown a real giraffe that is sleeping (as in Figure 3), it would identify that as fake.

Figure 2
Figure 3

Can we avoid this?

The samples generated by a GAN can only be semantically reliable if there is a way to add an explainable side to the model's generation capabilities. Explainable AI is a field of machine learning and artificial intelligence research that seeks to make AI systems more transparent by giving humans explanations of the learned models that drive decision-making. If we rely on a technique to explain our model's generations, it is critical to grasp that technique's underlying mechanics. Sadly, the problems stated above remain open issues. If the training data is not a good representation of the overall dataset, the data generation will never be reliable.

Figure 4

However, it would be beneficial to couple an explainable model such as LIME with the GAN, to help it understand which areas of the inputs are semantically important for data generation. The explainable model would inspect both the generated sample and the discriminator's output, and would adjust the generator accordingly so that the features the discriminator uses to identify fakes become reliable.

The LIME algorithm is usually preferred because it is model-agnostic, meaning it can be used with any machine learning model. The method probes the model by perturbing parts of the input samples and observing how the output changes. In this way, insignificant areas are identified and excluded from learning, and only the significant areas are used for data generation. This kind of solution may help in cases of overfitting where the dataset is too small. A rough sketch of the idea follows below.
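The snippet below shows how LIME's image explainer could be pointed at a trained discriminator to find the regions that drive its real/fake decision. It is only a sketch under assumptions: the `lime` package is used, and `discriminator_fn` and `generated_sample` are hypothetical stand-ins for a trained discriminator and a generated image, neither of which comes from this article.

```python
# Sketch: using LIME to find the image regions a GAN discriminator relies on.
# Assumes the `lime` package; `discriminator_fn` and `generated_sample`
# are hypothetical stand-ins, not part of the article.
import numpy as np
from lime import lime_image

def discriminator_fn(images):
    """Return [P(fake), P(real)] for a batch of HxWx3 images.
    A dummy stand-in here; in practice, call the trained discriminator."""
    p_real = np.clip(images.mean(axis=(1, 2, 3)), 0.0, 1.0)
    return np.stack([1.0 - p_real, p_real], axis=1)

generated_sample = np.random.rand(64, 64, 3)  # pretend output of the generator

explainer = lime_image.LimeImageExplainer()
explanation = explainer.explain_instance(
    generated_sample,
    discriminator_fn,   # the model being explained
    top_labels=2,
    num_samples=500,    # number of perturbed copies LIME evaluates
)

# Superpixels that push the discriminator toward "real" (label 1): these are
# the semantically important regions the generator should get right.
image, mask = explanation.get_image_and_mask(
    1, positive_only=True, num_features=5, hide_rest=False
)
```

In a coupled setup like the one described above, the resulting mask could in principle be used to steer the generator toward the regions the discriminator actually attends to, though how best to do that integration is still an open design question.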

Limitations

There are certain cases where overfitting cannot be avoided, though. If the samples come from imbalanced distributions, the problem may persist. In a way, this can even be helpful. Imagine a forger who wants to produce realistic documents: the GAN will not be able to generate such fakes unless there already exists a good number of convincing forged samples in its training data.
