Exploring GANs to Generate Synthetic Data

A simple step-by-step tutorial using the IRIS dataset

XQ
The Research Nest
11 min read · Jul 14, 2023

The goals of this tutorial are simple.

  1. Train a GAN on a dataset
  2. Use the trained generator to create synthetic data
  3. Train a machine learning model on the synthetic data
  4. Test this model trained on synthetic data (the “synthetic model”) on the real data and check how it performs

To explore this, let’s use the Iris dataset.

The Iris dataset is one of the classic datasets in machine learning and data science, often used for teaching and learning purposes.

The dataset consists of measurements from 150 iris flowers belonging to three different species:
1. Iris-setosa
2. Iris-versicolor
3. Iris-virginica

Each species is equally represented, with 50 samples each.

For each individual flower, the dataset includes four measurements, all in centimeters:

  1. Sepal Length: The length of the sepal. Sepals are the leaf-like structures on the outside of the flower that protect the flower bud.
  2. Sepal Width: The width of the sepal.
  3. Petal Length: The length of the petal. Petals are the inner parts of the flower and are often brightly colored.
  4. Petal Width: The width of the petal.

These four measurements are the features or independent variables of the dataset.

The species of each measured flower is also included, making this a supervised classification dataset. The species serves as the target or dependent variable and takes one of three classes, one per species.

In simple terms, the Iris dataset consists of measurements of the lengths and widths of petals and sepals from three different species of iris flowers. The goal of studying this dataset is to develop a model that can predict the species of an iris flower based on these four measurements.
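To see this structure for yourself, here is a quick sketch that loads scikit-learn's bundled copy of the dataset (the same datasets.load_iris() call used later in the tutorial) and checks its shape, feature names, species names, and class balance:

```python
from collections import Counter

from sklearn import datasets

# Load scikit-learn's bundled copy of the Iris dataset
iris = datasets.load_iris()

print(iris.data.shape)                 # (150, 4): 150 flowers, 4 measurements each
print(iris.feature_names)              # sepal/petal length and width, in cm
print(iris.target_names.tolist())      # ['setosa', 'versicolor', 'virginica']
print(Counter(iris.target.tolist()))   # 50 samples for each of the classes 0, 1, 2
```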

Once we understand the dataset, we can formulate our approach.

Our first goal is to build a GAN whose generator can produce synthetic values for all four measurements, conditioned on the species they should belong to.

Step 1: Training the GAN

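Generative Adversarial Networks, or GANs for short, are a type of artificial intelligence model that can create new data resembling data it has seen before. It’s a bit like an artist looking at a landscape and then painting a picture resembling it. A GAN pairs two neural networks: a generator, which turns random noise into candidate samples, and a discriminator, which tries to tell those samples apart from real ones. The two are trained against each other, so the generator gradually learns to produce samples the discriminator cannot distinguish from real data.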

In this case, we want to create new data that looks like the Iris dataset. This dataset is pretty simple. It has information about different kinds of iris flowers, including measurements of their petals and sepals, and what type of iris they are.

A good choice for this task would be a special type of GAN called a Conditional GAN, or cGAN for short. The ‘conditional’ part means that we give the model some extra information to help it in its task. In this case, we give the iris species class to the GAN.
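Mechanically, giving the class to the GAN usually means concatenating a one-hot encoding of the species with the network's input, which is what the Concatenate layer in the tutorial code does. A NumPy-only sketch of the resulting generator input, using np.eye for one-hot encoding (the tutorial uses scikit-learn's OneHotEncoder instead); the five labels are made up for illustration:

```python
import numpy as np

NOISE_DIM = 100   # same noise size as the tutorial's generator
NUM_CLASSES = 3   # setosa, versicolor, virginica

rng = np.random.default_rng(0)
labels = np.array([0, 1, 2, 1, 0])        # requested species for 5 samples (illustrative)
one_hot = np.eye(NUM_CLASSES)[labels]     # shape (5, 3), exactly one 1 per row
noise = rng.normal(0.0, 1.0, (len(labels), NOISE_DIM))

# The generator's first dense layer sees noise and class side by side
generator_input = np.concatenate([noise, one_hot], axis=1)
print(generator_input.shape)  # (5, 103)
```

The discriminator is conditioned the same way: its input is the four features concatenated with the same one-hot vector, giving 7 values per sample.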

When the cGAN creates new data, it doesn’t just make random iris measurements; it makes measurements corresponding to the species we asked for. Likewise, the discriminator also receives the species label when judging whether a sample looks real, so a generated sample is only convincing if it matches that species.
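For reference, this idea is usually written as the standard conditional GAN objective, where y is the one-hot species label and z is the Gaussian noise vector: the discriminator D tries to maximize V, while the generator G tries to minimize it. The second line is the common non-saturating generator loss; training the combined model with target label 1 and binary cross-entropy, as the training loop in the code does, corresponds to this variant.

```latex
\min_G \max_D \; V(D, G) =
  \mathbb{E}_{(x, y) \sim p_{\text{data}}}\left[\log D(x \mid y)\right]
  + \mathbb{E}_{z \sim \mathcal{N}(0, I),\, y \sim p(y)}\left[\log\left(1 - D(G(z \mid y) \mid y)\right)\right]

\mathcal{L}_G = -\,\mathbb{E}_{z \sim \mathcal{N}(0, I),\, y \sim p(y)}\left[\log D(G(z \mid y) \mid y)\right]
```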

Here’s the complete sample code to define a simple cGAN, train it, and generate synthetic samples.

Note: If you are new to Python and ML, don’t worry too much about the intricacies of the code below. Think of it as a black box: run it on your system, tweak the parameters here and there, and observe what happens. Learn by doing and observing.

```python
import numpy as np
import pandas as pd
from sklearn import datasets
from sklearn.preprocessing import MinMaxScaler, OneHotEncoder
import tensorflow as tf
from tensorflow.keras.layers import Input, Dense, Concatenate
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam

# Load iris dataset
iris = datasets.load_iris()
X = iris.data  # the four measurement features
y = iris.target  # species labels: 0, 1, 2

# Normalize each feature to the [0, 1] range
scaler = MinMaxScaler()
X = scaler.fit_transform(X)

# Convert data to pandas DataFrame
real_data = pd.DataFrame(X, columns=['a', 'b', 'c', 'd'])
real_labels = y

# One-hot encode labels
# (scikit-learn >= 1.2 calls this argument sparse_output; older versions use sparse)
one_hot_encoder = OneHotEncoder(sparse_output=False)
one_hot_labels = one_hot_encoder.fit_transform(np.array(real_labels).reshape(-1, 1))

# Constants
NOISE_DIM = 100
NUM_CLASSES = 3
NUM_FEATURES = 4
BATCH_SIZE = 64
TRAINING_STEPS = 5000

# Generator: (noise, one-hot class) -> four scaled features
def create_generator():
    noise_input = Input(shape=(NOISE_DIM,))
    class_input = Input(shape=(NUM_CLASSES,))
    merged_input = Concatenate()([noise_input, class_input])
    hidden = Dense(128, activation='relu')(merged_input)
    output = Dense(NUM_FEATURES, activation='linear')(hidden)
    model = Model(inputs=[noise_input, class_input], outputs=output)
    return model

# Discriminator: (features, one-hot class) -> probability that the sample is real
def create_discriminator():
    data_input = Input(shape=(NUM_FEATURES,))
    class_input = Input(shape=(NUM_CLASSES,))
    merged_input = Concatenate()([data_input, class_input])
    hidden = Dense(128, activation='relu')(merged_input)
    output = Dense(1, activation='sigmoid')(hidden)
    model = Model(inputs=[data_input, class_input], outputs=output)
    return model

# cGAN: the generator's output is fed to the discriminator together with the same class label
def create_cgan(generator, discriminator):
    noise_input = Input(shape=(NOISE_DIM,))
    class_input = Input(shape=(NUM_CLASSES,))
    generated_data = generator([noise_input, class_input])
    validity = discriminator([generated_data, class_input])
    model = Model(inputs=[noise_input, class_input], outputs=validity)
    return model

# Create and compile the Discriminator
discriminator = create_discriminator()
discriminator.compile(loss='binary_crossentropy', optimizer=Adam())

# Create the Generator
generator = create_generator()

# Create the GAN
gan = create_cgan(generator, discriminator)

# Freeze the discriminator inside the combined model so that gan.train_on_batch
# only updates the generator. In tf.keras (Keras 2), the trainable flag is read
# at compile time, so the discriminator compiled above still trains on its own.
discriminator.trainable = False

gan.compile(loss='binary_crossentropy', optimizer=Adam())

# Train GAN
for step in range(TRAINING_STEPS):
    # Select a random batch of real data with labels
    idx = np.random.randint(0, real_data.shape[0], BATCH_SIZE)
    real_batch = real_data.iloc[idx].values
    labels_batch = one_hot_labels[idx]

    # Generate a batch of new data (verbose=0 silences per-call progress bars)
    noise = np.random.normal(0, 1, (BATCH_SIZE, NOISE_DIM))
    generated_batch = generator.predict([noise, labels_batch], verbose=0)

    # Train the discriminator: real samples are labeled 1, generated samples 0
    real_loss = discriminator.train_on_batch([real_batch, labels_batch], np.ones((BATCH_SIZE, 1)))
    fake_loss = discriminator.train_on_batch([generated_batch, labels_batch], np.zeros((BATCH_SIZE, 1)))
    discriminator_loss = 0.5 * np.add(real_loss, fake_loss)

    # Train the generator: push the frozen discriminator to label generated samples as real (1)
    generator_loss = gan.train_on_batch([noise, labels_batch], np.ones((BATCH_SIZE, 1)))

    if step % 500 == 0:
        print(f"Step: {step}, Discriminator Loss: {discriminator_loss}, Generator Loss: {generator_loss}")

# Generate instances for a given class (output is in the MinMax-scaled feature space)
def generate_data(generator, data_class, num_instances):
    one_hot_class = one_hot_encoder.transform(np.array([[data_class]]))
    noise = np.random.normal(0, 1, (num_instances, NOISE_DIM))
    generated_data = generator.predict([noise, np.repeat(one_hot_class, num_instances, axis=0)], verbose=0)
    return pd.DataFrame(generated_data, columns=['a', 'b', 'c', 'd'])

# Generate 40 instances of class 1
generated_data = generate_data(generator, 1, 40)
print(generated_data)
```

We now have a function called “generate_data” that takes the trained “generator,” the required data class (iris species), and the number of instances, and returns that many synthetic rows. These rows are still in the MinMax-scaled feature space used for training.

Step 2: Use the trained generator to get synthetic data

Here’s the code block I used to generate 50 instances of each class and save them as a CSV file in a format similar to the Iris dataset.

```python
# Generate 50 instances for each class
synthetic_data_class_0 = generate_data(generator, 0, 50)
synthetic_data_class_1 = generate_data(generator, 1, 50)
synthetic_data_class_2 = generate_data(generator, 2, 50)

# Combine all synthetic data into a single DataFrame
synthetic_data = pd.concat([synthetic_data_class_0, synthetic_data_class_1, synthetic_data_class_2], ignore_index=True)

# Undo the MinMax scaling to bring the values back to centimeters
# (.values passes a plain array, matching how the scaler was fitted)
synthetic_data = pd.DataFrame(scaler.inverse_transform(synthetic_data.values), columns=['sepal length (cm)', 'sepal width (cm)', 'petal length (cm)', 'petal width (cm)'])

# Create corresponding class labels
synthetic_labels = [0]*50 + [1]*50 + [2]*50

# Add labels to the synthetic data
synthetic_data['class'] = synthetic_labels

# Save synthetic data as a CSV file
synthetic_data.to_csv('synthetic_iris_data.csv', index=False)
```
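Because the generator was trained on MinMax-scaled features and ends in a 'linear' activation, nothing forces its outputs to stay inside [0, 1]. A small sketch of what inverse_transform does with such values; the two scaled rows below are hypothetical, not actual generator output:

```python
import numpy as np
from sklearn import datasets
from sklearn.preprocessing import MinMaxScaler

iris = datasets.load_iris()
scaler = MinMaxScaler().fit(iris.data)  # learns the per-feature min and max

# Hypothetical generator outputs in the scaled space; the second row has a
# sepal length 10% beyond the top of the [0, 1] range
scaled_samples = np.array([
    [0.5, 0.5, 0.5, 0.5],
    [1.1, 0.5, 0.5, 0.5],
])

# inverse_transform computes x = x_scaled * (max - min) + min for each feature
restored = scaler.inverse_transform(scaled_samples)
print(restored[1, 0])  # about 8.26 cm, above the real maximum sepal length of 7.9 cm
```

This is one way synthetic values can end up outside the range seen in the real data. Clipping the scaled outputs with np.clip(..., 0, 1) before the inverse transform, or using a sigmoid output layer, would keep them within the observed range; whether that is desirable is a design choice.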

Let’s visualize how this synthetic data is distributed and compare it with the real data.

```python
import pandas as pd
import matplotlib.pyplot as plt
from sklearn import datasets

# Load the Iris dataset from sklearn
iris = datasets.load_iris()
real_data = pd.DataFrame(iris.data, columns=iris.feature_names)
real_data['class'] = iris.target

# Load the synthetic dataset
synthetic_data = pd.read_csv('synthetic_iris_data.csv')

# For each feature, create a histogram for the real and synthetic data
for feature in ['sepal length (cm)', 'sepal width (cm)', 'petal length (cm)', 'petal width (cm)']:
    plt.figure(figsize=(10, 5))

    plt.subplot(1, 2, 1)
    plt.hist(real_data[feature], bins=20, alpha=0.5, color='g', label='Real')
    plt.title(f"Real Data: {feature}")
    plt.legend()

    plt.subplot(1, 2, 2)
    plt.hist(synthetic_data[feature], bins=20, alpha=0.5, color='b', label='Synthetic')
    plt.title(f"Synthetic Data: {feature}")
    plt.legend()

    plt.show()

# Print the summary statistics for the real and synthetic data
print("Summary statistics for the real data:")
print(real_data.describe())
print("\nSummary statistics for the synthetic data:")
print(synthetic_data.describe())

# For each pair of features, create a scatter plot for the real and synthetic data
features = ['sepal length (cm)', 'sepal width (cm)', 'petal length (cm)', 'petal width (cm)']
for i in range(len(features)):
    for j in range(i+1, len(features)):
        plt.figure(figsize=(10, 5))

        plt.subplot(1, 2, 1)
        plt.scatter(real_data[features[i]], real_data[features[j]], alpha=0.5, color='g')
        plt.title(f"Real Data: {features[i]} vs {features[j]}")

        plt.subplot(1, 2, 2)
        plt.scatter(synthetic_data[features[i]], synthetic_data[features[j]], alpha=0.5, color='b')
        plt.title(f"Synthetic Data: {features[i]} vs {features[j]}")

        plt.show()
```
```
Summary statistics for the real data:
       sepal length (cm)  sepal width (cm)  petal length (cm)  \
count         150.000000        150.000000         150.000000
mean            5.843333          3.057333           3.758000
std             0.828066          0.435866           1.765298
min             4.300000          2.000000           1.000000
25%             5.100000          2.800000           1.600000
50%             5.800000          3.000000           4.350000
75%             6.400000          3.300000           5.100000
max             7.900000          4.400000           6.900000

       petal width (cm)       class
count        150.000000  150.000000
mean           1.199333    1.000000
std            0.762238    0.819232
min            0.100000    0.000000
25%            0.300000    0.000000
50%            1.300000    1.000000
75%            1.800000    2.000000
max            2.500000    2.000000

Summary statistics for the synthetic data:
       sepal length (cm)  sepal width (cm)  petal length (cm)  \
count         150.000000        150.000000         150.000000
mean            6.690141          3.107407           4.123409
std             0.935154          0.272033           1.988361
min             5.264442          2.745103           1.458329
25%             5.839154          2.888982           1.511931
50%             6.476824          2.999550           4.601924
75%             7.716801          3.390262           6.146769
max             8.360364          3.828354           6.711998

       petal width (cm)       class
count        150.000000  150.000000
mean           1.427482    1.000000
std            0.931393    0.819232
min            0.132300    0.000000
25%            0.225225    0.000000
50%            1.629165    1.000000
75%            2.389153    2.000000
max            2.635187    2.000000
```

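Overall, the real data look more continuous and spread out, while the synthetic data cluster into tighter pockets. The summary statistics also show some shifts (for example, synthetic sepal lengths average about 6.69 cm versus 5.84 cm in the real data), although the overall spread of most features is broadly similar.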
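Visual comparisons can be complemented with a number. One option, not used in the original experiment, is SciPy's two-sample Kolmogorov-Smirnov test, which measures the largest gap between the empirical distributions of a feature in two datasets (a statistic of 0 means the empirical distributions are identical). In this sketch the "synthetic" frame is a hypothetical stand-in built from noisy real rows with shifted sepal lengths, so the snippet runs on its own; with the tutorial's output you would pass the feature columns of synthetic_iris_data.csv instead:

```python
import numpy as np
import pandas as pd
from scipy.stats import ks_2samp
from sklearn import datasets


def ks_by_feature(real: pd.DataFrame, synthetic: pd.DataFrame) -> pd.DataFrame:
    """Two-sample KS statistic and p-value for each column of `real`."""
    rows = []
    for column in real.columns:
        result = ks_2samp(real[column], synthetic[column])
        rows.append({"feature": column, "ks_statistic": result.statistic, "p_value": result.pvalue})
    return pd.DataFrame(rows)


iris = datasets.load_iris()
real = pd.DataFrame(iris.data, columns=iris.feature_names)

# Hypothetical stand-in for GAN output: real rows plus small noise,
# with sepal length shifted by +0.8 cm
rng = np.random.default_rng(0)
stand_in = real + rng.normal(0.0, 0.1, real.shape)
stand_in["sepal length (cm)"] += 0.8

print(ks_by_feature(real, stand_in))
```

Larger statistics flag the features whose synthetic distribution drifts furthest from the real one, which can help decide where a generator needs the most work.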

Now, we have both synthetic and real datasets. Let’s train a synthetic model using the synthetic dataset we created.

Step 3: Training ML models on synthetic data

Since the Iris dataset is small and fairly simple, a simple model such as logistic regression, a decision tree, or k-nearest neighbors (KNN) could work quite well. For demonstration purposes, I’ll use KNN, a simple yet powerful algorithm that classifies a sample by a majority vote among its k closest training samples and handles multiclass problems natively.

First, let’s split the synthetic data into training and testing datasets. Then, we can use scikit-learn’s KNeighborsClassifier to train the model and evaluate its performance.

```python
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.neighbors import KNeighborsClassifier
from sklearn.metrics import classification_report, confusion_matrix

# Load the synthetic dataset
synthetic_data = pd.read_csv('synthetic_iris_data.csv')

# Separate the features and the target variable
X = synthetic_data.drop(columns=['class'])
y = synthetic_data['class']

# Split the dataset into training and testing datasets
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# Standardize the feature values, fitting on the synthetic training split only
# (note: this replaces the MinMaxScaler named "scaler" from Step 1)
scaler = StandardScaler()
X_train = scaler.fit_transform(X_train)
X_test = scaler.transform(X_test)

# Train a K-Nearest Neighbors model
knn = KNeighborsClassifier(n_neighbors=3)
knn.fit(X_train, y_train)

# Make predictions on the test dataset
y_pred = knn.predict(X_test)

# Evaluate the model
print("Confusion Matrix:\n", confusion_matrix(y_test, y_pred))
print("\nClassification Report:\n", classification_report(y_test, y_pred))
```
```
Confusion Matrix:
 [[10  0  0]
 [ 0  9  0]
 [ 0  0 11]]

Classification Report:
               precision    recall  f1-score   support

           0       1.00      1.00      1.00        10
           1       1.00      1.00      1.00         9
           2       1.00      1.00      1.00        11

    accuracy                           1.00        30
   macro avg       1.00      1.00      1.00        30
weighted avg       1.00      1.00      1.00        30
```

This script loads the synthetic Iris dataset, separates the features and the target variable, and then splits the data into training and testing datasets. It then standardizes the feature values using StandardScaler, trains a KNN model on the training data, makes predictions on the test data, and finally evaluates the model using a confusion matrix and a classification report.
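Standardization is a simple per-feature formula, and its statistics should come from the training split only, so that no information from the test set leaks into the model. A quick check on the real Iris data that StandardScaler matches the manual computation z = (x - mean) / std, using the population standard deviation:

```python
import numpy as np
from sklearn import datasets
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler

X, y = datasets.load_iris(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# Fit on the training split only; the test split is transformed with those statistics
scaler = StandardScaler().fit(X_train)

# Manual equivalent: z = (x - mean) / std, with NumPy's default ddof=0
mean = X_train.mean(axis=0)
std = X_train.std(axis=0)
manual_test = (X_test - mean) / std

print(np.allclose(scaler.transform(X_test), manual_test))  # True
```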

You can experiment with different models and parameters to see which works best on the synthetic data. The choice of model will depend on the specific characteristics of the synthetic data and the problem at hand.

We now have a trained synthetic model saved as “knn” to use on real data. Let’s see how it actually performs.

Step 4: Use the synthetic model on real data

Let’s load the entire real dataset of 150 samples and predict the classes.

```python
import pandas as pd
from sklearn import datasets
from sklearn.metrics import accuracy_score

# Load the real Iris data from sklearn
iris = datasets.load_iris()
real_data = pd.DataFrame(iris.data, columns=iris.feature_names)
real_labels = iris.target

# Standardize the real data with the scaler fitted on the synthetic training set,
# so it matches the scale the KNN model was trained on
real_data = scaler.transform(real_data)

# Use the trained KNN model to make predictions on the real data
real_pred = knn.predict(real_data)

# Calculate the accuracy of the model on the real data
accuracy = accuracy_score(real_labels, real_pred)
print("Accuracy of the model on the real data: ", accuracy)
```
Accuracy of the model on the real data:  0.7933333333333333

This is an interesting result, especially since we used a very raw approach with no optimizations. An accuracy of 79% is well above the roughly 33% we would expect from random guessing across three equally sized classes. So we are getting somewhere: synthetic data like this could potentially be used to train models entirely, or for data augmentation when real data is scarce.
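One detail makes or breaks this kind of evaluation: real data must pass through the same scaler that was fitted on the synthetic training data. One way to make that automatic, assuming you are willing to refactor Step 3 slightly, is a scikit-learn Pipeline that bundles the scaler and the KNN model. In this sketch the synthetic training set is a hypothetical stand-in (noisy copies of the real rows) so it runs on its own; with the tutorial's data you would fit on the feature and class columns of synthetic_iris_data.csv:

```python
import numpy as np
import pandas as pd
from sklearn import datasets
from sklearn.metrics import accuracy_score
from sklearn.neighbors import KNeighborsClassifier
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler

iris = datasets.load_iris()
real_X = pd.DataFrame(iris.data, columns=iris.feature_names)
real_y = iris.target

# Hypothetical stand-in for synthetic_iris_data.csv: noisy copies of the real rows
rng = np.random.default_rng(42)
synthetic_X = real_X + rng.normal(0.0, 0.2, real_X.shape)
synthetic_y = real_y.copy()

# fit() learns the scaler's statistics from the synthetic features, then trains KNN
synthetic_model = make_pipeline(StandardScaler(), KNeighborsClassifier(n_neighbors=3))
synthetic_model.fit(synthetic_X, synthetic_y)

# predict() re-applies the synthetic-fitted scaler first, so raw real features go straight in
real_pred = synthetic_model.predict(real_X)
print("Accuracy on real data:", accuracy_score(real_y, real_pred))
```

Because the scaler travels with the model, there is no separate "scaler" variable that can be accidentally overwritten or swapped for one fitted on real data.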

For a fairer comparison, we can train another KNN model on real data and evaluate both the synthetic model and the real model on the same test split of the real dataset.

```python
# Split the raw (unscaled) real features into training and testing datasets
real_features = pd.DataFrame(iris.data, columns=iris.feature_names)
X_train_raw, X_test_raw, y_train_real, y_test_real = train_test_split(real_features, real_labels, test_size=0.2, random_state=42)

# Standardize using statistics from the real training split only
scaler_real = StandardScaler()
X_train_real = scaler_real.fit_transform(X_train_raw)
X_test_real = scaler_real.transform(X_test_raw)

# Train a K-Nearest Neighbors model on the real data
knn_real = KNeighborsClassifier(n_neighbors=3)
knn_real.fit(X_train_real, y_train_real)

# Make predictions on the test dataset
y_pred_real = knn_real.predict(X_test_real)

# Evaluate the model
accuracy_real = accuracy_score(y_test_real, y_pred_real)
print("Accuracy of the model trained on real data: ", accuracy_real)
```
Accuracy of the model trained on real data:  1.0

Let’s compare both models side by side on the same real test split. Each model must receive the test features scaled the same way as its own training data: the synthetic model uses “scaler” (fitted on the synthetic training set), and the real model uses “scaler_real.”

```python
# Make predictions on the real test data with the model trained on synthetic data,
# scaling the raw test features with the scaler fitted on the synthetic training set
y_pred_synthetic_model = knn.predict(scaler.transform(X_test_raw))

# Evaluate the model trained on synthetic data
accuracy_synthetic_model = accuracy_score(y_test_real, y_pred_synthetic_model)
print("Accuracy of the model trained on synthetic data: ", accuracy_synthetic_model)

# Make predictions on the real test data with the model trained on real data
y_pred_real_model = knn_real.predict(X_test_real)

# Evaluate the model trained on real data
accuracy_real_model = accuracy_score(y_test_real, y_pred_real_model)
print("Accuracy of the model trained on real data: ", accuracy_real_model)
```
Accuracy of the model trained on synthetic data:  0.9333333333333333
Accuracy of the model trained on real data: 1.0

And there we have our results.

The synthetic model still clearly trails the model trained on real data, but it is a great starting point for further exploration. A natural research goal is to design architectures, or find parameters and fine-tuning techniques, that ultimately produce a synthetic model with high accuracy on real-world data.

Note: The generated synthetic data and the resulting accuracy scores will vary from run to run. Running the experiment several times and averaging the results gives a more reliable picture of the synthetic model’s accuracy.
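A sketch of that averaging loop. To keep it fast and self-contained, it swaps the GAN for a much simpler stand-in generator (one multivariate Gaussian fitted per class); this is not the tutorial's method, and to average the GAN results you would regenerate the synthetic data from the trained generator, or retrain the GAN, inside the loop. The number of runs (10) is an arbitrary choice:

```python
import numpy as np
from sklearn import datasets
from sklearn.metrics import accuracy_score
from sklearn.neighbors import KNeighborsClassifier
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler

X_real, y_real = datasets.load_iris(return_X_y=True)


def sample_per_class_gaussian(X, y, n_per_class, rng):
    """Hypothetical stand-in generator: one multivariate Gaussian per class."""
    X_parts, y_parts = [], []
    for label in np.unique(y):
        X_class = X[y == label]
        samples = rng.multivariate_normal(X_class.mean(axis=0), np.cov(X_class, rowvar=False), size=n_per_class)
        X_parts.append(samples)
        y_parts.append(np.full(n_per_class, label))
    return np.vstack(X_parts), np.concatenate(y_parts)


def synthetic_to_real_accuracy(seed):
    """Generate synthetic data, train KNN on it, and score it on all real samples."""
    rng = np.random.default_rng(seed)
    X_syn, y_syn = sample_per_class_gaussian(X_real, y_real, 50, rng)
    model = make_pipeline(StandardScaler(), KNeighborsClassifier(n_neighbors=3))
    model.fit(X_syn, y_syn)
    return accuracy_score(y_real, model.predict(X_real))


scores = np.array([synthetic_to_real_accuracy(seed) for seed in range(10)])
print(f"Mean accuracy over {len(scores)} runs: {scores.mean():.3f} ± {scores.std():.3f}")
```

Reporting the mean together with the spread across runs makes it easier to tell a genuine improvement from run-to-run noise. A simple baseline like this can also serve as a reference point that a GAN-based generator should ideally match or beat.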
