Add LR Scheduler to single device full finetune #1350

Merged
Changes from 4 commits
5 changes: 4 additions & 1 deletion recipes/configs/llama2/7B_full_low_memory.yaml
@@ -53,7 +53,10 @@ batch_size: 2
 epochs: 3
 optimizer:
   _component_: bitsandbytes.optim.PagedAdamW
-  lr: 2e-5
+  lr: 1e-5
+lr_scheduler:
+  _component_: torchtune.modules.get_cosine_schedule_with_warmup
+  num_warmup_steps: 1
Contributor
I saw your commit about needing to do this to pass the test. Which test is that referring to? I don't think we wanna make this change to our configs unless we know the results will be better, plus it's kinda weird to say we are using a scheduler "with warmup" when the warmup is just one step

Contributor Author

> I saw your commit about needing to do this to pass the test. Which test is that referring to? I don't think we wanna make this change to our configs unless we know the results will be better, plus it's kinda weird to say we are using a scheduler "with warmup" when the warmup is just one step

It's test_loss in tests/recipes/test_full_finetune_single_device.py. When I ran the test, increasing the warmup steps made the loss values diverge from the expected ones by more than the tolerance in

torch.testing.assert_close(loss_values, expected_loss_values, rtol=1e-4, atol=1e-4)

I suspected the warmup steps were the cause. If I loosen the tolerance instead, the value has to be quite high, possibly at the 1e-1 level. With a long warmup we are effectively training at a much smaller lr, and in my own training that produces large differences in the initial loss. As the screenshot below shows, the initial loss differs by about 2e-1, but the gap shrinks over later steps once warmup finishes. I haven't found a good way to resolve this; do you have any suggestions? I hope the tolerance difference isn't caused by a bug in my code, which is why I tested with 1 warmup step...

[Screenshot 2024-09-05: loss curves from the author's training run]
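For context, a minimal sketch of the warmup effect described above, using the scheduler from the config; this assumes torchtune.modules.get_cosine_schedule_with_warmup follows the usual linear-warmup convention implied by its num_warmup_steps argument:

```python
import torch
from torchtune.modules import get_cosine_schedule_with_warmup

model = torch.nn.Linear(4, 4)
opt = torch.optim.AdamW(model.parameters(), lr=2e-5)

# 100 warmup steps out of 1000 total, vs. the config's num_warmup_steps: 1
sched = get_cosine_schedule_with_warmup(
    opt, num_warmup_steps=100, num_training_steps=1000
)

for step in range(3):
    opt.step()
    sched.step()
    # lr ramps linearly from 0 (~2e-7, 4e-7, 6e-7) instead of the full 2e-5,
    # so early losses diverge from a run recorded without warmup
    print(step, sched.get_last_lr()[0])
```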

Contributor
@user074 ah I see, I think I was a bit confused about the exact issue but it makes sense now. Note that our recipe tests actually use test-specific overrides; that way we can decouple the configs we use for real fine-tunes from what we use for testing. You can see the overrides here; I would suggest just adding an override like lr_scheduler._component_=torch.optim.lr_scheduler.ConstantLR there. Hopefully that'll fix things!

Contributor Author
Do you know how to deal with num_warmup_steps when I set ConstantLR? It gives an error because of the extra argument, and I couldn't find a way to remove that keyword argument from the test script. model_config does not contain it explicitly, so I cannot remove it there.

Collaborator
You could try removing the flag directly in the test, like so:
https://pytorch.org/torchtune/stable/deep_dives/configs.html#removing-config-fields
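For concreteness, a hypothetical sketch of such an override, assuming the "~" field-removal syntax from the linked docs; the real override list lives in tests/recipes/test_full_finetune_single_device.py and may look different:

```python
# Hypothetical CLI-style overrides for the recipe test; "~" removes a config
# field, per the torchtune config docs linked above.
overrides = [
    # Swap the cosine schedule for a constant one during the test
    "lr_scheduler._component_=torch.optim.lr_scheduler.ConstantLR",
    "lr_scheduler.factor=1.0",
    # ConstantLR does not accept num_warmup_steps, so drop the field
    "~lr_scheduler.num_warmup_steps",
]
```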

Collaborator
I noticed you changed the learning rate in the config? Could that be it?

Contributor Author
> I noticed you changed the learning rate in the config? Could that be it?

Not really, since the test recipe overrides the lr to 2e-5. I also just tested with 2e-5 as the original lr; the loss discrepancy is the same.

Collaborator
Looks to be working on the build machines. Could you push your latest changes, even if they break the tests, and I can check it out?

Contributor (@ebsmothers, Sep 8, 2024)
Agree it'd be helpful to just push the changes and see what's happening. But also if using ConstantLR make sure you're setting factor=1 (otherwise with the defaults your LR will be 1/3 the value you want on iteration 0)

Contributor Author
> Agree it'd be helpful to just push the changes and see what's happening. But also if using ConstantLR make sure you're setting factor=1 (otherwise with the defaults your LR will be 1/3 the value you want on iteration 0)

Thanks. The factor was indeed what caused this; after setting factor=1 it now passes the tests.
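A quick sketch of the factor pitfall in plain PyTorch (standalone, not recipe code):

```python
import torch
from torch.optim.lr_scheduler import ConstantLR

model = torch.nn.Linear(4, 4)

# Default factor is 1/3: lr starts at a third of the configured value
opt = torch.optim.SGD(model.parameters(), lr=2e-5)
sched = ConstantLR(opt)
print(sched.get_last_lr()[0])  # ~6.67e-06

# factor=1.0 keeps lr at the configured 2e-5 from iteration 0
opt = torch.optim.SGD(model.parameters(), lr=2e-5)
sched = ConstantLR(opt, factor=1.0)
print(sched.get_last_lr()[0])  # 2e-05
```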

 optimizer_in_bwd: True
 loss:
   _component_: torchtune.modules.loss.CEWithChunkedOutputLoss
5 changes: 4 additions & 1 deletion recipes/configs/llama3/8B_full_single_device.yaml
@@ -52,7 +52,10 @@ batch_size: 2
 epochs: 3
 optimizer:
   _component_: bitsandbytes.optim.PagedAdamW8bit
-  lr: 2e-5
+  lr: 1e-5
+lr_scheduler:
+  _component_: torchtune.modules.get_cosine_schedule_with_warmup
+  num_warmup_steps: 1
 loss:
   _component_: torchtune.modules.loss.CEWithChunkedOutputLoss
 max_steps_per_epoch: null
37 changes: 37 additions & 0 deletions recipes/full_finetune_single_device.py
@@ -272,6 +272,13 @@ def setup(self, cfg: DictConfig) -> None:
             self._steps_per_epoch = self.max_steps_per_epoch
         self.global_step = self.epochs_run * self._steps_per_epoch

+        # Setup lr scheduler
+        self._lr_scheduler = self._setup_lr_scheduler(
+            cfg_lr_scheduler=cfg.lr_scheduler,
+            num_training_steps=self.total_epochs * self._steps_per_epoch,
+            last_epoch=self.global_step - 1,
+        )
+
         # Set up profiler, returns DummyProfiler (nullcontext object with no-op `step` method)
         # if cfg is missing profiler key or if `cfg.profiler.enabled = False`
         self._profiler = self._setup_profiler(cfg.get(PROFILER_KEY, None))
@@ -429,6 +436,34 @@ def _setup_optimizer(
         log.info("Optimizer is initialized.")
         return optimizer

+    def _setup_lr_scheduler(
+        self,
+        cfg_lr_scheduler: DictConfig,
+        num_training_steps: int,
+        last_epoch: int,
+    ) -> Optional[Optimizer]:
Collaborator
Could you add docstrings here?

+        if self._optimizer_in_bwd:
+            # Use the first optimizer from the wrapper to represent the learning rate
+            optimizer = next(iter(self._optim_ckpt_wrapper.optim_map.values()))
+        else:
+            # Standard case: use the single optimizer
+            optimizer = self._optimizer
+
+        # Instantiate the learning rate scheduler
+        lr_scheduler = config.instantiate(
+            cfg_lr_scheduler,
+            optimizer,
+            num_training_steps=num_training_steps,
+            last_epoch=last_epoch,
+        )
+
+        if self._optimizer_in_bwd:
+            # Modify the scheduler for optimizer_in_bwd case
+            self._optim_ckpt_wrapper.set_lr_scheduler(lr_scheduler)
+
+        log.info("Learning rate scheduler is initialized.")
+        return lr_scheduler

     def _setup_data(
         self,
         cfg_dataset: DictConfig,
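Responding to the docstring request above, one possible sketch for _setup_lr_scheduler (wording illustrative; signature copied from the diff):

```python
def _setup_lr_scheduler(
    self,
    cfg_lr_scheduler: DictConfig,
    num_training_steps: int,
    last_epoch: int,
) -> Optional[Optimizer]:
    """
    Set up the learning rate scheduler from the config.

    With optimizer_in_bwd there is one optimizer per parameter, so the
    scheduler is built against the first wrapped optimizer and then
    registered on the wrapper, which propagates lr updates to the rest.

    Args:
        cfg_lr_scheduler (DictConfig): scheduler config, e.g.
            ``torchtune.modules.get_cosine_schedule_with_warmup``.
        num_training_steps (int): total number of optimizer steps
            (total_epochs * steps per epoch).
        last_epoch (int): index of the last completed step, used to
            restore the schedule when resuming; -1 for a fresh run.
    """
    ...
```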
@@ -589,6 +624,8 @@ def train(self) -> None:
                     self._optimizer.step()
                     self._optimizer.zero_grad(set_to_none=True)

+                    # Need to fix `lr_scheduler.step()` before `optimizer.step()` warning
Collaborator
What is this comment referring to? : )

Contributor Author
There was a warning when I ran the code. I think it says "Detected call of `lr_scheduler.step()` before `optimizer.step()`". I don't get why it warns, since the `lr_scheduler.step()` is after `optimizer.step()`.
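For reference, the warning comes from PyTorch's LRScheduler, which wraps optimizer.step with a call counter and warns on the first scheduler step if the optimizer hasn't stepped yet. A minimal standalone repro (plain PyTorch, independent of this recipe):

```python
import torch
from torch.optim.lr_scheduler import LambdaLR

model = torch.nn.Linear(2, 2)
opt = torch.optim.SGD(model.parameters(), lr=0.1)
sched = LambdaLR(opt, lr_lambda=lambda step: 1.0)

# Stepping the scheduler before any opt.step() emits:
#   "Detected call of `lr_scheduler.step()` before `optimizer.step()` ..."
sched.step()

# Calling opt.step() before the first sched.step() avoids the warning
```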

+                    self._lr_scheduler.step()
                     self.global_step += 1

                     loss_to_log = running_loss.item()
59 changes: 58 additions & 1 deletion torchtune/utils/memory.py
@@ -10,12 +10,12 @@
 from typing import Any, Callable, Dict, Set, Type, Union

 import torch
-
 from torch import nn
 from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
     apply_activation_checkpointing,
 )
 from torch.distributed.fsdp.wrap import ModuleWrapPolicy
+from torch.optim.lr_scheduler import LRScheduler
 from torchtune.utils.logging import get_logger

 _log: logging.Logger = get_logger()
@@ -91,6 +91,7 @@ class OptimizerInBackwardWrapper:

     def __init__(self, optim_map: Dict[str, torch.optim.Optimizer]):
         self.optim_map = optim_map
+        self.lr_scheduler = None

     def state_dict(self) -> Dict[str, Any]:
         """
@@ -136,6 +137,62 @@ def get_optim_key(self, key: str) -> Any:
"""
return list(self.optim_map.values())[0].param_groups[0][key]

+    def set_lr_scheduler(self, lr_scheduler: LRScheduler) -> None:
+        """
+        Sets the learning rate scheduler and modifies its step method to update all optimizers.
+
+        Args:
+            lr_scheduler (LRScheduler): The learning rate scheduler to use.
+        """
+        self.lr_scheduler = lr_scheduler
+        original_step = self.lr_scheduler.step
+
+        def custom_step(epoch=None):
+            if epoch is None:
+                original_step()
+            else:
+                original_step(epoch)
+            new_lr = self.lr_scheduler.get_last_lr()[0]
+            for opt in self.optim_map.values():
+                for param_group in opt.param_groups:
+                    param_group["lr"] = new_lr
+
+        self.lr_scheduler.step = custom_step

+    def step_lr_scheduler(self, epoch: int = None):
+        """
+        Steps the learning rate scheduler if it exists.
+
+        Args:
+            epoch (int, optional): The current epoch number. Defaults to None.
+
+        Raises:
+            RuntimeError: If the LR scheduler has not been set.
+        """
+        if self.lr_scheduler:
+            self.lr_scheduler.step(epoch)
+        else:
+            raise RuntimeError(
+                "LR scheduler has not been set. Call set_lr_scheduler first."
+            )

+    def get_last_lr(self):
Collaborator
nit:

Suggested change:
-    def get_last_lr(self):
+    def get_last_lr(self) -> float:

Have you correctly configured your pre-commit hooks? Our linters should pick this up.

Contributor Author
[Screenshot 2024-09-09: pre-commit run with all hooks passing]

I got pre-commit to pass everything with this, which is kind of strange. I installed it according to CONTRIBUTING.md.

Contributor (@ebsmothers, Sep 10, 2024)
> Have you correctly configured your pre-commit hooks? Our linters should pick this up.

I think there are merge conflicts, so linters will not run in CI until you resolve them. Also, for missing type hints, I don't know whether flake8 actually catches those anyway (not positive though).

"""
Gets the last learning rate from the scheduler if it exists.

Returns:
float: The last learning rate.

Raises:
RuntimeError: If the LR scheduler has not been set.
"""
if self.lr_scheduler:
return self.lr_scheduler.get_last_lr()[0]
else:
raise RuntimeError(
"LR scheduler has not been set. Call set_lr_scheduler first."
)


 def create_optim_in_bwd_wrapper(
     model: torch.nn.Module, optim_dict: Dict[torch.nn.Parameter, torch.optim.Optimizer]
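To close, a hedged usage sketch of the new wrapper API; the optim_map here stands in for what create_optim_in_bwd_wrapper builds (one optimizer per parameter), and the scheduler choice is illustrative:

```python
import torch
from torch.optim.lr_scheduler import StepLR

from torchtune.utils.memory import OptimizerInBackwardWrapper

model = torch.nn.Linear(8, 8)
# One optimizer per parameter, as in the optimizer-in-backward setup
optim_map = {
    name: torch.optim.SGD([p], lr=1e-2) for name, p in model.named_parameters()
}
wrapper = OptimizerInBackwardWrapper(optim_map)

# Build the scheduler on the first wrapped optimizer; the patched step()
# then copies each new lr to every optimizer in optim_map
first_opt = next(iter(optim_map.values()))
wrapper.set_lr_scheduler(StepLR(first_opt, step_size=1, gamma=0.5))

# In the recipe the per-parameter optimizers step inside backward hooks;
# step them directly here to mimic that
for opt in optim_map.values():
    opt.step()

wrapper.step_lr_scheduler()
print(wrapper.get_last_lr())  # 0.005, now shared by all param groups
```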