教電腦畫畫:初心者的生成式對抗網路(GAN)入門筆記(TensorFlow + Python3)

計算機繪畫教室

聽來真是炫炮。

其實距離聽到生成式對抗網路(翻成中文雖然感覺很潮但真的拗口,以下簡稱GAN,Generative Adversarial Network)這個名詞也好一陣子了:第一次看到這個字是去年G社發的一篇論文;內容大致上是在說他們利用GAN這一種機器學習的方法讓計算機領域最著名的兩個人物Alice和Bob在一連串加解密的過程中發展出自己的加密法。單看這句話,可以說是集合了當今全民AI時代最風行的各個buzzword,什麼嗎訊冷凝,低普冷凝,人工智慧,機器發明,無怪乎當時那篇文章一發表也佔據了很多科技新聞的版面……啊離題了,有興趣的可以參考一下當時的報導

總之GAN的概念新歸新,不過自己一直沒有比較系統地去了解它背後的機制和原理。恰好前先時候看到了O’Reilly上推出的初心者教學,花不了多少時間,就一步一步照著做完了,順便把過程筆記一下。

GAN基本結構

Resource: https://github.com/jonbruner/generative-adversarial-networks/blob/master/gan-notebook.ipynb

首先需要大致瞭解GAN的運作模式:這裡面有兩個需要被訓練的model,一個是Discriminator network,另一個是Generator network;我們姑且稱之為偵探和工匠(腦海中第一個浮現的名詞,爬了一下其他文也有人稱畫家與鑑賞家):我們現在手上有真的data,工匠要做的事就是偽造出假的data,而偵探則是要分辨現在給他的data是真的還是假的,並且會給出一個回饋。工匠根據這個回饋來「訓練」他現在的工藝,也就是調整model的parameter;一旦工匠的工藝成熟到偵探分辨不出來誰真誰假,就可以說我們訓練出了一個能夠模擬真正data分布的model。

嘴砲那麼多,這中間疑點還是重重啊,回饋怎麼給?真假參數如何設定?偵探跟直接用loss function作回饋的區別在哪?直接看程式。

題目設定&環境

這次Tutorial設定的task是MNIST手寫辨識資料集,採用的神經網路framework是TensorFlow(1.2),語言是Python3.6。

首先我們把該資料集抓下來,TensorFlow上有包好好的指令集讓大家使用,連自己讀檔都省了,佛心啊。

MNIST的內容就是手寫的數字0–9(圖片),以下直接上範例示意,細節不再贅述。

Resource: https://corpocrat.com/wp-content/uploads/2014/10/figure_1.png

載入的MNIST訓練集會有mnist.train及mnist.validation兩個子集可以存取,這次只會使用到mnist.train.next_batch存取train set中的資料。next_batch的功能就是批次讀檔,例如mnist.train.next_batch(100)就會讀出100張手寫字,而讀出來的每張字會包含圖片本體和標籤。標籤這次不會用到,可以不用理會,圖片本體的dimension是784(28*28),在後面的處理中,我們會作一個reshape的動作(如下)方便在神經網路中進行訓練。

mnist.train.next_batch(100)[0].reshape([100, 28, 28, 1])

偵探(Discriminator network)

偵探的目的就是要分辨真假資料,在這個task上就是給一張圖片,然後輸出一個「相似度」的分數—越高表示這張圖片越像從真的dataset出來,反之則是由工匠偽造的。

Resource: https://d3ansictanv2wj.cloudfront.net/GAN_Discriminator-b767fb56c3473f66a935aa90f3b7f28b.png

一個比較特別的是,這邊最後輸出的value並沒有像一般的CNN加上sigmoid layer或softmax layer讓他維持在[1,0]之間(機率分布),這個設定是根據實際經驗得來的:這樣作有可能會讓偵探過於強大,也就是輸出的分數極端地偏向1或0,而沒有辦法有價值的回饋給工匠改進。想像一下問卷調查的結果如果只有「極度討厭」和「極度喜歡」這兩個選項,店家也很難根據武斷的評論作改善,有點類似這個情境。

在這個Tutorial裡,採用了CNN模型作為偵探;發展已久的CNN模型在圖像辨識的task上有許多很不錯的成果,TensorFlow官方文件上也有相關的教學

