Improving Classification Accuracy with ACGAN (Keras)

Eric Muccino
Apr 13

Supervised machine learning uses labeled data to train models for classification or regression over a set of targets. The performance of a model is a function of the data that is used to train it. The less data that is available, the harder it is for a model to learn to make accurate predictions on unseen data.

GAN

Generative Adversarial Networks (GANs) are an unsupervised machine learning technique for generating synthetic data that imitates the distributions and relationships found within a sample of real data. A GAN consists of two separate networks, a Generator and a Discriminator. The Generator learns to produce synthetic data, seeded from a randomly chosen vector in a latent space. The Discriminator learns to distinguish real data from the synthetic data produced by the Generator. The Generator uses the gradients of the Discriminator to improve the quality of the data it generates. The two networks take turns training, each learning from the improvements of the other.
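For reference, the original GAN paper (Goodfellow et al., 2014) frames this turn-taking as a two-player minimax game over a value function V(D, G). Here D(x) is the Discriminator's estimated probability that x is real, and G(z) is the Generator's output for a latent vector z:

```latex
\min_G \max_D V(D, G) =
  \mathbb{E}_{x \sim p_{\text{data}}}\big[\log D(x)\big]
  + \mathbb{E}_{z \sim p_z}\big[\log\big(1 - D(G(z))\big)\big]
```

The ACGAN variant used in this post replaces the binary real/fake output with a class-aware one. Read the formula as background, not as the exact loss trained here.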

In this post, we will see how to set up an Auxiliary Classifier GAN (ACGAN) for numerical and categorical data. Then, we will run an experiment to verify that synthetically generated data can improve the performance of a classification model. For this tutorial, we will use the Lending Club Loan Data set, which is available on Kaggle. A Kaggle kernel also provides a full guide to cleaning and using this data set.

In this tutorial, we will be using Keras via TensorFlow 2.1.0.

Data Setup

Our goal is to predict a positive or negative loan condition from the loan details. We will use a subset of the features available in the data set and ignore samples with missing values. We will also use only a portion of the data set, to simulate a scenario where data availability is limited.
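A minimal pandas sketch of this step follows. It assumes the Kaggle CSV has already been loaded into a DataFrame. The column list is an assumption: `loan_amnt`, `int_rate`, `home_ownership` and `verification_status` are Lending Club columns, while `loan_condition` is an assumed name for the binary good/bad target. `load_subset` and the sampling fraction are hypothetical.

```python
import pandas as pd

# Assumed feature subset; adjust to the columns you keep after cleaning.
FEATURES = ["loan_amnt", "int_rate", "home_ownership", "verification_status"]
TARGET = "loan_condition"  # assumed name of the good/bad loan label


def load_subset(df: pd.DataFrame, frac: float = 0.1, seed: int = 0) -> pd.DataFrame:
    """Keep the chosen columns, drop rows with missing values, then subsample.

    Subsampling (frac < 1) simulates a setting where labeled data is scarce.
    """
    subset = df[FEATURES + [TARGET]].dropna()
    return subset.sample(frac=frac, random_state=seed).reset_index(drop=True)


# Usage on a toy frame standing in for the Kaggle CSV.
raw = pd.DataFrame({
    "loan_amnt": [1000, 5000, None, 12000, 3000],
    "int_rate": [7.5, 13.2, 10.1, 18.9, 6.0],
    "home_ownership": ["RENT", "OWN", "RENT", "MORTGAGE", "RENT"],
    "verification_status": ["Verified", "Not Verified", "Verified",
                            "Source Verified", "Verified"],
    "loan_condition": ["Good Loan", "Bad Loan", "Good Loan", "Bad Loan", "Good Loan"],
    "id": [1, 2, 3, 4, 5],
})
small = load_subset(raw, frac=0.5)  # 4 complete rows remain, half of them are kept
```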

Data Organization and Preprocessing

Next, we need to organize our data so we can use it to train our models. Since we have a mix of data types, we sort our features by type so that the numeric and categorical features can be preprocessed separately. The numeric features are scaled to a common feature range. The categorical (text) features are tokenized so that they can be quickly converted into one-hot encoded vectors that a neural network can process. Once our features are preprocessed, we merge them back into a single DataFrame.

We set aside 20% of the preprocessed data for testing. The test data is not used for GAN or Classifier training; it only provides a final evaluation of our Classifiers once they have been trained on the training data.
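Here is a hedged sketch of this preprocessing with pandas and NumPy. Min-max scaling to [0, 1] is an assumed choice of feature range. The hand-rolled vocabulary stands in for whatever tokenizer the original code used. `preprocess` and the column lists are hypothetical names. Like the description above, the sketch scales before splitting; fitting the scaler on the training rows alone would avoid leaking test statistics.

```python
import numpy as np
import pandas as pd

NUMERIC = ["loan_amnt", "int_rate"]                      # assumed numeric features
CATEGORICAL = ["home_ownership", "verification_status"]  # assumed categorical features
TARGET = "loan_condition"                                # assumed target column


def preprocess(df, test_frac=0.2, seed=0):
    """Scale numeric columns, one-hot encode categorical ones, and hold out a test split."""
    # Min-max scale each numeric column to [0, 1]; constant columns map to 0.
    num = df[NUMERIC].astype("float64")
    span = (num.max() - num.min()).replace(0, 1)
    num = (num - num.min()) / span

    # Tokenize each categorical column (token -> integer id), then one-hot encode the ids.
    vocabs, parts = {}, [num]
    for col in CATEGORICAL:
        vocab = sorted(df[col].unique())
        ids = df[col].map({tok: i for i, tok in enumerate(vocab)}).to_numpy()
        onehot = np.eye(len(vocab), dtype="float32")[ids]
        parts.append(pd.DataFrame(onehot, index=df.index,
                                  columns=[f"{col}={tok}" for tok in vocab]))
        vocabs[col] = vocab

    # Merge the preprocessed features into one DataFrame; encode the target as integers.
    X = pd.concat(parts, axis=1)
    classes = sorted(df[TARGET].unique())
    y = df[TARGET].map({c: i for i, c in enumerate(classes)}).to_numpy()

    # Hold out test_frac of the rows; they are never used for GAN or classifier training.
    perm = np.random.default_rng(seed).permutation(len(X))
    n_test = int(round(test_frac * len(X)))
    test_idx, train_idx = perm[:n_test], perm[n_test:]
    return X.iloc[train_idx], X.iloc[test_idx], y[train_idx], y[test_idx], vocabs


toy = pd.DataFrame({
    "loan_amnt": [1000, 5000, 8000, 12000, 3000, 7000, 2000, 9000, 4000, 6000],
    "int_rate": [7.5, 13.2, 10.1, 18.9, 6.0, 11.0, 8.8, 15.5, 9.9, 12.1],
    "home_ownership": ["RENT", "OWN", "RENT", "MORTGAGE", "RENT",
                       "OWN", "RENT", "MORTGAGE", "OWN", "RENT"],
    "verification_status": ["Verified", "Not Verified"] * 5,
    "loan_condition": ["Good Loan", "Bad Loan"] * 5,
})
X_train, X_test, y_train, y_test, vocabs = preprocess(toy)
```

The stored `vocabs` make it possible to decode generated one-hot vectors back into tokens later.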

GAN Setup

Now it’s time to build our GAN. Since we want to generate data for a given target, we will use a form of ACGAN. The Generator takes a random seed and a specified target as input and learns to produce a synthetic data sample that corresponds to that target. Within the network, the latent vector and the target are merged and passed through hidden layers to produce an output. Because each categorical feature is represented by a set of output values in the form of a one-hot encoded vector, we give these features an extra set of hidden layers that do not intermingle with the numeric output features. These extra layers act as embedding layers that learn to produce the corresponding one-hot encoded token. By giving the categorical features a bottlenecked output path that is separate from the numeric features, we reduce their ability to dominate the gradients during training.
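A minimal tf.keras functional sketch of such a Generator follows. It is not the author's exact architecture. The layer sizes, `LATENT_DIM`, `NUM_NUMERIC` and `CAT_SIZES` are assumptions, as is the sigmoid numeric output, which presumes the numeric features were scaled to [0, 1]. Each categorical feature gets its own small branch ending in a softmax over its vocabulary, so its output has the same shape as a one-hot vector.

```python
import numpy as np
from tensorflow import keras
from tensorflow.keras import layers

LATENT_DIM = 32     # assumed size of the latent seed vector
NUM_CLASSES = 2     # good vs. bad loan
NUM_NUMERIC = 2     # assumed number of numeric features
CAT_SIZES = [3, 2]  # assumed vocabulary size of each categorical feature


def build_generator():
    z = keras.Input(shape=(LATENT_DIM,), name="latent")
    y = keras.Input(shape=(NUM_CLASSES,), name="target")  # one-hot target to condition on
    # Merge the latent vector and the target, then pass them through shared hidden layers.
    h = layers.Concatenate()([z, y])
    h = layers.Dense(64, activation="relu")(h)
    h = layers.Dense(64, activation="relu")(h)
    # Numeric output head; sigmoid matches features scaled to [0, 1].
    outputs = [layers.Dense(NUM_NUMERIC, activation="sigmoid", name="numeric")(h)]
    # One bottlenecked branch per categorical feature, kept separate from the numeric head.
    for i, size in enumerate(CAT_SIZES):
        branch = layers.Dense(8, activation="relu", name=f"cat_{i}_hidden")(h)
        outputs.append(layers.Dense(size, activation="softmax", name=f"cat_{i}")(branch))
    return keras.Model([z, y], outputs, name="generator")


generator = build_generator()
z = np.random.normal(size=(4, LATENT_DIM)).astype("float32")
y = np.eye(NUM_CLASSES, dtype="float32")[[0, 1, 0, 1]]
numeric, cat_0, cat_1 = [t.numpy() for t in generator([z, y])]
```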

The Discriminator takes a data sample as input and returns a discrimination, that is, a classification of the sample's validity. Within the network, the categorical encodings are first processed in a way that mirrors the Generator. They are then merged with the numeric data and passed through a series of hidden layers and an output layer that produces the discrimination. Because this is an ACGAN, our Discriminator either predicts the target of the sample or determines that the sample was synthetically generated.
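A matching sketch of the Discriminator follows, using assumed sizes. One design choice here is an assumption consistent with the description above: "predict the target or flag it as synthetic" is encoded as a single softmax over `NUM_CLASSES + 1` outputs. The original ACGAN paper instead uses two separate heads, one for real/fake and one for the class.

```python
import numpy as np
from tensorflow import keras
from tensorflow.keras import layers

NUM_CLASSES = 2     # good vs. bad loan
NUM_NUMERIC = 2     # assumed number of numeric features
CAT_SIZES = [3, 2]  # assumed vocabulary size of each categorical feature


def build_discriminator():
    numeric = keras.Input(shape=(NUM_NUMERIC,), name="numeric")
    cats = [keras.Input(shape=(s,), name=f"cat_{i}") for i, s in enumerate(CAT_SIZES)]
    # Mirror the Generator: each one-hot feature first passes through its own small layer.
    embedded = [layers.Dense(8, activation="relu")(c) for c in cats]
    # Merge with the numeric features and pass through shared hidden layers.
    h = layers.Concatenate()([numeric] + embedded)
    h = layers.Dense(64, activation="relu")(h)
    h = layers.Dense(64, activation="relu")(h)
    # Classes 0..NUM_CLASSES-1 mean "real sample of that target"; NUM_CLASSES means "synthetic".
    out = layers.Dense(NUM_CLASSES + 1, activation="softmax", name="discrimination")(h)
    return keras.Model([numeric] + cats, out, name="discriminator")


discriminator = build_discriminator()
# Integer labels 0..NUM_CLASSES, so sparse categorical cross-entropy fits.
discriminator.compile(optimizer="adam", loss="sparse_categorical_crossentropy")
rng = np.random.default_rng(0)
batch = [rng.random((4, NUM_NUMERIC)).astype("float32")] + [
    np.eye(s, dtype="float32")[rng.integers(0, s, 4)] for s in CAT_SIZES]
probs = discriminator(batch).numpy()
```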

The complete GAN is formed by connecting the Generator to the Discriminator, so that the Generator can train from the Discriminator's gradients. We also keep the standalone Generator and Discriminator models, which we use for generating synthetic data and for training the Discriminator, respectively.
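Here is a sketch of that wiring. It assumes a Generator with (latent, target) inputs and a Discriminator that accepts whatever the Generator outputs. `build_gan` is a hypothetical helper, and the tiny stand-in networks exist only to make the example run. The gradient check at the end shows the point of the combined model: a loss computed on the Discriminator's output can be differentiated with respect to the Generator's weights.

```python
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers


def build_gan(generator, discriminator):
    """Chain Generator -> Discriminator into one model (hypothetical helper)."""
    z = keras.Input(shape=tuple(generator.inputs[0].shape[1:]), name="gan_latent")
    y = keras.Input(shape=tuple(generator.inputs[1].shape[1:]), name="gan_target")
    return keras.Model([z, y], discriminator(generator([z, y])), name="gan")


# Tiny stand-in networks (assumed sizes): 4-d latent, 2 classes, 3 features.
z_in, y_in = keras.Input(shape=(4,)), keras.Input(shape=(2,))
toy_generator = keras.Model(
    [z_in, y_in], layers.Dense(3, activation="sigmoid")(layers.Concatenate()([z_in, y_in])))
x_in = keras.Input(shape=(3,))
# Discriminator output: 2 classes plus one "synthetic" class.
toy_discriminator = keras.Model(x_in, layers.Dense(3, activation="softmax")(x_in))

gan = build_gan(toy_generator, toy_discriminator)

# Generator objective: D should assign the conditioning target (not "synthetic") to its output.
z = tf.random.normal((8, 4))
y = tf.one_hot(tf.zeros(8, dtype=tf.int32), 2)
with tf.GradientTape() as tape:
    probs = gan([z, y])
    loss = tf.reduce_mean(
        keras.losses.categorical_crossentropy(tf.pad(y, [[0, 0], [0, 1]]), probs))
grads = tape.gradient(loss, toy_generator.trainable_variables)
```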

Batch Sampling

We need some helper functions for sampling batches of training data: one that provides latent vectors and targets for training the Generator, and a couple that provide real and synthetic data for training the Discriminator. The synthetic data is generated by running inference on the Generator.
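Hedged NumPy versions of these helpers follow. The function names and the module-level `rng` are assumptions. So is `generate_fn`, a callable standing in for Generator inference (for example, a wrapper around the Generator's `predict`). Synthetic rows are labeled with class index `num_classes`, assuming the Discriminator reserves that extra index for synthetic samples.

```python
import numpy as np

rng = np.random.default_rng(0)


def sample_latent(batch_size, latent_dim, num_classes):
    """Latent seeds plus random targets (one-hot and integer) for training the Generator."""
    z = rng.normal(size=(batch_size, latent_dim)).astype("float32")
    labels = rng.integers(0, num_classes, size=batch_size)
    return z, np.eye(num_classes, dtype="float32")[labels], labels


def sample_real(X, y, batch_size):
    """A random batch of real rows with their integer targets (without replacement)."""
    idx = rng.choice(len(X), size=batch_size, replace=False)
    return X[idx], y[idx]


def sample_fake(generate_fn, batch_size, latent_dim, num_classes):
    """Synthetic rows from the Generator, all labeled with the extra 'synthetic' class."""
    z, y_onehot, _ = sample_latent(batch_size, latent_dim, num_classes)
    X_fake = generate_fn(z, y_onehot)  # inference on the Generator
    return X_fake, np.full(batch_size, num_classes)


X = rng.random((100, 7)).astype("float32")
y = rng.integers(0, 2, size=100)
z, onehot, labels = sample_latent(5, 32, 2)
real_x, real_y = sample_real(X, y, 16)
fake_x, fake_y = sample_fake(lambda z, t: np.zeros((len(z), 7), "float32"), 16, 32, 2)
```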

GAN Training

We are ready to set up and run a training schedule. The GAN is trained by alternating between training the Discriminator and training the Generator. The Discriminator's trainable attribute must be enabled before training the Discriminator and disabled before training the Generator, so that the Generator step does not update the Discriminator's weights. During training, we want to monitor the Generator's progress. We can do this visually by periodically plotting the distributions of, and relationships between, real and synthetically generated data.
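Below is a self-contained sketch of the alternating schedule. It swaps the trainable-toggling idiom for explicit tf.GradientTape updates. Each step applies gradients only to the variables of the network being trained, which likewise keeps the Discriminator fixed during the Generator update. The tiny models, sizes, optimizer settings and use of raw logits are assumptions, and the monitoring plots are omitted.

```python
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

LATENT, CLASSES, FEATURES = 8, 2, 5  # assumed toy sizes

# Tiny stand-in Generator and Discriminator (D outputs logits for CLASSES + 1 "synthetic").
z_in, y_in = keras.Input(shape=(LATENT,)), keras.Input(shape=(CLASSES,))
h = layers.Dense(16, activation="relu")(layers.Concatenate()([z_in, y_in]))
gen = keras.Model([z_in, y_in], layers.Dense(FEATURES, activation="sigmoid")(h))
x_in = keras.Input(shape=(FEATURES,))
disc = keras.Model(x_in, layers.Dense(CLASSES + 1)(layers.Dense(16, activation="relu")(x_in)))

g_opt, d_opt = keras.optimizers.Adam(1e-3), keras.optimizers.Adam(1e-3)
xent = keras.losses.SparseCategoricalCrossentropy(from_logits=True)


def train_step(x_real, y_real):
    n = x_real.shape[0]
    z = tf.random.normal((n, LATENT))
    y_fake = tf.random.uniform((n,), 0, CLASSES, dtype=tf.int32)
    y_fake_1h = tf.one_hot(y_fake, CLASSES)

    # 1) Discriminator step: real rows -> their target, synthetic rows -> class CLASSES.
    x_fake = gen([z, y_fake_1h])
    with tf.GradientTape() as tape:
        d_loss = (xent(y_real, disc(x_real, training=True))
                  + xent(tf.fill((n,), CLASSES), disc(x_fake, training=True)))
    d_grads = tape.gradient(d_loss, disc.trainable_variables)
    d_opt.apply_gradients(zip(d_grads, disc.trainable_variables))

    # 2) Generator step: gradients flow through D, but only G's variables are updated.
    with tf.GradientTape() as tape:
        g_loss = xent(y_fake, disc(gen([z, y_fake_1h], training=True)))
    g_grads = tape.gradient(g_loss, gen.trainable_variables)
    g_opt.apply_gradients(zip(g_grads, gen.trainable_variables))
    return float(d_loss), float(g_loss)


rng = np.random.default_rng(0)
X = rng.random((64, FEATURES)).astype("float32")
Y = rng.integers(0, CLASSES, size=64)
for step in range(3):
    idx = rng.choice(64, size=16, replace=False)
    d_loss, g_loss = train_step(X[idx], Y[idx])
```

In the Generator step the loss asks the Discriminator to recognize each synthetic row as the target it was conditioned on. That pressure pushes generated rows toward the real data of that target.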

Let’s take a look at some examples of our synthetically generated data vs real data. We can use a scatter plot to view relationships between numeric features and a histogram to visualize occurrences of token pairs between categorical features.

Left: int_rate vs grade. Right: home_ownership vs verification_status
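The numbers behind a token-pair histogram like the right-hand plot can be computed with a couple of small helpers. This sketch is an assumption about how such a comparison might be done, not the author's plotting code. `decode_onehot` maps each (soft) one-hot output back to its most likely token. `pair_frequencies` computes how often each pair of tokens co-occurs, so real and synthetic distributions can be compared side by side, for example as grouped bars in matplotlib.

```python
from collections import Counter

import numpy as np


def decode_onehot(onehots, vocab):
    """Map each (possibly soft) one-hot row to its highest-scoring token."""
    return [vocab[i] for i in np.argmax(onehots, axis=1)]


def pair_frequencies(tokens_a, tokens_b):
    """Relative frequency of each (token_a, token_b) pair across the samples."""
    counts = Counter(zip(tokens_a, tokens_b))
    total = sum(counts.values())
    return {pair: n / total for pair, n in counts.items()}


real_home = ["RENT", "OWN", "RENT", "RENT"]
real_ver = ["Verified", "Verified", "Not Verified", "Verified"]
real_freq = pair_frequencies(real_home, real_ver)

# Softmax-style rows from a Generator's categorical head (made-up values for illustration).
synthetic_home = decode_onehot(np.array([[0.1, 0.7, 0.2], [0.6, 0.3, 0.1]]),
                               ["MORTGAGE", "OWN", "RENT"])
```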

Classification Experiment

With our GAN sufficiently trained, let’s see how we can use synthetic data to augment our real data and improve the accuracy of a Classifier. Our Classifier is designed very similarly to the Discriminator used in our GAN, with two differences. First, the Classifier's output only predicts the target; it does not judge whether a data sample is real or synthetic. Second, the hidden layers have been expanded in depth and width. This increases the Classifier's expressiveness, and with it the risk of overfitting our data. This is done only for the sake of the experiment, and it serves to highlight how synthetic data can help sharpen decision boundaries and regularize the model.
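A hedged sketch of such a Classifier follows. It assumes numeric and one-hot categorical features are fed as separate inputs, as in a Discriminator with per-feature categorical layers. `build_classifier`, `depth` and `width` are hypothetical, and their defaults are arbitrary values chosen only to illustrate a deeper and wider hidden stack.

```python
import numpy as np
from tensorflow import keras
from tensorflow.keras import layers

NUM_CLASSES = 2     # good vs. bad loan
NUM_NUMERIC = 2     # assumed number of numeric features
CAT_SIZES = [3, 2]  # assumed vocabulary size of each categorical feature


def build_classifier(depth=4, width=256):
    """Discriminator-like network that predicts only the target (no 'synthetic' class)."""
    numeric = keras.Input(shape=(NUM_NUMERIC,))
    cats = [keras.Input(shape=(s,)) for s in CAT_SIZES]
    # Same per-feature categorical layers as the Discriminator, then merge.
    h = layers.Concatenate()([numeric] + [layers.Dense(8, activation="relu")(c) for c in cats])
    for _ in range(depth):  # more and wider hidden layers raise capacity (and overfitting risk)
        h = layers.Dense(width, activation="relu")(h)
    out = layers.Dense(NUM_CLASSES, activation="softmax")(h)
    model = keras.Model([numeric] + cats, out, name="classifier")
    model.compile(optimizer="adam", loss="sparse_categorical_crossentropy",
                  metrics=["accuracy"])
    return model


clf = build_classifier()
rng = np.random.default_rng(0)
batch = [rng.random((4, NUM_NUMERIC)).astype("float32")] + [
    np.eye(s, dtype="float32")[rng.integers(0, s, 4)] for s in CAT_SIZES]
probs = clf(batch).numpy()
```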

Two Classifiers are initialized. One is trained with real data only. The other is trained with a combination of real and synthetic data, with each batch split evenly between the two. Finally, we evaluate the performance of each Classifier on the test data we set aside.
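A hedged NumPy sketch of the evenly split batches and the final evaluation follows. `mixed_batch`, `accuracy` and `generate_fn` (a stand-in for Generator inference) are hypothetical names. The sketch labels each synthetic row with the target it was generated for. That is the natural choice for a conditional Generator, but it is an assumption about the original code.

```python
import numpy as np


def mixed_batch(X_real, y_real, generate_fn, batch_size, latent_dim, num_classes, rng):
    """Half real rows, half synthetic rows labeled with the target they were generated for."""
    n_real = batch_size // 2
    n_syn = batch_size - n_real
    idx = rng.choice(len(X_real), size=n_real, replace=False)
    labels = rng.integers(0, num_classes, size=n_syn)
    z = rng.normal(size=(n_syn, latent_dim)).astype("float32")
    X_syn = generate_fn(z, np.eye(num_classes, dtype="float32")[labels])
    X = np.concatenate([X_real[idx], X_syn])
    y = np.concatenate([y_real[idx], labels])
    perm = rng.permutation(batch_size)  # shuffle so real and synthetic rows interleave
    return X[perm], y[perm]


def accuracy(y_true, probs):
    """Fraction of test rows whose highest-probability class matches the label."""
    return float(np.mean(np.argmax(probs, axis=1) == y_true))


rng = np.random.default_rng(0)
X_real = rng.random((50, 7)).astype("float32")
y_real = rng.integers(0, 2, size=50)
fake_gen = lambda z, t: np.full((len(z), 7), -1.0, dtype="float32")  # stand-in Generator
Xb, yb = mixed_batch(X_real, y_real, fake_gen, 10, 32, 2, rng)
```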

Results and Conclusion


In this experiment, the Classifier trained with a combination of real and synthetic data outperformed the Classifier trained only with real data. Augmenting the real data with synthetic data resulted in an accuracy improvement of almost 10%!

GANs are able to generate synthetic data that can provide an up-sampling of under-represented classes and fill in the gaps between samples. This is especially important when classes are imbalanced or the overall quantity of data is limited.
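For the class-imbalance case, one simple way to use a conditional Generator is to request just enough synthetic samples of each minority class to match the majority class. The helper below is a hypothetical illustration of that idea; the experiment above split batches evenly rather than balancing classes.

```python
from collections import Counter


def balancing_targets(labels):
    """Targets to feed the Generator so every observed class reaches the majority count."""
    counts = Counter(labels)
    majority = max(counts.values())
    targets = []
    for cls in sorted(counts):
        targets.extend([cls] * (majority - counts[cls]))
    return targets


needed = balancing_targets([0, 0, 0, 0, 1, 2, 2])  # needed == [1, 1, 1, 2, 2]
```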

If you plan to reproduce this experiment, I recommend experimenting with more features and/or a different sampling of data. Some further improvement could be made through model alterations as well as increased training duration.

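For more on GANs and ACGANs, check out the original papers: Goodfellow et al., “Generative Adversarial Networks” (2014), and Odena, Olah, and Shlens, “Conditional Image Synthesis with Auxiliary Classifier GANs” (ICML 2017).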

Mindboard

Case Studies, Insights, and Discussions of our…
