Add LR Scheduler to single device full finetune #1350
Changes from 4 commits
File: recipes/full_finetune_single_device.py (the single-device full finetune recipe)
@@ -272,6 +272,13 @@ def setup(self, cfg: DictConfig) -> None:
            self._steps_per_epoch = self.max_steps_per_epoch
        self.global_step = self.epochs_run * self._steps_per_epoch

        # Setup lr scheduler
        self._lr_scheduler = self._setup_lr_scheduler(
            cfg_lr_scheduler=cfg.lr_scheduler,
            num_training_steps=self.total_epochs * self._steps_per_epoch,
            last_epoch=self.global_step - 1,
        )

        # Set up profiler, returns DummyProfiler (nullcontext object with no-op `step` method)
        # if cfg is missing profiler key or if `cfg.profiler.enabled = False`
        self._profiler = self._setup_profiler(cfg.get(PROFILER_KEY, None))
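For context, a hedged sketch of how a `lr_scheduler` config section could be consumed the way this `setup` change does. The component path and warmup value are illustrative assumptions (borrowed from the style of torchtune's LoRA configs), not necessarily what this PR's configs contain.

```python
# Hedged sketch: instantiate an LR scheduler from a config section, mirroring
# the recipe's new _setup_lr_scheduler call. Names and values here are
# illustrative assumptions, not this PR's exact config.
import torch
from omegaconf import OmegaConf
from torchtune import config

cfg_lr_scheduler = OmegaConf.create(
    {
        "_component_": "torchtune.modules.get_cosine_schedule_with_warmup",
        "num_warmup_steps": 100,
    }
)

optimizer = torch.optim.AdamW(torch.nn.Linear(8, 8).parameters(), lr=2e-5)
lr_scheduler = config.instantiate(
    cfg_lr_scheduler,
    optimizer,
    num_training_steps=1000,  # total_epochs * steps_per_epoch in the recipe
    last_epoch=-1,            # the recipe passes global_step - 1
)
```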
@@ -429,6 +436,34 @@ def _setup_optimizer(
        log.info("Optimizer is initialized.")
        return optimizer

    def _setup_lr_scheduler(
        self,
        cfg_lr_scheduler: DictConfig,
        num_training_steps: int,
        last_epoch: int,
    ) -> Optional[Optimizer]:
        if self._optimizer_in_bwd:
            # Use the first optimizer from the wrapper to represent the learning rate
            optimizer = next(iter(self._optim_ckpt_wrapper.optim_map.values()))
        else:
            # Standard case: use the single optimizer
            optimizer = self._optimizer

        # Instantiate the learning rate scheduler
        lr_scheduler = config.instantiate(
            cfg_lr_scheduler,
            optimizer,
            num_training_steps=num_training_steps,
            last_epoch=last_epoch,
        )

        if self._optimizer_in_bwd:
            # Modify the scheduler for optimizer_in_bwd case
            self._optim_ckpt_wrapper.set_lr_scheduler(lr_scheduler)

        log.info("Learning rate scheduler is initialized.")
        return lr_scheduler

    def _setup_data(
        self,
        cfg_dataset: DictConfig,

Review comment on the `_setup_lr_scheduler` signature: Could you add docstrings here?
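A hedged sketch of a docstring that would address that request; it is not part of the PR at this commit, and the wording is only a suggestion.

```python
def _setup_lr_scheduler(
    self,
    cfg_lr_scheduler: DictConfig,
    num_training_steps: int,
    last_epoch: int,
) -> Optional[Optimizer]:
    """
    Instantiate the LR scheduler from config, attaching it to the single
    optimizer or, when optimizer_in_bwd is enabled, to a representative
    optimizer from the in-backward wrapper.

    Args:
        cfg_lr_scheduler (DictConfig): Scheduler config (``_component_`` plus kwargs).
        num_training_steps (int): Total number of training steps (epochs * steps per epoch).
        last_epoch (int): Index of the last completed step (``global_step - 1``); -1 for a fresh run.

    Returns:
        The instantiated learning rate scheduler.
    """
    ...
```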
@@ -589,6 +624,8 @@ def train(self) -> None:
                    self._optimizer.step()
                    self._optimizer.zero_grad(set_to_none=True)

                    # Need to fix `lr_scheduler.step()` before `optimizer.step()` warning
                    self._lr_scheduler.step()
                    self.global_step += 1

                    loss_to_log = running_loss.item()

Review discussion on the `# Need to fix ...` comment:

Reviewer: What is this comment referring to? :)

Author: There was a warning when I ran the code. I think it says "Detected call of `lr_scheduler.step()` before `optimizer.step()`".
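The warning mentioned in that reply is PyTorch's scheduler-ordering warning. A minimal plain-PyTorch sketch of the ordering the recipe follows (optimizer step first, then scheduler step):

```python
# Minimal sketch of the step ordering: parameters are updated first, then the
# LR schedule advances. Calling scheduler.step() before optimizer.step()
# triggers the PyTorch warning referenced above.
import torch

model = torch.nn.Linear(8, 8)
optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5)
scheduler = torch.optim.lr_scheduler.LinearLR(
    optimizer, start_factor=0.1, total_iters=10
)

for step in range(10):
    loss = model(torch.randn(4, 8)).sum()
    loss.backward()
    optimizer.step()                       # update parameters first
    optimizer.zero_grad(set_to_none=True)
    scheduler.step()                       # then advance the LR schedule
```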
File: the torchtune utils module defining OptimizerInBackwardWrapper (likely torchtune/utils/memory.py)
@@ -10,12 +10,12 @@
from typing import Any, Callable, Dict, Set, Type, Union

import torch

from torch import nn
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    apply_activation_checkpointing,
)
from torch.distributed.fsdp.wrap import ModuleWrapPolicy
from torch.optim.lr_scheduler import LRScheduler
from torchtune.utils.logging import get_logger

_log: logging.Logger = get_logger()
@@ -91,6 +91,7 @@ class OptimizerInBackwardWrapper:

    def __init__(self, optim_map: Dict[str, torch.optim.Optimizer]):
        self.optim_map = optim_map
        self.lr_scheduler = None

    def state_dict(self) -> Dict[str, Any]:
        """
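For context on what `optim_map` holds, a hedged sketch of the optimizer-in-backward setup this wrapper manages: one optimizer per parameter, stepped from a gradient hook. It mirrors the `create_optim_in_bwd_wrapper(model, optim_dict)` helper whose signature appears further down, but it is a plain-PyTorch illustration, not torchtune's implementation.

```python
# Illustrative sketch (requires PyTorch >= 2.1 for the hook API): one optimizer
# per parameter, keyed by name, with each optimizer stepped as soon as its
# parameter's gradient has been accumulated.
import torch

model = torch.nn.Linear(8, 8)
optim_dict = {p: torch.optim.SGD([p], lr=2e-5) for p in model.parameters()}
optim_map = {name: optim_dict[p] for name, p in model.named_parameters()}

for p in model.parameters():
    p.register_post_accumulate_grad_hook(
        lambda param: (optim_dict[param].step(), optim_dict[param].zero_grad())
    )

# Backward now performs the per-parameter updates; there is no separate
# optimizer.step() call in the training loop.
model(torch.randn(4, 8)).sum().backward()
```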
@@ -136,6 +137,62 @@ def get_optim_key(self, key: str) -> Any:
        """
        return list(self.optim_map.values())[0].param_groups[0][key]

    def set_lr_scheduler(self, lr_scheduler: LRScheduler) -> None:
        """
        Sets the learning rate scheduler and modifies its step method to update all optimizers.

        Args:
            lr_scheduler (LRScheduler): The learning rate scheduler to use.
        """
        self.lr_scheduler = lr_scheduler
        original_step = self.lr_scheduler.step

        def custom_step(epoch=None):
            if epoch is None:
                original_step()
            else:
                original_step(epoch)
            new_lr = self.lr_scheduler.get_last_lr()[0]
            for opt in self.optim_map.values():
                for param_group in opt.param_groups:
                    param_group["lr"] = new_lr

        self.lr_scheduler.step = custom_step

    def step_lr_scheduler(self, epoch: int = None):
        """
        Steps the learning rate scheduler if it exists.

        Args:
            epoch (int, optional): The current epoch number. Defaults to None.

        Raises:
            RuntimeError: If the LR scheduler has not been set.
        """
        if self.lr_scheduler:
            self.lr_scheduler.step(epoch)
        else:
            raise RuntimeError(
                "LR scheduler has not been set. Call set_lr_scheduler first."
            )

    def get_last_lr(self):
        """
        Gets the last learning rate from the scheduler if it exists.

        Returns:
            float: The last learning rate.

        Raises:
            RuntimeError: If the LR scheduler has not been set.
        """
        if self.lr_scheduler:
            return self.lr_scheduler.get_last_lr()[0]
        else:
            raise RuntimeError(
                "LR scheduler has not been set. Call set_lr_scheduler first."
            )


def create_optim_in_bwd_wrapper(
    model: torch.nn.Module, optim_dict: Dict[torch.nn.Parameter, torch.optim.Optimizer]

Review comments on `get_last_lr`:

Reviewer: nit (suggested change concerning the missing return type hint): Have you correctly configured your pre-commit hooks? Our linters should pick this up.

Reviewer: I think there are merge conflicts, so linters will not run in CI until you resolve them. Also, for missing type hints I'm not sure flake8 actually catches those anyway (not positive though).
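A self-contained sketch of the behavior `set_lr_scheduler` adds: one scheduler is attached to a representative optimizer and, after each step, its latest LR is copied into every optimizer in the map. Plain PyTorch objects stand in for the wrapper here; this is not the wrapper itself.

```python
# Standalone sketch of the LR-sync pattern implemented by custom_step above:
# drive one scheduler, then propagate its LR to all per-parameter optimizers.
import torch
from torch.optim.lr_scheduler import LambdaLR

params = [torch.nn.Parameter(torch.randn(2)) for _ in range(3)]
optim_map = {f"p{i}": torch.optim.SGD([p], lr=1e-2) for i, p in enumerate(params)}

representative = next(iter(optim_map.values()))
scheduler = LambdaLR(representative, lr_lambda=lambda step: 0.5 ** step)

for _ in range(3):
    # ... backward and per-parameter optimizer steps would happen here ...
    scheduler.step()
    new_lr = scheduler.get_last_lr()[0]
    for opt in optim_map.values():
        for group in opt.param_groups:
            group["lr"] = new_lr

print({k: opt.param_groups[0]["lr"] for k, opt in optim_map.items()})
```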
Review discussion on the config change (warmup steps):

Reviewer: I saw your commit about needing to do this to pass the test. Which test is that referring to? I don't think we want to make this change to our configs unless we know the results will be better; plus it's kind of weird to say we are using a scheduler "with warmup" when the warmup is just one step.

Author (user074): It should be test_loss in tests/recipes/test_full_finetune_single_device.py. When I ran the test, increasing the warmup steps produced loss values outside the expected tolerance: torch.testing.assert_close(loss_values, expected_loss_values, rtol=1e-4, atol=1e-4). I suspected that might be due to the warmup steps. However, if I loosen the tolerance the value might be too high; it could even reach the 1e-1 level. With a large warmup we effectively have a much smaller LR at first, and in my own training that caused the initial loss to differ by about 2e-1, with the difference shrinking after warmup finished. I haven't found a good way to resolve this. Do you have any suggestions? I hope it isn't some bug in my code causing the tolerance difference; that's why I tested with 1 warmup step.

Reviewer: @user074 ah I see, I think I was a bit confused about the exact issue but it makes sense now. Note that our recipe tests actually use test-specific overrides; that way we can decouple the configs we use for real fine-tunes from what we use for testing. You can see the overrides here -- I would suggest just adding an override like lr_scheduler=torch.optim.ConstantLr there. Hopefully that'll fix things!

Author: Do you know how to deal with num_warmup_steps when I switch to ConstantLr? It gives an error due to this extra argument, and I couldn't find a way to remove the keyword argument from the test script; model_config does not contain it explicitly, so I cannot remove it.

Reviewer: You could try removing the flag directly in the test, like so: https://pytorch.org/torchtune/stable/deep_dives/configs.html#removing-config-fields
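A hedged sketch of the kind of test override this suggestion and the linked doc describe. The field names here are assumptions about the test's override list, and the `~` prefix is the config-field removal mechanism the deep dive covers; this is not necessarily the override the PR ultimately used.

```python
# Hypothetical override list for the recipe test: swap the scheduler component
# and drop the warmup argument via the "~field" removal syntax from the linked
# torchtune config deep dive.
overrides = [
    "lr_scheduler._component_=torch.optim.lr_scheduler.ConstantLR",
    "lr_scheduler.factor=1.0",
    "~lr_scheduler.num_warmup_steps",
]
```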
Reviewer: I noticed you changed the learning rate in the config? Could that be it?

Author: Not really, since the test recipe overrides the LR to 2e-5. I also just tested with 2e-5 as the original LR; it's the same problem with the loss discrepancy.

Reviewer: Looks to be working on the build machines. Could you push your latest changes, even if they break the tests, and I can check it out?

Reviewer: Agree it'd be helpful to just push the changes and see what's happening. But also, if using ConstantLR, make sure you're setting factor=1 (otherwise with the defaults your LR will be 1/3 the value you want on iteration 0).

Author: Thanks. The factor=1 issue is actually what caused this; after fixing it, the tests now pass.
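For reference, a small sketch of the factor behavior mentioned above: with the default factor, `ConstantLR` scales the LR down to 1/3 of the configured value for its first `total_iters` steps, while `factor=1.0` leaves it unchanged from step 0.

```python
import torch
from torch.optim.lr_scheduler import ConstantLR

param = torch.nn.Parameter(torch.zeros(1))

# Default factor (1/3): iteration-0 LR is ~6.67e-6 instead of 2e-5.
opt_default = torch.optim.SGD([param], lr=2e-5)
print(ConstantLR(opt_default).get_last_lr())           # ~[6.67e-06]

# factor=1.0: the LR stays at the configured 2e-5 from the start.
opt_flat = torch.optim.SGD([param], lr=2e-5)
print(ConstantLR(opt_flat, factor=1.0).get_last_lr())  # [2e-05]
```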