2019-NLP最強模型: XLNet

WenWei Kang
Taiwan AI Academy
Published in
14 min readJul 8, 2019

在2019年6月中旬Google提出一個NLP模型XLNet,在眾多NLP任務包括RACE, GLUE Benchmark以及許多Text-classification上輾壓眾生,尤其是在號稱最困難的大型閱讀理解QA任務RACE足足超越BERT 6~9個百分點,其中XLNet模型改善了ELMo, GPT, BERT的缺點,有ELMo, GPT的AR性質,又有跟BERT一樣,使用AE性質能夠捕捉bidirectional context的訊息,最後再把Transformer-XL能夠訓練大型文本的架構拿來用,比我家的貓還愛吃。

目前NLP的發展趨勢越來越靠近Pre-train model+Downstream(transfer learning),即先訓練一個夠generalize的模型,接著依照下游任務的需求去更改結構並Finetune模型,真正重要的其實是Pre-train model token之間依賴關係,即word embedding,有良好的word embedding基本上效果也會不錯,而目前在Pre-train word embedding的方式都是預測Sequence本身的單詞,即依賴上下文來預測單詞,著名的模型如ELMo, GPT, BERT, ERNIE。

本篇文章會從基本NLP在Pre-train時會使用的AR以及AE模型性質開始討論,接著帶入XLNet的細節,介紹XLNet如何把這兩種性質同時實現:

AutoRegressive(AR):

Eq.1

又稱自迴歸,即給定一段Sequence {x1,x2,…,xt}, 在Pre-train時使用{x1}預測x2,接著使用{x1,x2}預測x3,直到最後使用{x1,…,xt-1}預測xt,這種性質的模型往往都認為下一個字的出現依賴於上文,也就是找出一個參數θ最大化{x1,x2,…,xt}的log-likelihood (Eq.1)。

Fig.1

目前在NLP使用到AR的模型如Fig.1,GPT使用Transformer搭配Mask來實現AR,而比較特別的是ELMo這個模型,ELMo使用Bi-LSTM,Foward與Backward兩部分各產生一個hidden state,接著把兩個hidden state concatenate,其實只是把資料反過來(Reverse)再訓練一次,兩個方向的LSTM在訓練過程中仍然是互相獨立的,講簡單一點就是一次訓練兩個AR,本質上還是一個AR模型。這種AR模型的缺點在於一個詞通常是要上下文一起判斷的,AR的思想就是一個詞只能從上文或者下文任一個方向來判斷。

AutoEncoding(AE):

Eq.2
Fig.2

又稱自編碼,這種模型的早期的提出是為了降維,後來出現許多種變形,例如Denosing AE(DAE)的提出是為了降噪,而常見的降噪例如說音訊雜質或是圖片的污點等等,而在Pre-train模型中首次有使用到DAE的模型就是BERT的MLM(Masked Language Model),即Eq.2的log-likelihood,做法是Pre-train時隨機把所有句子中15%的token使用<Mask> token來替代,然後在預測時將<Mask> token的位置預測成原來的字,把替換成<Mask>然後再還原的作法可以視為一種DAE,這種作法可以幫助預測<Mask>時充分運用到上下文的資訊,但是作者認為<Mask>只有在Pre-train時會用到,在Finetune時就完全不會用到,這會造成Pre-train和Finetune之間資訊不對稱的問題(Input noise),而這也是XLNet想要解決的方向,Fig.3假想是BERT在self-attention時的權重矩陣,橘點為attention的位置,能夠明顯觀察到BERT在Pre-train與Finetune時的差別。

Fig.3

另外BERT隨機將15%的token使用<Mask>,若當一個Sequence要預測的<Mask>有兩個以上時,以BERT的訓練方式是同時輸出<Mask>的位置,這會造成<Mask>之間相互獨立的現象(Independence Assumption),示意圖如Fig.4,x3與x4之間是有依賴關係的,但是在BERT的訓練方式則是用{x1,x2,x5}來一起預測{x3,x4},如Eq.4,所以說為了解決這個情況,必須套用AR的想法,先預測x3再預測x4。

