適用於小batch size的卷積神經網路結構改良-GN, WS

Cyan
Taiwan AI Academy
Published in
11 min readJun 11, 2020

Group Normalization

Batch Normalization(BN)的出現是深度學習的里程碑,它使得許多模型得以訓練得起來(或收斂得更快速),然而我們減少batch size的時候,往往會因為不準確的mini-batch統計量估算,導致BN的錯誤率急遽上升。

這讓許多需要大量gpu記憶體的電腦視覺課題(舉凡detection, segmentation, video)無法使用BN,所以本文作者(Yuxin Wu et al.)提出了Group Normalization來替代需要依賴batch size的BN。

很直觀地說,GN(Group Normalization)就是一種介於LN(Layer Normalization)與IN(Instance normalization)的標準化方法。

首先,標準化是指將中間層計算的特徵x,透過對特定的維度計算平均數(μ)與標準差(σ),將x減去平均數後除以標準差,而維度的定義以2D圖像為例,分別為N:一個batch的樣本數量、C:channel數量、H,W:圖片空間上的高與寬。(如下圖)

而計算平均數與標準差得方法如下圖:

S_i指的是要計算平均數或標準差的像素之集合,ϵ是很小的一個常數。

做完標準化後還要再將其乘上𝛾(伸縮量)再加上β(平移量),才會傳到下一層去。

以下分別針對四種歸一化的方法做解釋:

Batch Normalization

BN的S_i集合為

也就是說,當像素是在同一個channel的時候,就會一起做標準化。

Layer Normalization

LN的S_i集合為

意思是如果像素在同一個樣本內,就一起做標準化。

Instance Normalization

IN的S_i集合為

意即在同一個樣本內的同一個channel做標準化,換句話說其實就是將每個樣本中的各個channel的高、寬統一做標準化。

Group Normalization

最後是GN的集合,

看起來很複雜,其實就是同一個樣本中同一個”group”做歸一化。
而定義group的方法就是把channel數量除以定義好的group數量 — G(預設G=32,也就是分32群),得到平均每個group有幾個channel,再用所在的channel編號除以該值過下取整函數後,就得到分在第幾群了。
而GN分別是針對每個channel學習𝛾(伸縮量)與β(平移量),與BN一樣。

Experiment

我們再來看看作者實測的表現:

上圖是batch size=32,分別使用BN、LN、IN、GN的訓練誤差及驗證誤差,使用的模型為ResNet-50資料集為ImageNet。

可以看到BN在驗證誤差上的表現還是最好的,其次是GN,大概只小輸0.5%而已,而IN的表現則是最差。
另外可以發現GN的訓練誤差其實是最小的,其次是BN。作者的解釋是BN在計算平均數與標準差時,因為是隨機取樣batch資料而會產生不確定性,我們可以視此不確定性為幫助正則化(regularization)的工具。

上圖比較BN與GN在各個batch size下的驗證誤差。

很明顯地,因為BN的性質會造成在小batch下表現大幅下降,而GN則是會因為標準化與batch大小無關所以即便batch size = 2也能夠有與batch size=32一樣的表現。

Implementation

實作上也相當容易,只要改幾行code就可以使用(以tensorflow 為例):

