[機器學習] Cycle GAN 筆記

Hoskiss
Hoskiss stand
Published in
8 min readJun 20, 2019

--

Cycle gan 是個可以用來產生不同風格/材質轉換的一個神經網路,主要應用在影像對影像的變換,從原 paper 的下圖可以觀察出一些很酷的效果,也是因為這個網路架構還蠻經典的,在此筆記起來

不過也有一些限制,所以還是要看目標是什麼來評估模型

  • 大範圍的形狀變換效果不佳
  • 物體跟背景可能不容易分別
  • 產生的圖形缺乏多樣化 ( GeneGAN的作者提到例如他們實驗發現 Cycle gan可以產生有眼鏡的影像,但都是同樣樣式的眼鏡 )
較大部分形狀的轉換,物體與背景同樣被轉換,失敗

Cycle gan 的想法是將 A domain (例如馬) 的影像轉成 B domain (例如斑馬),因為是 unsupervised learning 所以不好保證效果,但可以再從 B domain (例如斑馬) 的圖轉回 A domain (例如馬),轉回的還原圖要跟原本給的 input 越像越好,類似於 auto-encoder 的概念,以下是架構圖

看起來有點驚人(是在樂高?),但實際上就只是兩個網路的結合,分別是
把 A(馬)轉成 B(斑馬)再轉回 A(馬),把 B(斑馬)轉成 A(馬)再轉回 B(斑馬)。附帶說明上面的示意圖看起來是上下兩個架構,但實際上這兩個是同一個架構,注意 Discriminator A、B、Generator A2B、B2A 都是同一個人,只是整個網路在不同時間點有不同的 input 跟 output,所以總共是兩個 generator (產生假A、假B),兩個 discriminator (分辨 A/假A、B/假B)

那在程式實作上怎麼實現呢?我主要參考了這個 github,因為原 paper 架構也算單純,基本的 convolution layer 搭配 resnet block 等,不過這個 github 卻用 U-Net 來做 ( conditional GAN 的 generator 也是用 unet 來做),用 Unet 也就算了,還用了 ”遞迴只應天上有 凡人該當用迴圈” 的遞迴寫法,只好截圖記錄起來 ( p.s. 以下貼的程式碼都是 github 上的,但是我個人還是傾向寫出可讀性高的程式碼,包含變數名稱寧願寫長一點避免只用 g、d 這種縮寫,還有除非很好懂的遞迴,不然盡量還是用迴圈等等 )。如果想對照看 Unet 畫出來是怎樣的堆疊可以比對參考這裡 (長長的)

recursive 呼叫 block 的 Unet,用 keras plot_model 把超長架構畫出來才好懂一些

下圖是 discriminator 堆疊比較短一點就順手貼了,基本的 CNN layers 看清楚一點就比較不怕

另外補充說明,因為 generator 有使用到 conv2d_transpose 層,所以如果觀察被產生的影像時,有時候會發現棋盤效應 ( Checkerboard Artifacts ),就是會出現像棋盤一樣一格一格的現象,原因是因為在做 upsampling 的時候,每個做 convolution kernel 的數值之間很容易重疊,跟 stride 的大小也有關係,推薦 google brain 有篇動畫解釋很清楚的文章,以及說明如何用 resize (用些內插的方式直接放大) + convolution,取代原本的 conv2d_transpose,來避免這個問題,效果非常不錯,厲害了 google 的腦

接下來我們看看怎麼把網路串起來,直接看 github 的人可以略過這段,畢竟這邊我全部貼也沒意義,就貼一部分為了連貫我們的理解 (p.s 再強調一次,我個人會避免 g、d、GA 這種不好讀的縮寫)。底下這段可以看到先定義一個 function,主要是下面兩行被執行,netGA、netGB 就是兩個 generator,這邊看 function 定義裡面,real_input 想成馬,經過 G1 變成 fake_output 假班馬,這個 fake_output 假班馬再丟給 G2 當成 input,產生 cyclic 的結果,也就是變數 rec_input,要很像原本的 real_input馬

接著搭配 loss function,主要執行也是下面幾行,netDA、netDB 就是 discriminator,其他三項 real_A、fake_A、rec_A 就是上面 generator 相關的結果,可以想說這個 discriminator DA 就是判斷A類 (馬) 像不像的分類器,所以他的輸入都跟 A 有關,可以參考往上翻回馬跟斑馬的 cycle gan 架構圖

計算整個網路的 loss

分成三個部分來看

  • discriminator loss ( loss_DA、loss_DB ),看 function 內定義,有兩項( loss_D_real、loss_D_fake ),一項是真的馬跟斑馬 output_real 與 1 / True 的距離,也就是說他要去學我們餵進去的資料是分類到 True 的這一類,以及假的馬跟斑馬 output_fake ( generator產生 )與 0 / False 的距離 ),他要去學這是 generator 產生的,試圖分類到 False
  • generator loss ( loss_GA、loss_GB ),看 function 內定義,假的馬跟斑馬 output_fake 與 1 / True 的距離,也就是說 generator 他要去學怎麼樣讓自己產生的假圖被 discriminator 認作是真的
  • cycle-consistency loss ( loss_cycA、loss_cycB ),看 function 內定義,經過兩個 generator 產生的 cyclic 馬 (變數 rec ),跟最原本的餵進去真實資料馬( 變數 real ),這兩者直接相減的差異

是不是蠻有(ㄕㄚˇ)趣(ㄧㄢˇ)呢?了解完這段我覺得就能體會 cycle gan 的精神囉,等於是兩個基本的 DCGAN 加上 cycle 的 input要很像 output 的概念,理解時覺得概念不複雜但卻很巧妙的架構跟想法呢,雖然實際 training 的時候,loss 看起來都震盪的很誇張~

也可能是對稱的關係,覺得這個網路很漂亮,也許以後有機會做些有趣的應用。如果有任何回饋或指點都非常歡迎,喜歡這篇的話也可以拍拍手按個讚謝謝~,我們接著繼續向機器學習學習

對稱美(?

延伸閱讀

--

--

Hoskiss
Hoskiss stand

生活是不斷成長以追求平衡的巧妙融合