Streaming offloading in (q)lora single device #1443

Merged: 11 commits, Sep 16, 2024
29 changes: 29 additions & 0 deletions docs/source/tutorials/memory_optimizations.rst
@@ -83,6 +83,35 @@ and in most cases training can slow-down quite a bit as a result of this activat
To enable activation checkpointing, use the ``enable_activation_checkpointing`` config entry or flag
in any of our recipes, e.g. ``enable_activation_checkpointing=True``.

.. _glossary_act_off:

Activation Offloading
---------------------

*What's going on here?*

You may have just read about activation checkpointing! Similar to checkpointing, offloading is a memory
efficiency technique that saves GPU VRAM by temporarily moving activations to the CPU and bringing
them back when they are needed in the backward pass.

See the `PyTorch autograd hook tutorial <https://pytorch.org/tutorials/intermediate/autograd_saved_tensors_hooks_tutorial.html#saving-tensors-to-cpu>`_
for more details on how this is implemented through ``saved_tensors_hooks``.
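
As a rough illustration of the mechanism that tutorial describes, the sketch below saves every activation
to the CPU during the forward pass and restores it on demand during the backward pass; ``model`` and
``inputs`` are placeholders.

.. code-block:: python

    import torch

    def pack_hook(tensor: torch.Tensor):
        # Forward pass: autograd is about to save this activation; stash a CPU copy
        # and remember which device it came from.
        return tensor.device, tensor.cpu()

    def unpack_hook(packed):
        # Backward pass: the activation is needed again; move it back to its device.
        device, tensor = packed
        return tensor.to(device)

    with torch.autograd.graph.saved_tensors_hooks(pack_hook, unpack_hook):
        loss = model(inputs).sum()   # activations saved by autograd land on the CPU
    loss.backward()                  # saved activations are restored as they are used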

This setting is especially helpful for larger batch sizes or longer context lengths when you're memory constrained.
However, these memory savings can come at the cost of training speed (i.e. tokens per second), since it takes runtime
and resources to move tensors from GPU to CPU and back. The implementation in torchtune has the ``offload_with_streams``
option to use multiple CUDA streams, overlapping the extra communication with computation to hide the extra
runtime. Because the communication workload varies with the number and size of tensors being offloaded, it is
common not to offload every single activation. In fact, one can use offloading in conjunction with activation
checkpointing, where each activation is either recomputed later in the backward pass or brought back from the CPU.
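
As a rough sketch of how this looks in a training step, the snippet below wraps the forward pass in
torchtune's ``OffloadActivations`` context manager, mirroring the recipe change later in this diff;
``model``, ``loss_fn``, ``tokens``, and ``labels`` are placeholders.

.. code-block:: python

    from torchtune.training import OffloadActivations

    activations_ctx = OffloadActivations()  # offloads saved activations to the CPU

    with activations_ctx:
        # Forward pass: activations saved for backward are moved to the CPU.
        logits = model(tokens)

    # Backward pass: activations are copied back to the GPU as they are needed.
    loss = loss_fn(logits, labels)
    loss.backward()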

*Sounds great! How do I use it?*

To enable activation offloading, use the ``enable_activation_offloading`` config entry or flag
in our LoRA single-device finetuning recipe, e.g. ``enable_activation_offloading=True``. To use
streams, make sure you are on a PyTorch version newer than 2.5.0.dev20240907 and
specify ``offload_with_streams=True``.
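
For example, assuming the standard ``tune run`` CLI and one of the single-device configs updated in this PR,
an invocation might look like:

.. code-block:: bash

    tune run lora_finetune_single_device \
        --config llama3/8B_lora_single_device \
        enable_activation_offloading=True

To also use streams on a new-enough nightly, additionally set ``offload_with_streams=True`` as described above.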

.. _glossary_grad_accm:

Gradient Accumulation
1 change: 1 addition & 0 deletions pyproject.toml
@@ -25,6 +25,7 @@ dependencies = [
"numpy<=1.26.4", # Pin here until https://github.com/tensorflow/tensorboard/issues/6869 is addressed
"tqdm",
"omegaconf",
"psutil",
Contributor Author:

@ebsmothers is this okay? This is a new requirement, since we use psutil to check CPU RAM usage and warn when too much of it is in use.

Contributor Author:

FYI: psutil does not pull in other deps!
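
For reference, a minimal sketch of the kind of check psutil enables (the actual threshold and warning text used in this PR may differ):

```python
import psutil

def warn_if_host_ram_is_high(threshold_pct: float = 90.0) -> None:
    # Offloaded activations live in host RAM, so warn before the machine starts swapping.
    ram = psutil.virtual_memory()
    if ram.percent >= threshold_pct:
        print(
            f"Warning: CPU RAM usage at {ram.percent:.0f}% "
            f"({ram.used / 1e9:.1f} / {ram.total / 1e9:.1f} GB); "
            "activation offloading may slow down or exhaust host memory."
        )
```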


]
dynamic = ["version"]
1 change: 1 addition & 0 deletions recipes/configs/code_llama2/7B_lora_single_device.yaml
@@ -73,6 +73,7 @@ device: cuda

# Memory management
enable_activation_checkpointing: True
enable_activation_offloading: False
dtype: bf16

# Logging
1 change: 1 addition & 0 deletions recipes/configs/code_llama2/7B_qlora_single_device.yaml
@@ -73,6 +73,7 @@ device: cuda

# Memory management
enable_activation_checkpointing: True
enable_activation_offloading: False
dtype: bf16

# Logging
1 change: 1 addition & 0 deletions recipes/configs/gemma/2B_lora_single_device.yaml
@@ -70,6 +70,7 @@ device: cuda

# Memory management
enable_activation_checkpointing: True
enable_activation_offloading: False

# Reduced precision
dtype: bf16
1 change: 1 addition & 0 deletions recipes/configs/gemma/2B_qlora_single_device.yaml
@@ -70,6 +70,7 @@ device: cuda

# Memory management
enable_activation_checkpointing: True
enable_activation_offloading: False

# Reduced precision
dtype: bf16
1 change: 1 addition & 0 deletions recipes/configs/gemma/7B_lora_single_device.yaml
@@ -72,6 +72,7 @@ device: cuda

# Memory management
enable_activation_checkpointing: True
enable_activation_offloading: False

# Reduced precision
dtype: bf16
1 change: 1 addition & 0 deletions recipes/configs/gemma/7B_qlora_single_device.yaml
@@ -72,6 +72,7 @@ device: cuda

# Memory management
enable_activation_checkpointing: True
enable_activation_offloading: False

# Reduced precision
dtype: bf16
2 changes: 2 additions & 0 deletions recipes/configs/llama2/13B_qlora_single_device.yaml
@@ -80,7 +80,9 @@ log_peak_memory_stats: False
# Environment
device: cuda
dtype: bf16

enable_activation_checkpointing: True
enable_activation_offloading: False

# Show case the usage of pytorch profiler
# Set enabled to False as it's only needed for debugging training
3 changes: 3 additions & 0 deletions recipes/configs/llama2/7B_lora_single_device.yaml
@@ -80,7 +80,10 @@ log_peak_memory_stats: False
# Environment
device: cuda
dtype: bf16

# Activations Memory
enable_activation_checkpointing: True
enable_activation_offloading: False

# Show case the usage of pytorch profiler
# Set enabled to False as it's only needed for debugging training
3 changes: 3 additions & 0 deletions recipes/configs/llama2/7B_qlora_single_device.yaml
@@ -79,7 +79,10 @@ log_peak_memory_stats: False
# Environment
device: cuda
dtype: bf16

# Activations Memory
enable_activation_checkpointing: True
enable_activation_offloading: False

# Show case the usage of pytorch profiler
# Set enabled to False as it's only needed for debugging training
3 changes: 3 additions & 0 deletions recipes/configs/llama3/8B_lora_single_device.yaml
@@ -79,7 +79,10 @@ log_peak_memory_stats: False
# Environment
device: cuda
dtype: bf16

# Activations Memory
enable_activation_checkpointing: True
enable_activation_offloading: False

# Profiler (disabled)
profiler:
3 changes: 3 additions & 0 deletions recipes/configs/llama3/8B_qlora_single_device.yaml
@@ -78,7 +78,10 @@ log_peak_memory_stats: False
# Environment
device: cuda
dtype: bf16

# Activations Memory
enable_activation_checkpointing: True
enable_activation_offloading: True

# Profiler (disabled)
profiler:
3 changes: 3 additions & 0 deletions recipes/configs/llama3_1/8B_lora_single_device.yaml
@@ -82,7 +82,10 @@ log_peak_memory_stats: False
# Environment
device: cuda
dtype: bf16

# Activations Memory
enable_activation_checkpointing: True
enable_activation_offloading: False

# Profiler (disabled)
profiler:
3 changes: 3 additions & 0 deletions recipes/configs/llama3_1/8B_qlora_single_device.yaml
@@ -81,7 +81,10 @@ log_peak_memory_stats: False
# Environment
device: cuda
dtype: bf16

# Activations Offloading
enable_activation_checkpointing: True
enable_activation_offloading: False

# Profiler (disabled)
profiler:
1 change: 1 addition & 0 deletions recipes/configs/mistral/7B_lora_single_device.yaml
@@ -76,6 +76,7 @@ device: cuda

# Memory management
enable_activation_checkpointing: True
enable_activation_offloading: False

# Reduced precision
dtype: bf16
1 change: 1 addition & 0 deletions recipes/configs/mistral/7B_qlora_single_device.yaml
@@ -77,6 +77,7 @@ device: cuda

# Memory management
enable_activation_checkpointing: True
enable_activation_offloading: False

# Reduced precision
dtype: bf16
3 changes: 3 additions & 0 deletions recipes/configs/phi3/mini_lora_single_device.yaml
@@ -71,6 +71,9 @@ device: cuda

# Memory management
enable_activation_checkpointing: True
enable_activation_offloading: False

# Reduced precision
dtype: bf16

# Logging
3 changes: 3 additions & 0 deletions recipes/configs/phi3/mini_qlora_single_device.yaml
@@ -71,6 +71,9 @@ device: cuda

# Memory management
enable_activation_checkpointing: True
enable_activation_offloading: False

# Reduced precision
dtype: bf16

# Logging
3 changes: 3 additions & 0 deletions recipes/configs/qwen2/0.5B_lora_single_device.yaml
@@ -79,7 +79,10 @@ log_peak_memory_stats: False
# Environment
device: cuda
dtype: bf16

# Activations Memory
enable_activation_checkpointing: True
enable_activation_offloading: False

# Show case the usage of pytorch profiler
# Set enabled to False as it's only needed for debugging training
3 changes: 3 additions & 0 deletions recipes/configs/qwen2/1.5B_lora_single_device.yaml
@@ -77,7 +77,10 @@ log_peak_memory_stats: False
# Environment
device: cuda
dtype: bf16

# Activations Memory
enable_activation_checkpointing: True
enable_activation_offloading: False

# Show case the usage of pytorch profiler
# Set enabled to False as it's only needed for debugging training
3 changes: 3 additions & 0 deletions recipes/configs/qwen2/7B_lora_single_device.yaml
@@ -81,7 +81,10 @@ log_peak_memory_stats: False
# Environment
device: cuda
dtype: bf16

# Activations Offloading
enable_activation_checkpointing: True
enable_activation_offloading: False

# Show case the usage of pytorch profiler
# Set enabled to False as it's only needed for debugging training
53 changes: 49 additions & 4 deletions recipes/lora_finetune_single_device.py
@@ -4,6 +4,7 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import contextlib
import sys
import time

@@ -30,8 +31,12 @@
validate_missing_and_unexpected_for_lora,
)
from torchtune.recipe_interfaces import FTRecipeInterface
from torchtune.training import DummyProfiler, PROFILER_KEY

from torchtune.training import (
DummyProfiler,
NoOpManager,
OffloadActivations,
PROFILER_KEY,
)
from tqdm import tqdm

log = utils.get_logger("DEBUG")
Expand All @@ -43,13 +48,25 @@ class LoRAFinetuneRecipeSingleDevice(FTRecipeInterface):
for single GPU training. Training on CPU is not supported.

Features:
- Activation Checkpointing. This can be controlled using the ``activation_checkpointing``
- Activation Checkpointing. This can be controlled using the ``enable_activation_checkpointing``
flag. Activation checkpointing helps reduce the memory footprint since we no longer keep
activations in memory and instead recompute them during the backward pass. This is especially
helpful for larger batch sizes when you're memory constrained. But these savings in memory
come at the cost of training performance. In most cases training can slow down quite a bit as
a result of this activation recomputation.

- Activation Offloading. This can be controlled using the ``enable_activation_offloading``
flag. Activation offloading is a technique similar to activation checkpointing that helps
reduce the memory footprint to prevent OOMs on CUDA and enable bigger batches. Where activation
checkpointing drops the activation in the forward pass to recompute it later in the backward pass,
activation offloading moves the activation to the CPU in the forward pass and brings it
back during the backward pass. As always, there is a tradeoff: these savings in memory can
come at the cost of training performance and CPU resources. To recover some of the runtime cost,
we've added an option to run offloading on a different stream to permit overlapping with
the computation. This option is currently only available on PyTorch nightly 2.5.0.dev20240907
or later and will be enabled by default if an acceptable torch version is found. Activation
offloading can be used in conjunction with activation checkpointing.

- Precision. Full fp32 and bf16 training are supported. Precision is controlled using the ``dtype``
flag. When ``dtype=bf16``, all activations, gradients and optimizer states are in bfloat16. In
most cases this should halve the memory footprint of full precision (fp32) training, without
@@ -101,6 +118,7 @@ class LoRAFinetuneRecipeSingleDevice(FTRecipeInterface):
Raises:
ValueError: If ``dtype`` is set to fp16.
RuntimeError: If ``dtype`` is set to bf16 and the hardware does not support bf16.
RuntimeError: If ``enable_activation_offloading`` is True and device is not CUDA.

"""

@@ -138,6 +156,13 @@ def __init__(self, cfg: DictConfig) -> None:
self._save_adapter_weights_only = cfg.get("save_adapter_weights_only", False)
self._gradient_accumulation_steps = cfg.gradient_accumulation_steps
self._clip_grad_norm = cfg.get("clip_grad_norm", None)
self._enable_activation_offloading = cfg.get(
"enable_activation_offloading", False
)
if self._enable_activation_offloading and self._device.type != "cuda":
raise RuntimeError(
"enable_activation_offloading should only be enabled for training on CUDA"
)

def load_checkpoint(self, cfg_checkpointer: DictConfig) -> Dict[str, Any]:
"""
@@ -222,6 +247,7 @@ def setup(self, cfg: DictConfig) -> None:
self._model = self._setup_model(
cfg_model=cfg.model,
enable_activation_checkpointing=cfg.enable_activation_checkpointing,
enable_activation_offloading=self._enable_activation_offloading,
compile_model=cfg.compile,
base_model_state_dict=checkpoint_dict[training.MODEL_KEY],
lora_weights_state_dict=(
@@ -367,6 +393,7 @@ def _setup_model(
self,
cfg_model: DictConfig,
enable_activation_checkpointing: bool,
enable_activation_offloading: bool,
compile_model: bool,
base_model_state_dict: Dict[str, Any],
lora_weights_state_dict: Optional[Dict[str, Any]] = None,
@@ -420,6 +447,23 @@ def _setup_model(
self.adapter_params.items(), dtype=self._dtype
)

self.activations_handling_ctx = contextlib.nullcontext()
if enable_activation_offloading:
self.activations_handling_ctx = OffloadActivations()

# Below is our hack to disable offloading for the last output Linear in every
# step, since offloading its activation and bringing it back shortly afterwards
# is expensive. Moreover, due to heuristics in our streaming API, we would
# actually use more memory if we offloaded it, as it interferes with chunkedCE.
if hasattr(model, "output") and isinstance(model.output, nn.Module):
noop_ctx = NoOpManager()
model.output.register_forward_pre_hook(
lambda *args: noop_ctx.__enter__()
)
model.output.register_forward_hook(
lambda *args: noop_ctx.__exit__(), always_call=True
)

log.info(f"Model is initialized with precision {self._dtype}.")

if self._device.type == "cuda":
@@ -576,7 +620,8 @@ def _loss_step(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor:
input_pos = batch.get("input_pos", None) # shape [b, s]

# run model
logits = self._model(tokens, mask=mask, input_pos=input_pos)
with self.activations_handling_ctx:
logits = self._model(tokens, mask=mask, input_pos=input_pos)

# Shift labels to compute loss
# equivalent to doing labels[..., 1:] and logits[..., :-1, :]