Quantized Neural Networks: Training Neural Networks with Low Precision Weights and Activations

totokk
3 min readApr 14, 2017

--

2016 QNN

MNIST

no preprocessing

vanilla implementations:

3 hidden layers, 4096 binary unit, L2-SVM as output layer,

Dropout, square hinge loss, ADAM, Batch Normalization,

2.1 Deterministic vs Stochastic Binarization

Deterministic:

Stochastic:

Stochastic 理論分析上常用,但因為硬體要生成隨機數較困難些及耗時,因此實作上 Deterministic 比較常用。(例外: activations at train-time)

2.2 Gradient Computation and Accumulation

雖然 BNN training 用 binary weights and activations 來計算參數的 gradient , 但權重的梯度是實數的 (real-valued), 且於一實數變數中積累( accumulated)

SGD 是需要的,但 SGD 在參數空間中以 small and noisy step 探索參數空間,noise 被 weight accumulation 平均掉了, 所以 accumulator 的 resolution 要夠! ( 不夠會怎樣? )

另外,額外加一個 noise 到 weights and activations 上 ( variational weight noise, Dropout etc) 提供了 regularization , 可以提供更多的泛化能力.

我們 train BNN 時, 用了類似 Dropout 的方法,但不是在計算參數的梯度時隨機地設一半數量的 activations 為零,而是 binarize both weights and activations.

2.3 Propagating Gradients Through Discretization

除了在 0,符號函數可微分,其導數為0。backprop 無法使用,該怎麼辦?(為何無法使用? 因為 cost 的梯度在離散化 ( discretization ) 之前就是 0了),這種限制就算是使用 stochastic quantization 也是存在。

Bengio (2013) 研究了關於該如何使用 stochastic discrete neuron 來估計梯度或傳遞梯度。結論是使用 straight-through estimator ( Hinton, 2012) 是最快的方式,QNN 是用稍微修改後的 ST estimator,考慮了飽和效應 ( saturation effect ),並且使用確定性來取樣 bit,而非 stochastic 來取樣 bit。

Sign function quantization

假設想估計的loss(cost)梯度的 estimator 已經造出:

另外梯度的 ST estimator 也造出:

where 1x 是 indicator function:

為了更瞭解為何 ST estimator 有用,考慮:

但,重寫機率函數,

其中,HT(r) 為 hard tanh (本來是 hard sigmoid )

如此,下一層 layer 的 input:

這裡用了:

n(r) 是 binarization noise with mean equal to zero. 這項如果layer is wide enough,可以當作 0。

HT(r) 是 hb(x) 的期望值,看下二式

所以,不能計算的

可以用下式取代:

阿 (6) 這不就是:

類似的 binarization process ,我們用在 weight 上的包含兩個要素

  1. 每個 real-valued weight 投影到 [-1, 1],這樣一來,real-valued weight 就算值成長到非常大也不會影響 weight。
  2. 當用 weight

quantize it using

簡單說就是作了以下處理,使得函數在[-1, 1]之間是可微分的,且梯度永遠為 1。

Round function

— — — — 補充資料 — — —

Ans: straight-through estimator (Bengio, 2013 or Hinton, 2012)

Training BNN

What is float16, float32?

https://zh.wikipedia.org/wiki/IEEE_754

--

--