Transfer Learning 轉移學習

陳明佐
我就問一句,怎麼寫?
11 min readApr 10, 2019

What is Transfer Learning?

來自台大李宏毅教程的介紹:
轉移學習就是把已經訓練好的模型、參數,轉移至另外的一個新模型上
使得我們不需要從零開始,重新訓練一個新model

舉例來說,你可以train好一個based on Cifar 10的CNN
然後把這個訓練出來的模型套用至其他影像辨識的數據上、
甚至是使用這個模型成為一個特徵萃取機制,串接傳統的SVM方法。

主旨在於,機器學習、深度學習領域,用於解決數據標記困難、數據取得不易的問題,重要的解決手段。

Why Transfer Learning?

僅基於數據的角度來談的話有三點

  • 收集數據比較困難
  • 標記(label)數據很耗時,很繁瑣,需要大量的人力
  • 訓練one2one的model,缺失了Machine Learning 中對於泛化性能的要求

上面三點,就使得我們必須進行Transfer Learning

Transfer Learning 方法分類

Based on instance or said sample
通過權重的分配,來分別作用到源域(source)和目標域(target)來進行轉移
EX:
在源域中有一個樣本和目標域中的一個樣本非常的相似,那麼我們就可以加大此樣本對應的權重。

Based on features
將源域和目標域的特徵變換到同一個空間
EX:
在兩個域上的feature具有很大的區別,那麼我們就可以通過將這兩個域的feature變換到同樣的空間,這個時候我們就可以很方便的研究這兩個域上的相關內容和性質了。

Based on model
通過源域和目標域的參數共享機制
EX:
這也是我們在做Deep Learning中用到最多的一個方法了,比如說,將pre-trained的模型拿過來,通過固定一些layer的parameters,修改部分layer的parameters得到最終的非常好的結果。

轉移學習的定義

Source Domain : 已存在的知識或已學習到的領域
Target Domain : 欲進行學習、訓練的領域

Task : 由目標函數、學習結果組成,為最終的學習結果。

我們手上有兩種域:Source domain, 有個function叫做 fs
此function的作用是來完成一個名為Ts的任務
另外一個域稱作 Target domain, 一樣存在function叫做 ft
並且需要此function去完成任務Tt

而轉移學習就是需要透過Source domain及Ts的學習,使得能夠幫助Target domain中的ft學習,進而使得更好完成Tt

轉移學習運作

李宏毅教授投影片

轉移學習有分成四大種情況,主要差別來自於資料是否有標注。

以下先以最簡單的,兩大Data都有label的情況 — Fine Tuning下手。

Fine Tuning

目前常見於DL,即比如做圖像分類,我們往往會拿已經訓練完成的ImageNet作為我們任務的pre-trained model,然後做一些參數調整及可以完成。

意味著Target data是屬於比較少量的數據集,如果Target data資料多,其實我們就不需要做什麼轉移學習,直接拿模型再次常規訓練即可。

假設,我們的 Source data是上千個人的語音資料,而 Target data為幾個人的兩三句話;

最直接的方式就是,直接將Source data拿去train出一個model,再用target data去tune原本的model。但因為target data數量過少,亦有可能產生over-fitting的問題。為了避免此問題發生,而產生了Conservative Training

Conservative Training

此方式並不是直接將target data丟入原先train好的model
而是在新的model加入一些正則化項Regularization
限制新model與原先的model在相同的輸入情況之下,盡可能得到相同的輸出

當如果拿掉這個限制,透過 target data fine tune出來的新model,再丟source data進行測試,會發現整個模型早已經不具備原先model的表現能力,僅能針對 target data進行表現,這是我們常說的災難性遺忘

Layer Transfer

此為另一種訓練方式,有別於conservative training
Layer transfer ,顧名思義就是層的轉移。
一樣透過source data去train好一個model,但是這次不把整個model 複製過來,而是萃取其中幾層的參數,其餘沒有萃取的參數再透過 target data去訓練學習。至於說該複製哪些層哪些參數,其實沒有硬性規定,但通常來說:

語音識別任務,我們通常transfer 模型最後幾層的參數
影像辨識任務,通常transfer 模型最前面幾層的參數

對於語音信號來說,每個人使用同樣的發音方式,所得到的聲音結構是不一樣的,因為他們的口腔結構肺部組織什麼的都會存在差異。在train中,網絡前幾層的目的就是辨識到這些人的發音方式,然後後幾層才會去聽到底這些人都說了什麼。所以,我們發現model的後幾層是和發音者完全無關的,而前面幾層才是在尋找發音者的發音方式。

