Sampling by GAN — A Simple Case Study

Nelson Tsai
Taiwan AI Academy
Published in
15 min readMay 5, 2020

在做機器學習時,是不是經常為缺乏資料、樣本所苦呢?其實不只機器學習,在數值模擬與統計分析也常常碰到這個問題。用數值方法產生樣本不難,難的是產生出來的數據要能逼近原始的資料分佈。本文就以簡單的實例來告訴大家,如何利用生成對抗網路 (Generative Adversarial Network, or simply GAN) 來產生符合目標分佈的合理樣本。

Sampling from normal distribution

想像在一個賣場裡,假如我們統計一下當天來客的身高分佈,或者考慮某學校考完試後,學生們最終分數的分佈,都不難發現分佈曲線長得就像一個鐘形;也就是所謂的高斯或常態分佈 (normal distribution),具有中央高兩邊低的形狀 (如 Fig. 1 中的橘線) [1]。如果沒看到數學式就覺得虛的朋友,其機率分佈函數 (p.d.f.) 長得像這樣:

Probability density function (p.d.f) of the normal (Gaussian) distribution

不難發現這個函數是由兩個參數來描述,一個是 μ,說明這個分佈的平均值,比如說賣場來客的平均身高;另一個是 σ,說明分佈的集中程度,值越小分佈越集中。我們將這個分佈寫成 N(μ, σ)。

Fig. 1 Left: Histogram (probability density) of complete height data from customers. Right: Same plot but with insufficient sample data. Note that the orange curve represents the underlying normal distribution.

在賣場的例子中,假使因為某些理由我們拿不到足夠人數的資料,那麼整個身高分佈直方圖看起來就會像 Fig. 1 右圖,跟左圖具完整資料的“完美”分佈比較起來就顯得不具代表性。所以問題來了,我們有沒有可能產生 (模擬) 一組隨機、多數量的數據樣本使得其符合給定目標機率分佈呢?這就是本文這次想分享的主題。當然為了簡單起見,就讓我們限定目標為常態分佈。

Monte Carlo way

要產生具目標分佈的隨機樣本一個常用手段就是藉由蒙地卡羅 (Monte Carlo) 模擬 [2]。但模擬總要有個起點,此法的一個常用起點便是均勻分布的隨機變數。因此,我們的任務用數學一點的說法就是如何利用 0 到 1 間的均勻分佈 U(0, 1) 轉換成常態分佈 N(0, 1) (為了簡化說明,我們取 μ=0, σ=1)。

這類轉換有個標準的作法叫 Inversion of the Cumulative Distribution Function,只是此法會受限於須找出目標分佈可解析的轉換函數。所以我們這裡示範能推廣並最具有蒙地卡羅精神的“接受-拒絕採樣” (Acceptance-Rejection Sampling) 演算法來完成任務。先不管其中的數學證明,我們直接來看一下它是如何操作的:

Fig. 2 The pseudo-code of Acceptance-Rejection Sampling [3].

讓我們視 g, f 分別為起點分佈、目標分佈的 p.d.f.。找到一個常數 c 使得 x 在給定的範圍內(採樣袋子),cg(x) 總是大於等於 f(x);同時再找一個均勻分佈的“骰子”,其跑出的數值永遠介於 0 跟 1 之間。接下來就是採樣的迴圈:首先根據起點分佈在採樣袋子取出一個值 Y 並計算比值 f(Y)/cg(Y),然後擲“骰子”得到數值 U。若 U 大於比值就從頭再採樣,反之我們就接受 Y 為我們的輸出。採樣夠多的點後,這樣就可以得到具目標分佈的隨機樣本了。下圖是採樣範圍限定 [-5, 5],利用蒙地卡羅法由均勻分佈轉換到常態分佈的結果 (10,000個採樣點)。同時結果也跟常用的套件 numpy.random 中 normal 函數產生出來的結果比較,不難看出相當一致。[Python code 可以參看連結]

Fig. 3 Monte Carlo results for making a Gaussian distribution from a uniform distribution; data points are within -5 and 5 [compared with the one generated by numpy.random.normal (orange curve)]

GAN way

除了蒙地卡羅方法可產生具目標分佈的隨機樣本外,當然還有本文的主角 — 生成對抗網路 (GAN)。那甚麼是 GAN 呢?先談談談本質。簡單來說 GAN 中的 G 就是指生成 (generative) 模型,而 GAN 整體是訓練生成模型中深度學習架構的一種。生成模型主要是學出資料(或特徵)與標籤的聯合機率,不同於判別 (discriminative) 模型學習給定資料得出標籤分佈的條件機率。前者不僅能利用貝氏法則來辨別資料種類,更重要的是能在同一類中生成新的樣本,而後者僅提供最佳的分類邊界,注重辨別不同資料種類間的差異。因此,生成模型所帶的訊息顯然比辨別模型更為豐富,但是需要付出的代價就是更難訓練。

