Train StyleGAN2-ADA with Custom Datasets in Colab
In this article we are going to train NVIDIA’s StyleGAN2-ADA on a custom dataset in Google Colab using TensorFlow 1.14.
Table of contents
- Crash Course on GANs
- StyleGAN
- How to Train StyleGAN2-ADA on Google Colab
- TensorFlow 1.x
- Mounting Google Drive
- Creating TFRecords Dataset
- Training
- References
Generative Adversarial Networks (GANs) are one of the hottest topics in computer science in recent times. They are a clever way of training a generative model (unsupervised learning) by framing the problem as a supervised learning problem.
Crash Course on GANs
The main idea is that two different models are trained simultaneously by an adversarial process.
Generative adversarial networks are based on a game theoretic scenario in which the generator network must compete against an adversary. The generator network directly produces samples. Its adversary, the discriminator network, attempts to distinguish between samples drawn from the training data and samples drawn from the generator.
The first model, the Generator, learns to generate plausible data.
The Generator never sees the real data: it only receives a fixed-length random vector (drawn from a Gaussian distribution) as input and produces a sample in the domain.
The second model, the Discriminator, takes examples (real or generated) as input and learns to distinguish the generator’s fake data from the real data. It penalizes the generator for producing implausible results.
After training, the discriminator model is discarded, since we are only interested in keeping the generator.
During training, the generator progressively becomes better at creating images that look real, while the discriminator becomes better at telling them apart. The process reaches equilibrium when the discriminator can no longer distinguish real images from fakes.
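The alternating updates described above can be sketched on a toy one-dimensional problem. The code below is purely illustrative and not part of StyleGAN: a linear generator learns to match the mean of a Gaussian against a logistic discriminator, using the non-saturating GAN loss; all names and hyperparameters are my own choices.

```python
import numpy as np

rng = np.random.default_rng(0)

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-np.clip(x, -30, 30)))

# Toy problem: real data ~ N(4, 1); the generator maps noise z ~ N(0, 1)
# to w*z + b, and the discriminator is a logistic classifier sigmoid(a*x + c).
w, b = 1.0, 0.0   # generator parameters
a, c = 0.1, 0.0   # discriminator parameters
lr = 0.05

for step in range(2000):
    real = rng.normal(4.0, 1.0, size=64)
    z = rng.normal(size=64)
    fake = w * z + b

    # Discriminator update: push d(real) toward 1 and d(fake) toward 0.
    dr, df = sigmoid(a * real + c), sigmoid(a * fake + c)
    a -= lr * (np.mean((dr - 1) * real) + np.mean(df * fake))
    c -= lr * (np.mean(dr - 1) + np.mean(df))

    # Generator update: push d(fake) toward 1 (non-saturating loss -log d(fake)).
    z = rng.normal(size=64)
    fake = w * z + b
    df = sigmoid(a * fake + c)
    w -= lr * np.mean((df - 1) * a * z)
    b -= lr * np.mean((df - 1) * a)

# The generator's offset b drifts from 0 toward the real mean (~4).
```

Note that the generator never touches a real sample: its gradient arrives only through the discriminator’s output, which is exactly the adversarial setup described above.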
StyleGAN
GANs learn to generate entirely new images that mimic the appearance of real photos. However, they offer very limited control over the generated images.
With StyleGAN (NVIDIA), a style-based generator architecture, the Generator automatically learns to separate different aspects of the images without any human supervision. After training, you can combine these aspects in any way you like.
The Generator treats an image as a collection of “styles”, where each style controls the effects at a particular scale.
With this type of architecture, you can choose the strength at which each style is applied, relative to an “average face” (in the case of a face dataset).
StyleGAN is one of NVIDIA’s most popular generative models. Several versions of StyleGAN have been released. We’ll use the latest version, StyleGAN2-ADA, which is more suitable for small datasets.
How to Train StyleGAN2-ADA on Google Colab
Colaboratory, or “Colab” for short, is a product from Google Research. Colab allows anybody to write and execute arbitrary python code through the browser, and is especially well suited to machine learning, data analysis and education.
Google Colab is free, but I highly recommend paying the extra $10 per month for the Pro version, which offers faster GPUs, more RAM and disk, and longer runtimes; this is critical for training GANs, which can run for hours or even days.
In order to run a notebook on Google Colab you need to connect it to a hosted runtime.
If you’re using the Pro version, then to make use of the faster GPUs you need to click on Change runtime type and set the Hardware accelerator to GPU.
TensorFlow 1.x
StyleGAN2-ADA uses TensorFlow 1.14 and doesn’t support TensorFlow 2.x.
However, Google Colab removed support for TensorFlow 1 in its latest release, so you can’t use %tensorflow_version 1.x anymore.
It’s still possible to manually install TensorFlow 1.x through pip. Just write the following code in the first cell in Google Colab, then import it using import tensorflow.
To check which GPU you are using, run !nvidia-smi.
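In practice, that first Colab cell looks something like this (a sketch of a Colab cell: the exact version pin is an assumption based on the TensorFlow 1.14 requirement above, and whether pip still serves these old wheels depends on the runtime’s Python version):

```python
# Colab cell — install TensorFlow 1.x manually (assumed pin: 1.14)
!pip install tensorflow-gpu==1.14

import tensorflow as tf
print(tf.__version__)  # should report a 1.x version

# Check which GPU the runtime assigned you
!nvidia-smi
```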
Mounting Google Drive
To train StyleGAN2-ADA we are using a custom dataset composed of .jpg images stored in a folder on Google Drive, so we need to connect Colab to Google Drive.
Custom: Custom datasets can be created by placing all images under a single directory. The images must be square-shaped and they must all have the same power-of-two dimensions.
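A quick way to catch offending files before building the dataset is a small dimension check. The helper below is my own, not part of NVIDIA’s tooling; it encodes the square, power-of-two requirement quoted above:

```python
def is_valid_stylegan_size(width: int, height: int) -> bool:
    """StyleGAN2-ADA expects square images whose side is a power of two
    (e.g. 256x256, 512x512, 1024x1024)."""
    return width == height and width > 0 and (width & (width - 1)) == 0

print(is_valid_stylegan_size(1024, 1024))  # True
print(is_valid_stylegan_size(1000, 1000))  # False: not a power of two
print(is_valid_stylegan_size(512, 256))    # False: not square
```

You can run this over the dimensions of every image in your folder (e.g. with Pillow) and resize or discard anything that fails.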
Type the following lines of code in a Google Colab cell
We first mount Google Drive, then create a directory called colab-sg2-ada and, inside it, another one called stylegan2-ada. Once that’s done, we clone the StyleGAN2-ADA repository from NVIDIA’s GitHub and create one directory for the datasets and one for the results.
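The Colab cell for this step might look like the following (a sketch: the Drive paths follow the folder names above but are assumptions about your layout, and note that git clone itself creates the stylegan2-ada directory):

```python
# Colab cell — mount Drive and set up the working directories
from google.colab import drive
drive.mount('/content/drive')

%cd /content/drive/MyDrive
!mkdir -p colab-sg2-ada
%cd colab-sg2-ada

# Cloning creates the stylegan2-ada directory with NVIDIA's code inside
!git clone https://github.com/NVlabs/stylegan2-ada.git
%cd stylegan2-ada
!mkdir -p datasets results
```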
Creating TFRecords Dataset
StyleGAN2-ADA, like all the other StyleGANs, takes TFRecords datasets as input.
If you want to use a custom dataset, you first need to convert it into a TFRecords dataset.
To convert the images to multi-resolution TFRecords we have to use the code provided with NVIDIA’s documentation:
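With the repository cloned, the conversion is a single call to the dataset_tool.py script shipped in NVIDIA’s repository. The dataset name and the source image path below are examples; replace them with your own Drive folder:

```python
# Colab cell — create multi-resolution TFRecords under ./datasets/mydataset
# from the square, power-of-two .jpg images stored on Drive
!python dataset_tool.py create_from_images ./datasets/mydataset /content/drive/MyDrive/my-images
```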
Training
Once the dataset is created, we can start training.
The training will export network pickles (network-snapshot-<KIMG>.pkl) and example images (fakes<KIMG>.png) at regular intervals (controlled by --snap). For each pickle, it will also evaluate FID by default (controlled by --metrics) and log the resulting scores in metric-fid50k_full.txt.
The snapshot_count indicates how often you want your model to generate a sample image and a .pkl file.
This is extremely important because Google Colab may disconnect the runtime after a period of inactivity or when usage limits are reached, so you need to save the model’s progress once in a while, otherwise you will lose all the work you have done.
When Colab times out, you can take the last .pkl file saved and use it to resume training from that model. Just copy the path of the last .pkl file and paste it in the resume_from line.
There are other options you can set to personalize the training. In the example I also used mirror and metrics, but you can find the full list in the train.py file in NVIDIA’s StyleGAN2-ADA repository on GitHub.
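Stripped of the notebook variables (snapshot_count, resume_from), a direct train.py call might look like this; the flag names come from NVIDIA’s repository, while the paths and values here are illustrative:

```python
# Colab cell — launch training; omit --resume on a fresh run, otherwise
# point it at the last network-snapshot-<KIMG>.pkl saved on Drive
!python train.py \
    --outdir=./results \
    --gpus=1 \
    --data=./datasets/mydataset \
    --snap=10 \
    --mirror=1 \
    --metrics=fid50k_full \
    --resume=/content/drive/MyDrive/path/to/network-snapshot-000120.pkl
```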
Once you start the training you’ll see the snapshots being saved in the results folder on Google Drive.
There you’ll find both the fake images and the .pkl files generated by the model at the different snapshots.
References
[1] Ian Goodfellow, Yoshua Bengio and Aaron Courville, Deep Learning (2016), The MIT Press
[2] Jason Brownlee, A Gentle Introduction to Generative Adversarial Networks (GANs) (2019)
[3] Deep Convolutional Generative Adversarial Network (2022), TensorFlow tutorials
[4] Overview of GAN Structure (2022), Google developers
[5] Tero Karras, A Style-Based Generator Architecture for Generative Adversarial Networks (2019), YouTube video
[6] Jonathan Hui, GAN — StyleGAN & StyleGAN2 (2020), Medium
[7] Tero Karras, Miika Aittala, Janne Hellsten, Samuli Laine, Jaakko Lehtinen and Timo Aila, StyleGAN2-ADA (2020)