(ML) Seq2seq -> Attention -> Self-attention -> Transformer

YEN HUNG CHENG
17 min readFeb 10, 2023

--

Photo by Jéan Béller on Unsplash

RNN (Recurrent Neural Network) 回顧

為何需要 RNN ?

Example:

今天對訓練好的模型輸入一句話,做簡單的詞性標註 (POS tagging)

例句:I saw a saw.

我們所期待模型輸出的結果:

I -> N

saw -> V

a -> DET

saw -> N

但是今天模型並沒有那麼聰明,我們輸入相同的 saw ,對模型而言 input 一樣的東西,output 也是一樣的東西,這時候我們就會想要我們的 network 具有記憶力,而具有記憶力的網路就稱作 RNN。

source

圖中在不同時間點所用到的 U, V, W 都一樣

Seq2seq (Sequence to sequence)

我們先看一下 Seq2seq 的架構

source

以上的架構跟前面提到的 RNN 是否有些相似呢?而 Seq2seq 採用的架構就是 Encoder-Decoder 的架構

Encoder-Decoder Network

上面的圖我們可以再簡化一下:

….

在機器學習中,我們對訓練模型輸入的就是一串向量/矩陣對吧?不管是影像辨識、聲音辨識、語音翻譯 等等任務。

….

Encoder-Decoder Network 我們可以把它看成兩個模型:

Encoder(編碼) 將輸入的資訊變成一個 Code (Content Vector)

Decoder(解碼)將輸入的 Code 解出成一串資訊

為何要採用 Encoder-Deconder 架構呢?

思考一下 Encoder-Decoder 的運作方式,是不是跟我們人學習語言的時候一樣呢?

Example :

今天我們不斷學習語言,我們大腦中已經學到了蘋果的英文是 Apple ,當有人問我說蘋果的英文是什麼時,因為我已經學習語言一陣子了,並知道蘋果的英文就是 Apple,所以我可以回答蘋果的英文就是 Apple。

大腦讀入 蘋果時 -> Encoder

從大腦中得知 蘋果 = Apple 的資訊 -> Code

嘴巴說出 Apple -> Decoder

缺點:

今天我們把一個不限長的輸入,給 Encoder 成一個 Code (固定長度的向量),也就是說隨著我們的輸入越來越長,Code 的訊息損失就會越大。

Example : 假設給你看一篇很長很長的文章,在你讀完一整篇文章後,你可能就會對越前面的資訊越模糊。

Seq2seq

再把焦點拉回 Seq2seq ,我們已經大概了解到了 Seq2seq 中所採用的架構就是 Encoder-Decoder 架構,而可以運用在架構中的可以是 RNN、LSTM、GRU。

而 Seq2seq 就是用來解決序列到序列之間的任務,而 Seq2seq 的 output 長度是由模型自行去學習的。

一般RNN可處理的問題是輸入序列長度與輸出序列長度一樣的case。實際上,它能處理輸入與輸出長度不一樣的case。
翻譯時就需要處理輸入與輸出長度不一樣的case。例如把機器學習翻成machine learning。翻譯時不知道翻出來的長度比輸入長度長還是短,此時需要 Seq2seq model。

Example :

機器翻譯

機器學習 -> machine learning

source

Encoder

將 input Sequence (機器學習) 表達成一個 vector ,然後輸入到 Seq2seq model 中的 Encoder 做 Training ,在圖中的 Encoder 最後的時間點的 hidden layer 的 output (Encoder 中紅色部分)抽出來,抽取出來的部分可能就包含了所有 input 的句子資訊。

Decoder

接下來將 Encoder 的 output 丟入 Seq2seq model 中的 Decoder 產生 文字 (machine),下一個時間點再丟一次,它就會輸出 learning ,最後會產生 . (句號)。

