Add LR Scheduler to single device full finetune #1350

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Changes from 9 commits
recipes/configs/llama2/7B_full_low_memory.yaml (5 changes: 4 additions & 1 deletion)

@@ -53,7 +53,10 @@ batch_size: 2
 epochs: 3
 optimizer:
   _component_: bitsandbytes.optim.PagedAdamW
-  lr: 2e-5
+  lr: 1e-5
+lr_scheduler:
+  _component_: torchtune.modules.get_cosine_schedule_with_warmup
+  num_warmup_steps: 100
 optimizer_in_bwd: True
 loss:
   _component_: torchtune.modules.loss.CEWithChunkedOutputLoss
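The scheduler added here ramps the LR linearly from 0 over the first num_warmup_steps, then follows a cosine decay toward 0 by num_training_steps. A minimal sketch of the shape, mirroring the usual cosine-with-warmup formula; the num_training_steps and base_lr defaults below are illustrative stand-ins, not values from this PR:

    import math

    def lr_at(step, base_lr=1e-5, num_warmup_steps=100, num_training_steps=1000, num_cycles=0.5):
        if step < num_warmup_steps:
            # Linear warmup from 0 to base_lr.
            return base_lr * step / max(1, num_warmup_steps)
        # Cosine decay toward 0 over the remaining steps.
        progress = (step - num_warmup_steps) / max(1, num_training_steps - num_warmup_steps)
        return base_lr * max(0.0, 0.5 * (1.0 + math.cos(math.pi * num_cycles * 2.0 * progress)))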
recipes/configs/llama3/8B_full_single_device.yaml (5 changes: 4 additions & 1 deletion)

@@ -52,7 +52,10 @@ batch_size: 2
 epochs: 3
 optimizer:
   _component_: bitsandbytes.optim.PagedAdamW8bit
-  lr: 2e-5
+  lr: 1e-5
+lr_scheduler:
+  _component_: torchtune.modules.get_cosine_schedule_with_warmup
+  num_warmup_steps: 100
 loss:
   _component_: torchtune.modules.loss.CEWithChunkedOutputLoss
 max_steps_per_epoch: null
recipes/full_finetune_single_device.py (57 changes: 57 additions & 0 deletions)

@@ -267,6 +267,13 @@ def setup(self, cfg: DictConfig) -> None:
         self._steps_per_epoch = self.max_steps_per_epoch
         self.global_step = self.epochs_run * self._steps_per_epoch

+        # Setup lr scheduler
+        self._lr_scheduler = self._setup_lr_scheduler(
+            cfg_lr_scheduler=cfg.lr_scheduler,
+            num_training_steps=self.total_epochs * self._steps_per_epoch,
+            last_epoch=self.global_step - 1,
+        )
+
         # Set up profiler, returns DummyProfiler (nullcontext object with no-op `step` method)
         # if cfg is missing profiler key or if `cfg.profiler.enabled = False`
         self._profiler = self._setup_profiler(cfg.get(PROFILER_KEY, None))
@@ -422,6 +429,53 @@ def _setup_optimizer(
log.info("Optimizer is initialized.")
return optimizer

def _setup_lr_scheduler(
self,
cfg_lr_scheduler: Optional[DictConfig],
num_training_steps: int,
last_epoch: int,
) -> Optional[Optimizer]:
Collaborator: Could you add docstrings here?
"""
Set up the learning rate scheduler based on the provided configuration.
It handles both standard optimization and optimizer-in-backward cases, and supports
schedulers from both torchtune.modules and torch.optim.

Args:
cfg_lr_scheduler (Optional[DictConfig]): The learning rate scheduler configuration.
num_training_steps (int): The total number of training steps.
last_epoch (int): The index of the last epoch.

Returns:
lr_scheduler (Optional[Optimizer]): The learning rate scheduler.
"""
+        if cfg_lr_scheduler is None:
+            log.info(
+                "No learning rate scheduler configured. Using constant learning rate."
+            )
+            return None
+
+        if self._optimizer_in_bwd:
+            # Use the first optimizer from the wrapper to represent the learning rate
+            optimizer = next(iter(self._optim_ckpt_wrapper.optim_map.values()))
+        else:
+            # Standard case: use the single optimizer
+            optimizer = self._optimizer
+
+        # Instantiate the learning rate scheduler
+        lr_scheduler = config.instantiate(
+            cfg_lr_scheduler,
+            optimizer,
+            num_training_steps=num_training_steps,
+            last_epoch=last_epoch,
+        )
+
+        if self._optimizer_in_bwd:
+            # Modify the scheduler for optimizer_in_bwd case
+            self._optim_ckpt_wrapper.set_lr_scheduler(lr_scheduler)
+
+        log.info("Learning rate scheduler is initialized.")
+        return lr_scheduler
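For the llama2 7B config above, this resolves to roughly the following direct call (a sketch; the model, optimizer, and step-count values are hypothetical stand-ins for the recipe's attributes):

    import torch
    from torchtune.modules import get_cosine_schedule_with_warmup

    # Hypothetical stand-ins for the recipe's state (illustrative values):
    model = torch.nn.Linear(4, 4)
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)
    total_epochs, steps_per_epoch, global_step = 3, 500, 0

    lr_scheduler = get_cosine_schedule_with_warmup(
        optimizer,                            # single optimizer, or the wrapper's first one
        num_warmup_steps=100,                 # from the YAML config
        num_training_steps=total_epochs * steps_per_epoch,
        last_epoch=global_step - 1,           # -1 on a fresh run, so the schedule starts at step 0
    )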

     def _setup_data(
         self,
         cfg_dataset: DictConfig,

@@ -586,6 +640,9 @@ def train(self) -> None:
                         self._optimizer.step()
                         self._optimizer.zero_grad(set_to_none=True)

+                    # Need to fix `lr_scheduler.step()` before `optimizer.step()` warning
Collaborator: What is this comment referring to? :)

Contributor Author: There was a warning when I ran the code. I think it says "Detected call of `lr_scheduler.step()` before `optimizer.step()`". I do not understand why this warning appears, since `lr_scheduler.step()` is called after `optimizer.step()`.
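For context: PyTorch's LRScheduler wraps optimizer.step at construction time to count calls, and warns on the first explicit scheduler.step() if that counter is still zero. Roughly, the check inside LRScheduler.step() looks like this (paraphrased from torch/optim/lr_scheduler.py; wording varies by version):

    if self._step_count == 1:
        if not hasattr(self.optimizer.step, "_with_counter"):
            warnings.warn("Seems like `optimizer.step()` has been overridden ...")
        elif self.optimizer._step_count < 1:
            warnings.warn("Detected call of `lr_scheduler.step()` before `optimizer.step()`. ...")

So the warning can fire even when the visible call order is correct, for example if the optimizer's step is invoked through a reference captured before the scheduler patched it, which is plausible with optimizer-in-backward hooks, since the counter then never increments.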

+                    if self._lr_scheduler is not None:
+                        self._lr_scheduler.step()
                     self.global_step += 1

                     loss_to_log = running_loss.item()
tests/recipes/test_full_finetune_single_device.py (2 changes: 2 additions & 0 deletions)

@@ -45,6 +45,8 @@ def _get_test_config_overrides(self):
"max_steps_per_epoch=2",
"optimizer=torch.optim.AdamW",
"optimizer.lr=2e-5",
"lr_scheduler.num_warmup_steps=0",
"lr_scheduler.num_cycles=0",
"log_every_n_steps=1",
"clip_grad_norm=100",
] + dummy_alpaca_dataset_config()
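These overrides presumably keep the test LR constant: with num_warmup_steps=0 there is no ramp, and with num_cycles=0 the cosine factor is 0.5 * (1 + cos(0)) = 1.0 at every step, so the recipe tests' expected loss values are unaffected by the new scheduler.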
torchtune/training/memory.py (59 changes: 58 additions & 1 deletion)

@@ -10,12 +10,12 @@
 from typing import Any, Callable, Dict, Set, Type, Union

 import torch
-
 from torch import nn
 from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
     apply_activation_checkpointing,
 )
 from torch.distributed.fsdp.wrap import ModuleWrapPolicy
+from torch.optim.lr_scheduler import LRScheduler
 from torchtune.utils import get_logger

 _log: logging.Logger = get_logger()
@@ -91,6 +91,7 @@ class OptimizerInBackwardWrapper:

     def __init__(self, optim_map: Dict[str, torch.optim.Optimizer]):
         self.optim_map = optim_map
+        self.lr_scheduler = None

     def state_dict(self) -> Dict[str, Any]:
         """

@@ -136,6 +137,62 @@ def get_optim_key(self, key: str) -> Any:
         """
         return list(self.optim_map.values())[0].param_groups[0][key]

+    def set_lr_scheduler(self, lr_scheduler: LRScheduler) -> None:
+        """
+        Sets the learning rate scheduler and modifies its step method to update all optimizers.
+
+        Args:
+            lr_scheduler (LRScheduler): The learning rate scheduler to use.
+        """
+        self.lr_scheduler = lr_scheduler
+        original_step = self.lr_scheduler.step
+
+        def custom_step(epoch=None):
+            if epoch is None:
+                original_step()
+            else:
+                original_step(epoch)
+            new_lr = self.lr_scheduler.get_last_lr()[0]
+            for opt in self.optim_map.values():
+                for param_group in opt.param_groups:
+                    param_group["lr"] = new_lr
+
+        self.lr_scheduler.step = custom_step
+
+    def step_lr_scheduler(self, epoch: int = None):
+        """
+        Steps the learning rate scheduler if it exists.
+
+        Args:
+            epoch (int, optional): The current epoch number. Defaults to None.
+
+        Raises:
+            RuntimeError: If the LR scheduler has not been set.
+        """
+        if self.lr_scheduler:
+            self.lr_scheduler.step(epoch)
+        else:
+            raise RuntimeError(
+                "LR scheduler has not been set. Call set_lr_scheduler first."
+            )
+
+    def get_last_lr(self) -> float:
+        """
+        Gets the last learning rate from the scheduler if it exists.
+
+        Returns:
+            float: The last learning rate.
+
+        Raises:
+            RuntimeError: If the LR scheduler has not been set.
+        """
+        if self.lr_scheduler:
+            return self.lr_scheduler.get_last_lr()[0]
+        else:
+            raise RuntimeError(
+                "LR scheduler has not been set. Call set_lr_scheduler first."
+            )
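A minimal usage sketch of the new wrapper plumbing (a hypothetical two-parameter setup; assumes OptimizerInBackwardWrapper is importable from this module):

    import torch
    from torchtune.training.memory import OptimizerInBackwardWrapper

    params = [torch.nn.Parameter(torch.zeros(1)) for _ in range(2)]
    optim_map = {f"p{i}": torch.optim.SGD([p], lr=1.0) for i, p in enumerate(params)}
    wrapper = OptimizerInBackwardWrapper(optim_map)

    # The scheduler is built on one representative optimizer; the patched
    # step() then mirrors its LR into every wrapped optimizer's param groups.
    sched = torch.optim.lr_scheduler.LambdaLR(
        next(iter(optim_map.values())), lr_lambda=lambda step: 0.1
    )
    wrapper.set_lr_scheduler(sched)
    wrapper.step_lr_scheduler()  # also reproduces the step-order warning, since no optimizer.step() ran

    assert wrapper.get_last_lr() == 0.1
    assert all(g["lr"] == 0.1 for opt in optim_map.values() for g in opt.param_groups)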


 def create_optim_in_bwd_wrapper(
     model: torch.nn.Module, optim_dict: Dict[torch.nn.Parameter, torch.optim.Optimizer]