這個模型包含了兩個(5*5)convolution layer和兩個fully connected layer,首先我們在第一層以5*5 convolution大小抽出32個feature map,第二層64個,最後兩次fully connected輸出一個value。

其實TensorFlow已經把所有最麻煩的東西通通包成了好操作的抽象class,只要知道怎麼設定那些weight size和參數大小,剩下的就是一些碼農事了。

想深入了解CNN是什麼的話,可以參考這個詳細的說明。不然就先暫且把他當作一個黑盒子吧。

工匠(Generator network)

接下來就是工匠了,工匠的目的是要偽造圖片,因此輸出入跟偵探是相反的:工匠要輸出一個圖片,而輸入則是一個隨機數。這邊比較不太直觀,我們可以把他想成是一個random_number_generator,吃了不同的seed會輸出不同的數字;而工匠則是吃了不同的seed會輸出不同的圖片,而若是工匠訓練地非常完美,他就可以不斷地輸出跟真實手寫數字相差無幾的圖片。

Resource: https://d3ansictanv2wj.cloudfront.net/GAN_Generator-8352b780e83fd13c28cb48fa8e7a4ddb.png

工匠採用的模型像是一個逆向的CNN,假設d=100,首先我們把產生出的d-dimensional noise vector投射到3136維的vector上,接著reshape成56*56,再利用3*3的convolution layer產生50個feature map,接著25個feature map,最後把25個feature map輸出成最後的一個map。

一般來說,CNN抽取feature map的數量會隨著層數而增加,在逆向版本中則是減少。另一個細節是,在中間每次做convolution的過程裡,最後都要把每層的輸出在resize回56*56,這是因為取了strides為2*2,map出來的大小都會減半,所以最後一次輸出時是不用resize的。

至於為什麼要這樣作而不直接用28*28的大小做convolution,Tutorial的原作者有提到他不知道原因,但這樣作效果比較好。這也是trial & error出來的結果,非常的heuristic。

範例圖片

那我們來看一下我們的工匠現在手藝如何吧。

用TensorFlow一開始稍微不直覺的地方就是要先把一切的東西都設定好在開始進行一個Session,不過練習一下之後覺得這樣其實滿不錯的,有一種神經網路插拔模組化的feel。

Sample image from untrained generator

果然是偽物,該好好調整一下了。

如何訓練

首先整理一下模型之間的輸出入關係。

一行文的話就是:Gz = 假圖、Dx = 真圖的分數、Dg = 假圖的分數

要出遠門最重要的就是找一個好的導航,要訓練一個模型最重要的就是定義好他的loss function。在GAN裡面會有兩個導航方向,一個是讓偵探(discriminator)往柯南的方向移動,另一個是讓工匠(generator)往魯班的方向移動。

要讓偵探變柯南,就要分別提高偵探認出真圖的能力和認出假圖的能力,也就是

  1. 讓真圖分數和真值(1)差別最小化
  2. 讓假圖分數和假值(0)差別最小化

我們知道Dx是真圖分數,Dg是假圖分數,因此分別對這兩個分數與真值(1)和假值(0)取cross_entropy,又我們要把分數scale到一個機率分布之間,因此可以使用TensorFlow內的tf.nn.sigmoid_cross_entropy_with_logits函式,這個函式和cross_entropy一樣,只是會先對未scale過的x值取sigmoid,公式如下:

Let x = logits, z = labels, the result is:
z * -log(sigmoid(x)) + (1 - z) * -log(1 - sigmoid(x))

工匠(generator)也是一樣的邏輯,只是換成讓假圖分數和1差別最小化:提高偵探對假圖的分數。

tf.ones_like(x)tf.zeros_like(x)則是會產生和x同樣大小個1和0。

最後作reduce_mean取平均值即可。

最佳化loss function

定義完loss function後,我們要設法讓loss function最佳化(ML的核心思想)。理解這部份需要的預備知識稍微多一點,有興趣可以參考這個。總之,最佳化loss function採用的方法叫做gradient descent(GD),而GD又有不同的操作方法。在這個Tutorial裡選擇的是Adam optimization。

參數是作夢夢到的,不要問為什麼(誤)。原Tutorial的作者說他們調了很久才發現較低的learning rate會有比較好的結果,這部份也是經驗法則。

