Streaming offloading in (q)lora single device #1443
Changes from 4 commits
@@ -4,6 +4,7 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.

+import contextlib
 import sys
 import time
@@ -30,8 +31,12 @@
     validate_missing_and_unexpected_for_lora,
 )
 from torchtune.recipe_interfaces import FTRecipeInterface
-from torchtune.training import DummyProfiler, PROFILER_KEY
+from torchtune.training import (
+    DummyProfiler,
+    NoOpManager,
+    OffloadActivations,
+    PROFILER_KEY,
+)
 from tqdm import tqdm

 log = utils.get_logger("DEBUG")
@@ -43,13 +48,25 @@ class LoRAFinetuneRecipeSingleDevice(FTRecipeInterface):
     for single GPU training. Training on CPU is not supported.

     Features:
-        - Activation Checkpointing. This can be controlled using the ``activation_checkpointing``
+        - Activation Checkpointing. This can be controlled using the ``enable_activation_checkpointing``
           flag. Activation checkpointing helps reduce the memory footprint since we no longer keep
           activations in memory and instead recompute them during the backward pass. This is especially
           helpful for larger batch sizes when you're memory constrained. But these savings in memory
           come at the cost of training performance. In most cases training can slow down quite a bit as
           a result of this activation recomputation.

+        - Activation Offloading. This can be controlled using the ``enable_activation_offloading``
+          flag. Activation offloading is a technique similar to activation checkpointing that helps
+          reduce the memory footprint to prevent OOMs and enable bigger batches. Where activation
+          checkpointing drops the activation in the forward pass to recompute it later in the backward
+          pass, activation offloading moves the activation to the CPU in the forward pass and brings it
+          back during the backward pass. As always, there is a tradeoff: these savings in memory can
+          come at the cost of training performance and CPU resources. To recover some of the runtime cost,
+          specify ``offload_with_streams: True`` to perform offloading on a separate stream and permit
+          overlap with the computation. This option is currently only available on PyTorch nightly
+          version 2.5.0.dev20240907 or later. Activation offloading can be used in conjunction with
+          activation checkpointing.

         - Precision. Full fp32 and bf16 training are supported. Precision is controlled using the ``dtype``
           flag. When ``dtype=bf16``, all activations, gradients and optimizer states are in bfloat16. In
           most cases this should halve the memory footprint of full precision (fp32) training, without
@@ -222,6 +239,8 @@ def setup(self, cfg: DictConfig) -> None:
         self._model = self._setup_model(
             cfg_model=cfg.model,
             enable_activation_checkpointing=cfg.enable_activation_checkpointing,
+            enable_activation_offloading=cfg.get("enable_activation_offloading", False),
+            offload_with_streams=cfg.get("offload_with_streams", False),
Review discussion on these two flags:

> I think that this is the last thing to approve the PR: I don't like having two arguments for offloading. What do you think about making ``enable_activation_offloading`` have 3 options? False / "with_stream" / "without_stream", or something like that?

> One question here: is there any case where I wouldn't want to use streams, provided I'm on a sufficiently recent PyTorch version? If there are no feature gaps, I'm inclined to just have a single flag.

> We haven't tested streams use on other hardware, or with distributed, or with other model shapes, so I was hesitant to make it the default. Since we're only landing the LoRA finetuning recipe on single device with this PR, it is fine to just not include that flag option for now. I've removed it, but it would be pretty easy to add it back if required in the future.
             compile_model=cfg.compile,
             base_model_state_dict=checkpoint_dict[training.MODEL_KEY],
             lora_weights_state_dict=(
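The three-option flag suggested in the review thread above might be parsed along these lines. This is a hypothetical sketch of the proposal only, not code from this PR (which, per the last comment, instead keeps ``enable_activation_offloading`` as a boolean and drops ``offload_with_streams``):

```python
# Hypothetical helper for the suggested False / "without_stream" / "with_stream" flag.
# Not part of this PR; the name and behavior are illustrative only.
import contextlib
from typing import ContextManager, Union

from torchtune.training import OffloadActivations


def build_activations_handling_ctx(
    enable_activation_offloading: Union[bool, str]
) -> ContextManager:
    if not enable_activation_offloading:
        return contextlib.nullcontext()
    return OffloadActivations(
        use_streams=(enable_activation_offloading == "with_stream")
    )
```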
@@ -367,6 +386,8 @@ def _setup_model(
         self,
         cfg_model: DictConfig,
         enable_activation_checkpointing: bool,
+        enable_activation_offloading: bool,
+        offload_with_streams: bool,
         compile_model: bool,
         base_model_state_dict: Dict[str, Any],
         lora_weights_state_dict: Optional[Dict[str, Any]] = None,
@@ -420,6 +441,25 @@ def _setup_model(
             self.adapter_params.items(), dtype=self._dtype
         )

+        self.activations_handling_ctx = contextlib.nullcontext()
+        if enable_activation_offloading:
+            self.activations_handling_ctx = OffloadActivations(
+                use_streams=offload_with_streams
+            )
+
+            # Below is our hack to disable offloading the last output Linear in every
+            # step, as the cost for offloading the activation and then soon after bringing
+            # it back is expensive. Moreover, due to heuristics in our streaming API,
+            # we actually use more memory if we offload it as it interferes with chunkedCE.
+            if hasattr(model, "output") and isinstance(model.output, nn.Module):
+                noop_ctx = NoOpManager()
+                model.output.register_forward_pre_hook(
+                    lambda *args: noop_ctx.__enter__()
+                )
+                model.output.register_forward_hook(
+                    lambda *args: noop_ctx.__exit__(), always_call=True
+                )
+
         log.info(f"Model is initialized with precision {self._dtype}.")

         if self._device.type == "cuda":
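As a toy illustration (not from this PR) of the hook-bracketing trick used above: a forward pre-hook enters a context manager and a forward hook exits it, so the wrapped module's forward runs entirely inside that context. ``LoudContext`` below is a made-up stand-in for ``NoOpManager``:

```python
import torch
import torch.nn as nn


class LoudContext:
    def __enter__(self):
        print("entering context before the last Linear's forward")
        # Must return None: a forward pre-hook's non-None return value would
        # replace the module's inputs.
        return None

    def __exit__(self, *exc):
        print("exiting context after the last Linear's forward")
        # Must return None: a forward hook's non-None return value would
        # replace the module's outputs.
        return None


model = nn.Sequential(nn.Linear(8, 8), nn.Linear(8, 2))
ctx = LoudContext()
model[1].register_forward_pre_hook(lambda *args: ctx.__enter__())
model[1].register_forward_hook(lambda *args: ctx.__exit__(), always_call=True)

_ = model(torch.randn(1, 8))  # prints the enter/exit messages around the last Linear only
```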
@@ -576,7 +616,8 @@ def _loss_step(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor:
         input_pos = batch.get("input_pos", None)  # shape [b, s]

         # run model
-        logits = self._model(tokens, mask=mask, input_pos=input_pos)
+        with self.activations_handling_ctx:
+            logits = self._model(tokens, mask=mask, input_pos=input_pos)

         # Shift labels to compute loss
         # equivalent to doing labels[..., 1:] and logits[..., :-1, :]
Review discussion on the new psutil requirement:

> @ebsmothers is this okay? This is a new requirement, as we use psutil to check CPU RAM usage and warn on too much usage.

> FYI: psutil does not pull in other deps!
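The psutil-based check itself is not shown in this diff view. As a rough illustration of the kind of check it enables, here is a sketch; the threshold and warning message are made up, not taken from this PR:

```python
# Illustrative psutil-based CPU RAM check; the 80% threshold and message are hypothetical.
import psutil

from torchtune import utils

log = utils.get_logger("DEBUG")


def warn_if_host_ram_pressure(threshold_pct: float = 80.0) -> None:
    """Warn when host RAM usage exceeds threshold_pct; offloaded activations live in CPU RAM."""
    vm = psutil.virtual_memory()
    if vm.percent > threshold_pct:
        log.warning(
            f"CPU RAM usage is at {vm.percent:.1f}% "
            f"({vm.used / 1e9:.1f} GB of {vm.total / 1e9:.1f} GB). "
            "Activation offloading may slow down or exhaust host memory."
        )
```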