AWS TrainiumでMoEを学習する

Yasuhisa Nakashima
KARAKURI Techblog
Published in
7 min readMay 21, 2024

はじめに

本記事では、AWS Trainiumを使用してMoEを学習する方法について解説します。

使用するコードは、KARAKURI LM 8x7B Chat v0.1の開発時に使用されたものをベースにしています。KARAKURI LM 8x7B Chat v0.1は、MoEをTrainiumで学習させた世界初の事例です。

AWS Trainiumとは

AWS Trainiumは、AWSによって開発された機械学習モデルのトレーニングに特化したアクセラレーターです。Trainiumは、同等のGPUインスタンスと比較して、トレーニングに要するコストを最大50%削減できるとされています。

Trainiumについての詳細は、AWSによるブログ記事を参照してください。

MoEとは

TransformerモデルにおけるMoE(Mixture of Experts)とは、フィードフォワード層を複数の独立したユニット(エキスパート)に分割し、その中から一部のエキスパートのみを選んで活性化する手法です。これにより、パラメータ数の増加に伴う計算コストの増加を抑えることができます。

MoEについての詳細は、Hugging Faceによるブログ記事を参照してください。

分散学習ライブラリ

AWS Trainiumでは、主にAWS Neuron Reference for NeMo Megatron(neuronx-nemo-megatron)が分散学習ライブラリとして使用されます。

今回使用するライブラリは、neuronx-nemo-megatronの一部を改変したものです。具体的には、Mixtralモデルで用いられているスパースMoE層に関する実装を追加しています。また、Hugging Face Hub上のデータセットを使用しやすくするため、Hugging Face Datasetsにも対応させています。

1. 環境の準備

まずは、AWSのインフラを構築します。以下のリンクを参考に、VPCとParallelClusterの設定を行ってください。

2. 必要なツールのインストール

次に、AWS公式のチュートリアルを参考に、必要なツールをインストールします。

ヘッドノードへの接続

まず、SSHでヘッドノードに接続します。

ssh -i YOUR_KEY.pem ubuntu@HEAD_NODE_IP_ADDRESS

仮想環境の有効化

仮想環境を有効化します。

cd ~
source ./aws_neuron_venv_pytorch/bin/activate

リポジトリのクローン

リポジトリをクローンします。

cd ~
git clone https://github.com/karakuri-ai/neuronx-nemo-megatron.git
cd neuronx-nemo-megatron

AWSのチュートリアルではaws-neuron/neuronx-nemo-megatronを使用しますが、ここではkarakuri-ai/neuronx-nemo-megatronを使用していることに注意してください。

パッケージのビルドと依存関係のインストール

パッケージをビルドし、依存関係をインストールします。

pip install wheel
./build.sh

pip install ./build/*.whl
pip install -r requirements.txt protobuf==3.20.3

cd ~
python -c "from nemo.collections.nlp.data.language_modeling.megatron.dataset_utils import compile_helper; \
compile_helper()"

3. データセットの準備

データセットのダウンロードと前処理を行います。

cd ~/neuronx-nemo-megatron/nemo/examples/nlp/language_modeling
python create_mixtral_sft_dataset.py

ここではNo Robotsデータセットを使用しています。他のデータセットを使用する場合は、 create_mixtral_sft_dataset.pyを編集してください。

4. モデルの学習

チェックポイントの変換(HF -> NeMo)

チェックポイントをHugging FaceフォーマットからNeMoフォーマットに変換します。

cd ~/neuronx-nemo-megatron/nemo/examples/nlp/language_modeling/checkpoint_conversion
python convert_hf_checkpoint_to_nemo_mixtral.py \
--path_to_checkpoint /path/to/hf_checkpoint \
--config_file /path/to/hf_checkpoint/config.json \
--model_bin_file /path/to/hf_checkpoint/pytorch_model.bin.index.json \
--output_path /path/to/nemo_checkpoint \
--tp_degree 8 \
--pp_degree 8 \
--save_bf16 True \
--num_shards 19

ParallelClusterの外部で実行した場合は、変換されたチェックポイントをFSxと紐付けているS3バケットにアップロードしてください。

パスの編集

mixtral_8x7b.shを編集して、データセットのパスとチェックポイントのパスを指定します。

モデルの事前コンパイル

モデルを事前コンパイルする場合は、以下のコマンドを実行します。

sbatch --nodes 2 compile.slurm ./mixtral_8x7b.sh

学習の開始

学習を開始します。

sbatch --nodes 2 run.slurm ./mixtral_8x7b.sh

チェックポイントの変換(NeMo -> HF)

学習が完了したら、チェックポイントをHugging Faceフォーマットに変換します。

cd ~/neuronx-nemo-megatron/nemo/examples/nlp/language_modeling/checkpoint_conversion
python convert_nemo_checkpoint_to_hf_mixtral.py \
--path_to_checkpoints /path/to/nemo_checkpoint \
--config_file /path/to/hf_config_file \
--output_path /path/to/hf_checkpoint \
--is_xser True \
--dtype bfloat16

5. 推論

推論はGPUやAWS Inferentia2で実行できます。Inferentia2による推論の実装は、AWSによるブログサンプルコードを参考にしてください。

まとめ

KARAKURI LM 8x7B Chat v0.1はHugging Face Hubにて公開されています。モデルの公開時点で、このモデルは日本語のマルチターン会話能力を測るベンチマークであるMT-Bench-jpで、オープンモデルの中でトップクラスの性能を記録しています。

また、Inferentia2による推論のデモも公開しています。

--

--