Add per-layer compile support to recipes #1419
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/1419
Note: Links to docs will display an error until the docs builds have been completed.
✅ No failures as of commit 4678ce9 with merge base 77fbb4f.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
Thanks, Will!
A couple of thoughts: I personally don't like that we have two flags for compiling, and would prefer to avoid cases like:
compile: False
per_layer_compile: True
IMO, we should either:
1. If compile is true, enable per_layer_compile by default
2. Make compile a nested config, like:
compile:
  per_layer: True
  loss: True
I prefer 1, as I don't think the user needs this level of control through configs. Any thoughts?
Second: do you think you could test it on our distributed recipe that already uses FSDP2? Or should that be another PR?
Would this also enable optimizer-in-backward 🤝 compile?
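As a rough sketch, option 1 could look something like this in a recipe's setup (purely illustrative; the `layers` attribute and the use of `nn.Module.compile` are assumptions, not the final torchtune code): a single compile: True flag drives both per-layer model compile and loss compile, so configs never carry two separate compile flags.

if cfg.compile:
    for layer in self._model.layers:  # assumes the model exposes its transformer layers
        layer.compile()  # nn.Module.compile applies torch.compile to the layer in place
    self._loss_fn = torch.compile(self._loss_fn)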
if self._model_compile:
    log.info("Compiling loss with torch.compile...")
    backend = os.environ.get("TORCH_COMPILE_BACKEND", "inductor")
    self._loss_fn = torch.compile(self._loss_fn, backend=backend)
chunked CE can only have compile on the CE + upcasting part. If the chunking is compiled with it, it loses the benefit :/
I think we can leave the loss compile out of this PR if chunked CE will be the default.
Edit: scratch that. Will add something like this to the Chunked CE PR:
loss_fn = instantiate(cfg.loss)
if isinstance(loss_fn, ChunkedCrossEntropy):
    loss_fn._cross_entropy.compile()
else:
    loss_fn.compile()
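As a minimal sketch of the shape being described (class and attribute names here are illustrative stand-ins, not the actual torchtune implementation): the chunking loop stays in eager so each chunk's upcast logits can be freed as the loop proceeds, and only the inner upcast + cross-entropy module is compiled, matching the loss_fn._cross_entropy.compile() call above.

import torch
import torch.nn.functional as F

class _UpcastCE(torch.nn.Module):
    # Upcast logits to fp32 and compute a summed cross-entropy; this is the piece worth compiling.
    def __init__(self, ignore_index: int = -100):
        super().__init__()
        self.ignore_index = ignore_index

    def forward(self, logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
        return F.cross_entropy(
            logits.float(), labels, ignore_index=self.ignore_index, reduction="sum"
        )

class ChunkedCrossEntropy(torch.nn.Module):
    # Illustrative chunked CE: the chunking loop itself is left uncompiled.
    def __init__(self, num_chunks: int = 8, ignore_index: int = -100):
        super().__init__()
        self.num_chunks = num_chunks
        self.ignore_index = ignore_index
        # Only this inner module gets compiled, e.g. via loss_fn._cross_entropy.compile()
        self._cross_entropy = _UpcastCE(ignore_index)

    def forward(self, logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
        # logits: [batch, seq_len, vocab_size], labels: [batch, seq_len]
        logit_chunks = logits.chunk(self.num_chunks, dim=1)
        label_chunks = labels.chunk(self.num_chunks, dim=1)
        total = sum(
            self._cross_entropy(l.reshape(-1, l.size(-1)), t.reshape(-1))
            for l, t in zip(logit_chunks, label_chunks)
        )
        return total / (labels != self.ignore_index).sum()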
Came across this PR. @yf225 Am I seeing a memory leak for "qlora-compile-main"? If you are using a torch nightly after 20240824, it might be the same as what I'm seeing here: pytorch/pytorch#134642
@gau-nernst yeah, same here. I'm not sure what's going on, because I've been compiling models with no problems for a while, but as of a couple of days ago I also started seeing what looks like a memory leak. Attaching a memory viz I ran on a full finetune; it seems like a bunch of memory is not getting freed after each step.
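For anyone trying to capture a similar trace, a memory viz like this can be produced with PyTorch's allocator-history API (the placement around the training loop here is an assumption; the torch.cuda.memory calls themselves are the standard API), and the resulting pickle can be opened at https://pytorch.org/memory_viz:

import torch

# Start recording allocator events (allocations, frees, and their stack traces).
torch.cuda.memory._record_memory_history(max_entries=100_000)

# ... run a handful of training steps of the recipe here ...

# Dump the history to a file that can be loaded at https://pytorch.org/memory_viz
torch.cuda.memory._dump_snapshot("full_finetune_mem.pickle")

# Stop recording.
torch.cuda.memory._record_memory_history(enabled=None)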
@@ -230,6 +231,10 @@ def setup(self, cfg: DictConfig) -> None:
        )

        self._loss_fn = config.instantiate(cfg.loss)
        if self._model_compile:
nit: should we rename "model_compile" to just "compile", since this flag will end up controlling loss compile and flex attention compile as well?
Not a priority, I am fine to keep it as is for now, and maybe add a TODO to rename.
Yeah, I think our naming here is not great. Let's do it in a follow-up; rather than adding a TODO in the code, I may just create an issue.
@gau-nernst @ebsmothers For the full-model compile memory leak issue, this should now be resolved by reverting pytorch/pytorch#134272. In this PR we've also switched to using per-layer compile (#1419), which also won't have the memory leak issue.
Enabling per-layer compile for our single-device LoRA, single-device full finetune, and FSDP2 LoRA recipes. FSDP2 full finetune will be done in a follow-up.
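The gist of the approach, as a hedged sketch (the helper name and the assumption that the model exposes its transformer layers as a ModuleList are illustrative, not the exact code landed in the recipes): rather than wrapping the whole model in a single torch.compile call, each transformer layer is compiled on its own, which keeps every graph small, cuts compile time, and lines up with FSDP2's layer-by-layer wrapping.

import os
import torch

def compile_model_per_layer(model: torch.nn.Module) -> None:
    # Compile each transformer layer in place instead of compiling the whole model.
    backend = os.environ.get("TORCH_COMPILE_BACKEND", "inductor")
    for layer in model.layers:  # assumes a ModuleList of transformer layers
        layer.compile(backend=backend)  # nn.Module.compile applies torch.compile to this layer only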
Results
All recipes were run with three different configurations: (1) per-layer compile (this PR), (2) full-model compile (i.e. compile=True on main), (3) no compile.
QLoRA single-device
WandB results
Repro
LoRA
WandB results
Repro
FFT
WandB results
Repro
LoRA FSDP2
Repro only, to ensure things aren't broken; will work on perf in a follow-up.
Repro
First change dev/llama2/70B_qlora_fsdp2.yaml recipe to this, then run
I've also consolidated E2E time over 100 steps and compile time for QLoRA, LoRA, and full finetunes of Llama3 8B on a single device. Note that compile on main currently OOMs due to some kind of memory leak, so we don't have E2E times for models compiled on main.