DPO Loss refactor #1197
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/1197
Note: Links to docs will display an error until the docs builds have been completed.
✅ No failures as of commit 55ccd92 with merge base 0ea274f.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
Excellent refactor. I honestly don't have any concerns here.
As a DPO noob, would you mind adding some details to the docstrings about the training setup for each of these losses? For example, what the policy and reference models are, and that the core value proposition of DPO compared to PPO is that it does not require a reward model.
Also, are we completely removing the kto pair loss?
Yeah, sorry, I forgot to point this out. The reference implementation in TRL seems to differ enough (it uses a separate trainer with a different dataset format), and I couldn't see tests for the KTO loss in the original DPO PR. I'd personally reconsider adding it back once we have tests against a reference implementation.
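For anyone following along who is new to DPO, here is a minimal sketch of the training setup the docstrings could describe. This is illustrative only, not torchtune's implementation, and the class and argument names are hypothetical: the policy is the model being fine-tuned, the reference is a frozen copy of the starting checkpoint, and, unlike PPO-style RLHF, no reward model appears anywhere.

```python
import torch
import torch.nn.functional as F
from torch import nn


class DPOLossSketch(nn.Module):
    """Illustrative DPO loss. The "policy" is the model being trained; the
    "reference" is a frozen copy of the starting model whose log-probs anchor
    the policy. Preferences enter directly via chosen/rejected completions,
    so no reward model is needed."""

    def __init__(self, beta: float = 0.1):
        super().__init__()
        self.beta = beta  # strength of the pull toward the reference model

    def forward(
        self,
        policy_chosen_logps: torch.Tensor,       # log p_policy(chosen | prompt)
        policy_rejected_logps: torch.Tensor,     # log p_policy(rejected | prompt)
        reference_chosen_logps: torch.Tensor,    # log p_ref(chosen | prompt)
        reference_rejected_logps: torch.Tensor,  # log p_ref(rejected | prompt)
    ) -> torch.Tensor:
        # Implicit rewards are scaled log-ratios between policy and reference.
        chosen_rewards = self.beta * (policy_chosen_logps - reference_chosen_logps)
        rejected_rewards = self.beta * (policy_rejected_logps - reference_rejected_logps)
        # Maximize the margin between chosen and rejected implicit rewards.
        return -F.logsigmoid(chosen_rewards - rejected_rewards).mean()
```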
Codecov Report: All modified and coverable lines are covered by tests ✅
Additional details and impacted files:

@@            Coverage Diff             @@
##             main    #1197      +/-   ##
==========================================
+ Coverage   68.63%   68.66%   +0.02%
==========================================
  Files         214      215       +1
  Lines        9687     9734      +47
==========================================
+ Hits         6649     6684      +35
- Misses       3038     3050      +12

☔ View full report in Codecov by Sentry.
Context
What is the purpose of this PR?
I was initially looking at this refactor to add SimPO into the DPO recipe, but got slightly carried away. I think we might need a separate SimPO recipe down the line anyway, since DPO-style losses use a reference model, whereas SimPO doesn't.
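As a rough illustration of why SimPO may not slot cleanly into the DPO recipe, here is a hedged sketch of a SimPO-style loss (function and argument names are hypothetical, not torchtune's API, and length-normalized log-probs are assumed as inputs). It consumes only policy log-probabilities, so the recipe would never need to load or run a frozen reference model.

```python
import torch
import torch.nn.functional as F


def simpo_loss_sketch(
    policy_chosen_logps: torch.Tensor,    # length-normalized log p_policy(chosen | prompt)
    policy_rejected_logps: torch.Tensor,  # length-normalized log p_policy(rejected | prompt)
    beta: float = 2.0,   # reward scaling
    gamma: float = 0.5,  # target reward margin
) -> torch.Tensor:
    # No reference-model terms: the length-normalized policy log-prob itself
    # serves as the implicit reward, so only one model is needed at train time.
    margin = beta * (policy_chosen_logps - policy_rejected_logps) - gamma
    return -F.logsigmoid(margin).mean()
```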
This PR refactors the DPO loss module into separate classes, one for each of the loss types it previously supported. Each loss is now documented with a reference to its corresponding paper and a little intuition about how it works, and comes with a corresponding unit test.
Test plan
- pre-commit install
- pytest tests
- pytest tests -m integration_test