Add LR Scheduler to single device full finetune #1350
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/1350
Note: Links to docs will display an error until the docs builds have been completed. ✅ No failures as of commit 69f4ca6 with merge base 7af77c7. This comment was automatically generated by Dr. CI and updates every 15 minutes.
Hi @user074! Thank you for your pull request and welcome to our community.

Action Required: In order to merge any pull request (code, docs, etc.), we require contributors to sign our Contributor License Agreement, and we don't seem to have one on file for you.

Process: In order for us to review and merge your suggested changes, please sign at https://code.facebook.com/cla. If you are contributing on behalf of someone else (e.g. your employer), the individual CLA may not be sufficient and your employer may need to sign the corporate CLA. Once the CLA is signed, our tooling will perform checks and validations. Afterwards, the pull request will be tagged with CLA signed.

If you have received this in error or have any questions, please contact us at [email protected]. Thanks!
Thank you for signing our Contributor License Agreement. We can now accept your code for this (and any) Meta Open Source project. Thanks!
Hi @user074 thanks for this PR! Sorry, somehow the review here seems to have fallen through the cracks a bit. Overall the implementation looks reasonable though. One request: for the case where optimizer_in_bwd is True, can we offload the logic into a standalone utility (e.g. in torchtune/utils/memory.py)?

(This is just one suggestion, open to any alternative ideas if you think this can be done more clearly another way.) Also I think some tests are failing; you may need to update the test configs here. Once you do that you can also run the tests locally via pytest tests -m integration_test.
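For what it's worth, here is a rough sketch of the kind of standalone utility being suggested for the optimizer_in_bwd case. All names here (SchedulerInBackwardWrapper, set_lr_scheduler, step_lr_scheduler) are illustrative assumptions, not torchtune's actual API:

```python
# Hypothetical sketch only: class and method names are illustrative, not torchtune's API.
from typing import Callable, Dict

import torch
from torch.optim.lr_scheduler import LRScheduler


class SchedulerInBackwardWrapper:
    """Keeps one LR scheduler per parameter-wise optimizer when optimizer steps run in backward."""

    def __init__(self, optimizers: Dict[str, torch.optim.Optimizer]) -> None:
        self._optimizers = optimizers
        self._schedulers: Dict[str, LRScheduler] = {}

    def set_lr_scheduler(
        self, scheduler_builder: Callable[[torch.optim.Optimizer], LRScheduler]
    ) -> None:
        # scheduler_builder could be e.g. functools.partial(get_cosine_schedule_with_warmup, ...)
        self._schedulers = {
            name: scheduler_builder(opt) for name, opt in self._optimizers.items()
        }

    def step_lr_scheduler(self) -> None:
        # Step every per-parameter scheduler once per training step.
        for scheduler in self._schedulers.values():
            scheduler.step()

    def get_last_lr(self) -> float:
        # All schedulers follow the same schedule, so reporting one is enough.
        return next(iter(self._schedulers.values())).get_last_lr()[0]
```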
Thanks @ebsmothers. I will update it by the weekend.
@user074 good question. Actually the implementation of the distributed version may be more straightforward given that we don't have the option to fuse optimizer and backward there. We may be able to help out with testing, if you share a repro command for the loss curves in the PR summary we can use that for the distributed recipe to ensure we get similar behavior there.
lr: 1e-5
lr_scheduler:
  _component_: torchtune.modules.get_cosine_schedule_with_warmup
  num_warmup_steps: 1
I saw your commit about needing to do this to pass the test. Which test is that referring to? I don't think we wanna make this change to our configs unless we know the results will be better, plus it's kinda weird to say we are using a scheduler "with warmup" when the warmup is just one step
It should be test_loss in tests/recipes/test_full_finetune_single_device.py. When I ran the test, increasing the warmup steps pushed the loss values outside the expected tolerance:

torch.testing.assert_close(loss_values, expected_loss_values, rtol=1e-4, atol=1e-4)

I suspect that is due to the warmup steps. If I loosen the tolerance instead, the value has to be quite high, possibly at the 1e-1 level. When the warmup is long we effectively have a much smaller lr, and from my own training this causes large differences in the initial loss: the initial loss differs by about 2e-1, but the difference shrinks over more training steps once warmup finishes. I haven't found a good way to resolve this. Do you have any suggestions? I hope it isn't a bug in my code causing the tolerance difference; that's why I tested it with 1 warmup step...

@user074 ah I see, I think I was a bit confused about the exact issue but it makes sense now. Note that our recipe tests actually use test-specific overrides; that way we can decouple the configs we use for real fine-tunes from what we use for testing. You can see the overrides here -- I would suggest just adding an override like lr_scheduler=torch.optim.lr_scheduler.ConstantLR there. Hopefully that'll fix things!
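As a sketch of that suggestion (the exact helper name and existing override values in tests/recipes/test_full_finetune_single_device.py are assumptions here), the test-only override could look something like:

```python
# Hypothetical sketch: the real override helper in the torchtune test suite may
# be structured differently; the point is to swap the scheduler only for tests.
def _get_test_config_overrides():
    return [
        "dtype=fp32",  # assumed existing override, shown for context
        "lr_scheduler._component_=torch.optim.lr_scheduler.ConstantLR",
        "lr_scheduler.factor=1.0",  # keep the LR at its configured value from step 0
    ]
```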
Do you know how to deal with num_warmup_steps when I set it to ConstantLR? It seems to give an error due to this extra argument, but I couldn't find a way to remove this keyword argument from the test script. model_config does not contain it explicitly, so I cannot remove it.
You could try removing the flag directly in the test, like so:
https://pytorch.org/torchtune/stable/deep_dives/configs.html#removing-config-fields
I noticed you changed the learning rate in the config? Could that be it?
Not really, since the test recipe overrides the lr to 2e-5. I also just tested with 2e-5 as the original lr, and it is the same problem with the loss discrepancy.
Looks to be working on the build machines. Could you push your latest changes, even if they break the tests, and I can check it out?
Agree it'd be helpful to just push the changes and see what's happening. But also, if using ConstantLR make sure you're setting factor=1 (otherwise with the defaults your LR will be 1/3 the value you want on iteration 0).
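A quick standalone check of the factor=1 point with plain PyTorch (values here are illustrative):

```python
# torch.optim.lr_scheduler.ConstantLR defaults to factor=1/3, so without
# factor=1 the first iterations run at a third of the configured LR.
import torch

param = torch.nn.Parameter(torch.zeros(1))

opt_default = torch.optim.SGD([param], lr=2e-5)
sched_default = torch.optim.lr_scheduler.ConstantLR(opt_default)
print(sched_default.get_last_lr())  # ~[6.67e-06] until total_iters (default 5) is reached

opt_fixed = torch.optim.SGD([param], lr=2e-5)
sched_fixed = torch.optim.lr_scheduler.ConstantLR(opt_fixed, factor=1.0)
print(sched_fixed.get_last_lr())  # [2e-05] from iteration 0
```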
Thanks. Actually, factor=1 not being set was what caused this. After fixing it, the tests now pass.
Codecov Report
Attention: Patch coverage is

Additional details and impacted files

@@            Coverage Diff            @@
##             main    #1350       +/-   ##
===========================================
- Coverage   69.33%   25.71%   -43.62%
===========================================
  Files         305      305
  Lines       15892    15961       +69
===========================================
- Hits        11018     4104     -6914
- Misses       4874    11857     +6983

☔ View full report in Codecov by Sentry.
torchtune/utils/memory.py (Outdated)

                "LR scheduler has not been set. Call set_lr_scheduler first."
            )

    def get_last_lr(self):
nit:

-    def get_last_lr(self):
+    def get_last_lr(self) -> float:
Have you correctly configured your pre-commit hooks? Our linters should pick this up
I think there are merge conflicts, so linters will not run in CI until you resolve them. Also for missing type hints idk if flake8 actually catches those anyways? (Not positive though)
    cfg_lr_scheduler: DictConfig,
    num_training_steps: int,
    last_epoch: int,
) -> Optional[Optimizer]:
Could you add docstrings here?
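For reference, a docstring along these lines might work (the method name, abbreviated signature, and wording below are a sketch, not the PR's final text):

```python
# Sketch only: abbreviated signature; the actual method in the recipe may differ.
from typing import Optional

from omegaconf import DictConfig
from torch.optim import Optimizer


def _setup_lr_scheduler(
    self,
    cfg_lr_scheduler: DictConfig,
    num_training_steps: int,
    last_epoch: int,
) -> Optional[Optimizer]:
    """
    Set up the learning rate scheduler from the config.

    Args:
        cfg_lr_scheduler (DictConfig): config for the LR scheduler component,
            e.g. torchtune.modules.get_cosine_schedule_with_warmup.
        num_training_steps (int): total number of optimizer steps, needed by
            schedulers (such as cosine with warmup) that depend on the full horizon.
        last_epoch (int): index of the last completed step, used when resuming
            training; -1 starts the schedule from scratch.

    Returns:
        The instantiated LR scheduler, or None if no scheduler is configured.
    """
    ...
```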
Hi @user074 just wanted to check in to see if you're still working on this. If not, let me know and one of us can help take it over the finish line (I think from here there should be very little left).
Just updated with the suggested changes. Sorry for the delay.
@user074 no problem. Can you also merge the latest changes from main into your branch? There are merge conflicts and we can't run CI until those get resolved.
if "torchtune.modules" in component: | ||
# Instantiate the learning rate scheduler | ||
lr_scheduler = config.instantiate( | ||
cfg_lr_scheduler, | ||
optimizer, | ||
num_training_steps=num_training_steps, | ||
last_epoch=last_epoch, | ||
) | ||
else: | ||
lr_scheduler = config.instantiate( | ||
cfg_lr_scheduler, | ||
optimizer, | ||
last_epoch=last_epoch, | ||
) |
Ah is this a consequence of using ConstantLR in the test? I think it's a bit confusing to condition on whether the component is coming from torchtune.modules or not. In general there will be things from torch.optim that don't match the signature in the else statement, and there will probably also (in the future) be things from torchtune.modules that don't match the signature in the if statement.

In that case (and sorry for changing my mind on this), is there any reason we can't just make lr_scheduler optional? Then if it's not passed we just continue to use the usual constant learning rate defined in the optimizer and set self._lr_scheduler=None in the recipe (and in the training loop check that it's not None prior to stepping). This will make the recipe test a non-issue and keep the setup clean.
Separately though it's clear that our LR scheduler instantiation is not flexible enough. We should figure out how to better support schedulers from torch.optim moving forward (though that's a problem for another PR).
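A minimal sketch of the optional-scheduler pattern described above, assuming the recipe exposes a _setup_lr_scheduler method and a self._lr_scheduler attribute (names taken from the discussion; the body is illustrative):

```python
# Sketch of the suggested pattern inside the recipe class (illustrative body).
from torchtune import config


def _setup_lr_scheduler(self, cfg_lr_scheduler, num_training_steps, last_epoch):
    # No scheduler configured: keep the constant LR defined in the optimizer.
    if cfg_lr_scheduler is None:
        self._lr_scheduler = None
        return
    self._lr_scheduler = config.instantiate(
        cfg_lr_scheduler,
        self._optimizer,
        num_training_steps=num_training_steps,
        last_epoch=last_epoch,
    )


# ... and in the training loop, guard the step call:
#     self._optimizer.step()
#     self._optimizer.zero_grad(set_to_none=True)
#     if self._lr_scheduler is not None:
#         self._lr_scheduler.step()
```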
Yes. I was thinking to make it more flexible as well and support other types of torch.optim.lr_scheduler. I found that torchtune.modules.lr_scheduler only has get_cosine_schedule_with_warmup, and the num_training_steps input would not be compatible with other torch.optim.lr_scheduler classes.

I can check and get it working with lr_scheduler as None by the weekend. It is mainly due to the testing config, which always passes num_training_steps even when the scheduler has no argument for it.
Updated to set self._lr_scheduler=None when no scheduler is configured. Also I modified the testing: I still use the cosine scheduler instead of a constant one, but effectively make it constant by setting num_cycles=0. It should be a bit more elegant imo.
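For anyone reading along, a small standalone check of that num_cycles=0 trick, using the standard cosine-with-warmup formula written out directly (the torchtune helper's internals are assumed to follow the same formulation):

```python
import math


def cosine_with_warmup_multiplier(step, num_warmup_steps, num_training_steps, num_cycles):
    # Standard cosine-with-warmup LR multiplier (HF-style formulation).
    if step < num_warmup_steps:
        return step / max(1, num_warmup_steps)
    progress = (step - num_warmup_steps) / max(1, num_training_steps - num_warmup_steps)
    return max(0.0, 0.5 * (1.0 + math.cos(math.pi * num_cycles * 2.0 * progress)))


# With num_cycles=0 the cosine term is cos(0) == 1 at every step after warmup,
# so the multiplier stays at 1.0, i.e. effectively a constant LR for the test.
for step in [0, 1, 10, 50, 99]:
    print(step, cosine_with_warmup_multiplier(step, 1, 100, 0))
```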
@@ -586,6 +640,9 @@ def train(self) -> None:
                            self._optimizer.step()
                            self._optimizer.zero_grad(set_to_none=True)

                        # Need to fix `lr_scheduler.step()` before `optimizer.step()` warning
What is this comment referring to? : )
There was a warning when I ran the code. I think it says "Detected call of `lr_scheduler.step()` before `optimizer.step()`". I do not get why it gives this warning, since `lr_scheduler.step()` is called after `optimizer.step()`.
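For context, PyTorch's warning is keyed off whether the optimizer object attached to the scheduler has recorded a step yet. Below is the ordering it expects, plus a hedged guess (not verified against this PR) at why the optimizer_in_bwd path could still trip it:

```python
# The ordering PyTorch's check expects: the *same* optimizer the scheduler wraps
# must call .step() before scheduler.step() on each iteration.
import torch

model = torch.nn.Linear(2, 2)
opt = torch.optim.SGD(model.parameters(), lr=1e-5)
sched = torch.optim.lr_scheduler.LambdaLR(opt, lr_lambda=lambda step: 1.0)

for _ in range(3):
    loss = model(torch.randn(1, 2)).sum()
    loss.backward()
    opt.step()                      # must come first for this optimizer
    opt.zero_grad(set_to_none=True)
    sched.step()

# Hedged guess: if the scheduler is attached to a dummy optimizer whose .step()
# is never called directly (the real steps happen inside backward hooks when
# optimizer_in_bwd=True), PyTorch's check on that dummy optimizer can still
# emit the warning even though the effective ordering is correct.
```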
Thanks, I think this is much cleaner! I left one more comment, can you run a test with no LR scheduler (running manually is fine) just to make sure nothing is broken after applying my suggestion? After that I think this should be good to go
Co-authored-by: ebsmothers <[email protected]>
I just ran a Llama 2 7B without lr_scheduler and it works.
Looks good, thanks for adding this!
Co-authored-by: ebsmothers <[email protected]>
Context
What is the purpose of this PR?
Please link to any issues this PR addresses: #1308
Changelog
I added an LR scheduler to the single device full finetune recipe (full_finetune_single_device.py). For example, we can add an lr_scheduler entry to the config 8B_full_single_device.yaml, along the lines of the snippet quoted earlier in this thread (a torchtune.modules.get_cosine_schedule_with_warmup component with num_warmup_steps).
If _optimizer_in_bwd is set, then I create a dummy optimizer in order to update the learning rate.
If not, then I just follow lora_finetune_single_device.py.
Be aware that 2e-5 with get_cosine_schedule_with_warmup and PagedAdamW8bit can cause instability, and I suspect it might be due to momentum; 1e-5 is stable in my case.
Test plan
Please make sure to do each of the following if applicable to your PR. (If you're not sure about any one of these just ask and we will happily help. We also have a contributing page for some guidance on contributing.)
- pre-commit install (install and run pre-commit hooks/linters)
- pytest tests (run unit tests)
- pytest tests -m integration_test (run recipe integration tests)
UX
If your function changed a public API, please add a dummy example of what the user experience will look like when calling it.
Example of docstring: torchtune/torchtune/modules/vision_transformer.py, line 285 (commit 6a7951f)
Example in our docs: https://pytorch.org/torchtune/main/tutorials/qat_finetune.html#applying-qat-to-llama3-models