Attention Mechanism

WenWei Kang
Taiwan AI Academy
Published in
10 min readMar 29, 2019

在本文中,筆者會以自己的角度與想法來介紹Attention mechanism,包括一開始發跡的論文、架構與想法與廣泛的應用層面。

在閱讀本文前,希望讀者有兩個基礎模型的先備知識,下面也會簡單的介紹這兩個模型,之後的主軸會在討論Attention mechanism:

  1. Recurrent Neural Network (RNN)
  2. Sequence to sequence (seq2seq)

1. RNN

中文稱為遞歸神經網路,這個模型主要用於解決序列問題,給予模型一個序列,接著進行預測,常見的序列問題在自然語言處理領域如Text classification等等,除此之外還有如股價預測等等,只要有跟時間序列有關的任務都可以嘗試使用RNN來預測。

之後有名的RNN系列如LSTM, GRU的提出都是為了解決RNN的長序列問題,LSTM如Fig.1,這裡需要記住的重點是,不管是哪種RNN,在每個timestep中一定都會產生一個hidden state,每個time step的hidden state包含了過去time step的資訊。

Fig.1

2. Seq2seq

Seq2seq是一種用序列來預測序列的架構,簡單來說就是模型的輸出輸入都是序列,著名的任務如Neural Machine Translation(NMT)、Speech Recognition、Name Entity Recognition(NER)等等,最近小弟研究的東西也是以這個為基底。

Fig.2 Image source: https://jeddy92.github.io/JEddy92.github.io/ts_seq2seq_intro/

Seq2seq分為Encoder與Decoder,以Fig.2來說,兩個分別都是GRU,Encoder負責接收輸入序列,Decoder負責預測輸出序列,不妨可以想像成現在在做NMT任務,一段中文句子丟入Encoder之後,在最後一個timestep產生Encoder state(hidden state),Decoder拿到Encoder state並將其作為initial state,接著輸出一段英文句子,第一篇論文將Seq2seq應用在NMT,之後的研究基本上都圍繞在Encoder與Decoder之間如何做串接對應,著名的架構就是Attention,接下來開始介紹Attention mechanism。

Neural Machine Translation

第一個提出Attention這個想法的論文就是應用在NMT任務上,之後許多論文其實仔細看的話,會發現計算方式非常相似:

NEURAL MACHINE TRANSLATION BY JOINTLY LEARNING TO ALIGN AND TRANSLATE

網路上的seq2seq有許多種型式,為了之後介紹順利,接下來用Fig.3的seq2seq來介紹一些符號,這個seq2seq是最常見的:

Fig.3

上下紅框各為Encoder與Decoder,都是RNN:

  • Encoder:輸入Source_sequence分為S₀~Sn,你可以想像成一個中文句子有n+1個字,每個字作為RNN的一個timestep去輸入,經過了n+1個timesteps之後,輸出最後一個timestep的hidden_state En,En會作為context vector C,這個C我們可以認為是整個Source_sequence的representation,而在每個timestep中,都會產生一個hidden_state (E₀~En),以最基礎的seq2seq而言,就是拿En作為context vector C
  • Decoder:接收到Encoder傳來的context vector C之後,C會作為Decoder的initial state繼續RNN的訓練,在Decoder的部分稍微特殊一點,我們需要把Decoder分成Training與Inference來討論,以Fig.4輔助說明:
  1. Training:訓練階段時,Decoder的RNN前一個timestep的輸出不會作為下一個timestep的輸入,而是直接拿正確答案進去輸入,這個技巧叫做teacher forcing,以Fig.4來說,藍色句子原始為(y1,y2,…,yn),輸入為(<SOS>,y1,y2,…,yn),輸出為(y1,y2,…,yn,<EOS>),第一個timestep t0的輸入為起始信號<SOS>,接著下一個timestep t1的輸入會不管上一個timestep t0的輸出正確與否,直接在t1輸入正確答案y1,這個做法是為了確保Decoder不會一直一錯再錯,如果在t0輸出錯誤答案,那以以往RNN的做法是直接把錯誤答案輸入t1,為了避免這個現象,才有teacher forcing的做法。
  2. Inference:在Inference階段,我們因為沒有正確答案輸入給Decoder,所以就遵循以往RNN的做法,t0的輸出會作為t1的輸入。
