This document describes AReaL's checkpointing system, which handles model saving for evaluation and fault-tolerant recovery during distributed RL training.
AReaL provides two complementary checkpointing mechanisms:
| Mechanism | Purpose | Format | Includes Optimizer/DataLoader State |
|---|---|---|---|
| Saver | Export models for evaluation or publishing | HuggingFace | No |
| RecoverHandler | Resume training after failures | DCP (Distributed Checkpoint) | Yes |
Both mechanisms are invoked automatically during training and are configured via `config.saver` and `config.recover`, respectively.
Used by Saver for model export:
- Standard HuggingFace model format (safetensors + config.json)
- Compatible with `transformers.AutoModel.from_pretrained()`
- Can be uploaded to the HuggingFace Hub
- Does not include optimizer state
Used by RecoverHandler for fault tolerance:
- Backend's native distributed checkpoint format (`torch.distributed.checkpoint` or Megatron distributed checkpoint)
- Sharded across all ranks for efficient parallel I/O
- Includes model weights, optimizer state, RNG state, etc.
- Backend-specific: checkpoints are only compatible with the same parallelism configuration
- Overwrites previous checkpoint to save disk space
```
PPOTrainer.train()
│
├── Training loop
│   ├── Rollout, compute values, PPO update...
│   │
│   ├── _save_hf()                       # HuggingFace export
│   │   └── Saver.save()
│   │       └── engine.save(weight_format="hf")
│   │
│   └── _save_recover_checkpoint()       # Fault tolerance
│       └── RecoverHandler.dump()
│           └── engine.save(weight_format="dcp", with_optim=True)
│
└── On restart
    └── RecoverHandler.load()
        ├── Restore dataloader, saver, evaluator states
        └── engine.load(weight_format="dcp", with_optim=True)
```
The Saver periodically exports model weights in HuggingFace format for evaluation or deployment.
The mode parameter controls how checkpoints are written:
| Mode | Behavior |
|---|---|
| `auto` | Use async for the Archon engine, sync for others (default). Zero-config optimal for all engines. |
| `sync` | Always synchronous `dcp.save()`. |
| `async` | Always process-based async with pinned-memory staging. Archon engine only; other engines fall back to sync with a warning. Requires extra pinned CPU memory proportional to the per-rank model shard size (e.g., ~17.5 GB/rank for a 70B model on 8 GPUs). |
With the default auto mode, Archon engine users get async checkpoint saving
automatically - the training loop blocks only while the checkpoint is staged to pinned
CPU memory, and the actual disk I/O happens in a background process.
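The pinned-memory figure above follows from simple arithmetic. A hypothetical helper (`staging_memory_gb` is illustrative, assuming bf16 weights evenly sharded across ranks):

```python
def staging_memory_gb(num_params: float, bytes_per_param: int = 2,
                      num_ranks: int = 8) -> float:
    """Per-rank pinned CPU memory to stage one model shard:
    total parameter bytes divided evenly across ranks."""
    return num_params * bytes_per_param / num_ranks / 1e9

print(staging_memory_gb(70e9))  # 17.5 GB/rank for a bf16 70B model on 8 GPUs
```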
Configure via config.saver:
| Parameter | Type | Default | Description |
|---|---|---|---|
| `mode` | str | `"auto"` | Save mode (see above). |
| `freq_epochs` | int \| None | None | Save every N epochs. None disables. |
| `freq_steps` | int \| None | None | Save every N steps. None disables. |
| `freq_secs` | int \| None | None | Save every N seconds. None disables. |
Example configuration:

```yaml
saver:
  freq_epochs: 1   # Save at the end of each epoch
  freq_steps: null # Disabled
  freq_secs: null  # Disabled
  # mode defaults to "auto" - Archon users get async saving automatically
```

Saving is triggered when any of the epoch, step, or time conditions is met.
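The trigger rule can be sketched as a small predicate (hypothetical; the real Saver tracks this state internally):

```python
def should_save(freq_epochs, freq_steps, freq_secs,
                epochs_since_save, steps_since_save, secs_since_save):
    """Save when ANY configured frequency condition is met;
    a frequency of None disables that condition."""
    if freq_epochs is not None and epochs_since_save >= freq_epochs:
        return True
    if freq_steps is not None and steps_since_save >= freq_steps:
        return True
    if freq_secs is not None and secs_since_save >= freq_secs:
        return True
    return False
```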
Checkpoints are saved to:

```
{fileroot}/checkpoints/{user}/{experiment_name}/{trial_name}/default/
└── epoch{E}epochstep{S}globalstep{G}/
    ├── config.json
    ├── model.safetensors   # or model-00001-of-00002.safetensors, etc.
    ├── tokenizer.json
    └── ...
```
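The directory naming can be reproduced with a small helper (`hf_checkpoint_dir` is illustrative, not an AReaL API):

```python
def hf_checkpoint_dir(fileroot, user, experiment_name, trial_name,
                      epoch, epoch_step, global_step):
    # Mirrors the documented layout:
    # {fileroot}/checkpoints/{user}/{experiment}/{trial}/default/epoch{E}epochstep{S}globalstep{G}
    return (f"{fileroot}/checkpoints/{user}/{experiment_name}/{trial_name}"
            f"/default/epoch{epoch}epochstep{epoch_step}globalstep{global_step}")

print(hf_checkpoint_dir("/data", "alice", "ppo-exp", "trial0", 0, 99, 99))
```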
Load saved checkpoints with standard HuggingFace APIs:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained(
    "/path/to/checkpoint/epoch0epochstep99globalstep99"
)
tokenizer = AutoTokenizer.from_pretrained(
    "/path/to/checkpoint/epoch0epochstep99globalstep99"
)
```

The RecoverHandler enables resuming training after failures by saving complete training state.
Configure via config.recover:
| Parameter | Type | Default | Description |
|---|---|---|---|
| `mode` | str | `"disabled"` | Recovery mode: `"on"`/`"auto"` or `"off"`/`"disabled"` |
| `freq_epochs` | int \| None | None | Checkpoint every N epochs |
| `freq_steps` | int \| None | None | Checkpoint every N steps |
| `freq_secs` | int \| None | None | Checkpoint every N seconds |
| `retries` | int | 3 | Number of recovery retries when recovery is enabled |
| Mode | Behavior |
|---|---|
| `on` or `auto` | Automatically resume if a valid checkpoint exists |
| `off` or `disabled` | No checkpointing or recovery |
When recovery is enabled (on/auto), the system will:
- Periodically save recovery checkpoints (model weights, optimizer state, dataloader position)
- Automatically resume from the last valid checkpoint on restart
- Retry up to `retries` times on failure
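The retry behavior might look like the following sketch (a hypothetical `run_with_recovery`, not AReaL's actual control flow): each failed attempt reloads the last valid checkpoint before trying again.

```python
def run_with_recovery(train, load_checkpoint, retries=3):
    """Run `train`, reloading the last valid checkpoint and retrying
    up to `retries` times after a failure."""
    for attempt in range(retries + 1):
        try:
            if attempt > 0:
                load_checkpoint()  # resume from last valid recovery checkpoint
            return train()
        except RuntimeError:
            if attempt == retries:
                raise  # retries exhausted; surface the failure
```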
Example configuration:

```yaml
recover:
  mode: "on"       # or "auto" for backward compatibility; quoted so YAML doesn't parse bare `on` as a boolean
  freq_steps: 100  # Checkpoint every 100 steps
  retries: 3
```

The RecoverHandler saves complete training state:
| Component | Contents |
|---|---|
| Model weights | DCP format, sharded across ranks |
| Optimizer state | Momentum, variance (Adam), learning rate scheduler |
| RNG state | Python, NumPy, PyTorch, CUDA random states |
| Dataloader state | Current position in dataset |
| Training progress | Epoch, step, global_step counters |
| Auxiliary states | Saver, Evaluator, StatsLogger states |
Recovery checkpoints are saved to:

```
{fileroot}/checkpoints/{user}/{experiment_name}/{trial_name}/
├── default/
│   └── recover_checkpoint/     # Model + optimizer (DCP format)
│       ├── __0_0.distcp
│       ├── __1_0.distcp
│       └── ...
├── critic/                     # If using a critic
│   └── recover_checkpoint/
└── recover_info/               # Metadata
    ├── step_info.json
    ├── saver_info.json
    ├── evaluator_info.json
    ├── stats_logger_info.json
    ├── checkpoint_info.json
    └── dataloader_info.pkl
```
When training resumes:
- `RecoverHandler.load()` restores all saved state (if any)
- Training continues from `last_step_info.next().global_step`
- Inference engine weights are synchronized to match the recovered state
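The resume-step arithmetic can be illustrated with a hypothetical `StepInfo` stand-in (the real `last_step_info` is recorded in `recover_info/step_info.json`):

```python
from dataclasses import dataclass

@dataclass
class StepInfo:
    # Hypothetical stand-in for the recorded last_step_info
    epoch: int
    epoch_step: int
    global_step: int

    def next(self):
        # Training resumes from the step AFTER the last checkpointed one
        return StepInfo(self.epoch, self.epoch_step + 1, self.global_step + 1)

resume = StepInfo(epoch=0, epoch_step=99, global_step=99).next()
print(resume.global_step)  # 100
```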
| Scenario | Recommended Setting |
|---|---|
| Long training runs | `freq_epochs: 1` or `freq_steps: 1000` |
| Unpredictable step/epoch duration | `freq_secs: 7200` |
| Unstable clusters | `freq_steps: 100` with `recover.mode: on` |
| Limited disk space | Lower frequency; rely on the final checkpoint |
| Debugging | `freq_steps: 1` for quick iteration |
- Saver: Each save creates a new directory. High frequency consumes significant space.
- RecoverHandler: Overwrites previous checkpoint. Only one copy exists at a time.
- Verify checkpoint validity: Check `recover_info/step_info.json` for the last saved step
- Same config required: DCP checkpoints require identical parallelism configuration, experiment name, and trial name
- Clean restart: Delete the `recover_info/` directory to start fresh