突破 Transformers 的速度瓶頸!Flash-Attention 介紹

黃顯堯
17 min readSep 2, 2023

--

前言

前一陣子在訓練公司內部的 LLM,為了最大化的使用公司內部所有的運算資源來加速整個訓練的過程,而嘗試了許多平行和優化的訓練方法,而其中也包含了 flash-attention ,在閱讀後 FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness 的論文,並且嘗試了一些實驗之後,覺得概念很有趣很值得分享,才有了寫這篇文章的想法!

FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness. Tri Dao, Daniel Y. Fu, Stefano Ermon, Atri Rudra, Christopher Ré

介紹

近年來, large-scale language models (LLM) 向世界展示出了非常強大的能力,在大家擔心自己會不會被 AI 取代的同時,各行各業也都積極的探索並且嘗試將這樣厲害的能力整合進不同的產品當中。

但是要駕馭 LLM (參數量 > 7B ) 本身並不是一件容易的事情,尤其是當我們本身的硬體資源不是那麼充足的時候 (e.g., 運算資源不夠或是 GPU memory 不夠)。舉例來說,假如我們想要 inference 一個參數量 7B 且 data type 為 float16 的 LLM 時,我們至少需要 7B * 2 bytes (= 13 GB) 的 GPU memory 來將整個 model 放進去 GPU 當中,而這也只是將 model 載入而已,不包含實際運算時 input data 和 activations 所需要的 memory!

儘管目前有許許多多的方法可以幫助我們降低硬體資源需求的負擔,像是將 model quantize 成更小的 data type (e.g., float16 -> int8) 或是將好幾個 operations fuse 成一個 operation 進而減輕 memory 存取的 loading,但是在現今以 Transformers 為底所建立的眾多 LLM 始終有一個根本的 bottleneck 使得我們沒有辦法讓 LLM infernece 更快更有效率,而這個根本原因就在於 Transformers 當中最重要的一個 operation — Self-Attention。

Structure of GPU Memory

在介紹 self-attention 為何會成為 bottleneck 之前,我們先來了解一下 GPU memory 本身的架構!

GPU Memory Hierarchy

在 GPU 當中,memory 也跟 CPU memory 一樣分成不同的 level,通常越上層空間越小但是速度越快,而大家平常主要提到的 GPU memory 通常是指 high bandwidth memory (HBM),以 A100 來說,HBM 大概有 40 GB ~ 80 GB 左右,且 HBM 的 bandwidth 為 1.5–2.0TB/s,再往上一層的 memory 稱為 SRAM,總容量大概有 192 KB * 108 (streaming multiproecssors) 左右,雖然大小少很多,但是他的 bandwidth 可以達到 19TB/s,因此當有運算需要從 HBM 當中不斷讀寫資料的時候,這樣的速度差就容易導致 HBM 的讀取變成整體效能的 bottleneck。

Execution Model

在 GPU 當中有非常大量的 threads (kernel) 負責執行 operation 的運算,而整個運算的過程基本上是從 HBM 當中將資料載入至 SRAM 中,執行運算並將 output 存回 HBM 當中。

Performance characteristics

而根據每個 operation 實際運算時間和 memory 存取的時間多寡,我們可以將 operations 歸納為兩個類別,分別是 compute-bound 以及 memory-bound

Compute-bound 的意思為運算的主要時間都耗費在 operation 的計算上,HBM 的存取只佔了其中一點點的時間,舉例來說,像是多維度的矩陣相乘或是高 channel 數的 convolution 都屬於這類。

Memory-bound 的意思為運算主要時間都耗費在 memory 的讀取上,而實際的運算只佔了其中一點點的時間,舉例來說,像是 elementwise (e.g., activation, dropout) 和 reduction (e.g., sum, softmax, batch norm, layer norm) 皆屬於 memory-bound。

Self-Attention

在了解了 GPU 簡單的架構後,我們可以來了解一下 self-attention 是怎麼運算的以及他的 bottleneck 是在當中的哪個地方!

Self-attention 是由 Google 在 2017 年時所提出的論文,而他主要的概念在於:

在處理 sequential data 時,self-attention 能夠讓模型為不同位置的 element 賦予不同的重要程度 (weights),以更好地捕捉它們之間的相互關係和重要性。

