E16 : Grouped Query Attention

Praveen Thenraj
Research Papers Summarized
Jan 20, 2024

Grouping the query heads of MultiHeadAttention into subgroups and assigning each subgroup a single key head and a single value head improves inference time without noticeably compromising model performance.

Paper Name : GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints

Paper URL : https://arxiv.org/pdf/2305.13245.pdf

Authors : Joshua Ainslie, James Lee-Thorp, Michiel de Jong, Yury Zemlyanskiy, Federico Lebrón, Sumit Sanghai

Conference : EMNLP 2023

Problem Statement :

  • The MultiHeadAttention (MHA) mechanism uses multiple (h) query heads along with the same number (h) of key and value heads. Though MHA performs well, inference is slow because of the memory bandwidth needed to load all the keys and values at every decoding step.
  • The MultiQueryAttention (MQA) mechanism, on the other hand, uses multiple (h) query heads with a single key head and a single value head. Though inference time is drastically reduced, there is a noticeable degradation in quality as well.
  • In short, reducing the number of key and value heads reduces inference time, but at the cost of model performance.
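
To make the memory-bandwidth argument concrete, here is a rough back-of-the-envelope sketch of how much key/value data a decoder has to keep around (and re-read at every decoding step) under MHA versus MQA. The function name `kv_cache_bytes` and the model dimensions are hypothetical, chosen only for illustration; they are not taken from the paper.

```python
def kv_cache_bytes(num_layers, num_kv_heads, head_dim, seq_len, bytes_per_value=2):
    """Bytes needed to cache keys and values for one sequence."""
    # Factor 2: one tensor for keys and one for values.
    return 2 * num_layers * num_kv_heads * head_dim * seq_len * bytes_per_value


# Hypothetical decoder: 24 layers, 64 query heads of size 128, 2048 cached tokens,
# 2 bytes per stored value (16-bit floats).
layers, heads, head_dim, seq_len = 24, 64, 128, 2048

mha = kv_cache_bytes(layers, num_kv_heads=heads, head_dim=head_dim, seq_len=seq_len)
mqa = kv_cache_bytes(layers, num_kv_heads=1, head_dim=head_dim, seq_len=seq_len)

print(f"MHA: {mha / 2**20:.0f} MiB, MQA: {mqa / 2**20:.0f} MiB, ratio: {mha // mqa}x")
# MHA: 1536 MiB, MQA: 24 MiB, ratio: 64x
```

Under these assumptions the cache shrinks by a factor equal to the number of heads, which is where MQA gets its speed-up.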

Solution :

  • What is needed is an enhanced form of MQA that can match, or come close to, the performance of MHA-based models.
  • The solution should keep the flavour of both MHA and MQA in order to retain the advantages of both mechanisms.
  • GroupedQueryAttention (GQA) uses more than one key head and one value head (as in MQA) but fewer than the number of query heads (as in MHA).
  • GQA is an intermediate between MHA and MQA.
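
The idea of sharing key/value heads across groups of query heads can be sketched in a few lines of NumPy. This is a minimal illustration, not the authors' implementation: the function name `grouped_query_attention`, the tensor shapes, and the absence of masking and projections are all simplifying assumptions.

```python
import numpy as np


def grouped_query_attention(q, k, v):
    """q: (h, t, d); k and v: (g, s, d), with h divisible by g. Returns (h, t, d)."""
    h, t, d = q.shape
    g = k.shape[0]
    assert h % g == 0, "query heads must split evenly into groups"
    # Share each key/value head across its group: query head i reads
    # key/value head i // (h // g).
    k = np.repeat(k, h // g, axis=0)  # (h, s, d)
    v = np.repeat(v, h // g, axis=0)  # (h, s, d)
    scores = q @ k.transpose(0, 2, 1) / np.sqrt(d)  # (h, t, s)
    scores -= scores.max(axis=-1, keepdims=True)    # numerically stable softmax
    weights = np.exp(scores)
    weights /= weights.sum(axis=-1, keepdims=True)
    return weights @ v  # (h, t, d)


rng = np.random.default_rng(0)
h, t, s, d = 8, 3, 5, 4  # hypothetical sizes: 8 query heads, head size 4
q = rng.normal(size=(h, t, d))

outputs = {}
for g in (8, 4, 1):  # g = h is MHA, g = 4 is GQA, g = 1 is MQA
    k = rng.normal(size=(g, s, d))
    v = rng.normal(size=(g, s, d))
    outputs[g] = grouped_query_attention(q, k, v)
    print(g, outputs[g].shape)
```

The same function covers all three mechanisms: only the number of key/value heads passed in changes, which is what makes GQA an interpolation between MHA and MQA.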

Approach :

  • The authors do not pre-train MQA and GQA architectures from scratch; instead, they up-train existing pre-trained checkpoints.
  • They propose an approach to convert an already pre-trained MHA transformer model into MQA and GQA architecture.
  • Assume an MHA transformer architecture with 8 heads, which means the query, key and value projections are each split into 8 parallel heads.
  • To convert MHA into an MQA architecture, the authors propose taking the checkpoint of a pre-trained model and mean pooling all the key heads (h=8) into a single key head. The same is done for the value heads.
  • To convert MHA into a GQA architecture, the 8 query heads (h=8) of a pre-trained model checkpoint are divided into 'n' (n=4) query groups, so each query group has 2 query heads. Similarly, the key and value heads are split into 4 key groups and 4 value groups, each containing 2 heads.
  • Mean pooling is applied to the two key heads in each key group to convert them into 1 key head, so each query group (2 heads) has 1 corresponding key head. The same is done for the value heads.
Figure: MultiHead has an equal number of query, key and value heads. MultiQuery has one key head and one value head for all query heads. GroupedQuery has one key head and one value head for each group of query heads.
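  • GQA is applied only to encoder-decoder models in this paper. Within these models, it is applied to the decoder's self-attention and cross-attention but not to the encoder's self-attention, because encoder representations are computed in parallel and memory bandwidth is therefore generally not the main bottleneck there.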
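
The mean-pooling conversion described above can be sketched as follows. This is an illustrative sketch only: the function name `mean_pool_kv_heads` is hypothetical, and it assumes the key (or value) projection is stored as a single matrix of shape (d_model, num_heads * head_dim) with the heads laid out contiguously along the output axis. Real checkpoints may use a different layout.

```python
import numpy as np


def mean_pool_kv_heads(w, num_heads, num_groups):
    """Pool a (d_model, num_heads * head_dim) projection down to num_groups heads."""
    d_model, total = w.shape
    head_dim = total // num_heads
    heads_per_group = num_heads // num_groups
    # Split the output axis into (group, head within group, head_dim),
    # then average over the heads that belong to the same group.
    grouped = w.reshape(d_model, num_groups, heads_per_group, head_dim)
    return grouped.mean(axis=2).reshape(d_model, num_groups * head_dim)


rng = np.random.default_rng(0)
d_model, num_heads, head_dim = 16, 8, 2  # hypothetical sizes
w_k = rng.normal(size=(d_model, num_heads * head_dim))

w_k_gqa = mean_pool_kv_heads(w_k, num_heads, num_groups=4)  # 8 heads -> 4 heads
w_k_mqa = mean_pool_kv_heads(w_k, num_heads, num_groups=1)  # 8 heads -> 1 head
print(w_k_gqa.shape, w_k_mqa.shape)  # (16, 8) (16, 2)
```

The same function would be applied to the value projection. The query projection is left untouched, since only the key and value heads are pooled.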

Experimentation :

  • Pre-trained model taken for up-training - T5 Large and XXL
  • Up-training dataset - Colossal Clean Crawled Corpus (C4 dataset) (same as for T5 model pre-training in original paper)
  • Up-training steps - an additional 5% of the original pre-training steps (T5 paper), on top of the original pre-training
  • Datasets for fine-tuning and evaluation - CNN/Daily Mail, arXiv, PubMed, MediaSum, WMT, Multi-News, TriviaQA
  • T5 variants used for evaluation :
    1. MHA-Large
    2. MHA-XXL (64 query heads paired with 64 key heads and 64 value heads)
    3. MQA-XXL (64 query heads sharing 1 key head and 1 value head)
    4. GQA-8-XXL (64 query heads split into 8 groups of 8 heads, each group sharing 1 key head and 1 value head)
  • Metrics for evaluation :
    1. ROUGE (summarisation) - CNN/Daily Mail, arXiv, PubMed, MediaSum, Multi-News
    2. BLEU (translation) - WMT
    3. F1 (Q&A) - TriviaQA

Observations :

  • All four T5 variants (MHA-Large, MHA-XXL, MQA-XXL, GQA-8-XXL) were evaluated on these datasets.
Figure: GQA-8-XXL performance was almost always in line with the performance of MHA-XXL
  • Results show that MQA-XXL outperformed MHA-Large while also having a lower inference time.
Figure: Inference Time Vs Performance
  • GQA-8-XXL came close to the performance of MHA-XXL with almost the same inference time as MQA-XXL.
  • A series of ablation studies was conducted to understand the effect of:
    (i). different mechanisms for combining key and value heads during checkpoint conversion
    (ii). different numbers of up-training steps
    (iii). different numbers of query head groups
  • Three ways of converting MHA-Large to MQA-Large before up-training were compared: mean pooling all key heads, selecting the first key head (out of 'h' key heads), and randomly initialising the key head. Results show that mean pooling gave the best performance.
Figure: Performance Vs Different conversion technique from MHA to MQA during uptraining
  • Up-training for an additional 5% of the original pre-training steps helped the MQA-XXL and GQA-8-XXL models achieve better performance.
  • GQA-8-XXL already performed reasonably well right after checkpoint conversion, without any additional training steps (alpha = 0). At alpha = 0, GQA-8-XXL outperformed MQA-XXL and came close to the performance of MHA-XXL. This indicates that grouping query heads and attending to more than one key and value head improves model performance.
Figure: Uptraining portion Vs Performance
  • Query heads were split into different numbers of groups and tested for inference time. Results showed that with a small number of groups, the inference time of GQA-XXL was close to that of MQA-XXL.
  • As the number of query groups increases, so does the number of key and value heads, and the inference time of GQA-XXL starts to rise because more keys and values have to be loaded at each decoding step.
Figure: Number of Query Groups Vs Inference time
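
The three checkpoint-conversion strategies compared in the ablation (mean pooling, taking the first head, random initialisation) can be sketched side by side for the MQA case. This is a hedged illustration, not the paper's code: the function name `to_single_kv_head`, the contiguous head layout, and the choice of a normal distribution scaled to the original weights for the random variant are all assumptions.

```python
import numpy as np


def to_single_kv_head(w, num_heads, method, rng=None):
    """Reduce a (d_model, num_heads * head_dim) projection to a single head."""
    d_model, total = w.shape
    head_dim = total // num_heads
    heads = w.reshape(d_model, num_heads, head_dim)
    if method == "mean":
        # Average all heads of the original checkpoint.
        return heads.mean(axis=1)
    if method == "first":
        # Keep only the first head and discard the rest.
        return heads[:, 0, :].copy()
    if method == "random":
        # Ignore the checkpoint; the initialisation scale here is an assumption.
        if rng is None:
            rng = np.random.default_rng()
        return rng.normal(scale=w.std(), size=(d_model, head_dim))
    raise ValueError(f"unknown method: {method}")


rng = np.random.default_rng(0)
w_k = rng.normal(size=(16, 8 * 2))  # hypothetical: d_model=16, 8 heads of size 2

converted = {
    m: to_single_kv_head(w_k, num_heads=8, method=m, rng=rng)
    for m in ("mean", "first", "random")
}
for name, w in converted.items():
    print(name, w.shape)
```

Intuitively, the three options keep progressively less information from the pre-trained model, which is consistent with mean pooling performing best in the ablation.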

Limitations :

  • The paper does not pre-train the GQA-n-XXL models from scratch; instead, it up-trains an existing pre-trained checkpoint after converting the MHA-XXL architecture.
  • There is no comparison between a T5 GQA-n-XXL model and an MHA-XXL model that were both pre-trained from scratch.
  • GQA has been evaluated only on encoder-decoder models and not on decoder-only models.

Conclusion :

  • GQA is a good starting point for interpolating between the advantages of the MHA and MQA mechanisms.
  • Given the size of emerging decoder-only large language models, GQA can be useful in reducing inference time for longer inputs without causing much degradation in performance.
