Streaming offloading in (q)lora single device #1443
@@ -83,6 +83,33 @@ and in most cases training can slow-down quite a bit as a result of this activat
To enable activation checkpointing, use the ``enable_activation_checkpointing`` config entry or flag
in any of our recipes, e.g. ``enable_activation_checkpointing=True``.

.. _glossary_act_off:

Activation Offloading
---------------------

*What's going on here?*

You may have just read about activation checkpointing! Similar to checkpointing, offloading is a memory
efficiency technique that saves GPU VRAM by temporarily moving activations to the CPU and bringing
them back when needed in the backward pass.

See the `PyTorch autograd hook tutorial <https://pytorch.org/tutorials/intermediate/autograd_saved_tensors_hooks_tutorial.html#saving-tensors-to-cpu>`_
for more details about how this is implemented through ``saved_tensors_hooks``.
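The core mechanism can be sketched in a few lines. The snippet below is a simplified illustration in the
spirit of that tutorial, not torchtune's actual implementation, and uses a toy CUDA model for concreteness:

.. code-block:: python

    import torch
    from torch.autograd.graph import saved_tensors_hooks

    model = torch.nn.Sequential(
        torch.nn.Linear(1024, 1024), torch.nn.ReLU(), torch.nn.Linear(1024, 1024)
    ).cuda()
    inputs = torch.randn(8, 1024, device="cuda")

    def pack_to_cpu(tensor):
        # Forward pass: move the activation that autograd wants to save onto the CPU.
        return tensor.to("cpu")

    def unpack_to_gpu(tensor):
        # Backward pass: copy the saved activation back to the GPU right before it is needed.
        return tensor.to("cuda")

    with saved_tensors_hooks(pack_to_cpu, unpack_to_gpu):
        loss = model(inputs).sum()
    loss.backward()  # activations are brought back from the CPU as gradients are computed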
This setting is especially helpful for larger batch sizes or longer context lengths when you're memory constrained.
However, these savings in memory can come at the cost of training speed (i.e. tokens per second), as it takes runtime
and resources to move tensors from GPU to CPU and back. The implementation in torchtune uses multiple CUDA streams
to overlap this extra communication with the computation and hide the extra runtime. Since the communication
workload varies with the number and size of the tensors being offloaded, it is common to not offload every
single activation. In fact, one can use offloading in conjunction with activation checkpointing, so that all
activations are either recomputed later in the backward pass or brought back from the CPU.
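To illustrate the overlap, a pack hook in this spirit can issue the device-to-host copy on a side stream. The
sketch below is hypothetical and omits the synchronization a real implementation needs before the CPU buffer
is read in the backward pass:

.. code-block:: python

    import torch

    side_stream = torch.cuda.Stream()

    def pack_to_cpu_async(tensor):
        # Stage the copy into pinned CPU memory on a side stream so that compute on the
        # default stream keeps running while the transfer is in flight.
        cpu_buf = torch.empty(tensor.size(), dtype=tensor.dtype, pin_memory=True)
        side_stream.wait_stream(torch.cuda.current_stream())
        with torch.cuda.stream(side_stream):
            cpu_buf.copy_(tensor, non_blocking=True)
        # Keep the GPU tensor's memory alive until the side stream finishes the copy.
        tensor.record_stream(side_stream)
        return cpu_buf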
*Sounds great! How do I use it?*

To enable activation offloading, use the ``enable_activation_offloading`` config entry or flag
in our lora finetuning single device recipe, e.g. ``enable_activation_offloading=True``.
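Under the hood, the recipe wraps the model's forward pass in the new ``OffloadActivations`` context manager
added by this PR. The snippet below is a minimal sketch of the same pattern on a toy model; the
``use_streams`` argument mirrors the ``offload_with_streams`` option wired through in this diff:

.. code-block:: python

    import torch
    from torchtune.training import OffloadActivations

    model = torch.nn.Sequential(
        torch.nn.Linear(1024, 1024), torch.nn.ReLU(), torch.nn.Linear(1024, 1024)
    ).cuda()
    batch = torch.randn(8, 1024, device="cuda")

    activations_handling_ctx = OffloadActivations(use_streams=True)

    with activations_handling_ctx:
        out = model(batch)
    out.sum().backward()  # offloaded activations are fetched back from the CPU here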
Review thread:

- I don't think that this will stay exclusive to lora finetuning. Maybe you are suggesting that this should be updated when it becomes available in full_single_device? PS: do you plan to do it? (It's fine if the answer is no, just checking :) )
- Can we also add a note about offloading in single device vs FSDP, since in distributed it is another command?
- Yeah, I was thinking we should mention lora finetuning single device for now in particular, as it's only enabled for those. It would be important to widen this when more support is added --> I'm happy to do it. In distributed, it should be the same command, but I haven't tested at all, so did not want to promise anything for that.
- I thought that for FSDP we use fsdp_cpu_offload and let FSDP do the offloading. Unless you think it makes sense to have two different offloading strategies?
- FSDP's CPU offload will just offload parameters, gradients, and optimizer states though, right? Not activations. Do we expect that the two will work together? I can't immediately think of any reason why they wouldn't, but maybe there's something obvious I'm missing.
- Oh, I missed that FSDP doesn't offload activations. That's nice! I thought it was a solved issue for FSDP and we were just improving single device.
- They theoretically should work together, but who knows man, we can't say so till we test! I'm suspicious that there will be unforeseen results with intranode comms, so the scope of this PR is just single device.
.. _glossary_grad_accm:

Gradient Accumulation
@@ -25,6 +25,7 @@ dependencies = [
    "numpy<=1.26.4", # Pin here until https://github.com/tensorflow/tensorboard/issues/6869 is addressed
    "tqdm",
    "omegaconf",
    "psutil",
Review thread:

- @ebsmothers is this okay? This is a new requirement, as we use psutil to check CPU RAM usage and warn on too much usage.
- FYI: psutil does not pull in other deps!
]
dynamic = ["version"]
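For context on the new dependency: the kind of check described in the thread above looks roughly like the
following sketch (a hypothetical helper, not the exact code in this PR):

```python
import psutil

def maybe_warn_cpu_ram_usage(threshold_pct: float = 80.0) -> None:
    # Activation offloading stages tensors in host RAM, so warn when usage gets high.
    used_pct = psutil.virtual_memory().percent
    if used_pct > threshold_pct:
        print(
            f"Warning: host RAM usage is at {used_pct:.0f}%; "
            "activation offloading may start to thrash or fail CPU-side allocations."
        )
```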
@@ -4,6 +4,7 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import contextlib
import sys
import time
@@ -30,8 +31,12 @@
    validate_missing_and_unexpected_for_lora,
)
from torchtune.recipe_interfaces import FTRecipeInterface
from torchtune.training import DummyProfiler, PROFILER_KEY
from torchtune.training import (
    DummyProfiler,
    NoOpManager,
    OffloadActivations,
    PROFILER_KEY,
)
from tqdm import tqdm

log = utils.get_logger("DEBUG")
@@ -43,13 +48,22 @@ class LoRAFinetuneRecipeSingleDevice(FTRecipeInterface):
    for single GPU training. Training on CPU is not supported.

    Features:
        - Activation Checkpointing. This can be controlled using the ``activation_checkpointing``
        - Activation Checkpointing. This can be controlled using the ``enable_activation_checkpointing``

          (felipemello1 marked this conversation as resolved.)

          flag. Activation checkpointing helps reduce the memory footprint since we no longer keep
          activations in memory and instead recompute them during the backward pass. This is especially
          helpful for larger batch sizes when you're memory constrained. But these savings in memory
          come at the cost of training performance. In most cases training can slow down quite a bit as
          a result of this activation recomputation.

        - Activation Offloading. This can be controlled using the ``enable_activation_offloading``
          flag. Activation offloading is a technique similar to activation checkpointing that helps
          reduce the memory footprint to prevent OOMs and enable bigger batches. Where activation
          checkpointing drops the activation in the forward pass to recompute it later in the backward pass,
          activation offloading moves the activation to the CPU in the forward pass and brings it
          back during the backward pass. As always, there is a tradeoff: these savings in memory can
          come at the cost of training performance and CPU resources. Activation offloading
          can be used in conjunction with activation checkpointing.

        - Precision. Full fp32 and bf16 training are supported. Precision is controlled using the ``dtype``
          flag. When ``dtype=bf16``, all activations, gradients and optimizer states are in bfloat16. In
          most cases this should halve the memory footprint of full precision (fp32) training, without
@@ -222,6 +236,8 @@ def setup(self, cfg: DictConfig) -> None:
        self._model = self._setup_model(
            cfg_model=cfg.model,
            enable_activation_checkpointing=cfg.enable_activation_checkpointing,
            enable_activation_offloading=cfg.get("enable_activation_offloading", False),
            offload_with_streams=cfg.get("offload_with_streams", False),
Review thread:

- I think that this is the last thing to approve the PR: I don't like having two arguments for offloading. What do you think about making enable_activation_offloading have 3 options? False / "with_stream" / "without_stream", or something like that?
- One q here: is there any case that I wouldn't wanna use streams, provided I'm on a sufficiently recent PyTorch version? If there are no feature gaps, I'm inclined to just have a single
- We haven't tested streams use on other hardware or with distributed or with other shape models, so I was hesitant to make it the default. Since we're only landing the lora finetuning recipe on single device with this PR, it is fine to just not include that flag option for now. I've removed it, but it would be pretty easy to add it back if required in the future.
            compile_model=cfg.compile,
            base_model_state_dict=checkpoint_dict[training.MODEL_KEY],
            lora_weights_state_dict=(
@@ -367,6 +383,8 @@ def _setup_model(
        self,
        cfg_model: DictConfig,
        enable_activation_checkpointing: bool,
        enable_activation_offloading: bool,
        offload_with_streams: bool,
        compile_model: bool,
        base_model_state_dict: Dict[str, Any],
        lora_weights_state_dict: Optional[Dict[str, Any]] = None,
@@ -420,6 +438,22 @@ def _setup_model(
            self.adapter_params.items(), dtype=self._dtype
        )

        self.activations_handling_ctx = contextlib.nullcontext()
        if enable_activation_offloading:
            self.activations_handling_ctx = OffloadActivations(
                use_streams=offload_with_streams
            )

            # Below is our hack to disable offloading the last output Linear in every
            # step, as the cost for offloading the activation and then soon after bringing
            # it back is expensive. Moreover, due to heuristics in our streaming API,
            # we actually use more memory if we offload it as it interferes with chunkedCE.
            noop_ctx = NoOpManager()
            model.output.register_forward_pre_hook(lambda *args: noop_ctx.__enter__())
            model.output.register_forward_hook(
                lambda *args: noop_ctx.__exit__(), always_call=True
Review thread:

- We need to test it with Qwen or Gemma, which have tied embeddings. They don't have model.output, since they use TiedTransformerDecoder, so this would fail. We would need to add something like "if hasattr(model.output)". The issue is that they use output = F.linear(h, self.tok_embeddings.weight). In this PR (#1527), I created a regular Python class (no nn.Module) to replace the F.linear logic, so we can get rid of the TiedTransformerDecoder. I imagine that hooks only work with nn.Module. Is that true? I didn't try to make the TiedLinear an nn.Module because it doesn't have its own weights, and I didn't want FSDP and other wrappers to interact with it.
- We should probably add this NoOp logic to some utils, but no strong opinion.
- Ah yea, I agree we need to test these. The hooks do only work with modules (since they're module hooks), so that would require some design... e.g., we enable adding hooks to this class you added, or we make it an nn.Module still but exclude it from FSDP/other wrappers. Or something else 😛
- Oh yeah, this tied embedding business is actually a pretty important point. As a hack we can maybe do something like … Would have to think more about whether we can get away with making … Anyways, before I get too far down that rabbit hole, I'd propose just going with the hack (provided it works). It's dumb, simple, and explicit, which I like.
            )

        log.info(f"Model is initialized with precision {self._dtype}.")

        if self._device.type == "cuda":
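A minimal sketch of the guard floated in the thread above (hypothetical, not part of this diff), where
``model`` is the decoder being set up and tied-embedding models are assumed to expose no ``model.output``
submodule:

```python
import torch
from torchtune.training import NoOpManager

# Only register the NoOp hooks when the model has a standalone output projection module.
# Tied-embedding models (e.g. Qwen/Gemma, as noted above) compute the output with
# F.linear on the embedding weight instead, so there is nothing to hook here.
if hasattr(model, "output") and isinstance(model.output, torch.nn.Module):
    noop_ctx = NoOpManager()
    model.output.register_forward_pre_hook(lambda *args: noop_ctx.__enter__())
    model.output.register_forward_hook(
        lambda *args: noop_ctx.__exit__(), always_call=True
    )
```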
@@ -576,7 +610,8 @@ def _loss_step(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor:
        input_pos = batch.get("input_pos", None)  # shape [b, s]

        # run model
        logits = self._model(tokens, mask=mask, input_pos=input_pos)
        with self.activations_handling_ctx:
            logits = self._model(tokens, mask=mask, input_pos=input_pos)

        # Shift labels to compute loss
        # equivalent to doing labels[..., 1:] and logits[..., :-1, :]
Review thread:

- As far as I remember, this only works if activation_checkpointing is True. Is that still right? If so, we should probably update this doc and add to the recipes to raise an error or set AC=True automatically. Another option, which I would prefer, is to investigate allowing offloading without AC, since streaming seems promising. Nit: I believe you meant "one can use offloading" instead of "once can use offloading".
- What do you mean by "this" in the first sentence? Activation offloading works when AC is false as well, it's just super slow.
- I remember trying to use only offloading, with AC=False, and it broke. Maybe it's not the case anymore?
- Yea, it shouldn't break! It should just be reaaaaally slow.