Eq.3
Fig.4
Eq.4

Permutation Language Modeling

XLNet的想法就是要使用AR的方式來預測單詞,又要能在不使用<Mask> token的前提下學習到上下文的資訊,所以XLNet提出的Permutation Language Modeling(PLM),即使用permutation實現上下文對於單詞的預測,其實訓練方式還是transfomer的self-attention,只是對輸入與attention matrix進行一點修飾。

Fig.5

Fig.5是PLM的實現想法,假設有一Sequence{x1,x2,x3,x4},則一開始先對Sequence做permutation,得到一個新Sequence{x2,x4,x3,x1},接著再隨機選擇一個target作為預測目標,以Fig.5的例子是x3的位置,這時你會發現新Sequence的排列方式不僅能夠讓x3 attend上文x2,也能夠attend下文x4,這就是PLM的核心思想,其中mem⁰是Transformer-XL的memory架構,先不討論,最後得到的Attention matrix會像下圖,橘點代表說token之間是否有互相attention,發現x3可以往後attend到{x3,x2,x4},x4往後attend到{x4,x2},x2則是attend到自己,而空白的部分則是被Mask起來(softmax(x-e30))。

Fig.6

PLM很容易的讓target去利用上下文來幫助預測,但是在細節實現中不太可能這麼做,因為每當你排列一次,都要記錄當前的排列前後的對照,然後在預測完後依照dictionary排回來,所以作者選擇在不更動原始Sequence順序的前提下,使用mask實現AR+permutation。

Fig.7

這裡一樣假設有一Sequence{x1,x2,x3,x4},排列新Sequence為{x1,x4,x2,x3},這時target若為x3,則x3可以attend{x1,x4,x2},x2可以attend{x1,x4,x2},x4可以attend{x1,x4},x1則attend{x1},真正在實現時不會動到attention matrix的兩軸原始順序,而是依照排列後所有的token中,各自往後attend的tokens,如Fig.8右邊就是兩軸原始順序的attention matrix,橘點就是有attention的位置。

Fig.8

以上PLM就是XLNet如何確保使用AR又能夠同時捕捉到上下文的作法,但筆者一開始看到PLM時,心中一開始的想法是若把Sequence原始順序打亂不就失去token之間的順序關係了嗎?不過事後想想如果這種方法若有效的話,是不是代表在token之間只需要attention來判斷,而不需要順序的概念了?

Two-Stream Self-Attention

筆者前面介紹了PLM如何捕捉上下文,但是還沒解決如何取代<Mask>這個token,在BERT中,<Mask>告訴了模型要預測的單詞位置還有前後文關係,而在XLNet則是用Two-Stream Self-Attention來實現這兩種目的,分別是Conten stream以及Query stream:

Fig.9

整個XLNet在Pre-train時分為兩個stream,Content stream負責學習上下文,而Query stream這個角色就是用來代替<Mask>token,其負責把Content stream產生的representation拿來做預測,如同BERT的<Mask>一樣,Query stream只有在Pre-train時預測單詞會用到,到了Finetune時就不會用到了,以下介紹這兩個stream的作法:

Content Stream

Content Stream是一個標準的self-attention(Eq.5),附圖Fig.10表示h1在進行Attention時會使用到QKV, h1作為Q,h1~4作為K與V,意思就是h1與h1~4進行attention,最後得到Attention weight,接著再與V相乘得到h1在第二層的representation。

Fig.10
Eq.5

真正計算時是所有h1~4一次計算完,如附圖Fig.11,這邊的Sequence排列順序是{x3,x2,x4,x1},橘點一樣表示每個token是否有往前attention,紅框則是上圖Fig.10在計算h1的位置。

Fig.11

Query Stream

