
Checkpointing

This document describes AReaL's checkpointing system, which handles model saving for evaluation and fault-tolerant recovery during distributed RL training.

Overview

AReaL provides two complementary checkpointing mechanisms:

| Mechanism | Purpose | Format | Includes Optimizer/DataLoader State |
|---|---|---|---|
| Saver | Export models for evaluation or publishing | HuggingFace | No |
| RecoverHandler | Resume training after failures | DCP (Distributed Checkpoint) | Yes |

Both mechanisms are invoked automatically during training and can be configured via config.saver and config.recover respectively.

Checkpoint Formats

HuggingFace Format

Used by Saver for model export:

  • Standard HuggingFace model format (safetensors + config.json)
  • Compatible with transformers.AutoModel.from_pretrained()
  • Can be uploaded to HuggingFace Hub
  • Does not include optimizer state

DCP Format (Distributed Checkpoint)

Used by RecoverHandler for fault tolerance:

  • Backend's native distributed checkpoint format (torch.distributed.checkpoint or Megatron distributed checkpoint)
  • Sharded across all ranks for efficient parallel I/O
  • Includes model weights, optimizer state, RNG state, etc.
  • Backend-specific: checkpoints are only compatible with the same parallelism configuration
  • Overwrites previous checkpoint to save disk space

Architecture

```
PPOTrainer.train()
│
├── Training loop
│   ├── Rollout, compute values, PPO update...
│   │
│   ├── _save_hf()                          # HuggingFace export
│   │   └── Saver.save()
│   │       └── engine.save(weight_format="hf")
│   │
│   └── _save_recover_checkpoint()          # Fault tolerance
│       └── RecoverHandler.dump()
│           └── engine.save(weight_format="dcp", with_optim=True)
│
└── On restart
    └── RecoverHandler.load()
        ├── Restore dataloader, saver, evaluator states
        └── engine.load(weight_format="dcp", with_optim=True)
```

Saver: HuggingFace Model Export

The Saver periodically exports model weights in HuggingFace format for evaluation or deployment.

Save Mode

The mode parameter controls how checkpoints are written:

| Mode | Behavior |
|---|---|
| auto | Use async for Archon engine, sync for others (default). Zero-config optimal for all engines. |
| sync | Always synchronous dcp.save(). |
| async | Always process-based async with pinned-memory staging. Archon engine only; other engines fall back to sync with a warning. Extra CPU pinned memory proportional to per-rank model shard size (e.g., ~17.5 GB/rank for 70B on 8 GPUs). |

With the default auto mode, Archon engine users get async checkpoint saving automatically: the training loop blocks only while the checkpoint is staged to pinned CPU memory, and the actual disk I/O happens in a background process.
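The mode resolution described above can be sketched in a few lines. This is illustrative only: `resolve_save_mode` and the `"archon"` engine name check are assumptions, not AReaL's actual API.

```python
def resolve_save_mode(mode: str, engine_name: str) -> str:
    """Resolve the configured save mode to a concrete strategy.

    Hypothetical helper mirroring the table above: async is only
    available on the Archon engine; other engines use sync.
    """
    if mode == "sync":
        return "sync"
    if mode in ("async", "auto"):
        # Only Archon supports process-based async saving; other
        # engines fall back to sync (the real system logs a warning
        # when an explicit "async" request falls back).
        return "async" if engine_name == "archon" else "sync"
    raise ValueError(f"unknown save mode: {mode!r}")
```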

Configuration

Configure via config.saver:

| Parameter | Type | Default | Description |
|---|---|---|---|
| mode | str | "auto" | Save mode (see above). |
| freq_epochs | int \| None | None | Save every N epochs. None disables. |
| freq_steps | int \| None | None | Save every N steps. None disables. |
| freq_secs | int \| None | None | Save every N seconds. None disables. |

Example configuration:

```yaml
saver:
  freq_epochs: 1      # Save at end of each epoch
  freq_steps: null    # Disabled
  freq_secs: null     # Disabled
  # mode defaults to "auto"; Archon users get async automatically
```

Saving is triggered when any of the epoch/step/time conditions is met.

Output Location

Checkpoints are saved to:

```
{fileroot}/checkpoints/{user}/{experiment_name}/{trial_name}/default/
└── epoch{E}epochstep{S}globalstep{G}/
    ├── config.json
    ├── model.safetensors (or model-00001-of-00002.safetensors, etc.)
    ├── tokenizer.json
    └── ...
```
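The directory layout above can be expressed as a simple path template. `hf_checkpoint_dir` is an illustrative helper, not part of AReaL's API.

```python
def hf_checkpoint_dir(fileroot: str, user: str, experiment_name: str,
                      trial_name: str, epoch: int, epoch_step: int,
                      global_step: int) -> str:
    """Build the HuggingFace export path following the layout above.

    Illustrative path construction; parameter names are assumptions.
    """
    return (f"{fileroot}/checkpoints/{user}/{experiment_name}/{trial_name}"
            f"/default/epoch{epoch}epochstep{epoch_step}"
            f"globalstep{global_step}")
```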

Usage

Load saved checkpoints with standard HuggingFace APIs:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained(
    "/path/to/checkpoint/epoch0epochstep99globalstep99"
)
tokenizer = AutoTokenizer.from_pretrained(
    "/path/to/checkpoint/epoch0epochstep99globalstep99"
)
```

RecoverHandler: Fault Tolerance

The RecoverHandler enables resuming training after failures by saving complete training state.

Configuration

Configure via config.recover:

| Parameter | Type | Default | Description |
|---|---|---|---|
| mode | str | "disabled" | Recovery mode: "on"/"auto" or "off"/"disabled" |
| freq_epochs | int \| None | None | Checkpoint every N epochs |
| freq_steps | int \| None | None | Checkpoint every N steps |
| freq_secs | int \| None | None | Checkpoint every N seconds |
| retries | int | 3 | Number of recovery retries when recovery is enabled |

Recovery Modes

| Mode | Behavior |
|---|---|
| on or auto | Automatically resume if valid checkpoint exists |
| off or disabled | No checkpointing or recovery |

When recovery is enabled (on/auto), the system will:

  1. Periodically save recovery checkpoints (model weights, optimizer state, dataloader position)
  2. Automatically resume from the last valid checkpoint on restart
  3. Retry up to retries times on failure

Example configuration:

```yaml
recover:
  mode: "on"          # or "auto" for backward compatibility
  freq_steps: 100     # Checkpoint every 100 steps
  retries: 3
```

Note that "on" should be quoted: unquoted `on` is parsed as a boolean by YAML 1.1 loaders.

What Gets Saved

RecoverHandler saves complete training state:

| Component | Contents |
|---|---|
| Model weights | DCP format, sharded across ranks |
| Optimizer state | Momentum, variance (Adam), learning rate scheduler |
| RNG state | Python, NumPy, PyTorch, CUDA random states |
| Dataloader state | Current position in dataset |
| Training progress | Epoch, step, global_step counters |
| Auxiliary states | Saver, Evaluator, StatsLogger states |

Output Location

Recovery checkpoints are saved to:

```
{fileroot}/checkpoints/{user}/{experiment_name}/{trial_name}/
├── default/
│   └── recover_checkpoint/     # Model + optimizer (DCP format)
│       ├── __0_0.distcp
│       ├── __1_0.distcp
│       └── ...
├── critic/                     # If using critic
│   └── recover_checkpoint/
└── recover_info/               # Metadata
    ├── step_info.json
    ├── saver_info.json
    ├── evaluator_info.json
    ├── stats_logger_info.json
    ├── checkpoint_info.json
    └── dataloader_info.pkl
```

Recovery Process

When training resumes:

  1. RecoverHandler.load() restores all saved state (if any)
  2. Training continues from last_step_info.next().global_step
  3. Inference engine weights are synchronized to match recovered state

Best Practices

Frequency Guidelines

| Scenario | Recommended Setting |
|---|---|
| Long training runs | freq_epochs: 1 or freq_steps: 1000 |
| Unpredictable time | freq_secs: 7200 |
| Unstable clusters | freq_steps: 100 with recover.mode: on |
| Limited disk space | Lower frequency, rely on final checkpoint |
| Debugging | freq_steps: 1 for quick iteration |

Disk Space Considerations

  • Saver: Each save creates a new directory. High frequency consumes significant space.
  • RecoverHandler: Overwrites previous checkpoint. Only one copy exists at a time.

Recovery Tips

  1. Verify checkpoint validity: Check recover_info/step_info.json for the last saved step
  2. Same config required: DCP checkpoints require identical parallelism configuration, experiment name, and trial name
  3. Clean restart: Delete recover_info/ directory to start fresh