Encoder 與 Decoder 是 Jointly train 的,也就是你給機器 Encoder 的 input ,還有 Decoder 該有的 output ,你就能一起訓練他們的參數,圖中 Encoder 與 Decoder 的參數是不一樣的,但是也能使用同一組參數去 train ,參數少也就比較不容易 overfiting,如果 data 比較多時,那就用不同組的參數去 train ,這樣比較有可能得到較好的 performance 。

以上就是機器翻譯的流程,其中所使用的架構就是使用 RNN

Attention Mechanism

Attention Mechanism 也就是注意力機制,為何需要 Attention 呢?試著回憶一下前面所提到 Seq2seq 的問題,當今天輸入的訊息過長時,Code 的損失也就越大,所以我們是不是能夠讓 Code 只專注 (Attention)在某幾個輸入資訊,不必輸入整個資訊做學習呢?

Example :

我們讓 Decoder 在每一個時間點,看到的資訊都不一樣,這樣做有什麼好處呢?

<Ans> 今天你的 input 資訊很複雜,Encoder 會很難將 input 的 sequence 變成一個 Code,

也就是說我們讓 Decoder 只考慮比較需要的部分,原本的 Decoder 的 input 是整個句的資訊,今天我們讓 Decoder 在產生 machine 這一個詞的時候,我們只去考慮 input 中 『機、器』這個部分就好,因為只有機器跟 machine 是有關的,這樣 model 其實會學習的更好,因為它並不需要再從整個句子的資訊去提取出需要的部分。

source

model 在看 learning 的時候, input 只需要看『學、習』這兩個詞就好,這樣 Decoder 就能產生出 learning 。

source

Attention-based model

實際作法 :

我們一樣先用 RNN 來做處理,在每一個時間點每一個詞,都能用一個 vector 來表達,這些 vector 也就是 RNN 中 hidden layer 的 output,接下來會有一個初始的 vector z0,可以把 z0 當作是 network 的參數,這是可以根據 training data 學出來的

source

再來會將 z0 和 h1 丟入一個 match function 中計算 α ,也就是在初始狀態時 會去計算 z0 與 h1 有多匹配。

source

match function 其實沒有固定的做法,比較常見的作法如下:

  1. 計算 z 與 h 的 cosine similarity
  2. 一個簡單的 NN (Neural Network)輸出一個 scalar
  3. 下圖數學式子, W 是 network 學出來的

若 match function 中有其它參數的話,會與 network 的其它部分,一起 train 出來。

有了 match function 後,我們就能對 每一個 hidden layer 的 output 計算 match 的分數,得到每個 match score 之後,可以通過 softmax 做 normalize,但這一步不一定要做,有人嘗試過不做 softmax 之後,performance 反而比較好

source

c0 是如何計算的呢?可以查看下面的圖來了解。

source

再來將 c0 丟入 Decoder ,而這一個丟入 c0 也就跟 『機、器』比較有關,因為它們的 score 比較高,這時 Decoder 就會 output 一個 machine

source

z1 它是把 c0 丟入 RNN 裡面後,RNN 的 hidden layer 的 output ,接下來你就可以用 z1 再去算一次 match score

source

持續剛剛 z0 所做過的事,我們會得到新的 α,得到的 α 再通過 softmax 得到 match score,再來就能計算出 c1

source

再來將 c1 丟入 Decoder ,而這一個丟入 c1 也就跟 『學、習』比較有關,因為它們的 score 比較高,這時 Decoder 就會 output 一個 learning

source

以上的步驟會持續下去,直到 Decoder 產生出 . (peroid),它就會停下這個循環

Self-attention (自注意力機制)

還記得上面我們使用 RNN 來產生 machine 與 learning 時,它是先產生 machine,再產生 learning 的,也就是說它並不容易被平行運算,而 Self-attention 就是用來處理平行運算的問題

實際做法:

把在 Encoder-Decoder 中的 RNN 架構改成 Self-attention

source