對於影像辨識來說,我們知道,在CNN做圖像分類的過程中,前面的捲積層做的工作都是在做feature extraction的過程,而提取到的那些特徵都是圖像中的最基本的特徵,比如邊邊角角,線等。所以,我們很需要這樣的底層結構基礎訊息,就可以transfer到其他的分類task上面了

Multitask Learning

Fine tuning 時,我們關注在於model對於 target data表現得好不好,至於說source domain data做得如何倒是無所謂。儘管source data進到新的模型訓練發現效果很差也無妨,只要新模型在target domain表現夠傑出即可。

但multitask learning的主旨就不同了,不僅要求最後的新model在target domain表現要好,而在原始的source domain也必須表現得同樣出色。

Multitask learning 常見的兩種結構方式

然而Multitask learning最成功的就是語音識別了。
收集到的聲學特徵,然後通過下圖的兩個紅色layer,這兩個layer是所有國家語言的基礎和共同性質訊息,然後對於不同國家語言的轉換,就對應了不同的task。

Domain-adversarial training

當我們的Source data具有label的時候,Target data不具有label的時候,我們就要採取一種叫做Domain-adversarial的方法

先以下列例子做為範例,MNIST為黑白照片並且具有標記
然而MNIST-M為彩色版照片但並無標記
我們想透過已經用MNIST訓練好的模型去預測MNIST-M
如果直接透過MNIST模型去預測MNIST-M,則準確率約為5成,非常低。
所以必須透過Domain-adversarial training調整。

綠色為feature extractor, 而藍色是classifier
Source data進入模型最後的數據都位於藍色的數據點上
而target data進入模型,最後的數據則是位於紅色上
得出一個結論即為:source & target data 呈現不匹配,也就是mismatch
於是想要透過下圖的方式進行改善。

為了使紅色數據點與藍色數據點能無差別的混合在一起,必須新增一個機制叫做domain classifier,這會使得整個training非常相似於GAN生成對抗網路

第一部分綠色的feature extractor其實要做的就是提取出source data和target data的feature,然後使得最後在做classification的時候,通過這些提取出來的feature,能夠得到一個非常好的精確度。還有盡可能讓這些mismatch的data混在一起,以至於domain classifier不能夠正確的判斷他們是否混在一起。

第二部分藍色的label predictor做的是,能夠盡可能大的輸出classification的精準度。

第三部分紅色的domain classifier做的是,能夠盡可能的將從feature extractor中提取出來的feature進行分開,將其各自歸屬到其所屬的domain裡。

domain-adversarial training的訓練方式如上圖
前面對於網絡結構的分析,我們知道了domain classifier和feature extractor之間的關係其實是互相傷害,也就是衝突的
換句話說,feature extractor希望將兩個domain提取出來的feature盡可能的混合在一起。但是domain classifier卻希望他能夠盡可能的把從feature extractor中提取出來的feature劃分到兩個domain中。
那麼在做BP的過程中,只需要將誤差對於參數的偏導數加一個“負號”,就能夠達到這種目的,其他部分的training和一般的網絡的training沒有什麼區別。

Source only指的是,只用source data訓練模型,然後測試target data
這種方法的效果是比較差的
Proposed approach指的是Domain-adversarial training,效果有很好的提高
Train on target指的是,直接拿target domain的data訓練,
是performance的upper-bound

Zero-shot learning

Zero-shot learning 的data label情況與domain-adversarial learning的情況相同,但Zero-shot learning 是source & target data針對不同的task。
例如,source data中有貓狗的圖片,沒有羊駝的圖片,而target data中是羊駝的圖片。那訓練出的模型能認出羊駝嗎?語音上常遇到Zero-shot Learning的問題,若把每個word都當做一個class,那麼training的時候與testing的時候就可能會看到不同的詞彙。解決辦法是,不去識別一段聲音屬於哪個word,而是識別一段聲音屬於哪個音素(音標),然後根據人的知識建立一個音標和詞彙關係的詞典。

Reference

  1. https://zhuanlan.zhihu.com/p/34589303
  2. https://zhuanlan.zhihu.com/p/49407624
  3. https://blog.csdn.net/xzy_thu/article/details/71921263

--

--