我們可以舉一個簡單的例子來說明這個精神,假設今天我們有一個 sentence “I am Eric”,而對於單字 “I” 來說,每個單字對於他應該都要有不同的重要程度,好比說 “Eric” 應該要比 “am” 來得重要,因為他指的是 “I” 這個人的名字,而同樣的概念也可以 apply 在單字 “am” 上,每個單字對於他也應該要有不同的重要程度,這樣子的關係也很好地反映了我們的話語 (natural language),而假如每個單字都可以找到他對於其他單字的重要程度的話,那我們就可以建構出一個 N*N 的 weight matrix,而我們稱這個 weight matrix 他為 P。

而 P 是如計算的呢?

我們需要先知道對於每個單字都有三個用來表示自己的 vectors,分別是 query vector, key vector, 和 value vector,因此假如我們有 N 個單字的句子,我們總共會有三個維度各自為 N*D (D 為 vector 的 dimension) 的 query matrix, key matrix, 和 value matrix。

而為了得到每個單字對於其他每個單字的關係性,我們可以將 query matrix 和 key matrix 的 transpose 做矩陣相乘,這時我們就可以得到一個維度為 N*N 的 matrix S,而在 S 上第 i 個 row 第 j 個 column 的 element 便是第 i 個單字的 query vector 和第 j 個單字的 key vector 內積的結果 (因此也可以當作反映著第 j 個單字對於第 i 個單字的重要程度!)。

然而,儘管我們得到了一個 matrix S,但是他還是沒有辦法真的被我們當作是權重,原因是他的 row-wise summation 並不是 1,因此我們還需要 apply 一個 row-wise 的 softmax 來真的得到我們要的 weight matrix P。

有了這樣的概念,和剛剛對於 GPU 的簡單介紹,我們可以透過以下的 algorithm 來看一下 self-attention 實際在 GPU 當中運算是怎麼進行的:

首先我們會需要將 Q (query matrix)和 K (matrix) 做矩陣相乘來得到 S,這時候我們需要先將 S 存入 HBM,接下來為了得到總和為 1 的權重,我們需要將 S 從 HBM load 出來再透過 softmax 來計算出真正的權重 P,這時候我們需要再將 P 存入 HBM 當中,然後最後還需要將 P 和 V 從 HBM 當中讀取,並且做矩陣運算得到最後的 output O 並完成整個 self-attention 的過程!

然而這樣的運算所造成 memory access 的時間複雜度為 O(N*D + N*N),其中通常 N >> D(e.g., N 為 4096 而 d 為 64),因此我們可以發現 S 和 P 的 memory access (N*N 的複雜度) 便是整體 self-attention 的 bottleneck!

但是,再仔細思考一下,其實我們可以發現對於整個 self-attention 當中,其實我們真正需要的是最後面的 output O 而已,過程當中不管 P 和 S 長什麼樣子其實對於我們來說都沒有很重要,既然他不重要為什麼我們還是要將他存入 HBM 呢?主要是因為以下兩個理由:

  1. 我們需要這些 intermediate activations 來幫助我們在 backward 的時候透過 backpropagation 計算 gradients,這也使得我們很難將多個 operations fuse 成一個 operation。
  2. 由於 SRAM 本身不夠大,而 softmax 這種需要計算 sum 的 operation,需要整個 row 的 element 都到齊後才可以計算,使得我們沒有辦法 apply 一些 divide and conquery 的 algorithm ,更使得我們沒有辦法把所有運算一口氣在 SRAM 當中計算完。

因此只要想辦法繞過上面這兩個理由,我們是不是就可以避免掉將 S 和 P 存入和讀取 HBM 所造成的 O(N*N) 的時間複雜度呢?

沒錯!!接下來我們就來介紹 flash-attention 是透過什麼樣子的魔法來解決這個問題的!

Flash-Attention: Tiling and Recomputation

在 flash-attention 當中,主要將 matrix 拆分成多個 blocks 並且用到了兩個概念: Tiling 和 Recomputation

Tiling:

在上一章節的介紹當中,假如我們有辦法避免 P 和 S 對於 memory 的存取,我們就有辦法讓 self-attention 整體的運算更加快速!因此最最極端的例子便是直接從 K 和 Q 一口氣計算出 O,直接避免中間先計算 S 和 P 的過程。

