Add LR Scheduler to single device full finetune #1350
File 1: the single-device full finetune recipe.

@@ -272,6 +272,13 @@ def setup(self, cfg: DictConfig) -> None:

        self._steps_per_epoch = self.max_steps_per_epoch
        self.global_step = self.epochs_run * self._steps_per_epoch

        # Setup lr scheduler
        self._lr_scheduler = self._setup_lr_scheduler(
            cfg_lr_scheduler=cfg.lr_scheduler,
            num_training_steps=self.total_epochs * self._steps_per_epoch,
            last_epoch=self.global_step - 1,
        )

        # Set up profiler, returns DummyProfiler (nullcontext object with no-op `step` method)
        # if cfg is missing profiler key or if `cfg.profiler.enabled = False`
        self._profiler = self._setup_profiler(cfg.get(PROFILER_KEY, None))
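For orientation, the `lr_scheduler` entry that `setup()` reads from the config could look roughly like the fragment below. This is a hypothetical sketch, not part of the PR; the component path and warmup value are assumptions based on the `get_cosine_schedule_with_warmup` scheduler discussed in the review thread.

```python
from omegaconf import OmegaConf

# Hypothetical config fragment (not from this PR): everything under
# `lr_scheduler` is forwarded to config.instantiate(), plus the two values
# the recipe computes in setup() (num_training_steps and last_epoch).
cfg = OmegaConf.create(
    {
        "lr_scheduler": {
            "_component_": "torchtune.modules.get_cosine_schedule_with_warmup",
            "num_warmup_steps": 100,
        }
    }
)
```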
@@ -429,6 +436,45 @@ def _setup_optimizer(

        log.info("Optimizer is initialized.")
        return optimizer

    def _setup_lr_scheduler(
        self,
        cfg_lr_scheduler: DictConfig,
        num_training_steps: int,
        last_epoch: int,
    ) -> Optional[LRScheduler]:

Reviewer: Could you add docstrings here?

        if self._optimizer_in_bwd:
            # Use the first optimizer from the wrapper to represent the learning rate
            optimizer = next(iter(self._optim_ckpt_wrapper.optim_map.values()))
        else:
            # Standard case: use the single optimizer
            optimizer = self._optimizer

        # Check if the lr_scheduler component is from torchtune.modules or torch.optim
        component = cfg_lr_scheduler.get("_component_", None)

        # Conditionally instantiate the scheduler based on the component
        if "torchtune.modules" in component:
            # Instantiate the learning rate scheduler
            lr_scheduler = config.instantiate(
                cfg_lr_scheduler,
                optimizer,
                num_training_steps=num_training_steps,
                last_epoch=last_epoch,
            )
        else:
            lr_scheduler = config.instantiate(
                cfg_lr_scheduler,
                optimizer,
                last_epoch=last_epoch,
            )

Reviewer: Ah, is this a consequence of using `num_training_steps`? In that case (and sorry for changing my mind on this), is there any reason we can't just make `lr_scheduler` optional? Separately, though, it's clear that our LR scheduler instantiation is not flexible enough. We should figure out how to better support schedulers from torch.optim moving forward (though that's a problem for another PR).

Author: Yes. I was also thinking of making it more flexible and supporting other torch.optim.lr_scheduler types. I find that torchtune.modules.lr_scheduler only has get_cosine_schedule_with_warmup, and its num_training_steps input is not compatible with the other torch.optim.lr_scheduler classes. I can make lr_scheduler work as None by the weekend. The branching here is mainly due to the testing config, which always passes num_training_steps even when the scheduler takes no such argument.

Author: Updated.

        if self._optimizer_in_bwd:
            # Modify the scheduler for the optimizer_in_bwd case
            self._optim_ckpt_wrapper.set_lr_scheduler(lr_scheduler)

        log.info("Learning rate scheduler is initialized.")
        return lr_scheduler

    def _setup_data(
        self,
        cfg_dataset: DictConfig,
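The branching above exists because torchtune's `get_cosine_schedule_with_warmup` expects a `num_training_steps` argument that stock `torch.optim` schedulers do not accept. A minimal sketch of the mismatch (the `StepLR` choice and toy model are illustrative, not from the PR):

```python
import torch
from torch.optim.lr_scheduler import StepLR

model = torch.nn.Linear(4, 4)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

# A torch.optim scheduler accepts `last_epoch` but not `num_training_steps`,
# so blindly forwarding num_training_steps raises a TypeError:
try:
    StepLR(optimizer, step_size=10, num_training_steps=100)
except TypeError as e:
    print(f"torch.optim scheduler rejected num_training_steps: {e}")

# The same scheduler instantiates fine with only the kwargs it knows about.
scheduler = StepLR(optimizer, step_size=10, last_epoch=-1)
```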
@@ -589,6 +635,8 @@ def train(self) -> None:

                    self._optimizer.step()
                    self._optimizer.zero_grad(set_to_none=True)

                    # Need to fix `lr_scheduler.step()` before `optimizer.step()` warning
                    self._lr_scheduler.step()
                    self.global_step += 1

                    loss_to_log = running_loss.item()

Reviewer: What is this comment referring to? :)

Author: There was a warning when I ran the code. I think it says "Detected call of `lr_scheduler.step()` before `optimizer.step()`".
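For reference, the ordering PyTorch expects is `optimizer.step()` first, then `lr_scheduler.step()`; stepping the scheduler before the optimizer is what produces the warning quoted above. A minimal sketch with an illustrative toy model and scheduler:

```python
import torch

# Correct ordering: optimizer.step() before lr_scheduler.step(). Stepping the
# scheduler first (or on iterations where the optimizer did not step) triggers
# the "Detected call of `lr_scheduler.step()` before `optimizer.step()`" warning.
model = torch.nn.Linear(8, 8)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10)

for _ in range(3):
    loss = model(torch.randn(2, 8)).sum()
    loss.backward()
    optimizer.step()
    optimizer.zero_grad(set_to_none=True)
    scheduler.step()  # after optimizer.step(): no warning
```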
File 2: the optimizer-in-backward wrapper utility.

@@ -10,12 +10,12 @@

from typing import Any, Callable, Dict, Optional, Set, Type, Union

import torch

from torch import nn
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    apply_activation_checkpointing,
)
from torch.distributed.fsdp.wrap import ModuleWrapPolicy
from torch.optim.lr_scheduler import LRScheduler
from torchtune.utils.logging import get_logger

_log: logging.Logger = get_logger()
@@ -91,6 +91,7 @@ class OptimizerInBackwardWrapper:

    def __init__(self, optim_map: Dict[str, torch.optim.Optimizer]):
        self.optim_map = optim_map
        self.lr_scheduler = None

    def state_dict(self) -> Dict[str, Any]:
        """
@@ -136,6 +137,62 @@ def get_optim_key(self, key: str) -> Any:

        """
        return list(self.optim_map.values())[0].param_groups[0][key]

    def set_lr_scheduler(self, lr_scheduler: LRScheduler) -> None:
        """
        Sets the learning rate scheduler and modifies its step method to update all optimizers.

        Args:
            lr_scheduler (LRScheduler): The learning rate scheduler to use.
        """
        self.lr_scheduler = lr_scheduler
        original_step = self.lr_scheduler.step

        def custom_step(epoch=None):
            if epoch is None:
                original_step()
            else:
                original_step(epoch)
            new_lr = self.lr_scheduler.get_last_lr()[0]
            for opt in self.optim_map.values():
                for param_group in opt.param_groups:
                    param_group["lr"] = new_lr

        self.lr_scheduler.step = custom_step

    def step_lr_scheduler(self, epoch: Optional[int] = None):
        """
        Steps the learning rate scheduler if it exists.

        Args:
            epoch (int, optional): The current epoch number. Defaults to None.

        Raises:
            RuntimeError: If the LR scheduler has not been set.
        """
        if self.lr_scheduler:
            self.lr_scheduler.step(epoch)
        else:
            raise RuntimeError(
                "LR scheduler has not been set. Call set_lr_scheduler first."
            )

    def get_last_lr(self):

Reviewer: nit: this is missing a return type hint (`def get_last_lr(self) -> float:`). Have you correctly configured your pre-commit hooks? Our linters should pick this up.

Reviewer: I think there are merge conflicts, so linters will not run in CI until you resolve them. Also, for missing type hints, idk if flake8 actually catches those anyway? (Not positive though.)

        """
        Gets the last learning rate from the scheduler if it exists.

        Returns:
            float: The last learning rate.

        Raises:
            RuntimeError: If the LR scheduler has not been set.
        """
        if self.lr_scheduler:
            return self.lr_scheduler.get_last_lr()[0]
        else:
            raise RuntimeError(
                "LR scheduler has not been set. Call set_lr_scheduler first."
            )


def create_optim_in_bwd_wrapper(
    model: torch.nn.Module, optim_dict: Dict[torch.nn.Parameter, torch.optim.Optimizer]
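To make the `set_lr_scheduler` mechanism concrete, here is a self-contained sketch of the same fan-out pattern outside the wrapper class: one scheduler is built against the first optimizer, and a wrapped `step()` copies the resulting learning rate into every per-parameter optimizer. The toy parameters and `LambdaLR` schedule are illustrative, not from the PR.

```python
import torch

# Sketch of the fan-out pattern used by set_lr_scheduler(): the scheduler
# drives one optimizer, and a wrapped step() broadcasts the new LR to all
# per-parameter optimizers (one optimizer per parameter, as in optimizer-in-backward).
params = [torch.nn.Parameter(torch.randn(2)) for _ in range(3)]
optim_map = {f"p{i}": torch.optim.SGD([p], lr=0.1) for i, p in enumerate(params)}

first_opt = next(iter(optim_map.values()))
scheduler = torch.optim.lr_scheduler.LambdaLR(first_opt, lambda step: 0.5**step)
original_step = scheduler.step

def custom_step(epoch=None):
    # Step the real scheduler, then copy the resulting LR into every optimizer.
    original_step() if epoch is None else original_step(epoch)
    new_lr = scheduler.get_last_lr()[0]
    for opt in optim_map.values():
        for group in opt.param_groups:
            group["lr"] = new_lr

scheduler.step = custom_step

scheduler.step()  # LR halves: 0.1 -> 0.05, in every optimizer
assert all(opt.param_groups[0]["lr"] == 0.05 for opt in optim_map.values())
```

Patching `step()` in place, rather than adding a separate entry point, keeps the recipe's existing call site (`self._lr_scheduler.step()` in `train()`) unchanged whether or not optimizer-in-backward is enabled.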