Add LR Scheduler to single device full finetune #1350
Changes from 9 commits
@@ -267,6 +267,13 @@ def setup(self, cfg: DictConfig) -> None:
            self._steps_per_epoch = self.max_steps_per_epoch
        self.global_step = self.epochs_run * self._steps_per_epoch

        # Setup lr scheduler
        self._lr_scheduler = self._setup_lr_scheduler(
            cfg_lr_scheduler=cfg.lr_scheduler,
            num_training_steps=self.total_epochs * self._steps_per_epoch,
            last_epoch=self.global_step - 1,
        )

        # Set up profiler, returns DummyProfiler (nullcontext object with no-op `step` method)
        # if cfg is missing profiler key or if `cfg.profiler.enabled = False`
        self._profiler = self._setup_profiler(cfg.get(PROFILER_KEY, None))
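Note: the scheduler is read from `cfg.lr_scheduler` and sized from `total_epochs * steps_per_epoch`, with `last_epoch` derived from `global_step` so a resumed run continues its schedule. A minimal sketch of the config shape this expects, assuming torchtune's usual `_component_` convention and its cosine-with-warmup helper; the component path and `num_warmup_steps` value are illustrative, not taken from this diff:

```python
# Illustrative only: a config fragment the new setup() code could consume.
# The component path and warmup value are assumptions, not part of this PR.
from omegaconf import OmegaConf

cfg = OmegaConf.create(
    """
    lr_scheduler:
      _component_: torchtune.modules.get_cosine_schedule_with_warmup
      num_warmup_steps: 100
    """
)

# How the recipe derives the scheduler horizon and resume point (mirrors the diff):
total_epochs, steps_per_epoch, epochs_run = 3, 500, 0
num_training_steps = total_epochs * steps_per_epoch   # 1500
last_epoch = epochs_run * steps_per_epoch - 1         # -1 on a fresh run
print(OmegaConf.to_yaml(cfg.lr_scheduler), num_training_steps, last_epoch)
```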
@@ -422,6 +429,53 @@ def _setup_optimizer(
        log.info("Optimizer is initialized.")
        return optimizer

    def _setup_lr_scheduler(
        self,
        cfg_lr_scheduler: Optional[DictConfig],
        num_training_steps: int,
        last_epoch: int,
    ) -> Optional[Optimizer]:
Review comment: Could you add docstrings here?
        """
        Set up the learning rate scheduler based on the provided configuration.
        It handles both standard optimization and optimizer-in-backward cases, and supports
        schedulers from both torchtune.modules and torch.optim.

        Args:
            cfg_lr_scheduler (Optional[DictConfig]): The learning rate scheduler configuration.
            num_training_steps (int): The total number of training steps.
            last_epoch (int): The index of the last epoch.

        Returns:
            lr_scheduler (Optional[Optimizer]): The learning rate scheduler.
        """
        if cfg_lr_scheduler is None:
            log.info(
                "No learning rate scheduler configured. Using constant learning rate."
            )
            return None

        if self._optimizer_in_bwd:
            # Use the first optimizer from the wrapper to represent the learning rate
            optimizer = next(iter(self._optim_ckpt_wrapper.optim_map.values()))
        else:
            # Standard case: use the single optimizer
            optimizer = self._optimizer

        # Instantiate the learning rate scheduler
        lr_scheduler = config.instantiate(
            cfg_lr_scheduler,
            optimizer,
            num_training_steps=num_training_steps,
            last_epoch=last_epoch,
        )

        if self._optimizer_in_bwd:
            # Modify the scheduler for optimizer_in_bwd case
            self._optim_ckpt_wrapper.set_lr_scheduler(lr_scheduler)

        log.info("Learning rate scheduler is initialized.")
        return lr_scheduler

    def _setup_data(
        self,
        cfg_dataset: DictConfig,
@@ -586,6 +640,9 @@ def train(self) -> None:
                    self._optimizer.step()
                    self._optimizer.zero_grad(set_to_none=True)

                    # Need to fix `lr_scheduler.step()` before `optimizer.step()` warning
Review comment: What is this comment referring to? :)
Reply: There was a warning when I ran the code. I think it says "Detected call of `lr_scheduler.step()` before `optimizer.step()`".
                    if self._lr_scheduler is not None:
                        self._lr_scheduler.step()
                    self.global_step += 1

                loss_to_log = running_loss.item()
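On the review question above: PyTorch emits the UserWarning "Detected call of `lr_scheduler.step()` before `optimizer.step()`" when the scheduler is stepped before the optimizer's first step, because the first value of the schedule would be skipped. A small self-contained illustration of the intended ordering, using plain `torch.optim` pieces rather than the recipe's own classes:

```python
# Sketch of the ordering that avoids PyTorch's
# "Detected call of `lr_scheduler.step()` before `optimizer.step()`" warning:
# step the optimizer first, then advance the schedule once per optimizer update.
import torch

model = torch.nn.Linear(4, 2)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
scheduler = torch.optim.lr_scheduler.LinearLR(optimizer, total_iters=10)

for _ in range(10):
    loss = model(torch.randn(16, 4)).sum()
    loss.backward()
    optimizer.step()                        # weight update first
    optimizer.zero_grad(set_to_none=True)
    scheduler.step()                        # then advance the LR schedule
```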