
Streaming offloading in (q)lora single device #1443


Merged
merged 11 commits on Sep 16, 2024
2 changes: 1 addition & 1 deletion docs/Makefile
@@ -6,7 +6,7 @@ ifneq ($(EXAMPLES_PATTERN),)
endif

# You can set these variables from the command line.
-SPHINXOPTS = -W -j auto $(EXAMPLES_PATTERN_OPTS)
+SPHINXOPTS = -W -j auto $(EXAMPLES_PATTERN_OPTS) -T -v
SPHINXBUILD = sphinx-build
SPHINXPROJ = torchtune
SOURCEDIR = source
8 changes: 6 additions & 2 deletions recipes/lora_finetune_single_device.py
@@ -31,8 +31,12 @@
    validate_missing_and_unexpected_for_lora,
)
from torchtune.recipe_interfaces import FTRecipeInterface
-from torchtune.training import DummyProfiler, PROFILER_KEY
-from torchtune.training._activation_offloading import NoOpManager, OffloadActivations
+from torchtune.training import (
+    DummyProfiler,
+    NoOpManager,
+    OffloadActivations,
+    PROFILER_KEY,
+)
from tqdm import tqdm

log = utils.get_logger("DEBUG")
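For context on the consolidated import above: NoOpManager and OffloadActivations now come from the public torchtune.training namespace rather than the private _activation_offloading module. Below is a minimal sketch of how a recipe might wrap its forward pass with them. It assumes OffloadActivations takes a use_streams flag and that NoOpManager suppresses offloading inside its scope; neither signature is confirmed by this diff.

```python
# Sketch only, not the recipe's actual code.
import torch

from torchtune.training import NoOpManager, OffloadActivations


def forward_with_offloading(model: torch.nn.Module, tokens: torch.Tensor) -> torch.Tensor:
    # Saved activations are offloaded to CPU during the forward pass; with
    # use_streams=True (assumed flag) the copies run on a side stream and
    # overlap compute instead of blocking it.
    with OffloadActivations(use_streams=True):
        logits = model(tokens)
        # Hypothetical exclusion: work done under NoOpManager keeps its saved
        # tensors on GPU, i.e. offloading becomes a no-op for this region.
        with NoOpManager():
            logits = logits.float()
    return logits
```

In the real recipe the offloading context would more plausibly be created once during setup and reused every step; it is inlined here only to keep the example self-contained.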
3 changes: 3 additions & 0 deletions torchtune/training/__init__.py
@@ -3,6 +3,7 @@
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
+from torchtune.training._activation_offloading import NoOpManager, OffloadActivations
from torchtune.training._compile import compile_loss, compile_model
from torchtune.training._distributed import (
contains_fsdp,
@@ -122,4 +123,6 @@
"setup_torch_profiler",
"compile_loss",
"compile_model",
"NoOpManager",
"OffloadActivations",
]
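The re-export above is what allows the recipe to drop the private import path. For illustration, the before and after spellings side by side (the commented line is the pre-change private path visible in the recipe diff):

```python
# After this change, the public path works:
from torchtune.training import NoOpManager, OffloadActivations

# Before, callers had to reach into the private module:
# from torchtune.training._activation_offloading import NoOpManager, OffloadActivations
```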
14 changes: 7 additions & 7 deletions torchtune/training/_activation_offloading.py
@@ -137,11 +137,11 @@ def pack_tensor(activation: torch.Tensor) -> int:
if use_streams:
    # First, sync back and dereference previously offloaded tensors
    # as the offloading should be done sufficiently long ago.
-    for k in [x for x in self.fwd_stash.keys()]:
-        if k <= tensor_id - self.max_fwd_stash_size:
-            _, ev = self.fwd_stash[k]
+    for id in [k for k in self.fwd_stash.keys()]:
+        if id <= tensor_id - self.max_fwd_stash_size:
+            _, ev = self.fwd_stash[id]
            self.s0.wait_event(ev)
-            del self.fwd_stash[k]
+            del self.fwd_stash[id]
        else:
            break

@@ -266,10 +266,10 @@ def hook(outputs, inputs):
self.bwd_ev_stash[unpack_tensor_id] = event

# if there are still things in the fwd_stash, get rid of them as we're in bwd now
-for k in [x for x in self.fwd_stash.keys()]:
-    _, ev = self.fwd_stash[k]
+for id in [k for k in self.fwd_stash.keys()]:
+    _, ev = self.fwd_stash[id]
    self.s0.wait_event(ev)
-    del self.fwd_stash[k]
+    del self.fwd_stash[id]

# wait on prev node's events and del those
for id in prev_node_ids:
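Beyond the loop-variable rename, these two hunks show the bookkeeping that keeps the streaming path correct: offloaded tensors sit in fwd_stash until they fall outside a window of max_fwd_stash_size entries, self.s0 waits on each entry's recorded event before the entry is deleted, and anything still stashed is drained once backward begins. The following simplified, CPU-only sketch illustrates that eviction pattern; BoundedStash and its methods are illustrative stand-ins rather than torchtune APIs, and plain Python objects replace CUDA tensors and events.

```python
from collections import OrderedDict


class BoundedStash:
    """Toy model of the fwd_stash window used by the streaming offloader."""

    def __init__(self, max_fwd_stash_size: int = 5) -> None:
        self.max_fwd_stash_size = max_fwd_stash_size
        self.fwd_stash = OrderedDict()  # tensor_id -> (payload, event placeholder)

    def pack(self, tensor_id: int, payload: str) -> None:
        # Evict entries offloaded long enough ago, i.e. ids outside the window.
        for key in list(self.fwd_stash.keys()):
            if key <= tensor_id - self.max_fwd_stash_size:
                _, event = self.fwd_stash[key]
                # The real code waits on a CUDA event here before freeing:
                # self.s0.wait_event(event)
                del self.fwd_stash[key]
            else:
                break  # ids arrive in increasing order, so we can stop early
        self.fwd_stash[tensor_id] = (payload, object())

    def drain(self) -> None:
        # Mirrors the backward-side cleanup above: release everything still stashed.
        for key in list(self.fwd_stash.keys()):
            del self.fwd_stash[key]


stash = BoundedStash(max_fwd_stash_size=2)
for i in range(6):
    stash.pack(i, f"activation-{i}")
print(list(stash.fwd_stash.keys()))  # [4, 5]: only the most recent window remains
stash.drain()
```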