Self-attention 的 input 就是 vector,它可能是你整個 network 的 input,或是某層的 hidden layer 的 output

經過 Self-attention 的 vector 用黑色框表示,它代表已經考慮過完整 input 的 sequence 才得到的 inforamtion

而下圖中的 b1 ~ b4 是同時被計算出來的

source

如何計算 b1 呢?

先找出 a1 與其他 vector 的相關性,並用 α 來表示

source

較常使用的方法是利用 Dot-product

input1 乘上 Wq matrix 形成 q
input2 乘上 Wk matrix 形成 k
q 與 k 做 inner product 形成 α

source

接下來開始計算 a1 與其它 vector(a2 ~ a4)的關聯

  1. 讓 a1 乘上 Wq matrix 形成 q1
  2. 讓 a2 ~ a4 乘上 Wk matrix 形成 k2 ~ k4
  3. 再來將 q1 與 k2 ~ k4 做 inner product 形成 α (attention score)
  4. 當然也會跟自己算關聯,所以其實 a1 也會 乘上 Wk matrix 形成 k1,最後再與 q1 做 dot product 形成 α
source

產生出 4 個 attention score 之後,會經過 softmax,當然不一定要使用 softmax ,也能使用 ReLU

source

得到這些 α’ 之後,我們就能透過 α’ 來抽取 sequence 的重要資訊,根據這個 α’ (最左邊的α),我們就能知道哪些 vector 是 a1 最有關係的,接下來我們要根據 attention score 來抽取重要的資訊,如何抽取重要的資訊呢?

將 a1 ~ a4 再乘上 Wv matrix 形成 v1 ~ v4,再把 v1 ~ v4 都乘上 α’ (attention score),最後把它加總起來得到 b1

source

如果今天 a1, a2 的關聯性很強, a2 的 α’ 應該很大,最後做 weighted sum 的時候,b1 的值可能會比較接近 v2

以上就是 Self-attention 的運算過程,而 b1 ~ b4 都是以上過程所同時產生出來

整理:

  1. 每個 a 都會得到 q k v 三個 matrix
  2. 每個 q. k 都會做 dot-product 得到 α,最後做 softmax 得到 α’
  3. 最後 α’ * v 相加得到 b
  4. 以上你會發現只有 Wq, Wk, Wv 是未知的,也就是我們要讓 model 自行去學習出來的參數

以上步驟稱為:Scaled Dot-Product Attention

Scaled Dot-Product Attention

Multi-head Self-attention

Multi-head Self-attention 是 Self-attention 的延伸, 透過計算更多的聯性來希望做出更佳的結果

假設有兩個 head, 第一個 head 計算

qi,1 在算 attention score 時,它只管 ki,1 kj,1 就好,也就是 qi,1 跟 ki,1 做 dot-product,qi,1 跟 kj,1 做 dot-product ,然後再各別乘上 vi1, vj1 做 weighted sum ,得到 bi,1

source

第二個 head 計算

qi,2 在算 attention score 時,qi,2 跟 ki,2 做 dot-product,qi,2 跟 kj,2做 dot-product ,然後再各別乘上 vi2, vj2 做 weighted sum ,得到 bi,2

source

如果有更多的 head 的話,就以此類推下去

Multi-Head Attention

Positional Encoding

為何需要 Positional Encoding ?

回憶一下對一個 self-attention layer 而言,每一個 input 它是出現在 sequence 的最前面還是最後面呢?它是完全沒有這個資訊的,上面的 a1, a2, a3, a4 其實是讓大家好理解過程而標記上去的,對 self-attention 而言,位置 1, 2, 3, 4 在哪都是沒有差別的,這 4 個位置的操作其實是一樣的,對它來說 q1 ~ q4 的距離沒有這麼遠,它們的距離其實都一樣,但是這樣設計可能會有一些問題,因為有時候位置的資訊也是很重要的,就像是最開始提到的 詞性標註 (POS tagging),你可以將位置的資訊塞進去,實際作法就是你給每個位置一個 vector 稱作 positional vector ei,ei 其中的 i 代表 position ,再來把 ei + ai 就結束了