Fig. 4 Discriminative vs. Generative

被大神Yann LeCun視為近10年來機器學習發展中最有意思的想法 [4]的 GAN是2014年 Ian Goodfellow 等人提出[5],其中架構主要包含了兩個互相競爭的神經網路模型:

  • A generator G,輸入 random noise z 產生目標樣本 x
  • A discriminator D,輸入 generator 產生的樣本 x 或真實樣本 y 並判斷是不是真的

然後利用所謂的對抗式訓練來訓練模型,也就是 GAN 中的 A (adversarial)。舉例來說,假使我們的目標是希望產生貓的自然圖片,一開始 generator 產生了一張圖 (如 Fig. 5 左上,顯然品質欠佳),結果被 discriminator 判為假,因此再次訓練 generator 來混淆 discriminator;當 discriminator 被混淆,也會再次訓練加強判斷能力以判斷 generator 產生的圖片。如此,經過兩者反覆競爭(魔高一尺,道高一丈;道高一丈,魔高十丈?) 最後生成的圖片就很難辨真假了。

Fig. 5 Adversarial training. Source: https://www.tensor flow.org/tutorials/generative/dcgan

Generator of 1D normal distribution

回到我們的主線任務,我們希望訓練出一個 generator 生成具常態分佈的隨機樣本,那該怎麼做呢?當然要先架一個 generator?Errrr…不是。在此剛好藉機提醒大家,回到所有機器學習的起點,第一個需要準備的就是數據。沒有數據就無法建模。

所以讓我們先製造 “real” 資料分佈,也就是具常態分佈的隨機樣本。為了說明簡單,我們樣本值取 -5 到 5 之間且分佈平均為 0、標準差為 1。製造方法可以用前面提到的蒙地卡羅或直接用 numpy 套件產生。不囉嗦,直接上code。

Fig. 6 Real data and noise input classes.

將 RealDistribution 製造出來的隨機樣本畫成直方圖會長得跟 Fig. 3 類似。同時,generator 也需要隨機樣本 (noise) 作為輸入,所以我們在 -5 到 5 之間均勻採樣,並且每個樣本點都隨機加上介於 0 到 0.01 之間的微擾,最終寫成NoiseDistribution class。這樣子我們就可以視等等要建立的 generator 為一個函數 G,輸入一個均勻分佈樣本點 z,輸出一個常態分佈樣本點 x = G(z)。

由於我們的任務相對簡單,generator 的架構長這樣就可以:

Fig. 7 Model architecture of the generator

輸入、輸出的 dimension 取 1,而中間通過一層以 ReLU 為激活函數具 32 個節點的 hidden layer。這裡我們用 pytorch framework 來建模。

Fig. 8 Generator class.

Discriminator of 1D normal distribution

其實 discriminator 與 generator 架構的差異也不大,同樣長得像 Fig. 7。唯一的差別是 discriminator 的 output 會再經過 sigmoid 激活函數。

Fig. 9 Discriminator class.

所以 Discriminator 可視為一個函數 D ,當輸入一個 generator 產生的“假”樣本點 G(z) 或是“真”樣本 y ,它就會輸出一個不大於 1 的正值。越接近 1 表示它認為是“真”的。

Training our GAN model and final results

建好 generator 與 discriminator 後,接下來就是訓練的問題。模型訓練需要定義兩個東西,一是優化器,再來是損失函數。在這次的任務中,我們選用最簡單的就行,比如說 stochastic gradient descent optimizer。那麼損失函數呢?基本上有兩個。先說針對 discriminator 的。很自然地,既然要求它能分辨真假,那麼最佳情況在輸入“真” (“假”) 樣本點其輸出當然要是 1 (0)而且損失值要是 0,而其他情況損失值都必須大於 0。這個不正是標準的二元分類問題?因此我們可以採用所謂的 binary cross entropy loss [6]來分別考慮輸出 D(x)、D(G(z)):

discriminator loss = -log(D(x)) -log(1-D(G(z)))

Fig. 10 Computing the discriminator loss

接下來是 generator 部分。它的任務是要糊弄 discriminator,所以其損失函數就變成

generator loss = -log(D(G(z)))

Fig. 11 Computing the generator loss

