FSDP is an essential technique for training large models that don't fit on a single device. There are tutorials on how to set it up in JAX, but none of them cover the modern Flax NNX API, and they skip parts of the full training recipe, such as checkpointing and RNG handling, that production code needs.
To close this gap, this repository contains a tutorial on how to implement FSDP in JAX using NNX modules. You can follow the step-by-step guide in the notebook or check out the complete code in `main.py`.
The code in this repository supports the following:
- Fully working FSDP implementation in JAX on TPU that evenly shards all weights across the devices, together with DDP (see the first sketch below)
- Modern, native Flax NNX module API
- Checkpointing to disk or a GCP bucket via Orbax (see the second sketch below)
- Reproducible `nnx.Rngs` for noise generation and dropout (see the third sketch below)
- The same checkpoint can be run on TPUs with a different number of devices
- EMA version of the model
- All model operation functions are JIT-compiled
- Tested on v4 and v5p GCP TPUs
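
To give a flavor of the approach before you open the notebook, here is a minimal sketch of NNX-style FSDP sharding. The `Linear` layer, the sizes, and the `'fsdp'` axis name are illustrative placeholders, not the exact code from `main.py`:

```python
import jax
import jax.numpy as jnp
from flax import nnx

# 1D device mesh; every weight is split along the 'fsdp' axis.
mesh = jax.sharding.Mesh(jax.devices(), ('fsdp',))

class Linear(nnx.Module):
    def __init__(self, din: int, dout: int, *, rngs: nnx.Rngs):
        # Annotate the kernel with a sharding spec: rows split across devices.
        init = nnx.with_partitioning(nnx.initializers.lecun_normal(), ('fsdp', None))
        self.w = nnx.Param(init(rngs.params(), (din, dout)))
        self.b = nnx.Param(jnp.zeros((dout,)))

    def __call__(self, x):
        return x @ self.w + self.b

@nnx.jit  # compiled, so parameters are created directly in sharded form
def create_sharded_model():
    model = Linear(1024, 1024, rngs=nnx.Rngs(0))
    state = nnx.state(model)
    pspecs = nnx.get_partition_spec(state)  # read the annotations back
    sharded_state = jax.lax.with_sharding_constraint(state, pspecs)
    nnx.update(model, sharded_state)
    return model

with mesh:
    model = create_sharded_model()
```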
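Checkpointing follows the usual Orbax + NNX pattern: save the model's state pytree, then restore it into an abstract (shape-only) model. Since the sharding of the restored arrays comes from the restore target rather than from the checkpoint, the same checkpoint can be loaded on a mesh with a different device count. The path and the reuse of `Linear` from the sketch above are placeholders:

```python
import orbax.checkpoint as ocp
from flax import nnx

checkpointer = ocp.StandardCheckpointer()

# Save: split the model into a graph definition and a state pytree of arrays.
_, state = nnx.split(model)
checkpointer.save('/tmp/ckpt/state', state)

# Restore: build a shape-only model, then fill it with the stored arrays.
abs_model = nnx.eval_shape(lambda: Linear(1024, 1024, rngs=nnx.Rngs(0)))
graphdef, abs_state = nnx.split(abs_model)
restored = checkpointer.restore('/tmp/ckpt/state', abs_state)
model = nnx.merge(graphdef, restored)
```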
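Finally, a sketch of reproducible `nnx.Rngs` streams and the EMA parameter update, reusing `model` from the first sketch. The stream seeds and the decay value are illustrative:

```python
import jax
import jax.numpy as jnp
from flax import nnx

# Named, seeded streams make dropout and noise reproducible.
rngs = nnx.Rngs(params=0, dropout=1)
dropout = nnx.Dropout(rate=0.1, rngs=rngs)
y = dropout(jnp.ones((2, 8)))  # draws from the 'dropout' stream

# EMA: keep a slowly-updated copy of the parameters next to the live model.
ema_model = nnx.clone(model)

@nnx.jit
def ema_update(model: nnx.Module, ema_model: nnx.Module, decay: float = 0.999):
    params = nnx.state(model, nnx.Param)
    ema_params = nnx.state(ema_model, nnx.Param)
    new_ema = jax.tree.map(lambda e, p: decay * e + (1.0 - decay) * p,
                           ema_params, params)
    nnx.update(ema_model, new_ema)  # nnx.jit propagates this in-place update
```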
Here are the instructions to try out the `main.py` code yourself. The repository contains a convenient script, `run_on_tpu.py`, for running `main.py` on GCP TPUs.
First, create a shell script file that looks like this:
```bash
TPU={your GCP tpu name}
EXP_NAME=fsdp_test
TIMESTAMP="$(date +"%Y%m%d_%H%M%S")"
LOGFILE=logs/output_${TIMESTAMP}_${EXP_NAME}.log
mkdir -p logs  # tee does not create the directory itself

python run_on_tpu.py \
  --resource-name "${TPU}" \
  --gcp-zone "{your GCP zone}" \
  --gcp-project "{your GCP project}" \
  --git-branch "main" \
  --run-command "python main.py \
    --experiment_name=${EXP_NAME} \
    --checkpoint_dir={path to a GCP bucket folder}" 2>&1 | tee "$LOGFILE"
```

Now you can run your shell script. It executes `run_on_tpu.py`, which downloads this repository onto the TPU, creates the conda environment, installs the Python dependencies, and runs `main.py`. The checkpoint is saved to the GCP bucket folder, and the outputs go to the `$HOME/outputs/` directory on the TPU machine with index 0.