聯邦學習(Federated learning). 模型之間的知識分享體系

王柏鈞
機器學習歷程
Published in
11 min readMar 29, 2021

聯邦學習是很有趣的「機器學習」應用(沒錯!他不是僅限深度學習),聯邦學習對台灣也許會是最有價值的技術之一。

本文簡介了 :

聯邦學習的起源、通用流程,以及聯邦學習的經典 FedAVG如何實現,並在最後分享了一個聯邦學習中相當重要,但在深度學習領域不太容易發現的知識,希望可以讓讀者對聯邦學習有概括的認知。

聯邦學習最早的起源應該是來自google在智慧輸入法的應用

為什麼要聯邦學習 -FL起源簡介

聯邦學習(Federated learning)應該被歸納為分散式學習(Distributed learning)的分支,不同的是,聯邦學習是一種不得不的分散式學習。

分散式學習透過將資料分配到更多的運算單元,有了更高的平行運算能力。而聯邦學習則是因為個別運算單元(e.g. 手機)中收集到的資料,不允許離開該單元,為了追求更好的模型訓練效果只好透過分散式學習的方式來進行。

舉例來說,Google為了聯合不同用戶的使用習慣優化智慧輸入法(Gboard),同時不違反歐盟GDPR對用戶隱私的保護,於是透過聯邦學習來進行模型訓練,這個過程中只有梯度會被共享出去,所以能夠很好的保護用戶隱私。

不過事實上,目前研究 (Deep Leakage from Gradients, Song han et al. 2019)指出,共享出來的梯度也可以讓惡意工作者還原出 NLP的 token或是 pixel-wise等級的影像(64x64)。

Fig. 分散式學習和聯邦學習在系統層面十分相似,我們將其稱為分散式系統。

聯邦學習架構

分散式學習和聯邦學習在系統層面十分相似,我們將其稱為分散式系統。首先,儘管有非常多的結構差異與變體,但所有的分散式都可以被分成兩種形式:有參數伺服器(parameter server),沒有參數伺服器,在拓樸上,這兩種形式分別代表了中心式和去中心式的分散式計算,並與許多節點(local node)共同合作。

以上圖為例,儘管嚴格來說這裡的Server並沒有保有參數(參數指的是模型,也就是server上不存在模型),但Server協助了梯度聚合和傳遞,我們可以把它當成中心式的聯邦學習系統。

source: Federated Machine Learning: Concept and Applications, Qiang Yang et al. 2019

了解聯邦學習的架構後,接著敘述聯邦學習的流程。

聯邦學習的流程大致上可以分成4步驟:

  1. 確定架構(拓樸) Formulate topology
  2. 梯度計算 Gradient compute
  3. 資訊交換 Information exchange
  4. 模型聚合 model aggregation

Step1: 確定架構(拓樸) Formulate topology

這裡就是上一個章節所提到,我們決定好中心或去中心化,制定溝通的順序與規則、資訊流通的方向、模型在哪裡聚合。

key word: distributed learning topology, communication cost, model exploitation, center variable

Step2: 梯度計算 Gradient compute

對於聯邦學習,不論是中心或是去中心式計算,大致上來說首先各個節點會先基於各自保存的局域資料集(local dataset)的進行梯度計算,梯度計算的方法會基於拓樸和演算法有很大的差異,也可能需要先做特定的資料處理。是聯邦學習很重要的研究主題。

key word: optimizer, FedAVG, FedSGD

Step3: 資訊交換 Information exchange

完成參數更新後,會基於拓樸架構來傳送梯度:對中心式模型來說,梯度會先聚合在參數伺服器、然後分送給各節點。對去中心式模型來說,梯度會傳送給鄰近節點或指定節點,梯度交換之前會先進行加密,確保梯度不會導致隱私洩漏。
在加密環節,以大陸Webank為首的聯邦學習研究團隊則經常介紹到「同態加密」(Homomorphic encryption)技術。

key word: gradient encryption, Homomorphic encryption, communication cost

Step4: 模型聚合 model aggregation

而聚合的時候,對於兩種拓樸,由於聯邦學習中,每個local dataset幾乎都是Non-IID的資料,單純的聚合很難適用,所以聚合的方法也是聯邦學習很重要的研究主題。

key word: model aggregation, Non-IID

舉個例子 我們來推演一下中心式聯邦學習的經典: FedAVG

Fig. 我們用FedAvg推演一個有3個local node, 1個中心伺服器的聯邦學習

ref: Communication-Efficient Learning of Deep Networks from Decentralized Data, 2017

Step1: 確定架構(拓樸) Formulate topology

首先,server建立初始化的Server模型(w0),用一個參數C決定要抽百分之多少的local node來參與這一輪的聯邦學習,我們假設這裡抽出66%的local node,也就是總共3個中的2個。

這時候FedAVG演算法為了控制local node的訓練,會傳送三個參數給local node,分別是:

  1. B, local node在訓練時的minibatch size大小。
  2. E, local node在這一輪的訓練中可以訓練幾個Epochs。
  3. lr, learning rate,local node在這一輪的訓練中的學習率

Step2: 梯度計算 Gradient compute

我們接下來觀察local node,這時候被抽到的兩個local node會平行的做這些事:

  1. 從Server那裏下載初始化的Server模型(wo)
  2. 確定接收到的B、E和lr參數。 (我們假設參數分別是 B = 32、E = 5、lr=0.001)
  3. 基於B參數指定的batch size(32)以及lr(0.001)開始訓練並更新梯度,總共訓練E(5)個epochs
  4. 完成第5個epoch訓練後,將當前的模型上傳到Server