這個概念其實沒有錯,然而前提是假如我們的 SRAM 夠大的話,原因如同剛剛提到的第二的理由,在計算 P 的過程中會需要計算 softmax,這代表我們一定要先有每個 row 當中的所有 element 才有可能計算出 summation,而這個需要先得到所有 row 的限制也使得我們也沒有辦法將 K 和 Q 切分成多個 sub blocks 各自計算出各自的結果。

而在 flash-attention 當中使用了 tiling 的 trick 來解決掉這個問題!接下來我們直接用以下的例子來說明什麼是 tiling:

假設我們目前的 S row vector (還沒經過 softmax)為 [0.1, 0.3, 0.5, 0.7],而 V column vector 為 [7, 8, 9, 10],以正常的 self-attention 做計算的話我們會先透過 softmax 得到 P row vector [0.1807, 0.2207, 0.2695, 0.3292],然後再和 V column vector 計算內積得到最後的值 8.7472。

而 tiling 的概念是直接將 S row vector 和 V col vector 切成多個 sub blocks,各自計算出各自的 value,但是這樣便會使得 softmax 的計算出現問題,因為我們不知道整個 row 的值,算不出分母的 exponential summation,而得不到正確的 weight,所以接下來我們可以接著用以下三個步驟來看 flash attention 是怎麼解決掉這個問題的:

首先我們可以先將 S 和 V 分別切割成 [[0.1, 0.3], [0.5, 0.7]] 和 [[7, 8], [9, 10]] 然後各自計算出真正的 value,也就如同以下的第一步和第二步:

第一步 — 計算第一個 block 的值

首先對 [0.1, 0.3] 計算 softmax 得到 [0.4502, 0.5498],並且和 [7, 8] 計算內積得到 7.5498,同時我們將 [0.1, 0.3] 的 exponential summation (softmax 的分母) 2.455 存下來

第二步 — 計算第二個 block 的值

再來因為我們有新的 block [0.5, 0.7] 出現,我們可以依照剛剛的方式計算 softmax 並和 [9, 10] 內積得出新的 value 9.5498,並且得到 [0.5, 0.7] 的 exponential summation (softmax 的分母) 3.66

再來為了算出正確的 value,flash attention 會將上面兩步的算出來的值先用各自的 exponential summation 進行還原,才重新計算出正確的值,也就是第三步 — 校準:

第三步 — 校準

在上面兩個 sub blocks 都計算完後,我們其實沒有辦法直接將兩個 blocks 內積後的 value 相加來當作是最後的 output,原因在於他們各自的 softmax 分母是不同的!

但是!由於我們剛剛有各自存下他們兩個 blocks 的 exponential summation,因此我們可以把它們各自的 output 和 expoential summation 相乘來還原出還沒經過 normalization 的 output,並且再將兩者的 expoential summation 相加作為新的 softmax 分母,也就是 (9.5498 * 3.66 + 7.5498 * 2.455) / (2.455 + 3.66) = 8.7469。

接下來,我們再將還原後的值除上新的 softmax 分母然後相加後神奇的事情就出現了!透過以上的步驟,我們最後會得到 8.746 這個值,而這個值就跟直接將整個 row 去做 softmax 得到的結果幾乎一樣 (實際 implement 會先減去 max value 來避免經過 exponential 後 overflow,這裡為了簡化就略過這個步驟了)。

而透過上面這個三個步驟我們可以發現,在第一步計算完之後,一直到第三步計算出正確的 value,sub-block 1 完全沒有再用到了!同時我們也還是可以得到正確的 O,甚至也不用管 P 和 S 的值是多,只需要額外多存下每個 sub block 的 exponential summation 就好,這使得我們可以從頭到尾只將 sub-block 1 對應的 Q 和 K 從 HBM load 到 SRAM 一次便可以直接算出最後的 O,中間也完全不用再多存取任何東西。

儘管這樣的方式不能讓我們避免 O(N*N) 的時間複雜度 (因為我們需要 for loop 將每個 Key vector 和 Query vector 做內積,如同下面的圖),但是這樣切割成 sub-block 直接計算出結果,且不用整個 row 一起存取的方式可以讓我們將整個時間複雜度除以 M (sub-block 數量),同時減少許多 O(N*N) memory 存取的次數,還是可以達到非常顯著的效果提升!

然而,tiling 只幫助我們解決了第二個理由,還有第一個理由 還沒解決—

