【KARAKURI LM 10本ノック】#3 LLM学習時のGPU消費量

Shun Katakami
KARAKURI Techblog
Published in
13 min read4 days ago

こんにちは、カラクリ R&D チームの片上です。
本記事では、大規模言語モデル(LLM)をGPUで学習させる際に必要となるメモリ使用量の見積もり方をまとめようと思います。LLMの推論・学習時のメモリ使用量をあらかじめ把握しておくことでGPUの見積もりや選定に役立つかと思います。その準備として、LLMの構造を簡単に把握しておく必要があります。そこで、まずLLMの構造について実際のLLMを用いて確認し、そして、LLMの推論・学習の仕方を踏まえて、どのようにメモリが使用されるかを試算してみましょう。

大規模言語モデルの学習に必要なメモリ使用量

1. 大規模言語モデルの読み込みに必要なメモリ

まず、代表的なLLM(Transformer)の構造を確認してみましょう。以下のコードを実行することでLLMの構造を確認することができます。

from transformers import AutoModelForCausalLM

model_name= "karakuri-ai/karakuri-lm-70b-v0.1"
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16)
model = model.to(torch.bfloat16)
print(model)

ここでは例としてmeta-llama/Llama-2–70b-hfを日本語用にFinetuningしたkarakuri-ai/karakuri-lm-70b-v0.1を取り扱っています。大きいモデルですので、float16で読み込んだ場合について以下説明していきます。お持ちのGPUのメモリが小さい場合はモデルを読み込むことができませんので、適宜自分の環境にあったサイズのモデルに置き換えてみてください。オススメは非常に軽量な本モデルと同型のモデルとして

model_name="ahxt/LiteLlama-460M-1T"

などがありますのでGPUメモリが小さい場合はこちらを用いて続きを一緒に確認してみてください。

上記のコードの実行により得られた構造は以下の通りでした。ここからkarakuri-lm-70bは文章の埋め込みを行い、self-attention+MLPを80回施して推論するモデル構造となっていることがわかります。

LlamaForCausalLM(
(model): LlamaModel(
(embed_tokens): Embedding(45416, 8192)
(layers): ModuleList(
(0-79): 80 x LlamaDecoderLayer(
(self_attn): LlamaAttention(
(q_proj): Linear(in_features=8192, out_features=8192, bias=False)
(k_proj): Linear(in_features=8192, out_features=1024, bias=False)
(v_proj): Linear(in_features=8192, out_features=1024, bias=False)
(o_proj): Linear(in_features=8192, out_features=8192, bias=False)
(rotary_emb): LlamaRotaryEmbedding()
)
(mlp): LlamaMLP(
(gate_proj): Linear(in_features=8192, out_features=28672, bias=False)
(up_proj): Linear(in_features=8192, out_features=28672, bias=False)
(down_proj): Linear(in_features=28672, out_features=8192, bias=False)
(act_fn): SiLU()
)
(input_layernorm): LlamaRMSNorm()
(post_attention_layernorm): LlamaRMSNorm()
)
)
(norm): LlamaRMSNorm()
)
(lm_head): Linear(in_features=8192, out_features=45416, bias=False)
)

具体的にパラメータの数を数え上げてみましょう。以下のコードを実行することで、パラメータの内訳を見ることができます。

# 各レイヤーのパラメータ数を表示
for name, parameter in model.named_parameters():
print(f"{name}: {parameter.numel()}")

# モデルのパラメータの総数を再計算
total_params = sum(p.numel() for p in model.parameters())
print(f"Total Parameters: {total_params}")

パラメータの内訳は以下のようになります。各層のパラメータの数は層の入力次元×出力次元で計算できるので、例えば、埋め込み層に関しては372047872=45416×8192, self-attentionに関しては67108864=8192×8192といった具合です。

model.embed_tokens.weight: 372047872
model.layers.0.self_attn.q_proj.weight: 67108864
model.layers.0.self_attn.k_proj.weight: 8388608
model.layers.0.self_attn.v_proj.weight: 8388608
model.layers.0.self_attn.o_proj.weight: 67108864
model.layers.0.mlp.gate_proj.weight: 234881024
model.layers.0.mlp.up_proj.weight: 234881024
model.layers.0.mlp.down_proj.weight: 234881024
model.layers.0.input_layernorm.weight: 8192
model.layers.0.post_attention_layernorm.weight: 8192
model.layers.1.self_attn.q_proj.weight: 67108864
model.layers.1.self_attn.k_proj.weight: 8388608
model.layers.1.self_attn.v_proj.weight: 8388608
model.layers.1.self_attn.o_proj.weight: 67108864
model.layers.1.mlp.gate_proj.weight: 234881024
model.layers.1.mlp.up_proj.weight: 234881024
model.layers.1.mlp.down_proj.weight: 234881024
model.layers.1.input_layernorm.weight: 8192
model.layers.1.post_attention_layernorm.weight: 8192
model.layers.2.self_attn.q_proj.weight: 67108864
model.layers.2.self_attn.k_proj.weight: 8388608
model.layers.2.self_attn.v_proj.weight: 8388608
model.layers.2.self_attn.o_proj.weight: 67108864
model.layers.2.mlp.gate_proj.weight: 234881024
model.layers.2.mlp.up_proj.weight: 234881024
...
model.layers.79.post_attention_layernorm.weight: 8192
model.norm.weight: 8192
lm_head.weight: 372047872
Total Parameters: 69196455936

上記からkarakuri-lm-70bのパラメータ数が69B程度であることがわかりました。試しに実際にモデルをGPUに載せてみましょう。

import torch
'''
モデルをGPUに載せてみる。
'''
model.eval()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