Query Stream負責在Pre-train擔任預測單詞的作用,因為在預測單詞時不允許模型看到當前的token是什麼,所以作者在這邊另外設置一個representation g來去attend其他位置,以Fig.12來說,就是用g1去attend h1~4(嚴格來說是h2~4,因為之後會把h1 mask掉),計算Attention公式如Eq.6,可以從紅底線的位置發現Eq.5和Eq.6之間的差別,Query stream會把當前t位置的attention weight mask掉。

Fig.12
Eq.6

如附圖Fig.13,各位會發現Query Stream為了防止模型看到當前的token,會把對角線的attention weight mask掉,紅框表示Fig.12的g1位置,g1只能attend h2~4。

Fig.13

針對Query stream有一個需要注意的地方,也就是PLM會造成預測一對多的情形,例子如下圖Fig.14,第一次排列後順序為{3,2,4,1}並預測4,以及第二次排列順序為{3,2,1,4}並預測1,會發線兩者的KV長得一模一樣,這是因為我們沒有告知模型預測token真正的位置,所以在計算時會告知模型預測位置,如公式Eq.7中的zt,而這個zt在實現中是一個one-hot vector。

Fig.14
Eq.7

Long Text Understanding

在XLNet中,作者為了讓模型有大型文本的學習能力,借鑑Transformer-XL的Segment recurrence mechanism和Relative positional encoding的方法,簡單來說就是讓不同Segment之間互相Attention,公式如Eq.8,h上方波浪符表示上一個Segment所有的hidden representation 。

Eq.8

從公式就知道是用當前Segment與兩者Segment concatenate的結果進行Attention,詳細的運作方式如Fig.15,Q是當前Segment,K是上一個Segment與當前Segment的concatenation,兩個矩陣相乘後就得到右邊的長方形Attetnion matrix,因為沒有舉例要Permutation的位置,筆者這裡隨便舉一個Sequence出來,重點是排列後的Sequence{z5,z7,z6,z8},並把z6當成Target去把Attention matrix沒有attend的位置mask掉。

XLNet重要的架構差不多都介紹完了,另外會在句子中使用[CLS]這個token輔助finetune的預測任務,其餘參數細節跟BERT都差不多,各項任務表現也請讀者自行查閱論文。

Conclusion

XLNet最主要的貢獻在於使用PLM讓AR和learning bidirectional contexts共存,並另外使用一個Query stream來代替BERT的<Mask>,接著借鑑Transformer-XL實現大型文本學習,其中最令筆者驚嘆和在意的還是PLM與AR之間的衝突,因為PLM會隨機把單詞順序打亂,而AR的預測方式是有順序性的,所以表示XLNet是用AR的方式來實現PLM,單詞順序在XLNet眼前感覺根本沒必要,只需要Attention搭配預測Target,就能夠無視單詞順序之間的關係,真的太厲害了,偉哉XLNet。

References

  1. Jacob Devlin, Ming-Wei Chang, Kenton Lee, and Kristina Toutanova. Bert: Pre-training of deep bidirectional transformers for language understanding. arXiv preprint arXiv:1810.04805, 2018.
  2. Matthew E Peters, Mark Neumann, Mohit Iyyer, Matt Gardner, Christopher Clark, Ken- ton Lee, and Luke Zettlemoyer. Deep contextualized word representations. arXiv preprint arXiv:1802.05365, 2018.
  3. Alec Radford, Karthik Narasimhan, Tim Salimans, and Ilya Sutskever. Improving language understanding by generative pre-training. URL https://s3-us-west-2. amazonaws. com/openai- assets/research-covers/languageunsupervised/language understanding paper. pdf, 2018.
  4. Zihang Dai, Zhilin Yang, Yiming Yang, William W Cohen, Jaime Carbonell, Quoc V Le, and Ruslan Salakhutdinov. Transformer-xl: Attentive language models beyond a fixed-length context. arXiv preprint arXiv:1901.02860, 2019.
  5. YANG, Zhilin, et al. XLNet: Generalized Autoregressive Pretraining for Language Understanding. arXiv preprint arXiv:1906.08237, 2019.

--

--