我們需要這些 intermediate activations 來幫助我們在 backward 的時候透過 chain rule 計算 gradients

這使得我們還是要將 P 和 S 存下來,因此接下來就要介紹 flash attention 的第二個 trick — Recomputation,使得我們不再需要將 P 和 S 寫回去 HBM 當中!

Recomputation

Recomputation 的概念其實很簡單 —

不儲存 intermediate activations 而是在有需要的時候再重新計算

這個概念其實類似於 gradient checkpointing,然而 gradient checkpointing 的主要精神是稍微犧牲一些速度但是可以大幅度的減少 GPU memory 的需求 (時間換取空間的感覺),而在這邊 flash attention 當中的 recomputation 這樣的做法除了可以節省 GPU memory 之外還可以加速!

原因在於當我們在計算 backward 時,我們本來就要將 K, Q, and V 載入 SRAM,而與其我們在 forward 時將 S 和 P 這兩個 N*N 的 matrix 存入 HBM, 並且在 backward 時再將他們兩個從 HBM load 到 SRAM,我們直接用本來就在 SRAM 當中的 K, Q, and V 重新計算出 S 和 P 反而可以更快。這點也反應了 HBM 本身相較於 SRAM 和 GPU computing 速度的差距!

而也因為結合了 Tiling 和 Recomputation,使得 flash attention 有辦法將多個 operations fuse 成一個 operation,更進一步避免了 HBM 的 read 和 write 的 loading,而不用擔心 fusion 後使得在 backward 時會無法進行 chain rule。

Experiments

接下來簡單分享一下實際測試下來的速度優化!這邊是用 flash attention 2 來做測試,而 flash attention 2 和 1 的基本概念一樣,只是有更進一步的優化,未來有機會再跟大家分享!

實驗的環境為 A100,且 parameters 數為 7B 的 language model

BatchSize (per device): 4, Sequence Length: 1024

ZeRO2 + tensor parallel:
- Fwd Time: 0.15 sec
- Bwd Time: 0.484 sec
- Optim Time: 0.09 sec
- GPU memory (per device): 17 GB

ZeRO2 + tensor parallel + flash-attn 2:
- Fwd Time: 0.143 sec
- Bwd Time: 0.44 sec
- Optim Time: 0.092 sec
- GPU memory (per device): 16 GB

BatchSize (per device): 4, Sequence Length: 2048

ZeRO2 + tensor parallel:
- Fwd Time: 0.34 sec
- Bwd Time: 0.98 sec
- Optim Time: 0.092 sec
- GPU memory (per device): 23 GB

ZeRO2 + tensor parallel + flash-attn 2:
- Fwd Time: 0.27 sec
- Bwd Time: 0.80 sec
- Optim Time: 0.092 sec
- GPU memory (per device): 20 GB

BatchSize (per device): 4, Sequence Length: 4096

ZeRO2 + tensor parallel:
- Fwd Time: 0.85 sec
- Bwd Time: 2.29 sec
- Optim Time: 0.093 sec
- GPU memory (per device): 35.6 GB

ZeRO2 + tensor parallel + flash-attn 2:
- Fwd Time: 0.529 sec
- Bwd Time: 1.55 sec
- Optim Time: 0.094 sec
- GPU memory (per device): 28.3 GB

想更了解 ZeRO2,可以看這篇文章

透過上述的實驗結果,其實我們可以發現當 N (Sequence length) 越大時,所產生的 performance improvement 更加的明顯!主要原因也在於當 N 越大時,對於 HBN O(N*N) 的存取 loading 也會更明顯,而在 sequence length 為 4096 時,flash attention 甚至帶來了快 80% 的加速!

結論

訓練一個 LLM 是一個很有趣也很有挑戰的事情,除了需要有一定的硬體運算資源之外,也因為 LLM 本身非常龐大,更需要在有限資源內最大化資源的使用率,而由於 GPU 當中往往對於矩陣相乘有特別的優化,因此使得 bottleneck 很容易出現在 memory 存取的地方,而在這篇文章當中所介紹的 flash attention,便是透過調整了 self-attention 的計算方法,來減少大量 O(N*N) 的 HBM access loading,使得整體的速度可以大幅提升,同時不會損失任何 accuracy。

最後,謝謝看到這裡的你,假如有任何錯誤的地方再請你不吝糾正,假如有想一起討論的地方,也歡迎寄信給我!

--

--