
[FSDP1] reduce GPU memory usage from 78G to 23G #843


Merged: 2 commits into pytorch:main on Apr 23, 2024

Conversation

@weifengpy (Contributor) commented Apr 22, 2024

In trunk, non-zero ranks use 78G of GPU memory during model init with sync_module_states=True. That path calls dist._broadcast_coalesced (https://fburl.com/rkq73zyp), which uses recordStream by default, so GPU memory is not immediately released.

This PR sets TORCH_NCCL_AVOID_RECORD_STREAMS=1, which reduces the memory from 78G to 23G.
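
A minimal sketch of how the workaround can be applied (illustrative; `build_model` and the launcher details are assumptions, not torchtune's actual recipe code). The environment variable has to be set before the NCCL process group is created so that the coalesced broadcast triggered by `sync_module_states=True` avoids `recordStream`:

```python
import os

# Must be set before the NCCL process group is constructed, i.e. before
# init_process_group(), so the broadcast in sync_module_states avoids recordStream.
os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1"

import torch
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

dist.init_process_group(backend="nccl")
local_rank = int(os.environ.get("LOCAL_RANK", 0))
torch.cuda.set_device(local_rank)

model = build_model()  # hypothetical model constructor

# sync_module_states=True broadcasts rank 0's weights to all other ranks via
# dist._broadcast_coalesced; without the env var above, recordStream keeps the
# staging buffers alive longer than necessary on non-zero ranks.
model = FSDP(
    model,
    device_id=torch.cuda.current_device(),
    sync_module_states=True,
)
```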

Memory profile during model init without the fix (78G, with accumulating spikes from sync_module_states=True and param_init_fn): [screenshot]

Memory profile after the fix: [screenshot]
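
For reference, memory profiles like these can be captured with PyTorch's CUDA memory snapshot tooling (a sketch; `_record_memory_history`/`_dump_snapshot` are private APIs whose signatures may vary across PyTorch versions, and `init_and_wrap_model` is a stand-in for the recipe's model setup):

```python
import torch
import torch.distributed as dist

# Assumes torch.distributed is already initialized on this rank.
# Start recording allocation history before model init.
torch.cuda.memory._record_memory_history(max_entries=100_000)

init_and_wrap_model()  # hypothetical: model construction + FSDP wrapping

# Dump a per-rank snapshot that can be inspected at https://pytorch.org/memory_viz
torch.cuda.memory._dump_snapshot(f"rank{dist.get_rank()}_init_snapshot.pickle")
torch.cuda.memory._record_memory_history(enabled=None)  # stop recording
```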


pytorch-bot commented Apr 22, 2024

🔗 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/843 (links to docs will display an error until the docs builds have completed).

✅ No failures as of commit 9a01b74 with merge base 4044b93. (Generated by Dr. CI; updates every 15 minutes.)

@facebook-github-bot added the CLA Signed label (Apr 22, 2024).
@rohan-varma self-requested a review (Apr 22, 2024).
@rohan-varma (Member) left a comment


Thanks for this PR and for root causing the high memory consumed by non-zero ranks, @weifengpy! This is a tricky issue we've been discussing quite a bit within torchtune, so it's great that we have a root cause and progress towards a fix.

One issue is that we probably can't remove our usage of these flags as they are required for correctness (basically broadcasting the state_dict from rank 0 to all ranks). I wonder if we can root cause this to either of the flags and make the appropriate fix in core.

Also, a quick question if you happen to have dug into either sync_module_states or param_init_fn: do we know if one of them is the root cause, or do both contribute to the increased memory usage?
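
For context on the two flags being discussed, here is a sketch of the usual meta-device init pattern (illustrative only; `load_checkpointed_model` and `build_model` are hypothetical, and this is not the exact torchtune recipe code). Rank 0 materializes real weights, the other ranks only allocate empty storage via `param_init_fn`, and `sync_module_states=True` then broadcasts rank 0's values:

```python
import torch
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

# Assumes dist.init_process_group("nccl") and torch.cuda.set_device(...) were called.
rank = dist.get_rank()

if rank == 0:
    model = load_checkpointed_model()   # hypothetical: real weights loaded on rank 0
else:
    with torch.device("meta"):
        model = build_model()           # hypothetical: shapes only, no real storage


def param_init_fn(module: torch.nn.Module) -> None:
    # Called by FSDP for modules whose parameters are still on the meta device;
    # allocates (uninitialized) GPU storage that sync_module_states then overwrites.
    module.to_empty(device=torch.cuda.current_device(), recurse=False)


model = FSDP(
    model,
    device_id=torch.cuda.current_device(),
    param_init_fn=param_init_fn if rank != 0 else None,
    sync_module_states=True,  # broadcast rank 0's weights to every rank
)
```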

@weifengpy changed the title from "[DEBUG][FSDP1] reduce memory from 78G to 22G on non-zero ranks" to "[DO NOT LAND][FSDP1] pinpoint why FSDP1 memory is 78G instead of 22G on non-zero ranks" (Apr 22, 2024)
@weifengpy (Contributor, Author) commented:

> Thanks for this PR and for root causing the high memory consumed by non-zero ranks, @weifengpy! This is a tricky issue we've been discussing quite a bit within torchtune, so it's great that we have a root cause and progress towards a fix.
>
> One issue is that we probably can't remove our usage of these flags as they are required for correctness (basically broadcasting the state_dict from rank 0 to all ranks). I wonder if we can root cause this to either of the flags and make the appropriate fix in core.

Yep. I will dig into FSDP internals to see if we can keep broadcasting from rank 0 to the other ranks.

@weifengpy changed the title from "[DO NOT LAND][FSDP1] pinpoint why FSDP1 memory is 78G instead of 22G on non-zero ranks" to "[FSDP1] reduce GPU memory usage from 78G to 23G" (Apr 22, 2024)
@rohan-varma (Member) left a comment


Awesome, thanks for root causing this! We should definitely enable this flag, as the memory savings are huge.

Thinking about our recipe UX, we probably want to avoid adding NCCL-specific details/configuration in the recipe itself, as we tend to keep recipes as simple as possible. As a result, I'm thinking about other places we could put this environment variable and possible future env configurations we might need.

One idea is that when tune is invoked in a distributed setting, tune run itself can set this before dispatching into torchtune. @joecummings what do you think / is this feasible?

Also, an alternative is to just change the backend impl of sync_module_states entirely. I'm attempting to do this in pytorch/pytorch#124679 so that future users get the memory savings out of the box instead of through an env variable. What do you think?
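
A sketch of the `tune run` idea above (hypothetical; this is not the actual `tune run` dispatch code): the launcher sets the variable in its own environment before spawning the distributed workers, so every rank inherits it and recipes stay free of NCCL-specific configuration:

```python
import os
import sys

from torch.distributed.run import main as torchrun_main


def run_distributed(torchrun_args: list) -> None:
    # Hypothetical launcher hook: set NCCL-related defaults before dispatching.
    # setdefault lets users still override the value from their own environment.
    os.environ.setdefault("TORCH_NCCL_AVOID_RECORD_STREAMS", "1")
    torchrun_main(torchrun_args)


if __name__ == "__main__":
    run_distributed(sys.argv[1:])
```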

@weifengpy (Contributor, Author) commented Apr 23, 2024

> Awesome, thanks for root causing this! We should definitely enable this flag, as the memory savings are huge.
>
> Thinking about our recipe UX, we probably want to avoid adding NCCL-specific details/configuration in the recipe itself, as we tend to keep recipes as simple as possible. As a result, I'm thinking about other places we could put this environment variable and possible future env configurations we might need.
>
> One idea is that when tune is invoked in a distributed setting, tune run itself can set this before dispatching into torchtune. @joecummings what do you think / is this feasible?
>
> Also, an alternative is to just change the backend impl of sync_module_states entirely. I'm attempting to do this in pytorch/pytorch#124679 so that future users get the memory savings out of the box instead of through an env variable. What do you think?

Glad to see pytorch/pytorch#124679. @rohan-varma, what do you think about me benchmarking your PR against the existing sync_module_states on Llama shapes?
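
A rough sketch of what such a benchmark could measure (illustrative; the comparison would run once against the current sync_module_states path and once against the patched path in pytorch/pytorch#124679, with `build_model` standing in for the Llama model constructor): wall-clock time and peak GPU memory of the FSDP init on each rank:

```python
import time

import torch
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP


def benchmark_fsdp_init(build_model):
    # Assumes dist.init_process_group("nccl") and torch.cuda.set_device(...) were called.
    torch.cuda.reset_peak_memory_stats()
    dist.barrier()
    start = time.perf_counter()

    model = FSDP(
        build_model(),
        device_id=torch.cuda.current_device(),
        sync_module_states=True,
    )

    torch.cuda.synchronize()
    dist.barrier()
    elapsed = time.perf_counter() - start
    peak_gib = torch.cuda.max_memory_allocated() / 2**30
    print(f"rank {dist.get_rank()}: init took {elapsed:.1f}s, peak {peak_gib:.1f} GiB")
    return model
```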

@kartikayk (Contributor) left a comment


Had an offline discussion with @awgu, and it makes a lot of sense to merge this. We can do a perf comparison as a follow-up, and the torchtune team can take an action item to clean up the UX a bit. The memory wins are too compelling, and this unblocks some ongoing parallel work.

@kartikayk merged commit bec7bab into pytorch:main on Apr 23, 2024
yinfan98 pushed a commit to yinfan98/sgl-tune-eagle that referenced this pull request May 26, 2025