Unet系列(2) — Unet++實作

Martin Huang
Unet與FCN系列
Published in
Jul 8, 2022

本篇的論文分析對應此篇

這次我們一樣寫一個彈性架構,只要輸入想要的深度,就可以自動架出對應的神經網路。

Unet++的架構

先把架構圖貼出來。

圖片來源:參考資料[1]

和Unet的差別主要是在中間的單元。仔細觀察這些單元,會發現有三個性質:

  1. 他們不會把資訊往下傳,而只會往右邊,或往上。
  2. 每一個單元,在垂直層只會接收來自下一層的資訊,但在水平層會接收來自所有左邊單元的資訊。
  3. 最右邊一排的資訊只往上傳。最上一層的資訊只往右傳,但同時也要傳到最後面。

從這邊可以知道,對於每個單元,根據其所在的位置,傳輸路線有所不同:

  1. 最左一排:只接收上一層的,傳輸往右、往下。和Unet的encoder一樣。
  2. 中間和最右一排:接收來自下一層的,和水平層中比自己左側的每一個單元。

略為調整encoder和decoder

先把單元內的結構微調了一下:這邊以upsample為例。

改成先upsample,再進CNN block。同時把encoder改成先downsample,再進CNN block。這樣的目的是把每個單元盡量都能模塊化,整個網路,除了初始的那個單元,其他都可以用downsample或upsample處理:

如此一來,要架Unet++的時候:

圖2

就可以用「一個down stream」和「數個up stream」組合。

CNN block要裝什麼都可以。從只是單純的一層CNN,到多層CNN+batch normalization,甚至到Residual block,隨個人喜好。當然,也要考量參數量、計算速度跟記憶體負荷等等環境因素就是了,不過這不在本篇討論範圍。

架構Unet++

encoder部分,其實和Unet差不多,也只有一組,所以相對簡單。但decoder有許多條,而且稍後在建立傳遞路徑(forward)的時候,彼此間還有交互作用。為了夠精確,我這裡用dictionary的方式處理。給予每個decoder一個key,而每個decoder都是一個list,裡面包含好幾個單元。每一個decoder所含的單元數不同。

pytorch支援dict和list,分別對應到ModuleDict和ModuleList。同時,它也支援dict和list的組合(就是像python一樣list可以放在dict裡面)。

用雙迴圈,先把一個decoder建好,再給予對應的key,收到dictionary裡面。如果有把握index不會出錯的話,其實這邊用兩層list應該也可以,第一層list包住第二層的數個decoder list。

這邊最難的其實是channel要算對…一個不小心很容易沒對到,pytorch就跳錯。我們看一個較簡單的,3層Unet++的結構:

圖3

這個是每個單元輸出時的channel數目,請配合圖2看。y指的是最後分類的class數目。

decoder的單元,其輸出channel的規則為2^j。但麻煩的是輸入:

這是根據上面輸出的channel計算的,x為base filter,這裡是2。這個就比原本unet複雜一點,因為多了填滿中間的單元。以這個三層結構,關鍵就是(i=1,j=2)這個單元:在水平側,其接收(j=0)的兩個單元。來自下層的資訊,不管在哪個單元都固定只有一個;但來自水平層的訊息,隨著架構深度越深、單元的位置越靠近輸出側,接收的channel數越多。其基數為2^j,但須乘以(i-j+2)的倍數。那個2,就是來自下層的單元。

forward

架構建好了,就要告訴機器如何運算,即資訊的傳遞方向(propagation)。正向的傳播途徑建好之後,pytorch會自動建立反向傳播,所以偏微分、鏈鎖率等等的數學計算式就不必自己寫了。(其實有些運算式偏微分之後要寫還很費工夫)

在運算過程中,幾乎每一個單元都會反覆被使用(最右側decoder除外)。這點和Unet不同,後者在運算過程中,每一個單元都經過一次,只有concatenation的時候還需要再呼叫一次。Unet++位於中間的神經元會反覆呼叫encoder及它左側的神經元,因此要把他們都儲存起來。

照論文的圖,Unet++應該是有多個輸出的:每一條decoder的輸出都要送到損失函數去計算。這邊我的j index有點錯亂了,其實不是很好的寫法。在編架構時,decoder單元的list是從下往上寫的(j由大到小),但這邊呼叫的時候j又由小到大,j=0時,呼叫的其實是上面j=i-1的那個單元。(應該考慮再用另一個index名稱以免搞混)

分成幾個情況:

  1. 如果單元是在decoder最底端,那它就像Unet一樣,只需做concatenation就好。
  2. 單元在中間:要呼叫下一層的單元,以及水平層左邊的所有單元。我用「倒數」的方式呼叫,因為每一條decoder的長度不一樣,但倒數的位置是一樣的,所以才會在同一層。
  3. 單元在最上面:要把運算結果送到output去。

損失函數

根據論文,損失函數長這樣:

是dice coefficient和cross entropy的結合。其實是可以用pytorch裡面的函數組裝啦。這個彈性比較大,甚至不用論文的損失函數也可以,我自己有測試過,只用dice coefficient,訓練還是可以收斂的。

這邊我就不放程式碼上來了。

訓練結果

用小樣本測試,使用Cavana資料集。首先是Unet++的結果:

unet++ train loss
unet++ val dice

其實在初期有一些浮動。後面收斂之後可以保持穩定。

再來是unet:

unet train loss
unet val dice

我覺得比unet++稍微穩定一點。小樣本的資料集和少量的訓練epoch,其實不能比較什麼。只是證明自己刻的玩具可以跑得動,有點成就感而已。

在測試集上的dice分布比較:

1是unet++,2是unet。和訓練過程一樣,unet++似乎表現並不是很穩定,但隨著資料集增加是否還是如此,可以再看看。

以上就是unet++的實作分享,謝謝看到這邊的你。歡迎討論和分享!

參考資料

[1] Zhou et al., UNet++: Redesigning Skip Connections to Exploit Multiscale Features in Image Segmentation. DOI: 10.1109/TMI.2019.2959609

--

--

Martin Huang
Unet與FCN系列

崎嶇的發展 目前主攻CV,但正在往NLP的路上。 歡迎合作或聯絡:martin12345m@gmail.com