def GroupNorm(x, gamma, beta, G, eps=1e−5):
# x: input features with shape [N,C,H,W]
# gamma, beta: scale and offset, with shape [1,C,1,1]
# G: number of groups for GN
N, C, H, W = x.shape
x = tf.reshape(x, [N, G, C // G, H, W])

mean, var = tf.nn.moments(x, [2, 3, 4], keep_dims=True)
x = (x − mean) / tf.sqrt(var + eps)

x = tf.reshape(x, [N, C, H, W])

return x ∗ gamma + beta

在tensorflow Addons也已經實裝GN,只要呼叫tfa.layers.GroupNormalization(groups)就可以使用了
(ps.這邊預設的group size = 2)

pytorch則是直接呼叫torch.nn.GroupNorm(num_groups, num_channels)就可以了

Weight Standardization

與GN的研究動機相同,本篇研究之目的是希望模型能夠在小batch上的表現超過一般batch size的BN,很驚人地,WS(weight standardization)透過與GN共用,能夠在batch size = 1的情況下,勝過大batch size的BN(紅色直方圖部分)

How Does Batch Normalization Help Optimization?(arXiv:1805.11604)中,說明BN之所以可以讓收斂速度變快、模型表現變好並非因為ICS(internal covariant shift),而是因為BN可以使損失函數平滑化,讓learning rate可以變大、同時減少local optimal。
這篇論文採用了其論點,所以在後面的篇幅主要是透過理論與實驗證明WS可以使損失函數更平滑。

WS algorithm

在進入證明之前首先要知道權重是如何標準化的,以下會先介紹WS的方法。
簡單來說,就像是標準化激發函數後的特徵一樣,計算卷積層中filter的平均值與標準差,再將filter內的權重減去平均值除以標準差,只是不需要與BN或GN一樣乘上𝛾(伸縮量)後加上β(平移量)的轉換,原因是作者認為後面BN或GN層會再做一次標準化與伸縮平移,而那步對權重的轉換在實驗上反而是會對訓練有害的。

而WS的方向如右下圖,filters對不同的輸出channel分別作標準化,換句話說就是同一組(對上同一個輸出channel)的filter做一次標準化。

Proof

下圖是WS在前向傳播(藍色)與後向傳播(紅色)的流程,數字是方程式的編號

<,>表示內積,◦2表示阿達瑪次方(Hardmard power)

由上面的方程式可以發現,不同於往常的梯度到W^hat就停止,這裡的梯度傳到W的過程中某種程度上也經過了一個標準化(式8,9)。

接著要證明WS可以讓損失函數比較平滑的話(利普希茨常數(Lipschitz constant)變小),我們直接檢查梯度範數(gradient norm)有沒有比較小就好(這邊直接跳過推導過程XD):

上式對應式8,同時對應式6作標準化的伸縮
上式對應式9,同時對應式5作平均數的平移

透過式11,12我們可以發現式8,9都可以降低gradient norm,作者進一步實驗這兩個步驟對於模型表現的影響,以及其實質降低的幅度:

左、中圖是分別使用GN、GN+Eq.5(平移)、GN+Eq.6(伸縮)、GN+Eq.5&6(標準化)的訓練/驗證結果,右圖是對於利普希茨常數的平均下降百分比(左邊percentage的刻度是log刻度)

從實驗中可以發現其實實際在作用的是式5(平移),而式6(伸縮)雖然進步比較少然而還是有一點點幫助的;而右圖也可以發現藍色部分(式8的造成的減少量)也是小於橘色部分(式9造成的減少量)的。

Experiment

再來是實際測試的結果:

由上表可以明顯的看出GN+WS(batch size=1)是可以贏過目前常用的BN(batch size=64/32)在模型ResNet-50/101上。另外BN+WS的表現也是贏過只使用BN的模型。

另外上圖可以發現,WS對於更深層網路的效果是更好的,無論是ResNet-101或是ResNeXt-101,多加了WS可以同時讓訓練與驗證的表現大幅上升的。

Implementation

只要對簡單增加幾行code就能夠使用強大的WS,不過要注意,必須要在 後面接Normalization層(GN,BN…) 才可以。

Pytorch:

class Conv2d(nn.Conv2d):    def __init__(self, in_channels, out_channels, kernel_size, 
stride=1, padding=0, dilation=1, groups=1,
bias=True):
super(Conv2d, self).__init__(in_channels, out_channels,
kernel_size, stride, padding,
dilation, groups, bias)

def forward(self, x):
weight = self.weight
weight_mean = weight.mean(dim=1, keepdim=True).mean(dim=2,
keepdim=True).mean(dim=3, keepdim=True)
weight = weight — weight_mean
std = weight.view(weight.size(0), -1).std(dim=1).view(-1, 1,
1, 1) + 1e-5
weight = weight / std.expand_as(weight)
return F.conv2d(x, weight, self.bias, self.stride, self.padding,
self.dilation, self.groups)

Tensorflow:

kernel_mean = tf.math.reduce_mean(kernel, axis=[0, 1, 2], keepdims=True, name=’kernel_mean’)
kernel = kernel — kernel_mean
kernel_std = tf.keras.backend.std(kernel, axis=[0, 1, 2], keepdims=True)
kernel = kernel / (kernel_std + 1e-5)

再把kernel放進tf.nn.conv2d(filters=kernel)就可以了

Reference:

  1. Group Normalization
  2. Weight Standardization
  3. How Does Batch Normalization Help Optimization?

--

--