Chest X-Ray ShenaniGANs

Areeba Abid · Published in The Startup · May 23, 2020

In this post, we train a GAN to generate fake chest X-rays, looking at how small changes to learning rates have a big impact on model quality.

For beGANners:

Generative-Adversarial Networks (GANs) are trained to generate fake data that can pass as real by pitting two models against each other: a generator (G) and a discriminator (D).

A CycleGAN (two GANs hooked up to each other) generating a fake zebra video from a horse video (src)

The generator never sees any real data. It produces a sample and is penalized if the discriminator can tell that it’s fake. Using feedback from the discriminator, it learns what kinds of output can pass as real.

The discriminator’s job is a lot easier. Its task is learning to spot those fakes, picking up patterns that separate real data from the fake stuff. This model is penalized when it is fooled by the generator’s synthetic samples.

Both models are trained simultaneously, and since each is penalized when the other does well, balancing their learning rates is critical. If one model learns too fast, its feedback becomes useless to the other.
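Concretely, one training iteration looks roughly like the sketch below. The tiny linear nets here are just stand-ins for the convolutional generator and discriminator used later in this post; the point is how the two penalties drive the two optimizers:

```python
# A toy version of one GAN training iteration. These tiny linear nets are
# stand-ins for the DCGAN generator/discriminator used later in this post.
import torch
import torch.nn as nn

nz = 100                                                   # latent noise size
netG = nn.Sequential(nn.Linear(nz, 64 * 64), nn.Tanh())    # toy generator
netD = nn.Sequential(nn.Linear(64 * 64, 1), nn.Sigmoid())  # toy discriminator
criterion = nn.BCELoss()
optD = torch.optim.Adam(netD.parameters(), lr=0.0002)
optG = torch.optim.Adam(netG.parameters(), lr=0.0002)

real = torch.rand(16, 64 * 64)        # stand-in for a batch of real X-rays
noise = torch.randn(16, nz)
fake = netG(noise)                    # G never sees `real`, only D's feedback

# Discriminator step: penalized for calling real "fake" or fake "real".
optD.zero_grad()
d_loss = (criterion(netD(real), torch.ones(16, 1))
          + criterion(netD(fake.detach()), torch.zeros(16, 1)))
d_loss.backward()
optD.step()

# Generator step: penalized whenever D correctly calls its sample "fake".
optG.zero_grad()
g_loss = criterion(netD(fake), torch.ones(16, 1))  # G wants D to say "real"
g_loss.backward()
optG.step()
```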

Fortunately, looking at our loss functions over the training period tells us a lot about how to adjust our learning rates.
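If you want to eyeball that during an experiment, one simple approach (a sketch, not the exact plotting code from the tutorial) is to log both losses at every iteration and plot them afterwards:

```python
# Sketch: collect per-iteration losses during training, then plot them.
import matplotlib.pyplot as plt

G_losses, D_losses = [], []

# ... inside the training loop, after each D and G step:
#     D_losses.append(d_loss.item())
#     G_losses.append(g_loss.item())

plt.plot(G_losses, label="G_loss")
plt.plot(D_losses, label="D_loss")
plt.xlabel("iteration")
plt.ylabel("loss")
plt.legend()
plt.show()
```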

Round 1: What I beGAN with

I started with a dataset of chest X-rays from Kaggle and boilerplate GAN code from this PyTorch tutorial.
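Wiring the two together takes very little code. Here's a rough sketch of the data-loading side; the folder path is a guess at the unzipped Kaggle layout, and the 64x64 image size follows the tutorial:

```python
# Sketch: feed the Kaggle chest X-ray images to the tutorial's DCGAN pipeline.
# "chest_xray/train" is a hypothetical path into the unzipped Kaggle dataset.
import torch
import torchvision.datasets as dset
import torchvision.transforms as transforms

dataset = dset.ImageFolder(
    root="chest_xray/train",
    transform=transforms.Compose([
        transforms.Resize(64),          # the tutorial trains on 64x64 images
        transforms.CenterCrop(64),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ]),
)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=128, shuffle=True)
```

With no changes to the model architecture, the results were pretty bad: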

Image 0: First round of training. Sad.

The generated images are just noise, so clearly G did not learn much. The loss functions confirm this:

Figure 0: Generator gets wrecked

Here we have a discriminator that crushes our generator. D learns to flag G's outputs as fake too quickly, and G doesn't have time to learn anything.

This doesn’t mean our discriminator is very good; it has probably learned just enough to stay ahead of our clueless generator. We need some gentle competition between the two to toughen them both up.

Round 2: Let’s try aGAN

To give the generator a better shot, we'll modify the learning rates (LRs). In the model from the tutorial, both LRs were set to 0.0002. Let's try multiplying G's LR by 10, to 0.002.
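In the tutorial's code this comes down to a one-line change when the optimizers are built. Here's a sketch, assuming the tutorial's netG and netD (and its Adam setup) are already in place:

```python
# Sketch of the LR tweak, assuming the tutorial's netG / netD and Adam setup.
import torch.optim as optim

lr_D = 0.0002          # discriminator keeps the tutorial default
lr_G = 0.0002 * 10     # generator gets a 10x larger learning rate (0.002)

optimizerD = optim.Adam(netD.parameters(), lr=lr_D, betas=(0.5, 0.999))
optimizerG = optim.Adam(netG.parameters(), lr=lr_G, betas=(0.5, 0.999))
```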

Image 1: After initial adjustments to the learning rates

We see a big improvement here. The generator is now getting a sense of the skeletal tissue around the borders of the X-rays, the spine, and the dense organs near the middle. The white blob on the left side of the spine is the heart! Cool.

Our G_loss doesn’t look so bad compared to D_loss anymore:

Figure 1: Losses after initial adjustments to learning rates

Seems promising, so with the same parameters, let's run it for longer and see what happens:

Image 2: More epochs with the same LRs

So after 10 more epochs with the same parameters, we’re starting to see more definition around the skeleton, and the outlines of the organs are a little more complex.

The loss curves have gotten kinda wacky, and don’t converge like they did before:

Figure 2: Note, the model is already semi-trained from our previous efforts, so the losses start out a lot lower.

D_loss is now greater than G_loss. This is bad for our generated image quality, since we want the discriminator to provide better and better feedback to the generator as D learns. If D suffers, then the penalties that G is basing its learning on will suffer too.

Let’s turn the learning rate of our generator back down a little. Maybe the 10-fold increase in G’s LR was too much, so we’ll reduce it from 0.002 to 0.001. (D’s LR is still at 0.0002.)

We'll also increase the number of epochs to 30 (free GPU, so why not) and let it run.
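For the record, the settings for this final round boil down to the following (a sketch; the variable names are mine, not the tutorial's):

```python
# Final hyperparameters for this round (sketch).
lr_G = 0.001       # down from 0.002; the 10x bump over the default was a bit much
lr_D = 0.0002      # unchanged from the tutorial throughout
num_epochs = 30
```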

Image 3: A long way from where we started

Not bad! I’m very happy with the rib lines that are starting to appear. We see a lot less blurriness around the neck, liver, and heart. Still a little fuzzy, but we’ve come a long way with just a few minor tweaks to the learning rates.

Here’s a link to my code.

Round 3: Taking a GANder at next steps

We’ve seen how our discriminator and our generator need to learn alongside each other at compatible rates. We’ve also seen how to use the loss graphs to adjust their relative learning rates.

We’ll stop here for now, but there’s more we can do to fine tune the learning of our model aside from learning rates:

  • Rather than increasing LR, we can try increasing the number of steps per epoch. This might help to improve resolution and address the fuzziness we saw.
  • We can also train for longer and add early stopping, to make sure we're reaching a point of convergence (a rough sketch of that idea follows below).
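As a teaser, here's a very loose sketch of the early-stopping idea. train_one_epoch is a hypothetical helper standing in for the training loop, and the patience and improvement threshold are made-up numbers:

```python
# Loose early-stopping sketch: quit when G_loss hasn't improved for a while.
# `train_one_epoch` is a hypothetical helper that runs one epoch over the
# dataloader and returns the mean generator loss for that epoch.
max_epochs = 100
patience, stale_epochs = 5, 0
best_g_loss = float("inf")

for epoch in range(max_epochs):
    epoch_g_loss = train_one_epoch(netG, netD, dataloader)
    if epoch_g_loss < best_g_loss - 1e-3:     # counts as a real improvement
        best_g_loss = epoch_g_loss
        stale_epochs = 0
    else:
        stale_epochs += 1
    if stale_epochs >= patience:
        print(f"No improvement for {patience} epochs, stopping at epoch {epoch}")
        break
```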

We’ll probably try those ideas out in a future post. Stay tuned, and remember, if I gan do this, so gan you!