Step3: 資訊交換 Information exchange

因為這是比較早期的Federated learning,並且論文的方向也不是在安全隱私上,所以這裡的模型參數資訊並沒有被加密,簡單的被送給Server。

Step4: 模型聚合 model aggregation

現在模型參數已經被送到Server了,接下來Server會做這幾件事

  1. 依據被選中的local node中所含的樣本數量,給定一個權重(nk/n)
  2. 將每個local node的模型乘上權重,並加總起來
  3. 加總起來的模型參數,就是新的Server模型(w1)

nk: 第 k個 local node 當中的樣本數。
n :所有節點中的所有樣本加總。

以上,就完成了一輪的FedAVG算法,接下來就是不斷重複這些流程,直到完成整個聯邦學習系統的訓練。

FedAVG帶給我們在聯邦學習上最重要的知識

乍看之下,FedAVG很簡單,但有一個知識是整個聯邦學習的一大核心。

相同初始值的模型,經過不同資料訓練後聚合再一起,可以增進模型效果。

不同初始值的模型,經過不同資料訓練後聚合再一起,模型效果會大減。

我們有兩個dataset,將兩個不同初始值但架構相同的模型分別在不同dataset上訓練後,將參數取平均,loss會大幅增加。將兩個相同初始值且架構相同的模型分別在不同dataset上訓練後,將參數取平均,loss反而可以降到低點。

這個知識對深度學習基礎的聯邦學習方法有很大的影響,之所以有這個效果Communication-Efficient Learning of Deep Networks from Decentralized Data 論文中大概有這樣的說法:

在近期的實驗中指出,一個有足夠參數且充分訓練的模型基本上都能夠到達全域最低點,而不會陷入早期懷疑可能造成影響的局域低點。

而兩個相同初始值的模型即使在不同的dataset上訓練,也會傾向於找到相近的全域最低點,並依循相似的路徑或方向進行參數更新。因此,當兩個具備相同初始值的模型在經過訓練後,單純的將參數相加取平均,他們的loss surface表現會驚人的好,甚至比單一模型在各自的dataset上取得更好的成績(實驗以MNIST為例)。

對於這個現象我個人的假設是:在實務中,有兩件事發生了

  1. 對兩個模型我們都無法實現真正完美充分的訓練
  2. 相同初始值的模型的loss surface傾向於擁有相近的全域最低點。

於是經過訓練的兩個模型,儘管看起來loss不再降低,但實際上還沒真正達到全域最低點,而是停滯在loss surface中全域最低點附近的某個淺窪處。

將兩個模型加總平均時,由於有相近的目標,於是讓他們有重新調整位置的機會,脫離了淺窪,而更靠近了全域最低點。

模型的初始值控制與聯邦學習的應用

基於這個觀念的應用,在2014年有一個在分散式學習的論文提出很有效的方法。

Deep learning with elastic averaging sgd, 2014

Deep learning with elastic averaging sgd, 2014 提出的 EASGD 方法會讓每個節點在訓練時,用一個懲罰項控制 local node 模型在發展時,不能跟中心伺服器上的模型距離太遠。直觀的想像,就是用狗鍊束縛住每一個local node,再根據每個local node的習性調整鍊條的長度。

source: https://petsmao.nownews.com/20140718-3465 哈利波特長大後

那為什麼要控制local node的模型訓練呢?剛才有提到,EASGD事實上是在控制每個節點和中心模型的距離,如此一來,就可以在實現訓練的同時,確保模型不會受到local dataset的過度吸引,進而導致過擬合(Overfitting)在該local dataset上,並降低模型聚合時,整體模型的效果。

我們可以想像,在上面的哈利波特其實閉著眼睛,而周圍的狗狗(local node)則不斷的在各自管轄的範圍內挖掘著資訊,而鍊條限制了狗狗挖掘不必要或是太遠的資訊。

於是只要簡單的將每個狗狗挖掘到的資訊聚合,哈利波特就得到了他周圍的空間資訊,這樣他就知道該往哪邊走了(確定中心模型該更新的方向)。

結論

在發現聯邦學習後,我一直在想也許聯邦學習可以讓台灣的製造業有新的突破。

先是讓企業內不同部門能夠建立互不洩漏資訊、低成本、容易佈署的AI決策與數據系統,再來讓不同企業攜手創建屬於小區域的地方數據鏈,讓不同企業共同組成能夠更加靈活、快速應變的國家隊。
當生產效率和彈性能夠提高,AI在智慧製造最令人詬病的「能得到多少收益」問題也許就有一個答案。

最終,不管是接軌歐洲的工業4.0或是美國的先進製造、或是如工業3.5之類獨屬於台灣自己的節奏,資訊和技術的開放都可以為製造業帶來巨大的改變。讓不同企業攜手創建局域地方數據鏈,讓不同企業共同組成能夠靈活、快速應變並能夠順利互相溝通的國家隊,最終接軌先進的製程,不論是歐洲的工業4.0、美國的先進製造、或是如工業3.5之類獨屬於台灣自己的節奏,創造品牌。

如果讀者有想法,也許可以一起來討論,很希望台灣製造能夠有一個品牌,讓台灣的工具機或技術能夠得到歐盟、美國或其他先進國家的認同。

--

--