Fig.4

接下來討論原始的seq2seq中會有什麼問題,之後衍伸出attention的想法:

  • Seq2seq中Encoder與Decoder之間串接的只有一個context vector C,這意味著不管輸入給Encoder的source_sequence長度多寡,都會被一個固定維度的向量C給限制住。
  • 不管RNN還是其餘系列(LSTM, GRU),其實都避免不了Long-Term Dependence的問題,一旦輸入序列拉長,越早輸入的Token就會越容易失去影響能力,更何況以seq2seq拿最後一個timestep作為context vector C的做法,同理,在Decoder是在第一個timestep接收到context vector C,一旦要預測的序列變長了,越後面timestep的預測值就越不受C控制,也就是跟Encoder越來越沒關係。

為了解決以上兩個問題而有了attention想法,attention的想法是說,有沒有辦法為Decoder中的每個timestep都設計一個context vector C?如Fig.5,在Decoder的第二個timestep準備進行預測T1時,就去計算屬於T1自己的context vector,這就是attention mechanism的核心思想。

Fig.5

Attention Mechanism

“計算屬於每個timestep自己的context vector”,這句話已經闡述了attention的想法,接下來會以上面提到的論文為主,介紹一開始最基礎的attention架構應用在NMT任務上。

Fig.6

Fig.6是這篇論文的主體,接下來會以藍線分為Encoder與Decoder個別介紹:

Encoder

相信讀者看到Encoder時已經有一點頭緒了,Encoder部分中作者使用Bidirectional GRU,在Forward與Backward階段都有一個GRU在處理,而Forward與Backward的差別在於輸入的序列順序相反,因為處理的是NMT任務,所以作者這樣做的目的是希望Encoder可以更好的學習句子的上下文關係,接著Forward與Backward各產生hidden state之後,這篇論文的做法是將兩個hidden states進行concatenation,作為Encoder的輸出。

也有其他做法來處理兩個hidden states,例如直接相加、通過一層全連接層、element-wise multiplication、取平均等等…

Fig.7

Decoder

當Decoder獲得了從Encoder的hidden state後,這時Decoder最重要目的在於如何利用這T個hidden state,對於Decoder來說,這T個hidden state有些重要而有些不重要,也就是說必須設計一個機制讓Decoder能有判別的能力,也就是讓Decoder能夠知道取捨。

Fig.8

這篇論文使用一層NN來計算t-1時Decoder對於所有Encoder的attention的值eᵢⱼ,這每一個e就代表著Decoder的hidden state St-1對於Decoder所有hidden state h的重要程度。

計算attention的方法不只有使用一層NN,還有直接用內積、concat之後丟NN、linear transform等等,最終都是用一個值來表示attention的程度。

Fig.9

得到e後,接下來較為普遍的作法是通過softmax,將每個e進行normalize得到a,每個a值介於0~1之間且總和為1,並視為attention weight,接著將attention weight與hidden state進行線性組合,得到最終的context vector c。

Fig.10

接著如何利用這個c就有許多版本了,以這篇論文來說是將c視為GRU另外一個輸入,得到這個時間點的hidden state s,另外還有其他種做法,例如:

  • 將s與c進行concate,然後通過一層NN輸出vocabulary的概率
  • 將embedding word Ey與c進行concate,然後通過GRU得到s,然後再通過一層NN輸出vocabulary的概率

以筆者的觀察,通常都是concatenation的效果比較好,用加的會把資訊互相抵銷,很多論文都是以concatenation的方式達到該任務的SOTA。

Conclusion

介紹到這邊,我們了解到attention mechanism的核心理念,之後許多論文都圍繞在Encoder與Decoder之間如何去attention,當然attention的應用不只有NMT任務,還有image caption、video caption等等,以image caption來說,就是生成caption時,讓每個timestep都去決定要attend圖片中哪些位置,基本上許多論文都有類似的想法。

除此之外,在2017年6月google發表了一篇paper: Attention is all you need.,該論文不管在Encoder還是Decoder都沒用到RNN或CNN系列的架構,從頭到尾就使用全連接層NN來實現Attention,並達到當時NMT的SOTA,實在令人驚嘆。

--

--