小資料系列二-Metric-based Meta Learning後篇

Shu-Yu Huang
Taiwan AI Academy
Published in
12 min readDec 21, 2021

經典Metric-Based Meta Learning介紹

上次我們簡介完Metric-based Meta Learning是什麼,那我們這次帶來兩個主要的流派較經典的Model Siamese Network以及Prototypical Networks。

Siamese Network

這是一個古老的Method,在2015年被Koch特別介紹可以當作Few shot的方案以前[2],在1993就已經被LeCun等人拿來做筆跡鑑定了[1]。如Fig.1所示,將真跡與待鑑定筆跡各丟到同一串神經網路中,萃取出兩個筆跡特徵的latent vectors,並計算兩個vector之間的cosine similarity,若超過某個閾值則代表是同一個人。

Fig.1 最初Siamese Network被使用在筆跡鑑定上

Siamese Network的計算方式就是如此簡單,將待推論資料跟一個對比資料透過同個(或稍有不同)embedding function特徵萃取後做相似性/距離計算,然後取一個閾值。這個距離方程式可以用Euclidean Distance/Manhatten Distance/ Minkowski Distance等等,見Fig.2。

Fig.2 Siamese network示意圖以及參考的Distance function

在訓練時,會以pair(資料對)為單位送進model進行訓練,並且使同類的label為"1"、異類的label為"0"。例如Fig.3所示,假設有三個不同類別(紅、黃、綠)要來訓練siamese network,然後每個類別只有兩筆資料("1","2"),那樣我們可以不重複排列出15種pair (彼此不重複,無順序之分)進行訓練。

Fig.3 資料對與丟進Siamese network進行運算的範例,圖中紅、黃、綠三種顏色代表資料屬於三種不同類別

確切一點講,當不論類別總共n個資料,套用排列組合的公式,總共pair數會是:

這麼多的pairs,訓練資料量與原先的n個相比,是O(n²)這個等級,可大幅增加訓練資料量。

在資料訓練上,初版使用的loss function 是Contrastive Loss ,並加上一個margin避免將異類間距離無限放大。其公式如下:

這邊設定a與b分別是query(又稱為錨點)跟suuport的資料embedded的latent vector, D(.)是distance function,另外y則代表label值,同類的label為”1"、異類的label為”0",最後margin是某個閾值,限制學習的極限他的意義分同類時的loss跟異類時的loss兩個部分。以下Fig.4以計算metric為distance當範例做說明。

如Fig.4左所示,當a,b兩者同類時y為1,這時啟動前式:loss為D(a,b)²,距離越小loss越小,因此在graient descent時會想盡辦法使得同類距離變小。

如Fig.4右所示,當a,b兩者為異類時y為0,這時啟動後式:margin-D(a,b)小於等於0時,兩類的距離已經被拉開就不用再訓練;反之margin-D(a,b)代表兩個latent距離還不夠開,應該再更開一點,所以loss就是(margin-D(a,b))²,注意distance前面有個負號,所以距離越大loss越小,在graient descent時會想盡辦法使得異類類距離變大。

Fig.4 計算Metric為distance時訓練的兩個目標。左圖,a,b同類時的loss目標;右圖,a,b異類時的loss目標。

除了同類的鑑定或者兩者相比的應用,Siamese訓練好以後也可以拿來做分類預測。如Fig.5所示,我可以將query照片放進embedding function 萃取出latent_q,並把類別0的照片support 0和類別1的照片support1丟進embedding function 分別萃取出latent_0和latent_1。這樣我們就可以使用latents作為代表比較”query, support1"以及 ”query, support 1"這兩個pair,選出一個預測結果。這個metric 值可以是相似度或者距離,取pair相似度最高者或者距離最近者為預測值。

Fig.5 使用Siamese Nework做分類預測

在同類間變異度不大時,上述式子已經可以訓練出一個好的分類模型,不過當同類間變異以使得latent間的變異大過margin時,會出現判別困難的情況。如Fig.6 所示,由於超過margin異類就不會再被拉開,query的latent要是跟同類support1間距離跟異類的support2同樣為margin時,這個query就無法藉由這兩個support來判別類別。

Fig.6 變異大過margin的情況

所以後來就有人將contrastive loss改版,是為Triplet Loss[3]。為防止同類變異過大導致同類間與異類間有相同間距,Triplet Loss只訓練同類間距與異類間距的差距。其公式如下:

與前面相同,設定a是query latent vector,但比對的support latent vector有兩個,一個b是與a同類的正樣本, D(.)是distance function,最後margin是某個閾值,限制學習的極限。從公式可看出這個loss不需要給label y,b與c的本身就是從類別抽取的所以不需要再有標準答案。如Fig.7所示,在D(a,b)² -D(a,c)²小於等於margin時不訓練model,代表同類間距與異類間距的差距已經夠大。反之則會開始訓練,gradient descent之下會想辦法使D(a,c)²變大、同時把D(a,b)²變小。

Fig.7 Triplet loss一次比對同類與一類的support,同時滿足縮小同類距離、放大異類距離的目標

Prototypical Networks