モデル読み込みによるGPUの使用量は以下のようにして見積もることができます。ここでは、モデルをfloat16で読み込んでいるので、karakuri-lm-70bのパラメータ1つ1つに16bit、すなわち2Bytes割り当てられます。したがって

  • モデルをfloat16で読み込んだ場合、メモリ使用量はパラメータ数×2Bytesです。

よってkarakuri-lm-70bをfloat16で読み込むのに必要なメモリは69,196,455,936×2Bytes≒131982MiBとなります。(MiB=2²⁰Bytesであることに注意。)

実際にGPUへ読み込みを行うと、この値よりやや多くメモリが使用されますが、システムによるGPU利用やライブラリ内部で別途メモリ使用が生じているためです。

2. 大規模言語モデルの推論に必要なメモリ

次にLLMに文章生成をさせたときに追加で必要となるメモリ使用量も確認しておきましょう。以下のように適当な文章を用いて、文章生成する手続きを簡単に確認します。LLMの文章生成は基本的には以下のようにoutputsを計算し進めることになります。

input_text = "人類が初めて人工物に対し「知能らしきもの」を知覚したと言われているものの一つに、オートマタ※があるという話を、ご存知ですか?"
input_ids = tokenizer.encode(input_text, return_tensors='pt')
input_ids = input_ids.to(device)
with torch.no_grad():
outputs = model(input_ids)

outputsの中にはlogitsとcacheを有効にしていた場合はpast_key_values(attentionのkeyとvalueの対応)が保存されています。cacheが無効の場合はpast_key_valuesは保存されません。各々の構造は

  • logits : [バッチサイズ, シーケンス長, 辞書サイズ]
  • past_key_values : [アテンション数, 2(key, value), バッチサイズ, シーケンス長, 分散表現サイズ]

のようになっており、

例えば、50 tokenのinput_textをバッチサイズ1で辞書サイズ45416, 分散表現サイズ 8192のkarakuri-lm-70bで推論した場合は

  • logits : 1×50×45416×2Bytes ≒4.3MiB
  • past_key_values : 80×2×1×50×8192×2Bytes=125MiB

のメモリが使用されます。このときのメモリ使用量は、合計4.3MiB+125MiB=129.3MiBとなります。実際のメモリ使用量はライブラリの関係でこれよりもやや大きくなる傾向にあります。

従って、推論時のメモリ使用量はモデルパラメータの読み込み時のメモリ使用に加えて

  • 推論時に追加で必要なメモリ使用量 cache有効の場合 = (辞書サイズ+2×分散表現サイズ×アテンション数)×バッチサイズ×シーケンス長×2Bytes
  • 推論時に追加で必要なメモリ使用量 cache無効の場合 = 辞書サイズ×バッチサイズ×シーケンス長×2Bytes

程度と見積もることが出来ます。

3. 大規模言語モデルの学習に必要なメモリ

LLMの学習時のメモリ使用量は、モデル読み込みと推論時のメモリ使用量に加えて、中間結果の保持に必要なメモリも考慮する必要があります。LLMの学習は勾配法によって行われるため、まず勾配の値をパラメータ数分保持しておく必要があり、更に最適化を行うにあたってはOptimizerを用いていくのですが、代表的なOptimizerとしてAdam等を用いる場合、勾配と2次勾配を内部で保持する必要があります。こちらに関してもパラメータ数分確保する必要があります。また勾配を計算するのに全ての中間表現を一時的に保存しておくことが効率的なようです。(古い記事ですが参考リンク。)

従って、以上をまとめるとLLMの学習を行うために確保する必要のあるメモリは

  • 学習用に必要なメモリ使用量 =3×パラメータ数×2Bytes + 2×全中間表現×バッチサイズ×シーケンス長×2Bytes

となります。内訳は以下の通りです。

  • 第1節で求めたモデルの読み込みのメモリ使用量=パラメータ数×2Bytes
  • 勾配保持のためのメモリ使用量=パラメータ数×2Bytes
  • 2次勾配保持のためのメモリ使用量=パラメータ数×2Bytes
  • 全中間表現のメモリ使用量=全中間表現×バッチサイズ×シーケンス長×2Bytes
  • 全中間表現の勾配のメモリ消費量=全中間表現×バッチサイズ×シーケンス長×2Bytes

として見積もることができます。

例えばkarakuri-lm-70bの場合、

  • 学習用にモデルを読み込むのに必要なメモリ使用量~3×131982MiB≒386GiB

であり、全中間表現の1 tokenあたりのメモリ使用量を2.58MiBであるとすると(ここでは全中間表現をlogitsとpast_key_valuesとした。)、バッチ数 8, トークン長 1024とした学習に対して

  • 全中間表現のメモリ消費~2×8×1024×2.58MiB≒41GiB

となり学習に必要なメモリ消費量は386GiB+41GiB427GiB程度であると見積もることができます。

ここでは2次勾配まで用いるOptimizerを想定しましたが、Optimizerを変更し、勾配が不要なものにすると勾配を保持する必要がなくなるのでメモリ使用の軽量化になります。また、実際のメモリ使用量は、システムなどによるGPU使用等により、ここで見積もったものより大きくなる傾向にあります。ここでは話しておりませんがdeepspeedなどを用いることでより効率的にGPUを利用することも可能です。

おわりに

本記事では、実際のLLMを用いておおまかにGPUのメモリ使用量の見積もり方を示しました。本内容が、LLMの推論・学習時におけるGPUの見積もり、選定の助けになれば幸いです。

今後も、KARAKURI LM を多くの皆様に活用していただくために、KARAKURI LM 10本ノックと称して様々な情報を共有してまいります。 最後に、皆様にお願いがあります。執筆の励みになりますので、Clap(拍手アイコン)のタップをお願いします!

--

--