Octoマルチノード学習
複数ノードでのOctoファインチューニングの実行
こんにちは、松尾・岩澤研究室リサーチエンジニアのAlfredoです。
コンテキスト
研究室の研究者たちは既存のモデルを独自のデータでファインチューニングしようとしていますが、単一ノードでは処理が遅すぎるため、コードを修正して複数ノードで並列実行できるようにすることが可能か、またその難易度を確認しようとしています。
対象となるモデルはJAXで記述された「トランスフォーマーベースの拡散ポリシー」のコレクションであるOctoです。このモデルを使用したことはありませんが、研究の多くはPyTorchを用いて行われています。それでも時間を割いて検討する価値はあります。
環境
まず、研究者が使用している環境を再現し、エラーが発生した場合に今後も再現できる基準を確立する必要があります。
今回はオリジナルのリポジトリをフォークしたこちらを使用しているため、メインプロジェクトではなくそのフォークを使用することが重要です。
研究室ではいくつかのHPCクラスターを利用できますが、今回は研究者が使用しているABCIを選び、OSライブラリやモジュールの不一致を避けます。
まず、インタラクティブノードをGPU付きで要求します。これはインストール時に必要となる可能性があり、また最終的にはGPU上でコードを実行する予定があるためです。
# ABCIにログイン
ssh abci
# 2時間のインタラクティブノードを要求
qrsh -l rt_F=1 -g gcb50389 -l h_rt=02:00:00
次に、git
を使用してコードの最新コピーを取得します。
# コードを取得
git clone https://github.com/TMats/octo
cd octo
リポジトリには詳細なインストール手順が記載されているため、それに従い、指定がある場合にはパッケージのバージョンを合わせるよう特に注意します。
# Conda環境でOctoをインストール
conda create -n octo python=3.10 -y
conda activate octo
pip install -e .
pip install -r requirements.txt
このプロジェクトは、Googleが設計したアクセラレータ向けのPythonライブラリであるJAXを使用しています。JAXにはTPU用とGPU用の2つのバージョンがあります。GPUを使用しているため、GPU用の推奨コマンドを実行します。
pip install --upgrade "jax[cuda11_pip]==0.4.20" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
インストールは正常に完了し、提供されたサンプルコマンドでテストを進めます。
python scripts/finetune.py --config.pretrained_path=hf://rail-berkeley/octo-small --debug
残念ながら、すぐに例外が発生します。
Traceback (most recent call last):
# 不要な出力や過剰な出力は簡略化のため省略
[...]
AttributeError: module 'scipy.linalg' has no attribute 'tril'
少し調べたところ、SciPYのリリースノートで、linalg
関数がバージョン1.13
で非推奨となり削除されたことがわかりました。
既存のrequirements.txt
ファイルにはバージョン1.6.0
以上という指定しかないため、上限も追加してパッケージを再インストールします。
# パッケージバージョンの上下限を指定
pip install "scipy>=1.6.0,<1.13"
[...]
Successfully installed scipy-1.12.0
代わりにv1.12
がインストールされ、再度サンプルコードを実行します。
python scripts/finetune.py --config.pretrained_path=hf://rail-berkeley/octo-small --debug
[...]
W0604 12:25:56.079283 22767111436096 xla_bridge.py:697] CUDA backend failed to initialize: Found cuDNN version 0, but JAX was built against version 8600, which is newer. The copy of cuDNN that is installed must be at least as new as the version against which JAX was built. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
動作はしますが、警告メッセージにより、検出されたcuDNNライブラリのバージョンがJAXでコンパイルされたものと一致していないことがわかります。このため、非常に遅くなり、CPUを使用することになります。プロセスを終了し、さらに調査を進めます。
現在インストールされているcuDNNのバージョンは次のとおりです。
# cuDNNバージョンを確認
pip list | grep cudnn
jaxlib 0.4.20+cuda11.cudnn86
nvidia-cudnn-cu11 9.1.1.17
JAXの依存関係に問題があるようで、バージョン8.6
ではなくバージョン9.1
のcuDNNを取得しています。適切なバージョンを取得するため、まずPyPIで完全なバージョン番号を検索し、8.6
が8.6.0.163
に対応していることを確認します。
その後、このバージョンをインストールします。
# JAXでコンパイルされたものと同じcuDNNバージョンをインストール
pip install --upgrade nvidia-cudnn-cu11==8.6.0.163
[...]
Successfully installed nvidia-cudnn-cu11-8.6.0.163
再びコードを実行してみます。
# サンプルコードを再実行
python scripts/finetune.py --config.pretrained_path=hf://rail-berkeley/octo-small --debug
[...]
I0604 16:00:50.505642 23000339048256 compilation_cache.py:101] Writing jit_train_step to persistent compilation cache with key jit_train_step-9723010c1ea073770e5495bd7365e7d217da4b1e678506ee5800f821775a8738.
0%|▏ | 94/50000 [01:58<6:14:33, 2.22it/s]
問題なくトレーニングが進んでいるようです。これで警告が発生しない作業環境が整いました。
サマリー
インストール手順を以下にまとめます。
git clone https://github.com/TMats/octo
cd octo
conda create -n octo python=3.10 -y
conda activate octo
pip install -e .
# 注:scipyの行末に「,<1.13」を追加します。手動で編集することも可能です。
sed '/scipy/ s/$/,<1.13/' requirements.txt > requirements_fix.txt
pip install -r requirements_fix.txt
pip install --upgrade "jax[cuda11_pip]==0.4.20" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
pip install --upgrade nvidia-cudnn-cu11==8.6.0.163
マルチノード
研究室ではPyTorchが最もよく使用されていますが、JAXで複数ノードでのトレーニング方法を学ぶため、まずそのドキュメントを確認しました。その結果、いくつかの規約に従えば、ライブラリがプロセス間の同期を処理してくれることがわかりました。
簡単に言えば、各ノードで少なくとも1つのプロセスを実行し、それらが同じコードを同じ順序で実行する必要があります。ユーザーとしては、ほとんどコードを変更せずに単一ノードコードを再利用できます。
JAXの初期化コードは、動作するためにプロセス数全体、コーディネータとして動作するプロセスのアドレス、およびグループ内の現在のプロセスの順序を知る必要があります。そのため、この情報を提供する必要があります。
scripts/finetune.py
ファイルの“dunder main”パターン内で、分散機能の初期化と終了を追加し、メインコードをラップします。
# 初期化
# [...]
jax.distributed.initialize(
coordinator_address=coordinator_address,
num_processes=world_size,
process_id=world_rank
)
# 実行
app.run(main)
# 終了
jax.distributed.shutdown()
これらの変数を設定する方法はプロセスの起動方法に依存します。今回の場合はジョブシステム(ABCIではSun Grid Engine、略してSGE)を使用するため、シェル環境から読み取ることにします。
coordinator_address = os.environ.get('COORDINATOR_ADDRESS') or 'localhost:12345'
world_size = int(os.environ.get('OMPI_COMM_WORLD_SIZE') or 1)
world_rank = int(os.environ.get('OMPI_COMM_WORLD_RANK') or 0)
COORDINATOR_ADDRESS
はジョブファイル内で直接設定する変数で、その他の2つはジョブスケジューラによって自動的に設定される標準的なMPI変数です。
world_size
はすべてのノード間のプロセス総数を表し、world_rank
はそのセット内での現在のプロセスのインデックスです。
このor
パターンは、対話モードで実行する際の互換性のために存在します(この場合、これらの変数は設定されません)が、コードがバッチ処理システム用に準備できたら削除することも可能です。
変更したコードを単一ノードのインタラクティブモードで実行すると、重要な変更はありませんが、システムがすべてのGPUを適切に検出し(この場合4つ)、グローバルバッチサイズがそれらの間で分割されることを確認します。
[...]
# Devices: 4
Batch size: 256 (64 per device)
複数ノードでこれを実行するには、まずSGE用のジョブファイルを作成する必要があります。
#!/bin/bash
#$ -l rt_F=1
#$ -l h_rt=0:10:00
#$ -j y
#$ -cwd
#$ -l USE_SSH=1
#$ -v SSH_PORT=2299
# システムのMPIライブラリを使用
source /etc/profile.d/modules.sh
module load hpcx/2.12
# Conda環境を有効化
source ~/miniforge3/etc/profile.d/conda.sh
conda activate octo
# フォルダへ移動
cd ~/blog/octo
# 最初のノードをコーディネータに設定
export COORDINATOR_ADDRESS=`head -1 $SGE_JOB_HOSTLIST`:12345
# 注:JAXはSLURM/OpenMPIで1GPUあたり1プロセスを期待
export NUM_GPUS=`nvidia-smi -L | wc -l`
mpirun -npernode $NUM_GPUS -hostfile $SGE_JOB_HOSTLIST \
python scripts/finetune.py --config.pretrained_path=hf://rail-berkeley/octo-small
ジョブスケジューラによってジョブに割り当てられた最初のノードのホスト名をコーディネータアドレスとしてエクスポートし、ランダムに未使用のポート番号12345
を指定します(他のアプリケーションが使用している場合は別のものを選択可能)。
多GPUノードでは、すべてのGPUを使用する単一プロセスを起動するか、各GPUに1つのプロセスを割り当てるかを選択できます。たとえば、ABCIのrt_F
ノードには4つのV100-16G
GPUがあります。このため、各ノードで1プロセスを起動して4つすべてのGPUを使用するか、各ノードで4プロセスを起動して各プロセスに1つのGPUを割り当てることができます。ただし、JAXはこちらの仕様によりSLURMまたはOpenMPIで1GPUあたり1プロセスを期待しています。
そのため、MPIにmpirun
の-npernode
パラメータを使用して各ノードで4プロセスを起動するよう指示します。
コードがバッチジョブとして正常に動作するか確認するため、ジョブを送信します。ジョブファイルの冒頭にあるrt_F=1
は1つのノードを要求し、h_rt=0:10:00
は短い時間制限(10分)を示しています。
$ qsub -g gcb50389 -N test1 job.sh
ジョブが実行を終了した後(正常に終了した場合でもエラーで終了した場合でも)、ログを確認します。
[...]
# Devices: 4
Batch size: 256 (64 per device)
[...]
FileExistsError: [Errno 17] File exists: '/home/acb11899xv/blog/octo/wandb/run-20240613_165553-experiment_20240613_165552/run-experiment_20240613_165552.wandb'
[...]
FileExistsError: [Errno 17] File exists: '/home/acb11899xv/blog/octo/wandb/run-20240613_165553-experiment_20240613_165552/run-experiment_20240613_165552.wandb'
[...]
FileExistsError: [Errno 17] File exists: '/home/acb11899xv/blog/octo/wandb/run-20240613_165553-experiment_20240613_165552/run-experiment_20240613_165552.wandb'
これにより、各GPUごとに1つのプロセスが実行されていることが確認されました。ただし、4つのプロセスがすべてWandBライブラリを初期化しようとしており、最初の1つがログファイルを作成すると他の3つが失敗しています。
一時的な対策として、ジョブファイル内で export WANDB_MODE=disabled を設定し、WandBライブラリを無効化します。この設定を用いて分散処理の変更に集中し、その後に複数初期化の問題を修正します。
この設定で再度ジョブを実行すると、次のような結果が得られます。
# Devices: 4
Batch size: 256 (64 per device)
[...]
# Devices: 4
Batch size: 256 (64 per device)
[...]
# Devices: 4
Batch size: 256 (64 per device)
[...]
# Devices: 4
Batch size: 256 (64 per device)
これにより、以前は1つのプロセスだけがログを記録していたのに対し、すべてのプロセスがログを記録していることが確認されました。また、GPUが正しく検出され、組み合わされているため、デバイスごとのバッチサイズが依然として 64 であることが確認できます。
しかし、その後に別のエラーが発生しました。
[...]
File "/home/acb11899xv/blog/octo/octo/model/octo_module.py", line 125, in __call__
assert horizon <= self.max_horizon, "horizon must be <= max_horizon"
AssertionError: horizon must be <= max_horizon
このエラーは、提供されたテンソルの形状が期待される形状に一致していないことに関連するエラーのようです。これはいくらかの進展です。
デバッグプロセスを続行する前に、まず2ノードを使用してGPUが複数ノードで正しく検出されるかを確認します。
ジョブファイルで以下を変更します。
# 2ノードを要求
#$ -l rt_F=2
再度ジョブを実行します.
[...]
# Devices: 8
Batch size: 256 (32 per device)
[...]
# Devices: 8
Batch size: 256 (32 per device)
[...]
# Devices: 8
Batch size: 256 (32 per device)
[...]
AssertionError: horizon must be <= max_horizon
8つのログプロセスが生成され、すべてのGPU(2ノード、4 GPUずつ、合計8デバイス)が検出されました。その後、同じアサーションエラーが発生し、前回の結果が確認されました。また、指定されたバッチサイズ256が8デバイスに分割され、4 GPUの単一ノードを使用した場合の半分(32対64)になったことにも注目してください。
エラーメッセージの原因を特定するためにソースコードを調査したところ、jax.tree_util.tree_leaves
関数、具体的にはその引数observations
に問題があることがわかりました。
batch_size, horizon = jax.tree_util.tree_leaves(observations)[0].shape[:2]
ジョブスクリプトでGPUの数を一時的に1に固定します。
[...]
#export NUM_GPUS=`nvidia-smi -L | wc -l`
export NUM_GPUS=1
その後、インタラクティブノードで手動で起動し、異なるテンソルの形状をデバッグおよび検査します。これを行うには、適切なポイントで実行を一時停止するコードを追加します。
import pdb; pdb.set_trace()
コードが一時停止している状態で、pdb
のコンソールを使用してテンソルの形状を出力します。
# 1 GPUのみでの実行時
(Pdb) observations
{'image_primary': Traced<ShapedArray(uint8[1,2,256,256,3])>with<DynamicJaxprTrace(level=1/0)>, 'image_wrist': Traced<ShapedArray(uint8[1,2,128,128,3])>with<DynamicJaxprTrace(level=1/0)>, 'pad_mask': Traced<ShapedArray(bool[1,2])>with<DynamicJaxprTrace(level=1/0)>, 'pad_mask_dict': {'image_primary': Traced<ShapedArray(bool[1,2])>with<DynamicJaxprTrace(level=1/0)>, 'image_wrist': Traced<ShapedArray(bool[1,2])>with<DynamicJaxprTrace(level=1/0)>, 'proprio': Traced<ShapedArray(bool[1,2])>with<DynamicJaxprTrace(level=1/0)>, 'timestep': Traced<ShapedArray(bool[1,2])>with<DynamicJaxprTrace(level=1/0)>}, 'proprio': Traced<ShapedArray(float32[1,2,8])>with<DynamicJaxprTrace(level=1/0)>, 'timestep': Traced<ShapedArray(int32[1,2])>with<DynamicJaxprTrace(level=1/0)>}
(Pdb) jax.tree_util.tree_leaves(observations)[0].shape
(1, 2, 256, 256, 3)
(Pdb) c
[...]
(Pdb) jax.tree_util.tree_leaves(observations)[0].shape
(1, 1, 256, 256, 3)
(Pdb) c
[...]
(Pdb) jax.tree_util.tree_leaves(observations)[0].shape
(256, 1, 256, 256, 3)
ここでは、Traced
型はJITコンパイラに関連しているため無視し、テンソルの形状に注目します。
特に、キーimage_primary
の形状(1, 2, 256, 256, 3)
は、バッチ処理された観測画像の文脈から(batch, horizon, width, height, depth)
に対応すると考えられます。しかし、最初の呼び出し時にはhorizon
が2、次の呼び出し時には1となり、その後も1に固定されています。
コードは異なる構成で複数回呼び出されているようです。詳細にコードを再調査したところ、以下の箇所が関係していることが判明しました。
# scripts/finetune.py
[...]
pretrained_model = OctoModel.load_pretrained(
FLAGS.config.pretrained_path,
step=FLAGS.config.pretrained_step,
)
[...]
del pretrained_model
model = OctoModel.from_config(
config,
example_batch,
text_processor,
rng=init_rng,
dataset_statistics=dataset.dataset_statistics,
)
del model
[...]
# on loss_fn
print("DEBUG before module bind")
bound_module = model.module.bind({"params": params}, rngs={"dropout": rng})
print("DEBUG after module bind")
OctoModel
は最初の2回はパラメータを読み取るために呼び出され、その後モデルが破棄されています。3回目はloss_fn
コード内でモジュールがバインドされるときに間接的にインスタンス化され、これにより何度も呼び出されます。このようなインスタンス化がどこで発生しているかを特定するために、print
文を追加しました。特にエレガントではありませんが、シンプルで効果的な方法です。
次に、1 GPUと2 GPU以上で動作が異なるかどうかを確認します。
jjjj;;;;;;j
2 GPUを使用してpdb
を実行するのは複雑(入出力のリダイレクト、タイムアウト、バリアなどの問題)であるため、簡単のため通常通り実行し、ログに直接形状を出力します。
# 2 GPUの場合
DEBUG OctoTransformer shape=(1, 2, 256, 256, 3)
DEBUG OctoTransformer shape=(1, 2, 256, 256, 3)
[...]
DEBUG OctoTransformer shape=(1, 256, 1, 256, 256, 3)
DEBUG OctoTransformer shape=(1, 256, 1, 256, 256, 3)
事前学習モデルを初期ロードした後、トレーニングが開始される段階(バッチサイズ256が設定されることからわかります)で、2 GPU以上で動作する場合に観測テンソルの次元が5から6に変更され、インデックスエラーが発生していることが確認されました。
octo_model.py
のソースコードを確認したところ、初期化メソッドfrom_config
内で、バッチを準備するためにmultihost_utils.process_allgather
が呼び出されていることがわかりました。
module = OctoModule.create(**config["model"])
rng = rng if rng is not None else jax.random.PRNGKey(0)
example_batch = multihost_utils.process_allgather(example_batch)
そのドキュメントを見ると、完全にアドレス可能でない配列の場合、データはそのまま複製され、完全にアドレス可能な配列(シャーディングされた配列)の場合は、tiled
パラメータに応じて挙動が変わることが記載されています。デフォルトでは、先頭に新しい次元を追加する(スタックする)挙動となります。
この挙動が今回の問題をうまく説明しています。テンソルに新しい次元が追加されているため、tiled
パラメータをtrue
に変更し、出力を結合させるように修正します。
#example_batch = multihost_utils.process_allgather(example_batch)
example_batch = multihost_utils.process_allgather(example_batch, tiled=True)
修正後に再実行してみると、コードは動作しましたが、新たなエラーが発生しました。
[...]
ValueError: Passing non-trivial shardings for numpy inputs is not allowed. To fix this error, either specify a replicated sharding explicitly or use `jax.experimental.multihost_utils.host_local_array_to_global_array(...)` to convert your host local numpy inputs to a jax.Array which you can pass to pjit. If the numpy input is the same on each process, then you can use `jax.make_array_from_callback(...) to create a `jax.Array` which you can pass to pjit. Please see the jax.Array migration guide for more information https://jax.readthedocs.io/en/latest/jax_array_migration.html#handling-of-host-local-inputs-to-pjit-like-batch-etc. Got arg shape: (256, 7), arg value: [[False False False ... False False True]
これは理想的な結果ではありませんが、進展です。
エラーメッセージの指摘に従い、明示的なレプリケートシャーディングを指定する必要があります。それでは、レプリケートシャーディングとは何でしょうか?
簡単に言えば、JAXは「メッシュ」という概念を使用してハードウェアを定義します。これは、デバイス識別子のセット(通常はjax.devices()
の出力)と、それらのデバイスを割り当てる必要がある次元を指定する軸のリストから構成されます。
たとえば、8 GPUのセットを1軸(1x8)、2軸(2x4)、3軸(2x2x2)、4軸(4x2x1x1)として定義できます。これは論理的な構造であり、実際のアクセラレータハードウェアの正確なトポロジーを使用することでさらなるパフォーマンスが得られる場合もありますが、多くの場合、ライブラリに任せるのが良い出発点です。
現在のコードでは「バッチ」と名付けられた単一の軸を設定しています。
# scripts/finetune.py
[...]
# "batch"と名付けられた1次元メッシュを作成
mesh = Mesh(jax.devices(), axis_names="batch")
# バッチはデータ並列シャーディングされ、各デバイスがバッチの一部を取得
dp_sharding = NamedSharding(mesh, PartitionSpec("batch"))
# モデルはデバイス間で複製される(データ並列のみ、モデル並列ではない)
replicated_sharding = NamedSharding(mesh, PartitionSpec())
その後、シャーディングは、メッシュからどの入力と出力がどの軸に割り当てられるかを定義します。
# scripts/finetune.py
@partial(
jax.jit,
in_shardings=[replicated_sharding, dp_sharding],
)
def train_step(state, batch):
この場合、train_step
の引数は順番にstate
とbatch
であり、partial
のin_shardings
パラメータは、state
が複製されないこと(空のPartitionSpec
名のreplicated_sharding
)を示し、batch
がバッチ軸を使用することを示しています(つまり、Nデバイス間で分割される)。
複数ノード用の明示的なシャーディングに変更するために、ドキュメントを確認し、shard_map
関数を使用することを決定しました。この関数を使用すると、デバイスごとに考えることができ、それ以外の処理をライブラリに任せることができます。
ドキュメントによると、shard_map
関数のシグネチャは現在使用しているpartial
とは少し異なります。メッシュとin_specs
およびout_specs
の両方をパラメータとして指定し、いくつかの追加インポートも必要です。
from jax.experimental.shard_map import shard_map
from jax.sharding import PartitionSpec as P
さらに、NamedSharding
の代わりにPartitionSpec
を直接使用する必要があります。簡潔さのため、インラインでそれらを定義します。例にあるコードでは、リスト[]
をタプル`()に置き換える必要があることに注意してください。関数はそれを期待していますが、自動的に変換されないためです。
また、jit
関数をshard_map
に置き換えたため、jit
コンパイルを保持するために、jit
を前のデコレータとして追加します。このチェーン構成は、jit(partial(train_step))
と同等であり、トレーニングステップが各バッチで部分的に適用され、その後パフォーマンスのためにjitted
されることを意味します。
# scripts/finetune.py
@jax.jit
@partial(
shard_map,
mesh=mesh,
in_specs=(PartitionSpec(), PartitionSpec("batch")),
out_specs=(PartitionSpec(), PartitionSpec()),
)
def train_step(state, batch):
この修正を行った後、再度ジョブを起動すると次のエラーが発生しました。
[...]
NotImplementedError: No replication rule for erf_inv. As a workaround, pass the `check_rep=False` argument to `shard_map`. To get this fixed, open an issue at https://github.com/google/jax/issues
新たなエラーですが、今回は回避策についての説明が含まれています。
@jax.jit
@partial(
shard_map,
mesh=mesh,
in_specs=(PartitionSpec(), PartitionSpec("batch")),
out_specs=(PartitionSpec(), PartitionSpec()),
check_rep=False
)
def train_step(state, batch):
この修正を加え、再実行すると次の結果が得られました。
[...]
12%|█▏ | 5888/50000 [28:17<2:59:31, 4.10it/s
これで動作が確認でき、単一ノードバージョンの約2倍の速度で実行されます(残り時間が6時間から3時間に短縮)。スピードアップは完全には線形ではありませんが、コードを複数ノードで動作させることができました。さらに長時間テストや、より多くのノードでの検証を行い、その実際の挙動を確認できます。
現時点では、変更内容を新しいブランチにコミットし、研究者に進捗状況を伝え、テストを依頼しました。