為了完成分類預測,Siamese Network由兩者比對的訓練漸漸演變成三者比對的訓練。Prototypical Network更進一步的乾脆把這個比對的support變成從n個類別中各取k個做對比。如Fig.8所示,假設有四個類別(紅、藍、綠、黃),我們會將query與各個類別的support資料丟進embedding function形成latent vector。將query與各support的latent vectors之間一一算出matric,相似度最高或距離最近的latent所屬的類別即是預測的類別。

Fig.8 Prototypical的比對方式概念圖

Prototypical network其實也不是一一比對,而是將query的latent與各個support種類的prototype(原型)做比對。如Fig.9所示,最簡單的prototype就是取平均,假設貓的類別有三張貓的圖片,貓的prototype就是在三張貓圖latent取平均。依此類推,我們可以找出獅子與老鼠的prototype,拿來跟query圖的prototype做比對算距離(或相似度),並取最小者(相似度取最大者)為預測的類別。

Fig.9 Prototpye示意圖

假設來源資料集原本有N’個類別,每個類別有K’個資料,在訓練時,會從任意不重複N'個類別中各取K個support資料再從N’類別剩餘資料中抽取1個query組成goup(資料組)為單位送進model進行訓練,並且label只能表達相對類別而不是絕對類別,只表明這個組內的分類。例如Fig.10所示,假設有四個不同類別(紅、黃、綠、藍)要來訓練siamese network,然後每個類別只有兩筆資料(“1”,”2")。若設定只做三選一的訓練(N=4,N’=3)且設K為1,則每次從任3類別中取出各1個資料做為support,然後再從這3類剩餘資料中再抽取1個資料作為query,訓練模型預測這個query是support的哪類。

Fig.10 訓練Prototypical 的資料組範例,圖中紅、黃、綠、藍四種顏色代表資料屬於四種不同類別

我們可以不重複排列出非常多種資料組來訓練,其公式比較複雜,如下式:

用剛剛Fig.9中的四種各兩個樣本、取3-way-1-shot support為例:

  1. 四個種類取三種為4!/((4–3)!3!),共4種組合
  2. query要三種取一個,共3種選擇
  3. 沒被抽到query的種類,要從2個抽1個當support,共2種選擇,總共要抽2次,所以是2²,4種組合
  4. 有被抽到query的種類,包含support根query要從2個抽2個,只有1種選擇
  5. 有被抽到query的種類,要從全部抽選資料中選1個當query,共2種選擇
  6. 全部乘起來,共有4*3*4*1*2=96個可能的資料組

*也可以套入公式2=[4!/((4-3)!(3-1)!)]*[2/((2-1)!1!)]³*[2-1]=12*8*1=96

前述例子原本有4*2=8個資料,但換成資料組卻可以到96個資料組。大致上,跟原本資料量K'*N'比較起來會讓整個輸入數量成長很多。詳細抽出組合數分布如Fig.11所示,不論是抽組合還是抽資料,當欲抽出數量為總數量一半時取得的組合數最多,不過都是級數等級的數量成長就是了,所以都會比原本的資料總數來得多。

Fig.11 排列組合數分布(log10 scale)- x軸: 待抽出數量;y軸: 選項總數

在每一組訓練時我們將N個類別的K個support都取平均得到prototype,算出prototype對query的負距離平均(或相似度平均),再對N類別作softmax得到output,如Fig.12所示。
(註: 圖上先算負距離再取平均只是為了化簡字句,實際上會先平均再取距離或相似度)

Fig.12 四選一的prototypical network算法,其中綠、紅、藍、黃代表不同的類別的圖的latent vector。

最後再對答案,將預測值與label算loss,做gradient descent,如Fig.13所示。這邊的loss可以使用各種拿來做分類問題的loss,像是Categorical Cross Entropy、 Focal Loss等等。在gradient escent之下,因為對應label的類別的答案是1,所以該類別的負距離會拉高,距離會拉近(或者相似度變高);剩下其他類別的答案是0,所以這些類別的負距離會降低,距離會推遠(或者相似度變低)。這種一次納入多個類別作Metric Learning的優點就是可以一次訓練所有牽涉到的類別間距。

Fig.13 預測值與label對應算loss

這次我們介紹了Metric-based的經典演算法,透過訓練塑形latent space,以至於兩個不同的類別在這個space中可以被良好的分開。除這個流派之外實還有許多不同的流派,Optimization-based, Memory-based,往後會一一介紹。

References

[1] Koch, G., Zemel, R., & Salakhutdinov, R. (2015, July). Siamese neural networks for one-shot image recognition. In ICML deep learning workshop (Vol. 2).

[2] Bromley, J., Bentz, J. W., Bottou, L., Guyon, I., LeCun, Y., Moore, C., … & Shah, R. (1993). Signature verification using a “siamese” time delay neural network. International Journal of Pattern Recognition and Artificial Intelligence, 7(04), 669–688.

[3] Schroff, F., Kalenichenko, D., & Philbin, J. (2015). Facenet: A unified embedding for face recognition and clustering. In Proceedings of the IEEE conference on computer vision and pattern recognition (pp. 815–823).

參考網站

度量學習中的pair-based loss — GetIt01

Proxy Anchor Loss for Deep Metric Learning | by Ahmed Taha | Medium

--

--

Shu-Yu Huang
Taiwan AI Academy

AI engineer in Taiwan AI Academy| Former process engineer in TSMC| Former Research Assistance in NYMU| Studying few-shot-learning and GAN