Add LR Scheduler to single device full finetune #1350

Merged

Conversation

@user074 (Contributor) commented Aug 15, 2024

Context

What is the purpose of this PR? Is it to

  • add a new feature
  • fix a bug
  • update tests and/or documentation
  • other (please add here)

Please link to any issues this PR addresses: #1308

Changelog

I added an LR scheduler to the single-device full finetune recipe (full_finetune_single_device.py). For example, we can add the following to the 8B_full_single_device.yaml config:

optimizer:
  _component_: bitsandbytes.optim.PagedAdamW8bit
  weight_decay: 0.01
  lr: 1e-5

lr_scheduler:
  _component_: torchtune.modules.get_cosine_schedule_with_warmup
  num_warmup_steps: 500
  

If _optimizer_in_bwd is set, I create a dummy optimizer so the scheduler has something to update.
If not, I just follow lora_finetune_single_device.py.
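
For reference, here is a minimal standalone sketch of the step ordering in the non-fused path (plain PyTorch, not the actual recipe code; the warmup lambda is a simplified stand-in for torchtune.modules.get_cosine_schedule_with_warmup):

import torch
from torch.optim.lr_scheduler import LambdaLR

model = torch.nn.Linear(8, 8)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)

# Linear warmup over the first num_warmup_steps, constant afterwards
# (the cosine decay part is omitted to keep the sketch short).
num_warmup_steps = 100
scheduler = LambdaLR(optimizer, lambda step: min(1.0, (step + 1) / num_warmup_steps))

for step in range(200):
    loss = model(torch.randn(4, 8)).sum()
    loss.backward()
    optimizer.step()
    optimizer.zero_grad(set_to_none=True)
    scheduler.step()  # advance the LR schedule only after the optimizer step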

Be aware that an LR of 2e-5 with get_cosine_schedule_with_warmup and PagedAdamW8bit causes instability, which I suspect might be due to momentum. 1e-5 was stable in my case.

[Screenshot: 2024-08-14 at 8:31 PM]

Test plan

Please make sure to do each of the following if applicable to your PR. (If you're not sure about any one of these just ask and we will happily help. We also have a contributing page for some guidance on contributing.)

  • run pre-commit hooks and linters (make sure you've first installed via pre-commit install)
  • add unit tests for any new functionality
  • update docstrings for any new or updated methods or classes
  • run unit tests via pytest tests
  • run recipe tests via pytest tests -m integration_test
  • manually run any new or modified recipes with sufficient proof of correctness
  • include relevant commands and any other artifacts in this summary (pastes of loss curves, eval results, etc.)

UX

If your function changed a public API, please add a dummy example of what the user experience will look like when calling it.
Example of docstring:


Example in our docs: https://pytorch.org/torchtune/main/tutorials/qat_finetune.html#applying-qat-to-llama3-models

  • I did not change any public API;
  • I have added an example to docs or docstrings;

pytorch-bot bot commented Aug 15, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/1350

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 69f4ca6 with merge base 7af77c7:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot

Hi @user074!

Thank you for your pull request and welcome to our community.

Action Required

In order to merge any pull request (code, docs, etc.), we require contributors to sign our Contributor License Agreement, and we don't seem to have one on file for you.

Process

In order for us to review and merge your suggested changes, please sign at https://code.facebook.com/cla. If you are contributing on behalf of someone else (eg your employer), the individual CLA may not be sufficient and your employer may need to sign the corporate CLA.

Once the CLA is signed, our tooling will perform checks and validations. Afterwards, the pull request will be tagged with CLA signed. The tagging process may take up to 1 hour after signing. Please give it that time before contacting us about it.

If you have received this in error or have any questions, please contact us at [email protected]. Thanks!

@facebook-github-bot

Thank you for signing our Contributor License Agreement. We can now accept your code for this (and any) Meta Open Source project. Thanks!

@facebook-github-bot added the "CLA Signed" label (managed by the Facebook bot; authors need to sign the CLA before a PR can be reviewed) Aug 15, 2024
@SalmanMohammadi self-requested a review August 19, 2024 22:48
@ebsmothers (Contributor)

Hi @user074 thanks for this PR! Sorry, somehow the review here seems to have fallen through the cracks a bit. Overall the implementation looks reasonable though. One request: for the case where optimizer_in_bwd is True, can we offload the logic into a standalone utility (e.g. in utils/memory.py) so that it's colocated with the OptimizerInBackwardWrapper? Maybe a signature like

def get_lr_scheduler_for_optim_in_bwd(
	lr_scheduler: LRScheduler,
	optim_in_bwd_wrapper: OptimizerInBackwardWrapper
) -> LRScheduler:

(This is just one suggestion, open to any alternative ideas if you think this can be done more clearly another way)

Also I think some tests are failing, you may need to update the test configs here. Once you do that you can also run the tests locally via pytest tests -m integration_test (do let me know if you need any other help on this though)
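
A hypothetical, self-contained sketch of the idea behind such a utility (the names and the plain dict of per-parameter optimizers are illustrative; the merged code instead hangs this off OptimizerInBackwardWrapper in torchtune/training/memory.py): one dummy optimizer carries the schedule, and its computed LR is copied into every real optimizer on each step.

import torch
from torch.optim.lr_scheduler import LambdaLR

params = [torch.nn.Parameter(torch.randn(4)) for _ in range(3)]
# With optimizer-in-backward there is effectively one optimizer per parameter.
optim_dict = {p: torch.optim.AdamW([p], lr=1e-5) for p in params}

# The dummy optimizer exists only so the scheduler has something to attach to.
dummy_optim = torch.optim.AdamW([torch.nn.Parameter(torch.zeros(1))], lr=1e-5)
lr_scheduler = LambdaLR(dummy_optim, lambda step: min(1.0, (step + 1) / 100))

def step_lr_scheduler() -> None:
    """Advance the schedule and broadcast the new LR to all wrapped optimizers."""
    lr_scheduler.step()
    new_lr = lr_scheduler.get_last_lr()[0]
    for opt in optim_dict.values():
        for group in opt.param_groups:
            group["lr"] = new_lr

step_lr_scheduler()  # called once per training step in the recipe's loop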

@user074 (Contributor Author) commented Aug 28, 2024

Thanks @ebsmothers. I will update it by the weekend.
Additional note: I can try to do the same implementation for the distributed version, but I do not have enough GPUs to test it out. Do you have any suggestions on how to test to ensure it works for the distributed version?

@ebsmothers (Contributor)

> Thanks @ebsmothers. I will update it by the weekend. Additional note: I can try to do the same implementation for the distributed version, but I do not have enough GPUs to test it out. Do you have any suggestions on how to test to ensure it works for the distributed version?

@user074 good question. Actually the implementation of the distributed version may be more straightforward given that we don't have the option to fuse optimizer and backward there. We may be able to help out with testing, if you share a repro command for the loss curves in the PR summary we can use that for the distributed recipe to ensure we get similar behavior there.

  lr: 1e-5
lr_scheduler:
  _component_: torchtune.modules.get_cosine_schedule_with_warmup
  num_warmup_steps: 1
Contributor:

I saw your commit about needing to do this to pass the test. Which test is that referring to? I don't think we wanna make this change to our configs unless we know the results will be better, plus it's kinda weird to say we are using a scheduler "with warmup" when the warmup is just one step

Contributor Author:

> I saw your commit about needing to do this to pass the test. Which test is that referring to? I don't think we wanna make this change to our configs unless we know the results will be better, plus it's kinda weird to say we are using a scheduler "with warmup" when the warmup is just one step

It's the test_loss test in tests/recipes/test_full_finetune_single_device.py. When I ran it, increasing the number of warmup steps made the loss values deviate from the expected values by more than the tolerance in
torch.testing.assert_close(loss_values, expected_loss_values, rtol=1e-4, atol=1e-4)
I suspect that is due to the warmup steps. If I loosen the tolerance instead, the required value gets too high, possibly around the 1e-1 level. With a large warmup we effectively have a much smaller lr, and in my own training that caused the initial losses to differ: the initial loss differs by about 2e-1, but the gap shrinks over more training steps once warmup finishes. I haven't found a good way to resolve this though. Do you have any suggestions? I hope it isn't some bug in my code causing the tolerance difference; that's why I tested it with 1 warmup step...

[Screenshot: 2024-09-05 at 7:00 AM]
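
A rough illustration of the arithmetic at play here, assuming the usual linear warmup where the effective LR at step s is lr * s / num_warmup_steps: longer warmups shrink the first updates, which is consistent with the early losses drifting outside a 1e-4 tolerance.

lr = 2e-5
for num_warmup_steps in (1, 10, 100):
    early_lrs = [lr * min(1.0, step / num_warmup_steps) for step in (1, 2, 3)]
    print(num_warmup_steps, early_lrs)
# With 100 warmup steps the first updates use ~2e-7 instead of 2e-5,
# so the first few losses differ noticeably from the constant-LR baseline.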

Contributor:

@user074 ah I see, I think I was a bit confused about the exact issue but it makes sense now. Note that our recipe tests actually use test-specific overrides; that way we can decouple the configs we use for real fine-tunes from what we use for testing. You can see the overrides here -- I would suggest just adding an override like lr_scheduler=torch.optim.lr_scheduler.ConstantLR there. Hopefully that'll fix things!
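
A hypothetical example of such an override, assuming the recipe tests follow their usual list-of-"key=value"-strings pattern (the helper name and exact keys here are illustrative, not the actual test code):

def _get_test_lr_scheduler_overrides() -> list:
    # Pin the schedule to a constant LR so the expected loss values stay valid.
    return [
        "lr_scheduler=torch.optim.lr_scheduler.ConstantLR",
        "lr_scheduler.factor=1.0",  # keep the optimizer's LR unchanged from step 0
    ]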

Contributor Author:

Do you know how to deal with num_warmup_steps when I set the scheduler to ConstantLR? It gives an error due to this extra argument, but I couldn't find a way to remove the keyword argument from the test script. model_config does not contain it explicitly, so I cannot remove it there.

Collaborator:

You could try removing the flag directly in the test, as described here:
https://pytorch.org/torchtune/stable/deep_dives/configs.html#removing-config-fields

Collaborator:

I noticed you changed the learning rate in the config? Could that be it?

Contributor Author:

> I noticed you changed the learning rate in the config? Could that be it?

Not really, since the test recipe overrides the lr to 2e-5. I also just tested with 2e-5 as the original lr; it's the same loss-discrepancy problem.

Collaborator:

Looks to be working on the build machines. Could you push your latest changes, even if they break the tests, and I can check it out?

@ebsmothers (Contributor) commented Sep 8, 2024

Agree it'd be helpful to just push the changes and see what's happening. But also if using ConstantLR make sure you're setting factor=1 (otherwise with the defaults your LR will be 1/3 the value you want on iteration 0)
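
A quick standalone check of the default being referenced here (torch.optim.lr_scheduler.ConstantLR defaults to factor=1/3 with total_iters=5, per the PyTorch docs):

import torch
from torch.optim.lr_scheduler import ConstantLR

param = torch.nn.Parameter(torch.zeros(1))

opt_default = torch.optim.AdamW([param], lr=2e-5)
sched_default = ConstantLR(opt_default)          # defaults: factor=1/3, total_iters=5
print(sched_default.get_last_lr())               # ~[6.67e-06] for the first 5 iterations

opt_fixed = torch.optim.AdamW([param], lr=2e-5)
sched_fixed = ConstantLR(opt_fixed, factor=1.0)  # factor=1 keeps the LR at 2e-5 from step 0
print(sched_fixed.get_last_lr())                 # [2e-05]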

Contributor Author:

> Agree it'd be helpful to just push the changes and see what's happening. But also if using ConstantLR make sure you're setting factor=1 (otherwise with the defaults your LR will be 1/3 the value you want on iteration 0)

Thanks. The factor setting was actually what caused this; after fixing it, the tests now pass.

@codecov-commenter commented Sep 8, 2024

Codecov Report

Attention: Patch coverage is 10.81081% with 33 lines in your changes missing coverage. Please review.

Project coverage is 25.71%. Comparing base (7cf656b) to head (69f4ca6).
Report is 3 commits behind head on main.

Files with missing lines                  Patch %   Lines
torchtune/training/memory.py              18.18%    18 Missing ⚠️
recipes/full_finetune_single_device.py     0.00%    15 Missing ⚠️
Additional details and impacted files
@@             Coverage Diff             @@
##             main    #1350       +/-   ##
===========================================
- Coverage   69.33%   25.71%   -43.62%     
===========================================
  Files         305      305               
  Lines       15892    15961       +69     
===========================================
- Hits        11018     4104     -6914     
- Misses       4874    11857     +6983     

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

"LR scheduler has not been set. Call set_lr_scheduler first."
)

def get_last_lr(self):
Collaborator:

nit:

Suggested change:
- def get_last_lr(self):
+ def get_last_lr(self) -> float:

Have you correctly configured your pre-commit hooks? Our linters should pick this up

Contributor Author:

[Screenshot: 2024-09-09 at 11:29 PM]

I got pre-commit to pass with this as-is, which is kind of strange. I installed it according to CONTRIBUTING.md.

@ebsmothers (Contributor) commented Sep 10, 2024

> Have you correctly configured your pre-commit hooks? Our linters should pick this up

I think there are merge conflicts, so linters will not run in CI until you resolve them. Also for missing type hints idk if flake8 actually catches those anyways? (Not positive though)

cfg_lr_scheduler: DictConfig,
num_training_steps: int,
last_epoch: int,
) -> Optional[Optimizer]:
Collaborator:

Could you add docstrings here?

@ebsmothers (Contributor)

Hi @user074 just wanted to check in if you're still working on this. If not let me know and one of us can help to take it over the finish line (I think from here there should be very little left)

@user074 (Contributor Author) commented Sep 11, 2024

> Hi @user074 just wanted to check in if you're still working on this. If not let me know and one of us can help to take it over the finish line (I think from here there should be very little left)

Just updated with the suggested changes. Sorry for the delay.

@ebsmothers (Contributor)

@user074 no problem. Can you also merge latest changes from main into your branch? There are merge conflicts and we can't run CI until those get resolved

Comment on lines 458 to 471
if "torchtune.modules" in component:
# Instantiate the learning rate scheduler
lr_scheduler = config.instantiate(
cfg_lr_scheduler,
optimizer,
num_training_steps=num_training_steps,
last_epoch=last_epoch,
)
else:
lr_scheduler = config.instantiate(
cfg_lr_scheduler,
optimizer,
last_epoch=last_epoch,
)
Contributor:

Ah is this a consequence of using ConstantLR in the test? I think it's a bit confusing to condition on whether the component is coming from torchtune.modules or not. In general there will be things from torch.optim that don't match the signature in the else statement, and there will probably also (in the future) be things from torchtune.modules that don't match the signature in the if statement.

In that case (and sorry for changing my mind on this), is there any reason we can't just make lr_scheduler optional? Then if it's not passed we just continue to use the usual constant learning rate defined in the optimizer and set self._lr_scheduler=None in the recipe (and in the training loop check that it's not None prior to stepping). This will make the recipe test a non-issue and keep the setup clean.

Separately though it's clear that our LR scheduler instantiation is not flexible enough. We should figure out how to better support schedulers from torch.optim moving forward (though that's a problem for another PR).
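
A minimal standalone sketch of the optional-scheduler pattern suggested above (the config handling is illustrative, not the recipe code):

from typing import Optional

import torch
from torch.optim.lr_scheduler import LambdaLR

def setup_lr_scheduler(optimizer, cfg_lr_scheduler: Optional[dict]) -> Optional[LambdaLR]:
    # No scheduler configured: fall back to the optimizer's constant LR.
    if cfg_lr_scheduler is None:
        return None
    warmup = cfg_lr_scheduler.get("num_warmup_steps", 1)
    return LambdaLR(optimizer, lambda step: min(1.0, (step + 1) / max(1, warmup)))

model = torch.nn.Linear(4, 4)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)
lr_scheduler = setup_lr_scheduler(optimizer, None)  # e.g. the config has no lr_scheduler block

for _ in range(3):
    model(torch.randn(2, 4)).sum().backward()
    optimizer.step()
    optimizer.zero_grad(set_to_none=True)
    if lr_scheduler is not None:  # only step the schedule when one was configured
        lr_scheduler.step()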

Contributor Author:

Yes. I was also thinking about making it more flexible and supporting other types of torch.optim.lr_scheduler. I found that torchtune.modules.lr_scheduler only has get_cosine_schedule_with_warmup, and the num_training_steps input would not be compatible with other torch.optim.lr_scheduler classes.

I can look into making lr_scheduler work as None by the weekend. The problem is mainly that the testing config always passes num_training_steps, even when the scheduler has no argument for it.

Contributor Author:

Updated to use self._lr_scheduler=None.
Also, I modified the testing. I still use the cosine schedule instead of a constant one, but effectively make it constant by setting num_cycles=0. It should be a bit more elegant, imo.
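
A small check of why num_cycles=0 behaves like a constant schedule once warmup ends, assuming the standard cosine-with-warmup formula (this sketch reimplements the lambda rather than importing torchtune.modules.get_cosine_schedule_with_warmup, so verify against the real function):

import math

def cosine_multiplier(step, num_warmup_steps, num_training_steps, num_cycles):
    if step < num_warmup_steps:
        return step / max(1, num_warmup_steps)
    progress = (step - num_warmup_steps) / max(1, num_training_steps - num_warmup_steps)
    return max(0.0, 0.5 * (1.0 + math.cos(math.pi * num_cycles * 2.0 * progress)))

# With num_cycles=0 the cosine term is cos(0) == 1, so the multiplier stays at 1.0 after warmup.
print([round(cosine_multiplier(s, 2, 10, 0.0), 3) for s in range(10)])
# -> [0.0, 0.5, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]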

@@ -586,6 +640,9 @@ def train(self) -> None:
self._optimizer.step()
self._optimizer.zero_grad(set_to_none=True)

# Need to fix `lr_scheduler.step()` before `optimizer.step()` warning
Collaborator:

What is this comment referring to? : )

Contributor Author:

There was a warning when I ran the code. I think it said "Detected call of `lr_scheduler.step()` before `optimizer.step()`". I do not get why there is such a warning, since the `lr_scheduler.step()` is after `optimizer.step()`.
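
One plausible (unconfirmed) explanation: PyTorch raises this warning on the first lr_scheduler.step() if the optimizer attached to the scheduler has never had its own .step() called, which is exactly the situation for a dummy optimizer in the optimizer-in-backward path, where the real updates happen inside backward. A small sketch that reproduces the warning:

import warnings

import torch
from torch.optim.lr_scheduler import LambdaLR

# A dummy optimizer whose .step() is never called, as in the optim-in-bwd path.
dummy_optim = torch.optim.SGD([torch.nn.Parameter(torch.zeros(1))], lr=1e-5)
scheduler = LambdaLR(dummy_optim, lambda step: 1.0)

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    scheduler.step()  # advance the schedule without ever stepping dummy_optim
    print([str(w.message) for w in caught])  # includes the lr_scheduler.step() warning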

@ebsmothers (Contributor) left a comment

Thanks, I think this is much cleaner! I left one more comment, can you run a test with no LR scheduler (running manually is fine) just to make sure nothing is broken after applying my suggestion? After that I think this should be good to go

@user074 (Contributor Author) commented Oct 4, 2024

> Thanks, I think this is much cleaner! I left one more comment, can you run a test with no LR scheduler (running manually is fine) just to make sure nothing is broken after applying my suggestion? After that I think this should be good to go

I just ran a Llama 2 7B finetune without the lr_scheduler and it works.

@ebsmothers (Contributor) left a comment

Looks good, thanks for adding this!

@ebsmothers merged commit a8a64ec into pytorch:main Oct 4, 2024
17 checks passed
mori360 pushed a commit to mori360/torchtune that referenced this pull request Oct 14, 2024