Streaming offloading in (q)lora single device #1443

Merged
11 commits merged on Sep 16, 2024
27 changes: 27 additions & 0 deletions docs/source/tutorials/memory_optimizations.rst
@@ -83,6 +83,33 @@ and in most cases training can slow-down quite a bit as a result of this activat
To enable activation checkpointing, use the ``enable_activation_checkpointing`` config entry or flag
in any of our recipes, e.g. ``enable_activation_checkpointing=True``.

.. _glossary_act_off:

Activation Offloading
---------------------

*What's going on here?*

You may have just read about activation checkpointing! Similar to checkpointing, offloading is a memory
efficiency technique that saves GPU VRAM by temporarily moving activations to the CPU and bringing
them back when they are needed in the backward pass.

See `PyTorch autograd hook tutorial <https://pytorch.org/tutorials/intermediate/autograd_saved_tensors_hooks_tutorial.html#saving-tensors-to-cpu>`_
for more details about how this is implemented through saved_tensors_hooks.
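
For intuition, here is a minimal sketch of that mechanism, following the linked tutorial
(``model`` and ``inputs`` are placeholders, and it assumes a single CUDA device; torchtune's
actual implementation adds CUDA streams and other heuristics on top)::

    import torch

    def pack(tensor):
        # Forward pass: move each activation saved for backward to the CPU.
        return tensor.to("cpu")

    def unpack(tensor):
        # Backward pass: bring the activation back to the GPU when it is needed.
        return tensor.to("cuda")

    with torch.autograd.graph.saved_tensors_hooks(pack, unpack):
        output = model(inputs)  # saved activations now live on the CPU
    output.sum().backward()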

This setting is especially helpful for larger batch sizes or longer context lengths when you're memory constrained.
However, these savings in memory can come at the cost of training speed (i.e. tokens per second), as it takes runtime
and resources to move tensors from GPU to CPU and back. The implementation in torchtune uses multiple CUDA streams
to overlap the extra communication with computation and hide the extra runtime. Because the communication
workload varies with the number and size of tensors being offloaded, it is common not to offload every
single activation. In fact, one can use offloading in conjunction with activation checkpointing, where all
activations will either be recomputed later in the backward pass or brought back from the CPU.
Contributor

As far as I remember, this only works if activation_checkpointing is True. Is that still right? If so, we should probably update this doc and add to the recipes to raise an error or set AC=True automatically.

Another option, which i would prefer, is to investigate allowing offloading without AC, since streaming seems promising

nit: I believe you meant "one can use offloading" instead of "once can use offloading"

Contributor Author

What do you mean by "this" in the first sentence? Activations offloading works when AC is false as well, it's just super slow.

Contributor

i remember trying to use only offloading, with AC=False, and it broke. Maybe its not the case anymore?

Contributor Author

Yea it shouldn't break! It should just be reaaaaally slow


*Sounds great! How do I use it?*

To enable activation offloading, use the ``enable_activation_offloading`` config entry or flag
in our lora finetuning single device recipe, e.g. ``enable_activation_offloading=True``.
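
For example, from the command line this might look something like the following. This is an
illustrative sketch: it assumes the standard ``tune run`` override syntax, and the config name is
simply one of those touched in this PR::

    tune run lora_finetune_single_device --config llama3_1/8B_qlora_single_device \
        enable_activation_checkpointing=True \
        enable_activation_offloading=True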
Contributor

i dont think that this will stay exclusive to lora finetuning. Maybe you are suggesting that this should be updated when it becomes available in full_single_device?

PS: do you plan to do it? (its fine if the answer is no, just checking :) )

Contributor

Can we also add a note about offloading in single device vs FSDP, since in distributed it is another command?

Contributor Author

Yea I was thinking we should mention lora finetuning single device for now in particular as it's only enabled for those. It would be important to widen this when more support is added --> I'm happy to do it.

In distributed, it should be the same command, but I haven't tested at all, so did not want to promise anything for that.

Contributor

In distributed, it should be the same command

I thought that for FSDP we use fsdp_cpu_offload and let fsdp do the offloading. Unless you think it makes sense to have two different offloading strategies

Contributor

FSDP's CPU offload will just offload parameters, gradients, and optimizer states though, right? Not activations. Do we expect that the two will work together? I can't immediately think of any reason why they wouldn't but maybe there's something obvious I'm missing

Contributor

oh, i missed that FSDP doesnt offload activations. Thats nice! I thought it was a solved issue for FSDP and we were just improving single device

Contributor Author

They theoretically should work together but who knows man, we can't say so til we test! I'm suspicious that there will be unforeseen results with intranode comms, so the scope of this PR is just single device.


.. _glossary_grad_accm:

Gradient Accumulation
1 change: 1 addition & 0 deletions pyproject.toml
@@ -25,6 +25,7 @@ dependencies = [
"numpy<=1.26.4", # Pin here until https://github.com/tensorflow/tensorboard/issues/6869 is addressed
"tqdm",
"omegaconf",
"psutil",
Contributor Author

@ebsmothers is this okay? This is a new requirement as we use psutil to check cpu RAM usage and warn on too much usage.

Contributor Author

FYI: psutil does not pull in other deps!


]
dynamic = ["version"]
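
As an aside on the new dependency above: psutil is used to check CPU RAM usage and warn when it runs
high. A hypothetical sketch of such a check (not the PR's actual code; the threshold and message are
made up):

import psutil

def warn_if_cpu_ram_high(threshold_pct: float = 90.0) -> None:
    # psutil.virtual_memory().percent reports system-wide RAM usage as a percentage.
    used_pct = psutil.virtual_memory().percent
    if used_pct > threshold_pct:
        print(
            f"Warning: CPU RAM usage is at {used_pct:.0f}%; "
            "offloaded activations may start swapping and slow down training."
        )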
1 change: 1 addition & 0 deletions recipes/configs/code_llama2/7B_lora_single_device.yaml
@@ -73,6 +73,7 @@ device: cuda

# Memory management
enable_activation_checkpointing: True
enable_activation_offloading: False
dtype: bf16

# Logging
1 change: 1 addition & 0 deletions recipes/configs/code_llama2/7B_qlora_single_device.yaml
@@ -73,6 +73,7 @@ device: cuda

# Memory management
enable_activation_checkpointing: True
enable_activation_offloading: False
dtype: bf16

# Logging
1 change: 1 addition & 0 deletions recipes/configs/gemma/2B_lora_single_device.yaml
@@ -70,6 +70,7 @@ device: cuda

# Memory management
enable_activation_checkpointing: True
enable_activation_offloading: False

# Reduced precision
dtype: bf16
1 change: 1 addition & 0 deletions recipes/configs/gemma/2B_qlora_single_device.yaml
@@ -70,6 +70,7 @@ device: cuda

# Memory management
enable_activation_checkpointing: True
enable_activation_offloading: False

# Reduced precision
dtype: bf16
1 change: 1 addition & 0 deletions recipes/configs/gemma/7B_lora_single_device.yaml
@@ -72,6 +72,7 @@ device: cuda

# Memory management
enable_activation_checkpointing: True
enable_activation_offloading: False

# Reduced precision
dtype: bf16
1 change: 1 addition & 0 deletions recipes/configs/gemma/7B_qlora_single_device.yaml
@@ -72,6 +72,7 @@ device: cuda

# Memory management
enable_activation_checkpointing: True
enable_activation_offloading: False

# Reduced precision
dtype: bf16
2 changes: 2 additions & 0 deletions recipes/configs/llama2/13B_qlora_single_device.yaml
@@ -80,7 +80,9 @@ log_peak_memory_stats: False
# Environment
device: cuda
dtype: bf16

enable_activation_checkpointing: True
enable_activation_offloading: False

# Show case the usage of pytorch profiler
# Set enabled to False as it's only needed for debugging training
3 changes: 3 additions & 0 deletions recipes/configs/llama2/7B_lora_single_device.yaml
@@ -80,7 +80,10 @@ log_peak_memory_stats: False
# Environment
device: cuda
dtype: bf16

# Activations Memory
enable_activation_checkpointing: True
enable_activation_offloading: False

# Show case the usage of pytorch profiler
# Set enabled to False as it's only needed for debugging training
3 changes: 3 additions & 0 deletions recipes/configs/llama2/7B_qlora_single_device.yaml
@@ -79,7 +79,10 @@ log_peak_memory_stats: False
# Environment
device: cuda
dtype: bf16

# Activations Memory
enable_activation_checkpointing: True
enable_activation_offloading: False

# Show case the usage of pytorch profiler
# Set enabled to False as it's only needed for debugging training
3 changes: 3 additions & 0 deletions recipes/configs/llama3/8B_lora_single_device.yaml
@@ -79,7 +79,10 @@ log_peak_memory_stats: False
# Environment
device: cuda
dtype: bf16

# Activations Memory
enable_activation_checkpointing: True
enable_activation_offloading: False

# Profiler (disabled)
profiler:
3 changes: 3 additions & 0 deletions recipes/configs/llama3/8B_qlora_single_device.yaml
@@ -78,7 +78,10 @@ log_peak_memory_stats: False
# Environment
device: cuda
dtype: bf16

# Activations Memory
enable_activation_checkpointing: True
enable_activation_offloading: True

# Profiler (disabled)
profiler:
3 changes: 3 additions & 0 deletions recipes/configs/llama3_1/8B_lora_single_device.yaml
@@ -82,7 +82,10 @@ log_peak_memory_stats: False
# Environment
device: cuda
dtype: bf16

# Activations Memory
enable_activation_checkpointing: True
enable_activation_offloading: False

# Profiler (disabled)
profiler:
3 changes: 3 additions & 0 deletions recipes/configs/llama3_1/8B_qlora_single_device.yaml
@@ -81,7 +81,10 @@ log_peak_memory_stats: False
# Environment
device: cuda
dtype: bf16

# Activations Offloading
enable_activation_checkpointing: True
enable_activation_offloading: False

# Profiler (disabled)
profiler:
1 change: 1 addition & 0 deletions recipes/configs/mistral/7B_lora_single_device.yaml
@@ -76,6 +76,7 @@ device: cuda

# Memory management
enable_activation_checkpointing: True
enable_activation_offloading: False

# Reduced precision
dtype: bf16
1 change: 1 addition & 0 deletions recipes/configs/mistral/7B_qlora_single_device.yaml
@@ -77,6 +77,7 @@ device: cuda

# Memory management
enable_activation_checkpointing: True
enable_activation_offloading: False

# Reduced precision
dtype: bf16
3 changes: 3 additions & 0 deletions recipes/configs/phi3/mini_lora_single_device.yaml
@@ -71,6 +71,9 @@ device: cuda

# Memory management
enable_activation_checkpointing: True
enable_activation_offloading: False

# Reduced precision
dtype: bf16

# Logging
3 changes: 3 additions & 0 deletions recipes/configs/phi3/mini_qlora_single_device.yaml
@@ -71,6 +71,9 @@ device: cuda

# Memory management
enable_activation_checkpointing: True
enable_activation_offloading: False

# Reduced precision
dtype: bf16

# Logging
3 changes: 3 additions & 0 deletions recipes/configs/qwen2/0.5B_lora_single_device.yaml
@@ -79,7 +79,10 @@ log_peak_memory_stats: False
# Environment
device: cuda
dtype: bf16

# Activations Memory
enable_activation_checkpointing: True
enable_activation_offloading: False

# Show case the usage of pytorch profiler
# Set enabled to False as it's only needed for debugging training
3 changes: 3 additions & 0 deletions recipes/configs/qwen2/1.5B_lora_single_device.yaml
@@ -77,7 +77,10 @@ log_peak_memory_stats: False
# Environment
device: cuda
dtype: bf16

# Activations Memory
enable_activation_checkpointing: True
enable_activation_offloading: False

# Show case the usage of pytorch profiler
# Set enabled to False as it's only needed for debugging training
3 changes: 3 additions & 0 deletions recipes/configs/qwen2/7B_lora_single_device.yaml
@@ -81,7 +81,10 @@ log_peak_memory_stats: False
# Environment
device: cuda
dtype: bf16

# Activations Offloading
enable_activation_checkpointing: True
enable_activation_offloading: False

# Show case the usage of pytorch profiler
# Set enabled to False as it's only needed for debugging training
43 changes: 39 additions & 4 deletions recipes/lora_finetune_single_device.py
@@ -4,6 +4,7 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import contextlib
import sys
import time

@@ -30,8 +31,12 @@
validate_missing_and_unexpected_for_lora,
)
from torchtune.recipe_interfaces import FTRecipeInterface
from torchtune.training import DummyProfiler, PROFILER_KEY

from torchtune.training import (
DummyProfiler,
NoOpManager,
OffloadActivations,
PROFILER_KEY,
)
from tqdm import tqdm

log = utils.get_logger("DEBUG")
@@ -43,13 +48,22 @@ class LoRAFinetuneRecipeSingleDevice(FTRecipeInterface):
for single GPU training. Training on CPU is not supported.

Features:
- Activation Checkpointing. This can be controlled using the ``activation_checkpointing``
- Activation Checkpointing. This can be controlled using the ``enable_activation_checkpointing``
flag. Activation checkpointing helps reduce the memory footprint since we no longer keep
activations in memory and instead recompute them during the backward pass. This is especially
helpful for larger batch sizes when you're memory constrained. But these savings in memory
come at the cost of training performance. In most cases training can slow-down quite a bit as
a result of this activation recomputation.

- Activation Offloading. This can be controlled using the ``enable_activation_offloading``
flag. Activation offloading is a technique similar to activation checkpointing that helps
reduce the memory footprint to prevent OOMs and enable bigger batches. Where activation
checkpointing drops the activation in the forward pass to recompute it later in the backward pass,
activation offloading moves the activation to the CPU in the forward pass and brings it
back during the backward pass. As always, there is a tradeoff: these savings in memory can
come at the cost of training performance and CPU resources. Activation offloading
can be used in conjunction with activation checkpointing.

- Precision. Full fp32 and bf16 training are supported. Precision is controlled using the ``dtype``
flag. When ``dtype=bf16``, all activations, gradients and optimizer states are in bfloat16. In
most cases this should halve the memory footprint of full precision (fp32) training, without
@@ -222,6 +236,8 @@ def setup(self, cfg: DictConfig) -> None:
self._model = self._setup_model(
cfg_model=cfg.model,
enable_activation_checkpointing=cfg.enable_activation_checkpointing,
enable_activation_offloading=cfg.get("enable_activation_offloading", False),
offload_with_streams=cfg.get("offload_with_streams", False),
Contributor

I think that this is the last thing to approve the PR: I dont like having two arguments for offloading. What do you think about making enable_activation_offloading have 3 options?

False / "with_stream" / "without_stream"

or something like that?

Contributor

One q here: is there any case that I wouldn't wanna use streams provided I'm on a sufficiently recent PyTorch version? If there are no feature gaps I'm inclined to just have a single enable_activation_offloading flag that will run with streams if possible and otherwise run the vanilla version.

Contributor Author

We haven't tested streams use on other hardware or with distributed or with other shape models, so I was hesitant to make it the default. Since we're only landing the lora finetuning recipe on single device with this PR, it is fine to just not include that flag option for now.

I've removed it but it would be p easy to add it back if required in the future.

compile_model=cfg.compile,
base_model_state_dict=checkpoint_dict[training.MODEL_KEY],
lora_weights_state_dict=(
@@ -367,6 +383,8 @@ def _setup_model(
self,
cfg_model: DictConfig,
enable_activation_checkpointing: bool,
enable_activation_offloading: bool,
offload_with_streams: bool,
compile_model: bool,
base_model_state_dict: Dict[str, Any],
lora_weights_state_dict: Optional[Dict[str, Any]] = None,
@@ -420,6 +438,22 @@ def _setup_model(
self.adapter_params.items(), dtype=self._dtype
)

self.activations_handling_ctx = contextlib.nullcontext()
if enable_activation_offloading:
self.activations_handling_ctx = OffloadActivations(
use_streams=offload_with_streams
)

# Below is our hack to disable offloading the last output Linear in every
# step, since offloading that activation and bringing it back shortly afterward
# is expensive. Moreover, due to heuristics in our streaming API, we actually
# use more memory if we offload it, as it interferes with chunkedCE.
noop_ctx = NoOpManager()
model.output.register_forward_pre_hook(lambda *args: noop_ctx.__enter__())
model.output.register_forward_hook(
lambda *args: noop_ctx.__exit__(), always_call=True
Contributor

we need to test it with qwen or gemma, that have tied embeddings. They dont have model.output, since they use TiedTransformerDecoder, so this would fail. We would need to add something like "if hasattr(model.output)"

The issue is that they use the output = F.Linear(h, self.tok_embeddings.weight).

in this PR (#1527), i created a regular python class (no nn.Module) to replace the F.Linear logic, so we can get rid of the TiedTransformerDecoder

class TiedLinear:
    def __init__(self, tied_module: nn.Module):
        self.tied_module = tied_module

    def __call__(self, x: torch.Tensor) -> torch.Tensor:
        return F.linear(x, self.tied_module.weight)

model = TransformerDecoder(output_projection=TiedLinear(embeddings))

I imagine that hooks only work with nn.module. Is that true? I didnt try to make the TiedLinear an nn.Module because it doesnt have its own weights, and i didnt want FSDP and other wrappers to interact with it.

Contributor

we should probably add this NoOp logic to some utils, but no strong opinion

Contributor Author

Ah yea, I agree we need to test these. The hooks do only work with module (since they're module hooks) so that would require some design...e.g., we enable adding hooks to this class you added, or we make it an nn.Module still but exclude it from FSDP/other wrappers. Or something else 😛

Contributor

Oh yeah this tied embedding business is actually a pretty important point. As a hack we can maybe do something like if hasattr(model, "output") and isinstance(model.output, nn.Module) to gate the registration of these hooks?

Would have to think more about whether we can get away with making TiedLinear an nn.Module though.. half the reason we're doing things this way is to avoid the breaking of references for tied weights when FSDP calls to_empty. Is there an easy way to exclude an nn.Module from all FSDP's usual hook registration etc?

Anyways before I get too far down that rabbit hole, I'd propose just going with the hack (provided it works). It's dumb, simple, and explicit, which I like.

)

log.info(f"Model is initialized with precision {self._dtype}.")

if self._device.type == "cuda":
@@ -576,7 +610,8 @@ def _loss_step(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor:
input_pos = batch.get("input_pos", None) # shape [b, s]

# run model
logits = self._model(tokens, mask=mask, input_pos=input_pos)
with self.activations_handling_ctx:
logits = self._model(tokens, mask=mask, input_pos=input_pos)

# Shift labels to compute loss
# equivalent to doing labels[..., 1:] and logits[..., :-1, :]
3 changes: 3 additions & 0 deletions torchtune/training/__init__.py
@@ -3,6 +3,7 @@
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from torchtune.training._activation_offloading import NoOpManager, OffloadActivations
from torchtune.training._compile import compile_loss, compile_model
from torchtune.training._distributed import (
contains_fsdp,
@@ -122,4 +123,6 @@
"setup_torch_profiler",
"compile_loss",
"compile_model",
"NoOpManager",
"OffloadActivations",
]
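
To tie these exports together, here is a condensed sketch of how the recipe above wires them up
(a paraphrase of the recipe diff, not new functionality; ``model`` and ``tokens`` are placeholders,
and ``use_streams`` is the kwarg shown in ``_setup_model``):

from torchtune.training import NoOpManager, OffloadActivations

# Offload activations saved for backward to the CPU during the forward pass.
activations_handling_ctx = OffloadActivations(use_streams=True)

# Skip offloading for the final output projection: moving it out and straight
# back in is expensive and interferes with the chunked CE loss (see the recipe comment).
noop_ctx = NoOpManager()
model.output.register_forward_pre_hook(lambda *args: noop_ctx.__enter__())
model.output.register_forward_hook(
    lambda *args: noop_ctx.__exit__(), always_call=True
)

with activations_handling_ctx:
    logits = model(tokens)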