[FSDP1] reduce GPU memory usage from 78G to 23G #843
Conversation
Thanks for this PR and for root-causing the high memory consumed by non-zero ranks, @weifengpy! This is a tricky issue we've been discussing quite a bit within torchtune, so it's great that we have a root cause and progress towards a fix.
One issue is that we probably can't remove our usage of these flags as they are required for correctness (basically broadcasting the state_dict from rank 0 to all ranks). I wonder if we can root cause this to either of the flags and make the appropriate fix in core.
Also qq: if you happen to have dug into either sync_module_states or param_init_fn, do we know if either one is the root cause, or do both contribute to the increased memory usage?
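For context, a minimal sketch of the init pattern these two flags support (the toy model and commented-out checkpoint load are placeholders, not torchtune's actual recipe code): rank 0 holds the real weights, the other ranks build on the meta device, param_init_fn materializes empty storage, and sync_module_states=True broadcasts rank 0's values to everyone.

```python
# Sketch of the sync_module_states / param_init_fn pattern under discussion.
# The model and checkpoint path below are placeholders, not torchtune code.
import torch
import torch.nn as nn
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

def build_model() -> nn.Module:
    # Toy stand-in for the llama model built in the recipe.
    return nn.Sequential(*[nn.Linear(4096, 4096, bias=False) for _ in range(8)])

dist.init_process_group("nccl")
torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())

if dist.get_rank() == 0:
    model = build_model()                          # real weights on CPU
    # model.load_state_dict(torch.load("ckpt.pt")) # placeholder checkpoint load
else:
    with torch.device("meta"):
        model = build_model()                      # no parameter storage allocated

model = FSDP(
    model,
    device_id=torch.cuda.current_device(),
    # Materialize meta parameters as empty GPU tensors on non-zero ranks.
    param_init_fn=lambda m: m.to_empty(
        device=torch.cuda.current_device(), recurse=False
    ),
    # Broadcast rank 0's weights to all ranks; this coalesced broadcast is
    # where the memory behavior discussed in this PR shows up.
    sync_module_states=True,
)
```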
Yep. I will dig into FSDP internals to see if we can keep broadcasting from rank 0 to the other ranks.
Awesome, thanks for root-causing this! We should definitely enable this flag, as the memory savings are huge.
Thinking about our recipe UX, we probably want to avoid adding NCCL-specific details/configuration in the recipe itself, as we tend to keep recipes simple where possible. As a result, I'm thinking about other places we could put this environment variable and possible future env configurations we would need.
One idea is that when tune is invoked in a distributed setting, tune run itself can set this before dispatching into torchtune (a rough sketch of this idea follows below). @joecummings what do you think / is this feasible?
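One possible shape for that idea, purely as an illustration (the function name, argument handling, and nproc value are assumptions, not the actual tune run code):

```python
# Illustrative sketch only: how a CLI launcher could set the env var before
# dispatching into a distributed recipe; `launch_recipe` is hypothetical.
import os
import subprocess
import sys

def launch_recipe(recipe_script: str, *recipe_args: str) -> int:
    env = os.environ.copy()
    # Must be set before the child processes create their NCCL process groups.
    env.setdefault("TORCH_NCCL_AVOID_RECORD_STREAMS", "1")
    cmd = [sys.executable, "-m", "torch.distributed.run",
           "--nproc_per_node=8", recipe_script, *recipe_args]
    return subprocess.run(cmd, env=env).returncode
```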
Also, an alternative is to just change the backend impl of sync_module_states entirely. I'm attempting this in pytorch/pytorch#124679 so that future users get the memory savings out of the box instead of through an env variable. What do you think?
Glad to see pytorch/pytorch#124679. @rohan-varma what do you think if I benchmark your PR vs. the existing sync_module_states on llama shapes?
Had an offline discussion with @awgu and it makes a lot of sense to merge this. We can do a perf comparison as a follow-up, and the torchtune team can take an action item to clean up the UX a bit. The memory wins are too compelling, and this unblocks some ongoing parallel work.
In trunk, non-zero ranks reach 78G of GPU memory during model init with sync_module_states=True. FSDP calls dist._broadcast_coalesced (https://fburl.com/rkq73zyp), which uses recordStream by default, so GPU memory is not immediately released.

This PR sets TORCH_NCCL_AVOID_RECORD_STREAMS=1, which reduces the memory from 78G to 23G.

Memory profiles during model init without the fix:
[memory profile screenshot]

Memory profiles after the fix:
[memory profile screenshot]

Memory profiles for the 78G case show accumulating spikes from sync_module_states=True and param_init_fn.
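For anyone unfamiliar with the recordStream behavior referenced above, here is a small sketch of the mechanism (a simplified stand-in, not the actual dist._broadcast_coalesced code): a tensor recorded on a side stream cannot have its memory reused by the caching allocator until that stream's pending work finishes, so many coalesced broadcasts can pile up transient buffers, which matches the accumulating spikes in the profiles.

```python
# Simplified sketch of why record_stream delays memory reuse; this is a
# stand-in for what happens inside NCCL collectives, not the broadcast code.
import torch

side_stream = torch.cuda.Stream()                      # stand-in for the NCCL comm stream
buf = torch.empty(256 * 1024 * 1024, device="cuda")    # ~1 GB fp32 scratch buffer
with torch.cuda.stream(side_stream):
    buf.add_(1.0)                                      # stand-in for the broadcast kernel
buf.record_stream(side_stream)                         # allocator must now wait on side_stream
del buf                                                # reference dropped, but the caching
                                                       # allocator cannot reuse the block until
                                                       # side_stream's pending work completes
# With TORCH_NCCL_AVOID_RECORD_STREAMS=1, ProcessGroupNCCL skips record_stream
# and instead holds references to the tensors until the collective completes,
# which is what lets this PR drop non-zero-rank init memory from 78G to 23G.
```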