Day 92 — Simple GAN for MNIST

今日主題:使用原版生成對抗網路模擬MNIST圖像資料



筆記

一百天計畫的最後幾天,終於又可以回到生成對抗網路這個主題了。在找這個主題的Keras實作程式碼的時候,發現了[1]這個很棒的資源。裡面除了原版GAN以外還實做了其他很多不同的變體。實際拿來Google Colab上測試的結果也很成功,在使用GPU的情況下訓練速度還算可以接受。

今天的Code Study就以[1]這個程式碼來寫。

完整的程式碼可以看這裡

原版的程式碼是用Python物件導向的方式來寫,以下筆記的幾個method都是在 class GAN() 底下。


  • 初始化
def __init__(self):
self.img_rows = 28
self.img_cols = 28
self.channels = 1
self.img_shape = (self.img_rows, self.img_cols, self.channels)
self.latent_dim = 100
optimizer = Adam(0.0002, 0.5)
    # Build and compile the discriminator
self.discriminator = self.build_discriminator()
self.discriminator.compile(loss='binary_crossentropy',
optimizer=optimizer,
metrics=['accuracy'])
    # Build the generator
self.generator = self.build_generator()
    # The generator takes noise as input and generates imgs
z = Input(shape=(self.latent_dim,))
img = self.generator(z)
    # For the combined model we will only train the generator
self.discriminator.trainable = False
    # The discriminator takes generated images as input and 
determines validity

validity = self.discriminator(img)
    # The combined model  (stacked generator and discriminator)
# Trains the generator to fool the discriminator

self.combined = Model(z, validity)
self.combined.compile(loss='binary_crossentropy',
optimizer=optimizer)

這邊考慮了輸入的MNIST圖片是 28 * 28 的格式,而且是灰階圖所以 channel 只有 1

參考自[3],Keras中 Adam Optimizer的用法如下:

Adam(lr=0.001, beta_1=0.9, beta_2=0.999, epsilon=None, decay=0.0, amsgrad=False)

__init__() 中引用了兩個同一個 Class 的 method: build_discriminatorbuild_generator 。這兩個 method中定義了Discriminator以及Generator的模型。

底下來看這兩個method的定義。


  • Generator
def build_generator(self):
    model = Sequential()
    model.add(Dense(256, input_dim=self.latent_dim))
model.add(LeakyReLU(alpha=0.2))
model.add(BatchNormalization(momentum=0.8))
model.add(Dense(512))
model.add(LeakyReLU(alpha=0.2))
model.add(BatchNormalization(momentum=0.8))
model.add(Dense(1024))
model.add(LeakyReLU(alpha=0.2))
model.add(BatchNormalization(momentum=0.8))
model.add(Dense(np.prod(self.img_shape), activation='tanh'))
model.add(Reshape(self.img_shape))
    model.summary()
    noise = Input(shape=(self.latent_dim,))
img = model(noise)
    return Model(noise, img)

相當好讀的一段,產生出來的模型如下:

注意在這邊Generator的Input是隨機產生的雜訊,用這個雜訊來產生模擬MNIST的圖片。


  • Discriminator
def build_discriminator(self):

model = Sequential()
model.add(Flatten(input_shape=self.img_shape))
model.add(Dense(512))
model.add(LeakyReLU(alpha=0.2))
model.add(Dense(256))
model.add(LeakyReLU(alpha=0.2))
model.add(Dense(1, activation='sigmoid'))
model.summary()
    img = Input(shape=self.img_shape)
validity = model(img)
    return Model(img, validity)

產生的模型如下:


  • 訓練流程
def train(self, epochs, batch_size=128, sample_interval=50):
# Load the dataset
(X_train, _), (_, _) = mnist.load_data()
    # Rescale -1 to 1
X_train = X_train / 127.5 - 1.
X_train = np.expand_dims(X_train, axis=3)
    # Adversarial ground truths
valid = np.ones((batch_size, 1))
fake = np.zeros((batch_size, 1))
    for epoch in range(epochs):
# Train Discriminator
# Select a random batch of images
idx = np.random.randint(0, X_train.shape[0], batch_size)
imgs = X_train[idx]
        noise = np.random.normal(0, 1, 
(batch_size, self.latent_dim))
        # Generate a batch of new images
gen_imgs = self.generator.predict(noise)
        # Train the discriminator
d_loss_real = self.discriminator.train_on_batch(imgs, valid)
d_loss_fake = self.discriminator.train_on_batch(gen_imgs,
fake)

d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)

# Train Generator
noise = np.random.normal(0, 1, (batch_size,
self.latent_dim))
# Train the generator (to have the discriminator label
samples as valid)
g_loss = self.combined.train_on_batch(noise, valid)
        # If at save interval => save progress and generated image 
samples
if epoch % sample_interval == 0:
print ("%d [D loss: %f, acc.: %.2f%%] [G loss: %f]"
% (epoch, d_loss[0], 100*d_loss[1], g_loss))
self.sample_images(epoch)

訓練流程也不怎麼難懂,大概的邏輯是這樣的:

在每一個epoch中,首先訓練Discriminator:

  1. 讀入真實MNIST資料
  2. 將真實的MNIST資料標記為 valid
  3. 使用(當下epoch的)Generator,給入一個雜訊產生假的MNIST資料
  4. 將假的MNIST資料標記為 fake
  5. 損失函數 d_loss 設定為真實資料的 loss d_loss_real 與假資料的loss d_loss_fake 的平均。

然後來訓練Generator:

  1. 產生雜訊
  2. combined 是完整的G-D 模型,這時候餵給它雜訊,要求期望的輸出是 valid (這代表G產生的圖片騙過D了)。
  3. Discriminator設計為會回傳一個 validity 參數,也就是D認為輸入的圖片有多大機率是假的,這個值會被用來當作G的損失函數 g_loss (代表還有多大的機率沒騙過D)。

  • 隨機看結果的Helper Function
def sample_images(self, epoch):
r, c = 5, 5
noise = np.random.normal(0, 1, (r * c, self.latent_dim))
gen_imgs = self.generator.predict(noise)
    # Rescale images 0 - 1
gen_imgs = 0.5 * gen_imgs + 0.5
    fig, axs = plt.subplots(r, c)
cnt = 0
for i in range(r):
for j in range(c):
axs[i,j].imshow(gen_imgs[cnt, :,:,0], cmap='gray')
axs[i,j].axis('off')
cnt += 1
fig.savefig("images/%d.png" % epoch)
plt.close()

這就比較沒什麼好講的,就只是隨機取 5 x 5 共25張圖片,然後每隔幾個epoch來比較看看是不是更像真實的MNIST資料了一點。

那來看一下產生出來的圖片:

Epoch 0
Epoch 5,000
Epoch 10,000
Epoch 15,000
Epoch 30,000

可以看到,在Epoch 0的時候基本上產生出來的圖片就只是雜訊而已,但是到了Epoch 5,000的時候,其實就已經蠻有樣子了。原程式碼的設定到30,000,好在Google Colab的GPU跑得算快,大概30分鐘~1小時可以跑完。這時候產生出來的圖就已經非常接近原始的MNIST圖片了。


接下來的幾天應該都會嘗試做其他不同GAN的變體,大概沒時間真正深入研究每一種變體,以後找到時間應該可以來玩一下不同的應用。