ML : Self-attention
Self-attention(自注意力機制)
前言:
至目前為止,model用到的輸入皆可看為一個vector,但遇到更複雜的輸入時,像是輸入為一個sequence或是每次輸入長短不一的向量!
解決目標:
處理下述複雜的輸入!
舉較為複雜的例子來說:
1. 文字句子可以作為一個vector
2. 聲音訊號也可以是一段vector
3. Graph當成一串vector
4. Drug Discovery分子架構做為一段vector
Q.而複雜的輸入對應什麼輸出呢?
- 每個vector有自己的label (NtoN)
應用面:
a. 文字處理
POS tagging(詞性標注)
b. 語音處理
c. Social Network
- 輸入整個sequence,輸出一個label (Nto1)
應用面:
a. 文字判斷正負面:This is good =>正面
b. 給graph,輸出label:分子架構圖,決定親水性
- 訓練家不知要有多少label,由model自己決定(NtoN’)
N不一定等同於N’
應用面:
a. Sequence to sequence(Seq2Seq): 輸入一段語音,翻譯出一段文字。
Sequence Labeling
input number = output number
目前我們使用FC(Fully connected) neural network,對I saw a saw(我看到一把鋸子)做POS tagging分類,如果模型一單一個字彙訓練模型,句子中的兩個saw應該會被判斷為同一類型,但這不符合分類預期!
所以模型只針對一個字彙訓練,無法判斷名詞動詞,應該要考慮字彙間的關係,使FC考慮上下文的關係,用一個window蓋住部分sequence
But,此方法有極限,像是如果window 涵蓋整個sequence,容易導致訓練參數量暴增,且易overfitting!
解決方法 — Self-attention
Attention is all you need!
介紹:
Self-attention會吃一整個sequence的資訊,輸出相同數量的結果,且在訓練時他考慮一整個sequence 。
- a1~a4為input,或是hidden layer的output.
a1~a4可對應到 I saw a saw - b1~b4為相對應的output.
產生bx的步驟:
根據a1找出其他a2~a4跟a1的相關程度alpha
以b1來說明(b2~b4同理):
- 計算相關性:
- Dot-product(常用):
input1 乘上Wq矩陣形成q
input2 乘上Wk矩陣形成k
q與k做inner product形成alpha - Additive:
input1 乘上Wq矩陣形成q
input2 乘上Wk矩陣形成k
q與k做加法再透過activation function,
再乘上W矩陣,行程alpha
2. 計算過程
- 套用Dot-product在self-attention
alpha1,1~4稱為attention score
右上角的公式為soft-max的公式,不一定要soft-max,也可以用ReLU
- 根據alpha抽取重要資訊產生b1
如果attention數值越大,weighted sum 後的結果會偏向attention大的v
進階version:Multi-head Self-attention
用q找k,相關性可能有不同類別,所以需要多個q找出k來負責多種相關性。
說明看影片15:07
應用面
- NLP
- 語音辨識
*Truncated self-attention:不要看一整句話,看小範圍即可。 - Image
Image可視為vector sets.
- CNN: 可視為簡化版的self-attention filed,因為只專注於receptive field(擷取出的特徵)
- self-attention: 複雜化的CNN,receptive field自己被學出來
3. CNN v.s. self-attention:
當資料少時:選CNN ->無法從更大量的資料get好處
當資料多時:選self-attention->太少資料可能overfitting
- Graph
關聯性似乎不用計算,看圖內點之間有相連的node即可,
沒有相連(關係)則免去學習的過程。
補充說明
- self-attention的優點為output平行化產生
- 目前的self-attention沒有位置資訊
所有字句的位置對self-attention皆相同 #天涯若比鄰
但部分字句的位置資訊可能很重要
=>使用Positional Encoding
ei為位置資訊,每個字有自己的ei,且ei不重複。
3. self-attention的變形通常稱為xxformer