Add LR Scheduler to single device full finetune #1350

Merged
Changes from 5 commits
5 changes: 4 additions & 1 deletion recipes/configs/llama2/7B_full_low_memory.yaml
@@ -53,7 +53,10 @@ batch_size: 2
epochs: 3
optimizer:
_component_: bitsandbytes.optim.PagedAdamW
lr: 2e-5
lr: 1e-5
lr_scheduler:
_component_: torchtune.modules.get_cosine_schedule_with_warmup
num_warmup_steps: 100
optimizer_in_bwd: True
loss:
_component_: torchtune.modules.loss.CEWithChunkedOutputLoss
5 changes: 4 additions & 1 deletion recipes/configs/llama3/8B_full_single_device.yaml
@@ -52,7 +52,10 @@ batch_size: 2
epochs: 3
optimizer:
_component_: bitsandbytes.optim.PagedAdamW8bit
lr: 2e-5
lr: 1e-5
lr_scheduler:
_component_: torchtune.modules.get_cosine_schedule_with_warmup
num_warmup_steps: 100
loss:
_component_: torchtune.modules.loss.CEWithChunkedOutputLoss
max_steps_per_epoch: null
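For context on what these config changes wire up: `get_cosine_schedule_with_warmup` follows the common warmup-then-cosine-decay shape. A rough standalone sketch of that schedule built on `torch.optim.lr_scheduler.LambdaLR` (the exact torchtune signature is assumed here, not quoted from the library):

```python
import math

import torch
from torch.optim.lr_scheduler import LambdaLR


def cosine_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps, num_cycles=0.5):
    """Linear warmup from 0 to the base lr, then cosine decay toward zero."""

    def lr_lambda(step):
        if step < num_warmup_steps:
            return step / max(1, num_warmup_steps)
        progress = (step - num_warmup_steps) / max(1, num_training_steps - num_warmup_steps)
        return max(0.0, 0.5 * (1.0 + math.cos(math.pi * num_cycles * 2.0 * progress)))

    return LambdaLR(optimizer, lr_lambda)


# Mirror the config above: lr 1e-5, 100 warmup steps.
param = torch.nn.Parameter(torch.zeros(1))
opt = torch.optim.SGD([param], lr=1e-5)
sched = cosine_schedule_with_warmup(opt, num_warmup_steps=100, num_training_steps=1000)
```

With these numbers the lr ramps linearly to 1e-5 over the first 100 steps, then decays to roughly zero by step 1000.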
48 changes: 48 additions & 0 deletions recipes/full_finetune_single_device.py
@@ -272,6 +272,13 @@ def setup(self, cfg: DictConfig) -> None:
self._steps_per_epoch = self.max_steps_per_epoch
self.global_step = self.epochs_run * self._steps_per_epoch

# Setup lr scheduler
self._lr_scheduler = self._setup_lr_scheduler(
cfg_lr_scheduler=cfg.lr_scheduler,
num_training_steps=self.total_epochs * self._steps_per_epoch,
last_epoch=self.global_step - 1,
)

# Set up profiler, returns DummyProfiler (nullcontext object with no-op `step` method)
# if cfg is missing profiler key or if `cfg.profiler.enabled = False`
self._profiler = self._setup_profiler(cfg.get(PROFILER_KEY, None))
@@ -429,6 +436,45 @@ def _setup_optimizer(
log.info("Optimizer is initialized.")
return optimizer

def _setup_lr_scheduler(
self,
cfg_lr_scheduler: DictConfig,
num_training_steps: int,
last_epoch: int,
) -> Optional[LRScheduler]:
Collaborator: Could you add docstrings here?

if self._optimizer_in_bwd:
# Use the first optimizer from the wrapper to represent the learning rate
optimizer = next(iter(self._optim_ckpt_wrapper.optim_map.values()))
else:
# Standard case: use the single optimizer
optimizer = self._optimizer

# Check if the lr_scheduler component is from torchtune.modules or torch.optim
component = cfg_lr_scheduler.get("_component_", None)

# Conditionally instantiate the scheduler based on the component
if component is not None and "torchtune.modules" in component:
# Instantiate the learning rate scheduler
lr_scheduler = config.instantiate(
cfg_lr_scheduler,
optimizer,
num_training_steps=num_training_steps,
last_epoch=last_epoch,
)
else:
lr_scheduler = config.instantiate(
cfg_lr_scheduler,
optimizer,
last_epoch=last_epoch,
)
Contributor:

Ah, is this a consequence of using ConstantLR in the test? I think it's a bit confusing to condition on whether the component is coming from torchtune.modules or not. In general there will be things from torch.optim that don't match the signature in the else statement, and there will probably also (in the future) be things from torchtune.modules that don't match the signature in the if statement.

In that case (and sorry for changing my mind on this), is there any reason we can't just make lr_scheduler optional? Then if it's not passed we just continue to use the usual constant learning rate defined in the optimizer and set self._lr_scheduler=None in the recipe (and in the training loop check that it's not None prior to stepping). This will make the recipe test a non-issue and keep the setup clean.

Separately, though, it's clear that our LR scheduler instantiation is not flexible enough. We should figure out how to better support schedulers from torch.optim moving forward (though that's a problem for another PR).
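The optional-scheduler shape described above is straightforward. A minimal sketch (the recipe structure and names here are hypothetical, reduced from the actual torchtune recipe) of how setup and the train loop might guard on `None`:

```python
import torch


class Recipe:
    def __init__(self, cfg_lr_scheduler=None):
        p = torch.nn.Parameter(torch.zeros(1))
        self._optimizer = torch.optim.SGD([p], lr=2e-5)
        # If no scheduler is configured, fall back to the optimizer's constant lr.
        if cfg_lr_scheduler is None:
            self._lr_scheduler = None
        else:
            self._lr_scheduler = cfg_lr_scheduler(self._optimizer)

    def train_step(self):
        self._optimizer.step()
        # Guard: only step the scheduler when one was configured.
        if self._lr_scheduler is not None:
            self._lr_scheduler.step()


recipe = Recipe()  # no scheduler configured: lr stays at 2e-5
recipe.train_step()
```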

Contributor Author:

Yes. I was also thinking of making it more flexible and supporting other types of torch.optim.lr_scheduler. I find that torchtune.modules.lr_scheduler only has get_cosine_schedule_with_warmup, and the num_training_steps input would not be compatible with other torch.optim.lr_scheduler classes.

I can make it work with lr_scheduler as None by the weekend. The conditional is mainly due to the testing config, which always passes num_training_steps even when the scheduler has no such argument.

Contributor Author:

Updated self._lr_scheduler=None.
Also, I modified the testing: I still use the cosine schedule instead of a constant one, but effectively make it equivalent to constant by setting num_cycles=0. It should be a bit more elegant, imo.
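The `num_cycles=0` trick works because the cosine factor degenerates to `cos(0) = 1` at every post-warmup step. A quick check of that claim against the usual cosine-with-warmup formula (reimplemented here, not imported from torchtune):

```python
import math


def cosine_factor(step, num_warmup_steps, num_training_steps, num_cycles):
    # Multiplier applied to the base lr at a given step.
    if step < num_warmup_steps:
        return step / max(1, num_warmup_steps)
    progress = (step - num_warmup_steps) / max(1, num_training_steps - num_warmup_steps)
    return max(0.0, 0.5 * (1.0 + math.cos(math.pi * num_cycles * 2.0 * progress)))


# With num_cycles=0 (and no warmup) every factor is exactly 1.0, i.e. a constant lr.
factors = [cosine_factor(s, 0, 10, 0.0) for s in range(10)]
```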


if self._optimizer_in_bwd:
# Modify the scheduler for optimizer_in_bwd case
self._optim_ckpt_wrapper.set_lr_scheduler(lr_scheduler)

log.info("Learning rate scheduler is initialized.")
return lr_scheduler

def _setup_data(
self,
cfg_dataset: DictConfig,
@@ -589,6 +635,8 @@ def train(self) -> None:
self._optimizer.step()
self._optimizer.zero_grad(set_to_none=True)

# Need to fix `lr_scheduler.step()` before `optimizer.step()` warning
Collaborator:

What is this comment referring to? : )

Contributor Author:

There was a warning when I ran the code. I think it says "Detected call of `lr_scheduler.step()` before `optimizer.step()`". I do not get why it raises this warning, since the `lr_scheduler.step()` is after `optimizer.step()`.
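One plausible cause (an assumption, not confirmed in this thread): with `optimizer_in_bwd`, `optimizer.step()` runs inside backward hooks rather than being called directly, so the step counter the scheduler installs on its wrapped optimizer never advances. The warning itself is easy to reproduce standalone:

```python
import warnings

import torch

param = torch.nn.Parameter(torch.zeros(1))
opt = torch.optim.SGD([param], lr=0.1)
sched = torch.optim.lr_scheduler.ConstantLR(opt, factor=1.0)

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    # Stepping the scheduler before any opt.step() triggers PyTorch's warning,
    # because the scheduler checks the optimizer's internal step counter.
    sched.step()

messages = [str(w.message) for w in caught]
```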

self._lr_scheduler.step()
self.global_step += 1

loss_to_log = running_loss.item()
3 changes: 3 additions & 0 deletions tests/recipes/test_full_finetune_single_device.py
@@ -45,6 +45,9 @@ def _get_test_config_overrides(self):
"max_steps_per_epoch=2",
"optimizer=torch.optim.AdamW",
"optimizer.lr=2e-5",
"lr_scheduler=torch.optim.lr_scheduler.ConstantLR",
"lr_scheduler.factor=1.0",
"~lr_scheduler.num_warmup_steps",
"log_every_n_steps=1",
"clip_grad_norm=100",
] + dummy_alpaca_dataset_config()
59 changes: 58 additions & 1 deletion torchtune/utils/memory.py
@@ -10,12 +10,12 @@
from typing import Any, Callable, Dict, Set, Type, Union

import torch

from torch import nn
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
apply_activation_checkpointing,
)
from torch.distributed.fsdp.wrap import ModuleWrapPolicy
from torch.optim.lr_scheduler import LRScheduler
from torchtune.utils.logging import get_logger

_log: logging.Logger = get_logger()
@@ -91,6 +91,7 @@ class OptimizerInBackwardWrapper:

def __init__(self, optim_map: Dict[str, torch.optim.Optimizer]):
self.optim_map = optim_map
self.lr_scheduler = None

def state_dict(self) -> Dict[str, Any]:
"""
@@ -136,6 +137,62 @@ def get_optim_key(self, key: str) -> Any:
"""
return list(self.optim_map.values())[0].param_groups[0][key]

def set_lr_scheduler(self, lr_scheduler: LRScheduler) -> None:
"""
Sets the learning rate scheduler and modifies its step method to update all optimizers.

Args:
lr_scheduler (LRScheduler): The learning rate scheduler to use.
"""
self.lr_scheduler = lr_scheduler
original_step = self.lr_scheduler.step

def custom_step(epoch=None):
if epoch is None:
original_step()
else:
original_step(epoch)
new_lr = self.lr_scheduler.get_last_lr()[0]
for opt in self.optim_map.values():
for param_group in opt.param_groups:
param_group["lr"] = new_lr

self.lr_scheduler.step = custom_step

def step_lr_scheduler(self, epoch: int = None):
"""
Steps the learning rate scheduler if it exists.

Args:
epoch (int, optional): The current epoch number. Defaults to None.

Raises:
RuntimeError: If the LR scheduler has not been set.
"""
if self.lr_scheduler:
self.lr_scheduler.step(epoch)
else:
raise RuntimeError(
"LR scheduler has not been set. Call set_lr_scheduler first."
)

def get_last_lr(self):
Collaborator:

nit:

Suggested change
def get_last_lr(self):
def get_last_lr(self) -> float:

Have you correctly configured your pre-commit hooks? Our linters should pick this up

Contributor Author:

[screenshot: pre-commit run with all hooks passing]

I got pre-commit all passed with this. Kinda strange. I installed it according to CONTRIBUTING.md

Contributor (ebsmothers, Sep 10, 2024):

> Have you correctly configured your pre-commit hooks? Our linters should pick this up

I think there are merge conflicts, so linters will not run in CI until you resolve them. Also, for missing type hints, I don't know if flake8 actually catches those anyway (not positive though).

"""
Gets the last learning rate from the scheduler if it exists.

Returns:
float: The last learning rate.

Raises:
RuntimeError: If the LR scheduler has not been set.
"""
if self.lr_scheduler:
return self.lr_scheduler.get_last_lr()[0]
else:
raise RuntimeError(
"LR scheduler has not been set. Call set_lr_scheduler first."
)
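The `set_lr_scheduler` hook above can be exercised in isolation. A toy sketch of the same pattern (standalone, not importing torchtune) that patches a scheduler's `step` so the new lr fans out to several per-parameter optimizers, as in the optimizer-in-backward case:

```python
import torch
from torch.optim.lr_scheduler import LambdaLR

# Two parameters, each with its own optimizer (mimicking optimizer-in-backward).
params = [torch.nn.Parameter(torch.zeros(1)) for _ in range(2)]
optim_map = {f"p{i}": torch.optim.SGD([p], lr=1e-2) for i, p in enumerate(params)}

# Drive the scheduler from the first optimizer; halve the lr every step.
pilot = next(iter(optim_map.values()))
sched = LambdaLR(pilot, lambda step: 0.5 ** step)

original_step = sched.step


def custom_step(epoch=None):
    # Step the real scheduler, then copy its new lr into every optimizer.
    original_step() if epoch is None else original_step(epoch)
    new_lr = sched.get_last_lr()[0]
    for opt in optim_map.values():
        for group in opt.param_groups:
            group["lr"] = new_lr


sched.step = custom_step

sched.step()  # every optimizer's lr is now halved
```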


def create_optim_in_bwd_wrapper(
model: torch.nn.Module, optim_dict: Dict[torch.nn.Parameter, torch.optim.Optimizer]