positional vector 是 hand-crafted(人設的)

source

Transformer

什麼是 Transformer?當然不是變形金剛 😂,可以先看一下左圖 Transformer 完整的架構,看起來有點複雜對吧?但如果從上面文章看到這的你,實際上你已經都相當熟悉內部的操作了

而 Transformer 採用的也是 Encoder-Decoder 架構,右圖所示 紅色框為 Encoder 藍色框為 Decoder,其中的 Multi-Head Attention 與 Masked Multi-Head Attention 採用的其實就是 self-attention model,而 Positional Encoding 就是給定輸入的資料 position 資訊

Encoder

再次複習一下,對於一個 Seq2seq model 來說,今天 input 一排 vector 輸出另外一排 vector,RNN 或是 CNN 也都能做到

先來看一下 Transformer 中 Encoder 的下半部分吧,下方的三個輸入也就是 K, Q, V,至於怎麼產生的,在介紹 Self-attention 時已經說明過了,實際步驟如下:

  1. 做 Self-attention
  2. 輸入的 vector 做 residual connection 也就是 input 與 output 做相加
  3. 通過 Normalize(layer normalize) 得到輸出的 vector
source

最後 Transformer 中 Encoder 的上半部分步驟:

  1. 剛剛處理的 vector 通過一個 FFN (Feed Forward Networks)
  2. 原本的 vector 與 通過 FFN 的 vector 做 residual connection
  3. 通過 Normalize(layer normalize) 得到最後的輸出 vector
source
FC = FFN

FFN (Feed Forward Networks)

公式可以看到輸入 x 先做線性運算後,再送入 ReLU,最後再做一次線性運算

source

整理:

  1. inputs 會先加入 Positional Encoding 告訴 inputs 的 vector 位置資訊
  2. inputs 通過一次 Self-attention 後,再與 input 做 residual connection,最後再做 Layer Normalize 得到 outputs
  3. 得到的 outputs 後會先通過 FFN,再與 outputs 做 residual connection,最後再做 Layer Normalize 得到最終的 outputs
source

Decoder

來對比一下 Encoder 與 Decoder 的架構,先把 Endoer 與 Decoder 連接的部分遮住,我們可以發現 Encoder 與 Decoder 的架構基本上是一樣的,唯一較為不同的是 Decoder 下半部採用的是 Masked Mulit-Head Attention

source

Masked Self-attention

現在產生 b 的時候,不再看右邊的資訊,這是什麼意思呢?

產生 b1 時,只能看 a1 的資訊

產生 b2 時,只看 a1, a2 的資訊

產生 b3 時,只看 a1, a2, a3 的資訊

產生 b4 時,能看 a1, a2, a3, a4 的資訊

source

思考看看 Decoder 的運作方式,它的 output 是一個一個產生的,所以先有 a1 才有 a2,有 a2 才有 a3 以此類推下去,在 Encoder 時 Self-attention 的 a1 ~ a4 是同時產生的,但在 Decoder 時,你產生 a1 時 a2, a3, a4 根本還沒產生出來,所以你根本無法將還沒有產生的資料考慮進來

講完 Masked Self-attention 之後,剩下 Encoder 與 Decoder 連接的部分,而這連接計算的方式,稱作 Cross-attention

Cross-attention

我們可以發現 Decoder 中間部分的 Self-attention 有兩個 input 是來自 Encoder 的 output ,最後一個就來自 Decoder 的 output ,而計算 attention score 的過程都相同,只是今天的 inputs Wk, Wv 是來自 Encoder 的 output 所產生的,而 inputs Wq 是來自於 Decoder 的 output 所產生的

source
source

以上就是 Transformer 大致的運作過程

source

--

--