也就是說當 discriminator 覺得它產生出來的樣本點為“真”,則損失值才為0。不知大家有沒有注意到跟一般監督式學習不同,損失函數並非直接比較generator 的輸出與固定的標準答案而是透過還在學習中的 discriminator 的輸出?其實,我們可以看做 generator 用了一整個神經網路 D 來當損失函數,所以效果可以很好 (如果訓練成功的話),但這也是 GAN 訓練無法容易收斂的根本原因。

設置好優化器與損失函數後就可以開始訓練了。這裡假設真實資料 y 有10,000 個樣本點、batch size 為 10、優化器 learning rate 取 0.001,先看看最開始訓練 1 個 epoch 後的結果:

Fig. 12 Distribution (histogram) for the real and generated samples after 1 epoch training.

除了橘線、綠線分別代表真實與生成樣本分佈外,decision boundary (藍線) 代表的是當輸入樣本點 x,discriminator 的輸出值 D(x)。所以越是兩側的值,它越不認為是“真”的。訓練到 2,000 個 epoch 後,結果變成:

Fig. 13 Distribution (Histogram) for the real and generated samples after 2,000-epoch training.

可以看得出來,生成樣本的分佈越來越貼近常態分佈;同時,decision boundary 在整個值域都接近 0.5,這說明了 generator 糊弄 discriminator 成功。整個訓練過程的變化如下圖所示。

Fig. 14 Training evolution

完整的demo code可以參考這個連結[7]。

How to improve our GAN training?

雖然結果看起來還不錯,但實際上訓練時大家可能會發現有時會訓練不起來或整組壞掉。其實 GAN 是有名的難訓練,人們常常碰到的問題除了之前提過的不易收斂外,還有 generator 只生出單調的樣本或 gradient 驟減的可能 [8]。下面提供幾個小訣竅或許能進一步改善本次的訓練任務:

  • Pretraining

眼尖的讀者可能會注意到 Fig. 12 中的 decision boundary 一開始就很鐘形,為什麼?其實這個例子中我們有預訓練 discriminator。對於真實樣本 y ,我們希望一開始的 D(y) 越接近真實樣本的 p.d.f. 越好。這樣 discriminator 才有鑑別力督促一開始僅產生平均分佈樣本的 generator 前進。

  • Minibatch discrimination

這個想法主要來自 openAI 團隊 2016 年的工作 [9]。數學稍微難點但想法卻不難了解。前面提過 generator 可能只會產出單調的樣本,所以我們可以在discriminator 中多加入額外的訊息,例如在每次 batch 輸入的資料中計算這個batch 資料間的特徵是否相同,若太相似則損失值會大增,從而減少樣本過分單調的問題。

  • WGAN & WGAN with gradient penalty (GP)

從數學上來說,目前要最佳化的 GAN 損失函數實質上就是所謂的 Jensen–Shannon Divergence [10],代表著訓練出來的生成樣本分佈與真實目標樣本分佈的“距離”。只可惜這個“距離”往往無法反映遠近而是不連續的增減 (當兩個分佈無重疊時),這樣使得訓練失去“方向”。所以人們找到了另一種更好代表“距離”的函數,Wasserstein Distance,於是出現了所謂的 WGAN 與進階版WGAN-GP。有興趣的讀者可以參看 [10]。

最後就在此告一段落,相信大家已迫不及待想自己 train 一發看看?上述的demo code 同樣放在這個連結,大家可以自己試試,Good Luck!

References & further readings

[1] https://en.wikipedia.org/wiki/Normal_distribution

[2] https://en.wikipedia.org/wiki/Monte_Carlo_method, https://en.wikipedia.org/wiki/Metropolis%E2%80%93Hastings_algorithm

[3] https://statweb.stanford.edu/~owen/mc/

[4] https://www.quora.com/What-are-some-recent-and-potentially-upcoming-breakthroughs-in-deep-learning

[5] Ian J. Goodfellow et al., “Generative Adversarial Networks”, arXiv: 1406.2661 (2014).

[6] https://gombru.github.io/2018/05/23/cross_entropy_loss/

[7] Demo code shown in this short essay is basically modified from several pioneers and, in particular, they are Seungwon Park, Eric Jang, John Glover, Jason Brownlee and Minjae Kim. Really appreciate their nice works and explanations.

[8] Jonathan Hui, “GAN — Why it is so hard to train Generative Adversarial Networks!”

[9] Tim Salimans et al., “Improved Techniques for Training GANs”, arXiv: 1606.03498 (2016).

[10] https://lilianweng.github.io/lil-log/2017/08/20/from-GAN-to-WGAN.html

--

--