一個細節是,對偵探作最佳化的時候,我們希望調整的是偵探的參數,而對工匠最佳化時亦然。不管是誰的loss function,在做計算時都有用到對方模型的參數(eg. 在算d_loss_fake時,會使用到Gz),但我們不希望optimize時一起調整,因此要限定調整的參數值,寫在minimize(var_list=…)這個變量裡。

只要事先用tf.get_variable()定義好的話,TensorFlow的架構可以容易讓我們取得不同變量的名稱。

開始訓練

設定了這麼久,終於可以開始訓練步驟了(感動QQ),TensorFlow在紀錄訓練過程中有個好用的工具TensorBoard可以使用,這邊我沒有使用,但可以參考一下。

再複習一下GAN的架構:我們有真圖,工匠出假圖;偵探認圖最佳化,工匠造假最佳化。OK,我們首先用所有的真圖和沒有訓練過的工匠做出的假圖給偵探一個新手教學。

根據原Tutorial,這個步驟的功用是讓一開始的偵探可以具有給出有方向性的feedback回工匠的能力(gradient),有點像是新人訓練的感覺。

接下來我們開始偵探 — 工匠的循環訓練。

這部份就滿直覺的了,每次的iteration都會batch出真圖和假圖,然後餵給偵探訓練;接下來出假圖,給訓練過的偵探送feedback回工匠訓練。就這樣持續不斷地來回最佳化。

初步結果

訓練一次GAN等於要一次訓練兩個CNN,沒有GPU不知道要等到何年何月,所以當然是放在學校的工作站上跑。工作站的配備是2*GeForce GTX TITAN Black,跑了100000次iteration的訓練時間花了約5個小時。

我們來看一下「工匠」的成果如何:

滿有趣的,「工匠」似乎「學會」了橫豎撇點這些筆畫的概念。

當然還是有一些「失敗」的作品,到底是想表達什麼啦!!

把前10000次iteration,每100次訓練的結果紀錄下來,可以察覺GAN嘗試去模擬圖畫的過程:

好可怕啊,天網要來惹(誤

100000次iteration,每100次iteration後紀錄一次圖片,所有產生的圖片(共1000張)都在這。而程式碼大致上和原Tutorial是相同的,請參閱原repo

一些隨想

  1. 其實在跟著實作的過程中,一直想到NP問題(笑)。在驗證一個問題是不是NP問題時,我們假設有一個上帝會給你小抄(certificate),只要你能在polynomial time裡檢驗這個小抄上的答案是不是對的,就可以說這是一個NP問題。偵探和鐵匠的關係給人一種類似的類比,當然在意義上是完全不同的。
  2. 在這過程中完全沒有用到label,因此是一種unsupervised learning。這給人很大的想像空間:GAN可以套用到任何的dataset中,並模擬出他們的distribution。所以可以訓練出各式各樣的generator:圖像的generator、影片的generator、語言的generator……。
  3. GAN結合了不同的model並交互訓練,有種用model訓練model的感覺。
  4. 很多設定和參數的調整其實都只是概念和heuristic,還有一大部分的數學細節尚未理解……。
  5. 如果直接用generator作training呢?loss function如何設計?這一定有人想過但不知道有沒有相關的成果。(書讀太少QQ)

小結

在原Tutorial的最後有提到GAN訓練非常的困難:從參數調整、不同的模型選擇、到訓練程序,所有的選項都有巨量的分支可以選擇,而且訓練時間非常長,要驗證又是另一個難題。不過,正因為這是一個新興的領域(2014年由Ian Goodfellow提出),還有許多的可能性可以繼續探索,所以……繼續交給那些被壓榨的phd研究(還有Google Brain),看有沒有human-generator被做出來的一天吧(笑)。

最後偷渡個敝校大老Yann LeCun說的話 — ”Generative Adversarial Network is the most interesting idea in the last ten years in machine learning.”

存參。

參考資料

  1. O’reilly Tutorial,原本看到的Tutorial
  2. O’reilly Tutorial Code
  3. 李弘毅老師的GAN講解,清楚明瞭。
  4. Generative Adversarial Networks, Ian GoodFellow發表的GAN原論文
  5. 2016 NIPS GAN tutorial, 一樣由Ian GoodFellow發表