Octo Fine-tuning on Multiple Nodes
Running Octo fine-tuning in multiple nodes
Hello, I’m Alfredo, a research engineer at the Matsuo-Iwasawa Lab.
Context
Researchers at the lab are trying to fine-tune an existing model on their own data, but the process is too slow on a single node. They want to know whether it is possible, and how hard it would be, to change the code so that multiple nodes run in parallel and they can iterate faster.
The model in question is Octo, a collection of “transformer-based diffusion policies” written in JAX. We have never used this particular model, and most of the research is conducted with PyTorch, but it is worth reserving some time to check the situation.
Environment
First, we need to recreate the environment the researchers are using, to have a baseline and be able to reproduce any error they might encounter from now on.
In this case, they are using a fork of the original repository, so it is important to use that instead of the main project.
At the lab, we have access to several HPC clusters, but in this case we choose the same one the researchers are using, ABCI, to avoid any possible mismatch of operating system libraries, modules, etc.
To start, we request an interactive node with GPUs, since they may be required at installation time by some libraries and we are planning to execute the code in the GPUs anyway.
# Log in to ABCI
ssh abci
# Request an interactive node for 2 hours
qrsh -l rt_F=1 -g gcb50389 -l h_rt=02:00:00
Then we use git to get an up-to-date copy of the code:
# Get a copy of the code
git clone https://github.com/TMats/octo
cd octo
Since there are detailed installation instructions at the repository, we follow them, paying special attention to using the same versions of the packages (when present):
# Install Octo in a Conda environment
conda create -n octo python=3.10 -y
conda activate octo
pip install -e .
pip install -r requirements.txt
The project uses JAX, a Python library for accelerators developed by Google and designed mainly for its TPUs, and it ships separate builds for TPUs and GPUs. We are using GPUs, so we run the suggested command for them:
pip install --upgrade "jax[cuda11_pip]==0.4.20" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
The installation completes successfully, and we proceed to test it with the provided example command:
python scripts/finetune.py --config.pretrained_path=hf://rail-berkeley/octo-small --debug
Unfortunately we immediately get an exception:
Traceback (most recent call last):
# Unrelated and/or excessive output omitted for brevity
[...]
AttributeError: module 'scipy.linalg' has no attribute 'tril'
A bit of searching leads to the SciPy release notes, which indicate that the scipy.linalg.tril function was deprecated and has been removed in version 1.13 of SciPy.
The existing requirements.txt file only asks for version 1.6.0 or later, so we add an upper limit too and install the package again:
# Specify both lower and upper bounds for package version
pip install "scipy>=1.6.0,<1.13"
[...]
Successfully installed scipy-1.12.0
So we get v1.12 instead, and try the example code again:
python scripts/finetune.py --config.pretrained_path=hf://rail-berkeley/octo-small --debug
[...]
W0604 12:25:56.079283 22767111436096 xla_bridge.py:697] CUDA backend failed to initialize: Found cuDNN version 0, but JAX was built against version 8600, which is newer. The copy of cuDNN that is installed must be at least as new as the version against which JAX was built. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
It runs, but the warning message tells us that the detected cuDNN version does not match the one JAX was compiled against. JAX then falls back to the CPU, so training would be very slow; we kill the process and investigate a bit more.
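Because this fallback is only reported as a warning, it can be worth failing fast before starting a long run. A minimal sketch using the standard `jax.default_backend()` and `jax.devices()` calls; the helper name `accelerator_ok` is ours, for illustration only:

```python
import jax


def accelerator_ok(expected: str = "gpu") -> bool:
    """Report JAX's default backend and return True if it matches `expected`.

    When the CUDA backend fails to initialize (e.g. because of a cuDNN
    mismatch), JAX silently falls back to "cpu" with only a warning.
    """
    backend = jax.default_backend()  # typically "cpu", "gpu" or "tpu"
    print(f"jax {jax.__version__}: backend={backend}, devices={jax.devices()}")
    return backend == expected
```

On the interactive GPU node, `accelerator_ok()` should return True once the environment is fixed, and False while JAX is still running on the CPU.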
The currently installed cuDNN version is:
# Check cuDNN versions
pip list | grep cudnn
jaxlib 0.4.20+cuda11.cudnn86
nvidia-cudnn-cu11 9.1.1.17
There seems to be an issue with the JAX dependencies, where we get version 9.1 of cuDNN instead of 8.6. To get the right one, we first search PyPI for the complete version number and find that 8.6 corresponds to 8.6.0.163.
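The comparison we just did by eye can also be scripted. A sketch, assuming the `+cudaXX.cudnnYY` local version suffix seen in the `pip list` output above is how these jaxlib CUDA wheels encode their build; the function names are hypothetical:

```python
import re
from typing import Optional, Tuple


def expected_cudnn(jaxlib_version: str) -> Optional[Tuple[int, int]]:
    """Extract the cuDNN (major, minor) a CUDA jaxlib wheel was built against.

    Assumes a suffix like "0.4.20+cuda11.cudnn86" -> (8, 6).
    Returns None when there is no "cudnnXY" tag to compare against.
    """
    match = re.search(r"cudnn(\d)(\d+)", jaxlib_version)
    if match is None:
        return None
    return int(match.group(1)), int(match.group(2))


def cudnn_mismatch(jaxlib_version: str, cudnn_version: str) -> bool:
    """True if the installed cuDNN major.minor differs from jaxlib's build tag."""
    expected = expected_cudnn(jaxlib_version)
    if expected is None:
        return False  # nothing to compare against
    major, minor = (int(part) for part in cudnn_version.split(".")[:2])
    return (major, minor) != expected
```

In practice both strings could come from `importlib.metadata.version("jaxlib")` and `importlib.metadata.version("nvidia-cudnn-cu11")`; with the versions listed above, `cudnn_mismatch` flags 9.1.1.17 and accepts 8.6.0.163.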
Then, we install that version:
# Install the same cuDNN version that JAX was compiled with
pip install --upgrade nvidia-cudnn-cu11==8.6.0.163
[...]
Successfully installed nvidia-cudnn-cu11-8.6.0.163
And try running the code again:
# Run the example code again
python scripts/finetune.py --config.pretrained_path=hf://rail-berkeley/octo-small --debug
[...]
I0604 16:00:50.505642 23000339048256 compilation_cache.py:101] Writing jit_train_step to persistent compilation cache with key jit_train_step-9723010c1ea073770e5495bd7365e7d217da4b1e678506ee5800f821775a8738.
0%|▏ | 94/50000 [01:58<6:14:33, 2.22it/s]
OK, it seems to be training properly. We now have a working version of the environment that shows no relevant warnings.
Summary
The installation can be summarized as:
git clone https://github.com/TMats/octo
cd octo
conda create -n octo python=3.10 -y
conda activate octo
pip install -e .
# NOTE adds ,<1.13 at the end of the scipy line. Could be edited manually too.
sed '/scipy/ s/$/,<1.13/' requirements.txt > requirements_fix.txt
pip install -r requirements_fix.txt
pip install --upgrade "jax[cuda11_pip]==0.4.20" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
pip install --upgrade nvidia-cudnn-cu11==8.6.0.163
Multi-node
PyTorch is the most commonly used framework at the lab, so to learn how to train on multiple nodes in JAX we first check its documentation. It turns out that if we follow a few conventions, the library takes care of the synchronization between processes.
In short, we must run at least one process on each node and have them execute the same code in the same order. As a user, that means we can reuse the single node code with hardly any change.
The JAX initialization code still needs to know the total number of processes, the address of the process that will act as coordinator, and the rank (index) of the current process within the group, so we need to provide that information.
In the scripts/finetune.py file, inside the “dunder main” pattern (if __name__ == "__main__":), we wrap the main code with the initialization and shutdown of the distributed features:
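if __name__ == "__main__":
    # Initialize
    # [...]
    jax.distributed.initialize(
        coordinator_address=coordinator_address,
        num_processes=world_size,
        process_id=world_rank,
    )
    try:
        # Run (absl's app.run calls sys.exit() when main returns)
        app.run(main)
    finally:
        # Shutdown, still reached when app.run exits via SystemExit
        jax.distributed.shutdown()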
Depending on the way we launch these processes we could set these variables in multiple ways. In our case we are going to use the job system in ABCI (Sun Grid Engine, or SGE for short), so we are going to read them from the shell environment:
coordinator_address = os.environ.get('COORDINATOR_ADDRESS') or 'localhost:12345'
world_size = int(os.environ.get('OMPI_COMM_WORLD_SIZE') or 1)
world_rank = int(os.environ.get('OMPI_COMM_WORLD_RANK') or 0)
COORDINATOR_ADDRESS is a variable that we will set directly in the job file, while the other two are set automatically by Open MPI's mpirun for every process it launches.
The world size is the total number of processes across all nodes, and the world rank is the index of the current process inside that set.
Note that the or pattern is there just for compatibility when running interactively, where those variables are not set; it could be removed once the code is ready for the batch processing system.
We then run this modified code on a single interactive node. It does not change anything of importance, but it lets us confirm that the system properly detects all the GPUs (4 in this case) and that the global batch size is divided among them:
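The same lookup can be packaged in a small helper that takes the environment as a parameter, which makes it easy to check both the interactive fallbacks and the values an MPI launch would produce without a cluster. A sketch; `params_from_env` and `DistributedParams` are our own names, and only the variable names and defaults come from the code above:

```python
import os
from typing import Mapping, NamedTuple


class DistributedParams(NamedTuple):
    # Field names match the keyword arguments of jax.distributed.initialize().
    coordinator_address: str
    num_processes: int
    process_id: int


def params_from_env(env: Mapping[str, str] = os.environ) -> DistributedParams:
    """Read the distributed settings, falling back to a single local process.

    COORDINATOR_ADDRESS comes from our job script; OMPI_COMM_WORLD_SIZE and
    OMPI_COMM_WORLD_RANK are exported by Open MPI's mpirun.
    """
    return DistributedParams(
        coordinator_address=env.get("COORDINATOR_ADDRESS") or "localhost:12345",
        num_processes=int(env.get("OMPI_COMM_WORLD_SIZE") or 1),
        process_id=int(env.get("OMPI_COMM_WORLD_RANK") or 0),
    )
```

The result could then be passed as `jax.distributed.initialize(**params_from_env()._asdict())`.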
[...]
# Devices: 4
Batch size: 256 (64 per device)
To launch it on multiple nodes, we first need to write a job file for SGE:
#!/bin/bash
#$ -l rt_F=1
#$ -l h_rt=0:10:00
#$ -j y
#$ -cwd
#$ -l USE_SSH=1
#$ -v SSH_PORT=2299
# Use the system MPI libraries
source /etc/profile.d/modules.sh
module load hpcx/2.12
# Activate our Conda environment
source ~/miniforge3/etc/profile.d/conda.sh
conda activate octo
# Go to folder
cd ~/blog/octo
# Make first node the coordinator
export COORDINATOR_ADDRESS=`head -1 $SGE_JOB_HOSTLIST`:12345
# NOTE JAX expects 1 process per GPU in SLURM/OpenMPI
export NUM_GPUS=`nvidia-smi -L | wc -l`
mpirun -npernode $NUM_GPUS -hostfile $SGE_JOB_HOSTLIST \
python scripts/finetune.py --config.pretrained_path=hf://rail-berkeley/octo-small
We export the coordinator address as the hostname of the first node assigned to the job by the scheduler, plus an arbitrary unused port above 1024, in this case 12345 (if it is already taken by another application, we can always choose another one).
On multi-GPU nodes we can normally choose between launching a single process that uses all the GPUs or one process per GPU. For example, on ABCI an rt_F node has 4 V100-16G GPUs, so we could launch 1 process per node with access to all 4 GPUs, or 4 processes per node with one GPU assigned to each. However, JAX expects 1 process per GPU when launched under SLURM or Open MPI.
So we tell mpirun to start as many processes per node as there are GPUs (4 here) with its -npernode parameter.
We launch the job to verify that the code still works properly as a batch job. Note that rt_F=1 at the top of the file requests a single node, and h_rt=0:10:00 sets a short time limit (10 minutes), since we only want to verify that the code works at this stage.
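To make the one-process-per-GPU layout concrete: Open MPI also exports OMPI_COMM_WORLD_LOCAL_RANK, the rank of a process within its own node, which is the natural index of "its" GPU. The post relies on JAX to pick the device; this sketch only illustrates the mapping. Passing `local_device_ids=[local_gpu()]` to `jax.distributed.initialize` would be an untested alternative, and the consecutive placement in `rank_layout` is an assumption (mpirun's `--display-map` option shows the real one):

```python
import os
from typing import List, Mapping, Tuple


def local_gpu(env: Mapping[str, str] = os.environ) -> int:
    """GPU index for this process: its rank within the node, as set by Open MPI."""
    return int(env.get("OMPI_COMM_WORLD_LOCAL_RANK") or 0)


def rank_layout(num_nodes: int, gpus_per_node: int) -> List[Tuple[int, int, int]]:
    """(world_rank, node, local_gpu) triples, assuming ranks fill node 0 first,
    then node 1, and so on."""
    return [
        (rank, rank // gpus_per_node, rank % gpus_per_node)
        for rank in range(num_nodes * gpus_per_node)
    ]
```

With 2 rt_F nodes, `rank_layout(2, 4)` gives a world size of 8, and world rank 5 lands on node 1, GPU 1.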
$ qsub -g gcb50389 -N test1 job.sh
After the job finishes running, either successfully or with an error, we check the logs and get:
[...]
# Devices: 4
Batch size: 256 (64 per device)
[...]
FileExistsError: [Errno 17] File exists: '/home/acb11899xv/blog/octo/wandb/run-20240613_165553-experiment_20240613_165552/run-experiment_20240613_165552.wandb'
[...]
FileExistsError: [Errno 17] File exists: '/home/acb11899xv/blog/octo/wandb/run-20240613_165553-experiment_20240613_165552/run-experiment_20240613_165552.wandb'
[...]
FileExistsError: [Errno 17] File exists: '/home/acb11899xv/blog/octo/wandb/run-20240613_165553-experiment_20240613_165552/run-experiment_20240613_165552.wandb'
So we now have 4 processes running, one per GPU, as expected, but that also means all four processes try to initialize the WandB library, and after the first one creates the log file, the other three fail.
For the moment, we can set export WANDB_MODE=disabled in the job file to disable the WandB library while we focus on the distributed changes, and fix the multiple initialization problem later.
Doing so and launching the job again, we get:
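One possible fix for the repeated initialization, sketched here as an assumption rather than what the lab ended up doing, is to let only the first process (`jax.process_index() == 0`) log to WandB and disable it everywhere else. The helper name is hypothetical; `mode="disabled"` is the per-call equivalent of WANDB_MODE=disabled:

```python
from typing import Dict


def wandb_settings(process_index: int, mode: str = "online") -> Dict[str, str]:
    """Keyword arguments for wandb.init(): only process 0 keeps `mode`,
    every other process gets "disabled" so it never creates run files."""
    return {"mode": mode if process_index == 0 else "disabled"}
```

In the training script this would become something like `wandb.init(..., **wandb_settings(jax.process_index()))`.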
# Devices: 4
Batch size: 256 (64 per device)
[...]
# Devices: 4
Batch size: 256 (64 per device)
[...]
# Devices: 4
Batch size: 256 (64 per device)
[...]
# Devices: 4
Batch size: 256 (64 per device)
This shows that all 4 processes are now logging (instead of only 1 before) and that the GPUs are detected and combined, so the per-device batch is still 64.
However, there is a different error after that:
[...]
File "/home/acb11899xv/blog/octo/octo/model/octo_module.py", line 125, in __call__
assert horizon <= self.max_horizon, "horizon must be <= max_horizon"
AssertionError: horizon must be <= max_horizon
OK, this looks more like a “provided tensor shape did not match the expected shape” type of error, so it is some progress.
Before continuing with the debugging, we first check with 2 nodes as well, to verify that the GPUs are still detected properly across nodes.
We change this in the job file:
# Request 2 nodes
#$ -l rt_F=2
And launch the job again:
[...]
# Devices: 8
Batch size: 256 (32 per device)
[...]
# Devices: 8
Batch size: 256 (32 per device)
[...]
# Devices: 8
Batch size: 256 (32 per device)
[...]
AssertionError: horizon must be <= max_horizon
We get eight logging processes and all GPUs detected (2 nodes with 4 GPUs each, a total of 8 devices), followed by the same assertion error, confirming the previous result. Also note that the specified batch size, 256, is divided among all 8 devices, so the per-device batch is now half of what it was on a single node with 4 GPUs (32 vs 64).
Searching the source code for the error message, we find that the horizon being checked is read from the first leaf returned by jax.tree_util.tree_leaves, or rather, from its observations argument:
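The arithmetic behind those log lines is the global batch divided by the total number of devices (nodes × GPUs per node). A tiny, hypothetical helper to sanity-check a configuration before submitting a job; the divisibility check reflects our assumption that an even data-parallel split is required:

```python
def per_device_batch(global_batch: int, num_nodes: int, gpus_per_node: int) -> int:
    """Split a global batch evenly over every GPU taking part in the job."""
    devices = num_nodes * gpus_per_node
    if global_batch % devices:
        raise ValueError(f"batch {global_batch} is not divisible by {devices} devices")
    return global_batch // devices
```

For example, `per_device_batch(256, 1, 4)` gives 64 and `per_device_batch(256, 2, 4)` gives 32, while 3 nodes (12 GPUs) would be rejected because 256 is not a multiple of 12.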
batch_size, horizon = jax.tree_util.tree_leaves(observations)[0].shape[:2]
We then temporarily fix the number of GPUs to 1 in the job script:
[...]
#export NUM_GPUS=`nvidia-smi -L | wc -l`
export NUM_GPUS=1
and launch it manually on an interactive node to debug and inspect the shapes of the different tensors. To do that, we add a breakpoint to pause the execution at the right point:
import pdb; pdb.set_trace()
With the code paused, we use pdb’s console to print the tensor shapes:
# On 1 GPU only
(Pdb) observations
{'image_primary': Traced<ShapedArray(uint8[1,2,256,256,3])>with<DynamicJaxprTrace(level=1/0)>, 'image_wrist': Traced<ShapedArray(uint8[1,2,128,128,3])>with<DynamicJaxprTrace(level=1/0)>, 'pad_mask': Traced<ShapedArray(bool[1,2])>with<DynamicJaxprTrace(level=1/0)>, 'pad_mask_dict': {'image_primary': Traced<ShapedArray(bool[1,2])>with<DynamicJaxprTrace(level=1/0)>, 'image_wrist': Traced<ShapedArray(bool[1,2])>with<DynamicJaxprTrace(level=1/0)>, 'proprio': Traced<ShapedArray(bool[1,2])>with<DynamicJaxprTrace(level=1/0)>, 'timestep': Traced<ShapedArray(bool[1,2])>with<DynamicJaxprTrace(level=1/0)>}, 'proprio': Traced<ShapedArray(float32[1,2,8])>with<DynamicJaxprTrace(level=1/0)>, 'timestep': Traced<ShapedArray(int32[1,2])>with<DynamicJaxprTrace(level=1/0)>}
(Pdb) jax.tree_util.tree_leaves(observations)[0].shape
(1, 2, 256, 256, 3)
(Pdb) c
[...]
(Pdb) jax.tree_util.tree_leaves(observations)[0].shape
(1, 1, 256, 256, 3)
(Pdb) c
[...]
(Pdb) jax.tree_util.tree_leaves(observations)[0].shape
(256, 1, 256, 256, 3)
We can ignore the Traced types for now, since they come from JAX tracing the code for the JIT compiler, and focus on the tensor shapes instead.
In particular, the shape (1, 2, 256, 256, 3) of the image_primary key likely corresponds to (batch, horizon, height, width, channels), given the context of batched observation images. However, the first time the code is called the horizon is 2, the second time it is 1, and the third time it stays at 1 (while the batch grows to 256).
So the code seems to be called multiple times with different configurations. Looking again in detail at the code we find:
# scripts/finetune.py
[...]
pretrained_model = OctoModel.load_pretrained(
    FLAGS.config.pretrained_path,
    step=FLAGS.config.pretrained_step,
)
[...]
del pretrained_model
model = OctoModel.from_config(
    config,
    example_batch,
    text_processor,
    rng=init_rng,
    dataset_statistics=dataset.dataset_statistics,
)
del model
[...]
# on loss_fn
print("DEBUG before module bind")
bound_module = model.module.bind({"params": params}, rngs={"dropout": rng})
print("DEBUG after module bind")
The first two times, OctoModel is used to read some parameters and the model is then discarded. The remaining calls come from the loss_fn code, where the model is indirectly instantiated when the module is bound, and is thus called many times. This is tricky to tell at first sight, which is why we added the extra print statements to find where in the code the instantiation takes place. Not especially elegant, but simple, and it works.
We then check whether its behavior changes when run on 1 or 2 (or more) GPUs.
With 2 GPUs, running pdb is more complicated (input redirection, timeouts or barriers, etc.), so for simplicity we run normally and print the shapes we care about directly to the log:
# On 2 GPUs
DEBUG OctoTransformer shape=(1, 2, 256, 256, 3)
DEBUG OctoTransformer shape=(1, 2, 256, 256, 3)
[...]
DEBUG OctoTransformer shape=(1, 256, 1, 256, 256, 3)
DEBUG OctoTransformer shape=(1, 256, 1, 256, 256, 3)
So it seems that after the initial loading of the pretrained models, once training is about to start (we can tell because the proper batch size of 256 is set), the version with 2 or more GPUs changes the number of dimensions of the observation tensor from 5 to 6. Since the code reads (batch, horizon) from the first two dimensions, it now gets a batch of 1 and a horizon of 256, which trips the max_horizon assertion.
Looking into the octo_model.py source, we find that the from_config initialization method contains a multihost_utils.process_allgather call that prepares the batch:
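The effect of that extra leading axis on the `shape[:2]` unpacking can be reproduced with NumPy alone. A sketch with small stand-in images; `MAX_HORIZON` is a hypothetical value (not Octo's real setting), and the sorted-key lookup mimics how `jax.tree_util.tree_leaves` orders the values of a flat dict:

```python
import numpy as np

MAX_HORIZON = 10  # hypothetical value, just for the illustration


def batch_and_horizon(observations: dict) -> tuple:
    """Mimic `jax.tree_util.tree_leaves(observations)[0].shape[:2]`.

    For a flat dict, tree_leaves returns the values in sorted-key order,
    so the first leaf belongs to the alphabetically first key.
    """
    first_leaf = observations[sorted(observations)[0]]
    return first_leaf.shape[:2]


# Small stand-ins for the real 256x256 images (batch=256, horizon=1).
obs = {"image_primary": np.zeros((256, 1, 8, 8, 3), np.uint8),
       "timestep": np.zeros((256, 1), np.int32)}
# The same batch with an extra leading axis, as in the multi-GPU log above.
stacked = {k: v[np.newaxis] for k, v in obs.items()}

print(batch_and_horizon(obs))      # (256, 1): horizon 1 passes the check
print(batch_and_horizon(stacked))  # (1, 256): "horizon" 256 > MAX_HORIZON
```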
module = OctoModule.create(**config["model"])
rng = rng if rng is not None else jax.random.PRNGKey(0)
example_batch = multihost_utils.process_allgather(example_batch)
Looking at its documentation, for a non-fully addressable array (one whose data is spread across several processes) the data is fully replicated, but for NumPy arrays and fully addressable arrays (whose data is all local to the current process) the behavior depends on the tiled parameter, which defaults to False: the inputs are stacked along a new leading dimension instead of concatenated.
This matches our situation, where we find an additional dimension in the tensor, so we set the parameter to True to have the output concatenated:
#example_batch = multihost_utils.process_allgather(example_batch)
example_batch = multihost_utils.process_allgather(example_batch, tiled=True)
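The difference between the two modes can be reproduced with plain NumPy by simulating the per-process inputs by hand. This only mirrors the shapes described in the documentation; it does not call `process_allgather` itself, and the 2-process, 4-example batches are made up:

```python
import numpy as np

# Hypothetical per-process inputs: each of 2 processes holds a local batch
# of 4 examples with 3 features, filled with its own rank.
local_batches = [np.full((4, 3), rank) for rank in range(2)]

# tiled=False behaves like stacking: a new leading "process" axis appears.
stacked = np.stack(local_batches)        # shape (2, 4, 3)
# tiled=True behaves like concatenating along the existing batch axis.
tiled = np.concatenate(local_batches)    # shape (8, 3)

print(stacked.shape, tiled.shape)
```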
We try again. This time it gets further, but we get a new error instead:
[...]
ValueError: Passing non-trivial shardings for numpy inputs is not allowed. To fix this error, either specify a replicated sharding explicitly or use `jax.experimental.multihost_utils.host_local_array_to_global_array(...)` to convert your host local numpy inputs to a jax.Array which you can pass to pjit. If the numpy input is the same on each process, then you can use `jax.make_array_from_callback(...) to create a `jax.Array` which you can pass to pjit. Please see the jax.Array migration guide for more information https://jax.readthedocs.io/en/latest/jax_array_migration.html#handling-of-host-local-inputs-to-pjit-like-batch-etc. Got arg shape: (256, 7), arg value: [[False False False ... False False True]
Not ideal, but it is progress.
As the error message mentions, we need to specify an explicit replicated sharding. So, what is that?
In simple terms, JAX uses the concept of a “mesh” to describe the hardware to use: an array of devices (commonly the output of jax.devices(), to use everything available) and a list of axis names that label the dimensions along which those devices are arranged.
For example, a set of 8 GPUs could be arranged as a single axis of 8 (8), two axes (2x4), three axes of 2 (2x2x2), four axes with the extra dimensions set to 1 (4x2x1x1), etc. Note that this is a logical construct: matching the exact topology of the actual accelerator hardware may yield extra performance, but in most cases leaving it to the library is a good starting point.
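A sketch of those layouts with `jax.sharding.Mesh`. Assumption: no GPUs at hand, so XLA is asked for 8 fake CPU devices (the flags must be set before JAX is imported); the axis names other than "batch" are arbitrary:

```python
import os

# Assumption: simulate 8 devices on the CPU; must happen before importing JAX.
os.environ.setdefault("JAX_PLATFORMS", "cpu")
os.environ.setdefault("XLA_FLAGS", "--xla_force_host_platform_device_count=8")

import numpy as np
import jax
from jax.sharding import Mesh

devices = np.array(jax.devices())
n = devices.size

# The same devices arranged as different logical meshes.
layouts = {"batch": ((n,), ("batch",))}
if n == 8:
    layouts.update({
        "2x4": ((2, 4), ("data", "model")),
        "2x2x2": ((2, 2, 2), ("x", "y", "z")),
        "4x2x1x1": ((4, 2, 1, 1), ("a", "b", "c", "d")),
    })

meshes = {name: Mesh(devices.reshape(shape), names)
          for name, (shape, names) in layouts.items()}

for name, mesh in meshes.items():
    print(name, dict(mesh.shape))  # axis name -> number of devices
```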
The code is currently setting a single axis called “batch”:
# scripts/finetune.py
[...]
# create a 1D mesh with a single axis named "batch"
mesh = Mesh(jax.devices(), axis_names="batch")
# Our batches will be data-parallel sharded -- each device will get a slice of the batch
dp_sharding = NamedSharding(mesh, PartitionSpec("batch"))
# Our model will be replicated across devices (we are only doing data parallelism, not model parallelism)
replicated_sharding = NamedSharding(mesh, PartitionSpec())
The shardings then define how inputs and outputs are laid out on that mesh, that is, which inputs and outputs are split along which axis:
# scripts/finetune.py
@partial(
    jax.jit,
    in_shardings=[replicated_sharding, dp_sharding],
)
def train_step(state, batch):
In this case, the parameters of train_step are, in order, the state and the batch, and the in_shardings argument (passed to jax.jit through partial) indicates that the state should be replicated on every device (an empty PartitionSpec, named replicated_sharding) and that the batch should be split along the batch axis of the mesh (that is, divided among the N devices).
To switch to explicit shardings that work across multiple nodes, we do some reading and decide to use the shard_map function, since it lets us think at a per-device level and leave the rest to the library.
As the documentation shows, its signature is slightly different from the jax.jit call we are using at the moment. We need to provide the mesh, plus both in_specs and out_specs, as parameters, as well as add some extra imports:
from jax.experimental.shard_map import shard_map
from jax.sharding import PartitionSpec as P
Also, instead of NamedSharding we need to use PartitionSpec directly (for simplicity, we define them inline). Note that the lists [] must be replaced with tuples (), as in the example code: in_specs has to match the structure of the function's positional arguments, which JAX treats as a tuple, and a list is a different container type for JAX's pytrees, so it is not converted automatically.
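A toy version of that setup shows what each device actually receives. Assumptions: 8 fake CPU devices stand in for the GPUs, and `toy_step` is a hypothetical stand-in for train_step (a linear model's mean output), not Octo's code:

```python
import os

# Assumption: simulate 8 devices on the CPU; must happen before importing JAX.
os.environ.setdefault("JAX_PLATFORMS", "cpu")
os.environ.setdefault("XLA_FLAGS", "--xla_force_host_platform_device_count=8")

from functools import partial

import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec

mesh = Mesh(np.array(jax.devices()), axis_names=("batch",))
dp_sharding = NamedSharding(mesh, PartitionSpec("batch"))   # split dimension 0
replicated_sharding = NamedSharding(mesh, PartitionSpec())  # full copy everywhere


@partial(jax.jit, in_shardings=(replicated_sharding, dp_sharding))
def toy_step(weights, batch):
    # Hypothetical stand-in for train_step: mean output of a linear model.
    return jnp.mean(batch @ weights)


n = len(jax.devices())
weights = jax.device_put(jnp.ones((3,)), replicated_sharding)
batch = jax.device_put(jnp.ones((4 * n, 3)), dp_sharding)

# Each device holds the full weights but only 4 rows of the batch.
print(weights.addressable_shards[0].data.shape)  # (3,)
print(batch.addressable_shards[0].data.shape)    # (4, 3)
print(toy_step(weights, batch))                  # 3.0
```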
And because we have replaced jit with shard_map, we keep the jit compilation by adding it as an outer decorator. The decorator stack is equivalent to jax.jit(shard_map(train_step, mesh=mesh, in_specs=..., out_specs=...)): partial only pre-fills shard_map's keyword arguments, shard_map runs train_step on each device's slice of the batch, and jit compiles the result for performance.
# scripts/finetune.py
@jax.jit
@partial(
    shard_map,
    mesh=mesh,
    in_specs=(PartitionSpec(), PartitionSpec("batch")),
    out_specs=(PartitionSpec(), PartitionSpec()),
)
def train_step(state, batch):
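The same decorator pattern on a toy function makes the per-device view explicit. Assumptions: 8 fake CPU devices, a hypothetical `toy_step` instead of Octo's train_step, and an import fallback because newer JAX releases expose `shard_map` at the top level while 0.4.x (as used here) keeps it in `jax.experimental`:

```python
import os

# Assumption: simulate 8 devices on the CPU; must happen before importing JAX.
os.environ.setdefault("JAX_PLATFORMS", "cpu")
os.environ.setdefault("XLA_FLAGS", "--xla_force_host_platform_device_count=8")

from functools import partial

import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec

try:  # newer JAX exposes shard_map at the top level
    from jax import shard_map
except ImportError:  # JAX 0.4.x
    from jax.experimental.shard_map import shard_map

mesh = Mesh(np.array(jax.devices()), axis_names=("batch",))


@jax.jit
@partial(
    shard_map,
    mesh=mesh,
    in_specs=(PartitionSpec(), PartitionSpec("batch")),  # a tuple, not a list
    out_specs=PartitionSpec(),
)
def toy_step(weights, batch):
    # Inside shard_map, `batch` is only this device's slice of the rows.
    local_loss = jnp.mean(batch @ weights)
    # Average over devices so every device returns the same (replicated) value.
    return jax.lax.pmean(local_loss, axis_name="batch")


n = len(jax.devices())
weights = jnp.ones((3,))
batch = jnp.arange(4 * n * 3, dtype=jnp.float32).reshape(4 * n, 3)
print(toy_step(weights, batch), jnp.mean(batch @ weights))  # same value
```

Because every device holds the same number of rows, the mean of the per-device means equals the global mean.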
Once that is done, we launch the job again and get:
[...]
NotImplementedError: No replication rule for erf_inv. As a workaround, pass the `check_rep=False` argument to `shard_map`. To get this fixed, open an issue at https://github.com/google/jax/issues
Well, that’s another error, but this time we get a helpful description of how to work around it:
@jax.jit
@partial(
    shard_map,
    mesh=mesh,
    #in_shardings=[replicated_sharding, dp_sharding],
    in_specs=(PartitionSpec(), PartitionSpec("batch")),
    out_specs=(PartitionSpec(), PartitionSpec()),
    check_rep=False
)
def train_step(state, batch):
Adding that and running it once more:
[...]
12%|█▏ | 5888/50000 [28:17<2:59:31, 4.10it/s]
So that’s it. It runs properly, at approximately twice the speed of the single-node version (4.10 vs 2.22 it/s, with an estimated 3 hours left instead of over 6). The speedup won’t be truly linear, but we managed to get the code running on multiple nodes, and it can later be tested for longer and on more nodes to verify its actual behaviour.
For now, we commit our changes to a new branch and let the researchers know about the progress so they can test it themselves.
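The scaling claim can be checked with a few lines of arithmetic, taking the two progress-bar readings at face value (the single-node figure was read early in a `--debug` run, so both numbers are rough):

```python
# Throughput figures read from the two progress bars above.
single_node_its = 2.22   # 1 node x 4 GPUs
two_node_its = 4.10      # 2 nodes x 4 GPUs
total_steps = 50_000

speedup = two_node_its / single_node_its
efficiency = speedup / 2  # fraction of ideal (linear) scaling with 2 nodes

hours_single = total_steps / single_node_its / 3600
hours_double = total_steps / two_node_its / 3600

print(f"speedup {speedup:.2f}x, efficiency {efficiency:.0%}")  # 1.85x, 92%
print(f"full run: {hours_single:.1f} h vs {hours_double:.1f} h")  # 6.3 h vs 3.4 h
```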