[機器學習] GAN 筆記

Hoskiss
Hoskiss stand
Published in
6 min readMay 19, 2019

--

紀錄學習 GAN (Generative Adversarial Network) 的重點摘要,真心佩服李宏毅老師與知乎神文等,可以把事情說得清楚透徹令人容易理解的大師

GAN 的目標是,產生與原始資料(抽樣成training data)分佈相近的資料,所以參考上圖流程,給 generator 餵進一組分佈,產生假的資料,再跟真的資料一起餵給 discriminator,讓他來判斷分類這是真的還是假的

所以如果把前面這一坨神經網路包起來看,整個其實可以看成是個二元分類問題,談到分類問題的話,我們來了解一下GAN的 loss function,老朋友 cross entropy。情境是像這樣:考慮一次從真實 data 抽樣 m 個 x,generator (有個初始化的參數) 也同時產生 m 個 x tilde(~),先訓練 discriminator 也就是分類器,他要想辦法最大化下面這個東西 (等於 cross entropy 加個負號),前面項表示要判斷真實 data 越接近 1 越好,後面項是 fake 越接近 0 越好,理論上 discriminator 參數最好要多 update 幾次

discriminator 想最大化這坨

接下來換固定 discriminator 參數,更新 generator,但是 generator 要想辦法讓產生的結果騙過 discriminator,也就是讓 discriminator 判斷越接近 1,這時候我們的 discriminator 已經(接近)讓上面的式子最大化了,generator 要開始試圖讓後面項最小化,(前面項完全是真實 data 跟 discriminator 的事,generator 插不了手),而 min max V(G, D) 就是這個意思

generator 想最小化這坨

就這樣,基本的 GAN 就重複這個訓練過程,看起來一切都很美好,另外有什麼事嗎?還真的有,來繼續看下去這些大神們怎麼研究這個 network。長長的推導就不寫出來了,不然也只是複製貼上一遍,這邊寫下我的理解順序。當 discriminator 被訓練得很好的時候,數學式上可以把最佳的 D 代入 V,這時候會推出 V 等於一個常數加上 real data 跟 fake data 的 JS divergence,當兩個分佈沒有交集的時候,這個 JS divergence 就是個常數,那如果 loss function 是個常數,gradient descent 就毫無用武之地,generator 找不到方向變強只能躺在床上耍廢

兩個分佈真的那麼難交集嗎?沒錯啦,主要是兩個原因: 1. 我們是抽樣 data,本來就很難重疊 2. 分佈是高維中的 manifold (念一下覺得很厲害的詞啊~),意思大概就是想成三維空間中的兩條線,重疊機率很小吧,就算剛好有一點相交也是沒用的,至少必須要一小段重疊才行,gradient 才有辦法往拉近兩個分佈更新。那為什麼說是 manifold?假設 real data 是個一千維空間中的分佈,我們給 generator 一百維的 noise 叫他想辦法產生跟 real data 很像的分佈,generator 再怎麼神奇參數,他的 output 就是被一百維的 input 限制,產生實際較低維度的 manifold。所以總之如果 discriminator 太強,GAN 都很難訓練得起來,generator 產生不出很像的 data,聽起來他也很無奈,幫 QQ

故事還沒結束,實際上這個 loss function 是有調整過的版本,對於 generator 來說,用下面這項來取代原本的式子,好處是在 training 的時候,一開始的梯度比較大,參數更新比較快

但這樣也有後遺症,一個就是梯度不穩定,另一個是 mode collapse,mode collapse 就是會產生很多看起來很像的 data,例如有 mnist 的資料,要產生類似真實的手寫數字,train 到後來的圖像都產生數字 1 或某個數字,感覺 generator 在想:反正這個就能騙過 discriminator,就都給你,很像全班考卷答案都一樣的感覺 XD。WGAN 前傳提出了解釋,他推導用這樣的 loss function,對於 generator 來說,等價於最小化下面這個式子

最小化的情形就是,分別是想要減少 KL divergence 跟增加 JS divergence,這是很奇怪的事情,像是想同時拉近又推遠兩個分佈的距離,欲拒還迎是有沒有這麼曖昧,也因為如此導致梯度不太穩定。另外前面的 KL 項,他會為了不讓 loss 趨近於無窮大,( Pg 在沒有 Pr 分佈的地方產生 data ),會偏向”保守”,產生很多重複而較安全的 data,也就是看到的 mode collapse,李宏毅老師有提到,即使用了倒過來的 KL,問題也沒有被解決,就這樣,原本的 GAN 被 WGAN 作者拳拳到肉地打到感覺再起不能,他們也有實驗的 loss 圖很有說服力地證明這些論點

一些GAN範例code的loss寫法

所以,WGAN 作者精心演算了原本架構的不足之處,那後續有沒有提出解法呢?有的 (廢話不然怎麼有 WGAN),大神們不愧是大神,不過等我念完下集再紀錄吧 XD,超級感謝我看到的網路教學,帶領著讓我一步一步看懂這些本來像是天書般 的 paper 內容,希望以後有機會也能這樣讓別人理解。如果喜歡這篇的話,麻煩幫我按下拍手給些鼓勵吧,任何回饋或建議也都歡迎喔~謝謝,我們接著繼續向機器學習學習

延伸閱讀

--

--

Hoskiss
Hoskiss stand

生活是不斷成長以追求平衡的巧妙融合