
FSDP Llama3 wrapping improvements for full finetune #865


Merged 30 commits into main on May 7, 2024

Conversation

@rohan-varma (Member) commented Apr 25, 2024

Context

What is the purpose of this PR? Is it to

  • add a new feature
  • fix a bug
  • update tests and/or documentation
  • other (please add here)

Changelog

This PR primarily seeks to improve memory efficiency specifically for llama3 full distributed training and enable a distributed finetune in 4x 24GB of memory. We do this with a new FSDP wrapping policy that wraps the token embedding and output projections. These are much larger for llama3 due to the increased vocab size, so sharding them across GPUs has more of an effect.

  • Added a new API to retrieve a memory-efficient FSDP wrapping policy (a rough sketch of the idea follows this list). To control whether the memory-efficient wrapping policy is retrieved, we introduce a new flag memory_efficient_wrapping in our configs. Currently, this is only set to True for llama3 distributed full finetuning. As follow-up work, we'll investigate other workloads with this wrapping and enable it where beneficial.
  • Added appropriate unittests
  • Integrated in full_finetune_distributed. Did a bit of study on potential integration into LoRA, but memory savings were less pronounced there - this needs further investigation.
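
To make the idea concrete, here is a minimal sketch of the kind of FSDP auto_wrap_policy described above, assuming a decoder-layer class like torchtune's TransformerDecoderLayer. This is illustrative, not the exact API added in this PR (which additionally tags the output projection so it gets its own unit, since an isinstance check on nn.Linear would wrap every linear layer):

from functools import partial
from typing import Set, Type

import torch.nn as nn
from torchtune import modules


def _memory_efficient_wrap_policy(
    module: nn.Module, recurse: bool, modules_to_wrap: Set[Type], **kwargs
) -> bool:
    # FSDP calls the policy repeatedly while traversing the module tree: with
    # recurse=True it asks whether to keep descending into children, and with
    # recurse=False it asks whether this module should become its own FSDP unit.
    if recurse:
        return True
    return isinstance(module, tuple(modules_to_wrap))


# Shard the large token embedding (and, in the PR, the output projection) as
# their own units in addition to per-layer wrapping; llama3's ~128k-entry vocab
# makes these parameters large enough that sharding them matters.
# NOTE: modules.TransformerDecoderLayer is assumed to be the decoder-layer
# class here; substitute the model's actual layer class as needed.
auto_wrap_policy = partial(
    _memory_efficient_wrap_policy,
    modules_to_wrap={nn.Embedding, modules.TransformerDecoderLayer},
)

The resulting callable would then be passed to FullyShardedDataParallel as the auto_wrap_policy argument.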

Test plan

  • Added unittests.

This PR only seeks to ship improvements to llama3 training.

Docs

[Screenshots of the rendered docs]

Full finetune

Run command for full finetune: CUDA_VISIBLE_DEVICES=0,3,6,7 tune run --nproc_per_node 4 full_finetune_distributed --config llama3/8B_full batch_size=1

  • With this PR: peak_memory_active:20.830272 peak_memory_alloc:19.085376 peak_memory_reserved:23.699914752, 1.06it/s
  • Without this PR: peak_memory_active:24.170446336 peak_memory_alloc:21.988057088 peak_memory_reserved:27.908898816, 1.08it/s
  • About 13% savings in allocated memory and 15% in reserved memory. This allows us to get a 4x 24GB finetune. (A sketch of how these peak-memory stats can be read out follows this list.)
  • NOTE: A previous version of this PR also wrapped the token embedding and output projection in their own activation checkpointing units, but this is not needed. Vocabulary size is increased, but the activations generated are proportional to sequence length, not vocab size, so checkpointing these won't help more for llama3 compared to llama2. A quick study checkpointing these versus not shows roughly the same memory efficiency. In particular, with checkpointing the token embedding and output proj, we achieve peak_memory_active:20.880037376 peak_memory_alloc:19.135141376 peak_memory_reserved:24.13821952, while without it we achieve the numbers reported above; they are very comparable.
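
As an aside, the peak-memory numbers above divide evenly by 1e9, i.e. they are byte counts reported in GB. A minimal sketch of how such stats can be read from the CUDA caching allocator (torchtune has its own memory-stats helper, so this is illustrative rather than the exact logging code used here):

import torch


def peak_memory_gb(device: torch.device) -> dict:
    # Peak "active" bytes come from the caching allocator's stats dict;
    # the allocated and reserved peaks have dedicated helpers.
    stats = torch.cuda.memory_stats(device)
    return {
        "peak_memory_active": stats["active_bytes.all.peak"] / 1e9,
        "peak_memory_alloc": torch.cuda.max_memory_allocated(device) / 1e9,
        "peak_memory_reserved": torch.cuda.max_memory_reserved(device) / 1e9,
    }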

pytorch-bot (bot) commented Apr 25, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/865

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 6124dd5 with merge base 7d05579:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot added the CLA Signed label Apr 25, 2024
@rohan-varma marked this pull request as draft April 25, 2024 00:49
def _llama3_ac_policy(module: nn.Module, recurse: bool, modules_to_wrap, **kwargs):
    # Label that output_proj should be wrapped individually.
    if isinstance(module, modules.TransformerDecoder):
        targ = module.output.module
@rohan-varma (Member, Author):
this should probably use a helper function called _get_fsdp_wrapped_module or something that is aware of whether module.output is wrapped in FSDP or not and intelligently unwraps it.

@rohan-varma changed the title from "[WIP] Llama wrapping improvements" to "AC and FSDP Llama3 wrapping improvements" Apr 26, 2024
@rohan-varma marked this pull request as ready for review April 26, 2024 18:36
@rohan-varma requested a review from ebsmothers April 26, 2024 18:36
@ebsmothers (Contributor) left a comment:
Overall this makes sense to me and the memory savings for full finetune are great. My main question is around whether model type is the most natural way to expose this feature.

There's nothing about this functionality that is unique to Llama3, it's just that it proves most beneficial there. By doing things this way we are kinda making the decision for people that only Llama3 should have this feature, and supporting other models with large vocab sizes will then require updating the wrapping internals instead of just flipping a config. I know we had discussed Gemma as one specific model where this is a challenge, but I wonder if we can do an assert on the backend to raise an error if the model class is not TransformerDecoder as we would expect.

def _llama3_ac_policy(module: nn.Module, recurse: bool, modules_to_wrap, **kwargs):
    # Label that output_proj should be wrapped individually.
    if isinstance(module, modules.TransformerDecoder):
        targ = _maybe_fsdp_unwrap(module.output)
Contributor:
Sorry, dumb question here: is this just because we want to support both non-FSDP and FSDP models? Cause rn we are only integrating into distributed recipe, in which case we could (not saying should) assume FSDP

@rohan-varma (Member, Author):
So this isn't really because we want to support FSDP and non FSDP models, but we do happen to have this support with this function.

The reason we need this unwrap call is that the model may already have been wrapped in FSDP by the time we wrap in AC. We want to AC-wrap module.output, but accessing module.output on an FSDP-wrapped model may give us back the FSDP class, so we further unwrap to retrieve the local nn.Module.
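
A rough sketch of what such an unwrap helper could look like (the body here is a guess for illustration, not necessarily the PR's implementation):

import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP


def _maybe_fsdp_unwrap(module: nn.Module) -> nn.Module:
    # FSDP exposes the original module via its ``module`` attribute, so peel
    # off the wrapper before tagging the submodule for AC wrapping.
    return module.module if isinstance(module, FSDP) else module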

@ebsmothers (Contributor):

> CUDA_VISIBLE_DEVICES=0,3,6,7

👀

@rohan-varma (Member, Author):

> Overall this makes sense to me and the memory savings for full finetune are great. My main question is around whether model type is the most natural way to expose this feature.
>
> There's nothing about this functionality that is unique to Llama3, it's just that it proves most beneficial there. By doing things this way we are kinda making the decision for people that only Llama3 should have this feature, and supporting other models with large vocab sizes will then require updating the wrapping internals instead of just flipping a config. I know we had discussed Gemma as one specific model where this is a challenge, but I wonder if we can do an assert on the backend to raise an error if the model class is not TransformerDecoder as we would expect.

@ebsmothers I definitely agree here. My proposal for a way forward would be to decouple the model type from the checkpointer and offer it as a general accessor to determine which model is being trained - there's currently no easy way to go about this. I'd also like this change to stay focused on llama3 (so the initial rollout of these policies will only be done for llama3). As follow-up work we should enable it for llama2 and investigate other models, although verifying these improvements is currently a long-running process and should ideally be done iteratively and/or by multiple folks, IMO.

@joecummings could you chime in on ModelType for this sort of use case, and on any alternative ways you might have to achieve this sort of gating based on specific models?

@joecummings (Contributor):

@rohan-varma I could totally be missing something here, but why can't we include embedding in the modules to wrap within the config for Llama3, rather than tie this directly to ModelType? That way, you can expand this to any new models that have this large embedding space, which is starting to become more popular.

cc: @ebsmothers

@musabgultekin (Contributor) commented Apr 29, 2024

I haven't tested this, but will this allow 70B on 8x80GB? I was only able to full finetune 70B with CPU offloading.

@rohan-varma (Member, Author):

> @rohan-varma I could totally be missing something here, but why can't we include embedding in the modules to wrap within the config for Llama3, rather than tie this directly to ModelType? That way, you can expand this to any new models that have this large embedding space, which is starting to become more popular.
>
> cc: @ebsmothers

@joecummings This is because modules_to_wrap is not configurable right now, and making it configurable would be a little tricky (i.e. we'd have to parse a string like torch.nn.Embedding and turn it into a class).
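
For illustration only (not torchtune code), resolving a dotted string from a config back into a class could look like this:

import importlib


def _resolve_class(path: str) -> type:
    # "torch.nn.Embedding" -> import torch.nn, then look up Embedding on it.
    module_path, _, class_name = path.rpartition(".")
    return getattr(importlib.import_module(module_path), class_name)


# e.g. _resolve_class("torch.nn.Embedding") is torch.nn.Embedding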

@rohan-varma (Member, Author):

> I haven't tested this, but will this allow 70B on 8x80GB? I was only able to full finetune 70B with CPU offloading.

This unfortunately won't allow 70B on 8x80GB without CPU offloading, from my experiments, but I can do a bit more testing. Our current thinking is to enable full finetune for 70B models with CPU offload.

@rohan-varma requested a review from ebsmothers April 30, 2024 22:49
@rohan-varma changed the title from "FSDP Llama3 wrapping improvements" to "FSDP Llama3 wrapping improvements for full finetune" May 3, 2024
    have not been verified and may not see the same improvements.
    Returns:
        FSDPPolicyType: Wrapping policy that can be passed into ``FullyShardedDataParallel`` as the ``auto_wrap_policy``
            argument. Please see documentation for `torchtune.utils.FSDPPolicyType` for additional details.
@rohan-varma (Member, Author):
Is there any way to link directly to this docstr? @ebsmothers or @NicolasHug maybe?

Contributor:
Can try this? ref

:const:`~torchtune.utils.FSDPPolicyType`

Contributor:
Bumping this

@rohan-varma (Member, Author):

> Thanks for adding this!
>
> I don't think I fully understand this:
>
> > New AC wrapping policy that checkpoints the token embedding and output projections as well. Similar reason to above - they generate larger activations so it would be useful to not store those in memory.
>
> Irrespective of the size of the vocab, the output of the embedding table would just depend on the sequence length? So why does this have anything to do with the vocab size? Or am I misunderstanding?

@kartikayk Thanks for the feedback and the review! You're totally right that this doesn't have anything to do with the vocab size and this was an oversight on my end. I verified that if we remove the modified AC wrapping, we don't change anything about the memory improvements we're shipping here. So this PR is now only limited to FSDP wrapping changes.

Also added a bunch more documentation to FSDPPolicyType to clearly explain it to the user and link back to the FSDP wrapping docs where useful. Thanks!

    have not been verified and may not see the same improvements.
    Returns:
        FSDPPolicyType: Wrapping policy that can be passed into ``FullyShardedDataParallel`` as the ``auto_wrap_policy``
            argument. Please see documentation for `torchtune.utils.FSDPPolicyType` for additional details.
Contributor:
Can try this? ref

:const:`~torchtune.utils.FSDPPolicyType`

"""
A default policy for wrapping Llama-3 style models for full finetuning using FSDP. Specifically,
this will wrap the model's token embedding and output projection into their own FSDP units to
maximize memory savings. After this is done, model will also be hierarchically wrapped
Contributor:
Can we be a little bit more explicit about why this maximizes memory savings here? (At least say that this helps because the embedding and output layers are quite large)

@rohan-varma (Member, Author):
Added

def llama3_wrap(module: nn.Module, recurse: bool, **kwargs):
    # Label that output_proj should be wrapped individually.
    if isinstance(module, modules.TransformerDecoder):
        module.output._wrap = True
Contributor:
Re the transformer decoder changes, I think the main thing is that the if isinstance check may need to be generalized (since now e.g. our TransformerDecoderLM or TransformerDecoderClassifier classes will both have output layers). But realistically I think the main use case will still be for when we're projecting to vocab_size, so maybe just directly replacing with TransformerDecoderLM (or whatever we're calling it) will be sufficient here.

    return ModuleWrapPolicy(modules_to_wrap)


def _llama3_full_fsdp_wrap_policy(modules_to_wrap: Set[Type]) -> FSDPPolicyType:
Contributor:
Also bump: are we still naming this based on llama3?

@rohan-varma requested a review from ebsmothers May 6, 2024 19:28
@ebsmothers (Contributor) left a comment:
OK a few more comments but overall no major concerns from my side

def llama3_wrap(module: nn.Module, recurse: bool, **kwargs):
    # Label that output_proj should be wrapped individually.
    if isinstance(module, modules.TransformerDecoder):
        module.output._wrap = True
Contributor:
Oh yeah to clarify there are likely to be some inbound changes to the TransformerDecoder class itself. Basically we will probably have

(a) a base class without any output layer (so just token embeddings, decoder layers, and norm),
(b) a class equivalent to our existing TransformerDecoder, but with (a) as a component + the usual output projection to vocab size, and
(c) a separate classifier composing (a) with a more general head module.

Personally I don't think you should optimize for changes that haven't landed yet, but we should at least have an idea of how we'd need to change the util to support this.
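
Purely as a sketch of the decomposition described above (these class names are hypothetical; only TransformerDecoderLM is floated in the comment, and none of this exists in torchtune as of this PR):

import torch.nn as nn


class TransformerBackbone(nn.Module):
    # (a) token embeddings + decoder layers + final norm, no output layer
    def __init__(self, tok_embeddings: nn.Embedding, layers: nn.ModuleList, norm: nn.Module):
        super().__init__()
        self.tok_embeddings = tok_embeddings
        self.layers = layers
        self.norm = norm


class TransformerDecoderLM(nn.Module):
    # (b) backbone + projection to vocab size; a wrapping util could check
    # for this class when deciding to shard the output projection separately
    def __init__(self, backbone: TransformerBackbone, output: nn.Linear):
        super().__init__()
        self.backbone = backbone
        self.output = output


class TransformerClassifier(nn.Module):
    # (c) backbone composed with a more general head module
    def __init__(self, backbone: TransformerBackbone, head: nn.Module):
        super().__init__()
        self.backbone = backbone
        self.head = head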


    Args:
        memory_efficient_fsdp_wrap (bool): If ``True``, will also wrap embedding and output projection layers
            with FSDP.
Contributor:
I think you need to indent or something. This is rendering weirdly in the live docs (can be seen in the screenshot you attached for the summary)

@rohan-varma (Member, Author):
Oh oops, I changed this for LoRA on L226 but didn't change it here, thanks!

    have not been verified and may not see the same improvements.
    Returns:
        FSDPPolicyType: Wrapping policy that can be passed into ``FullyShardedDataParallel`` as the ``auto_wrap_policy``
            argument. Please see documentation for `torchtune.utils.FSDPPolicyType` for additional details.
Contributor:
Bumping this

@codecov-commenter:

Codecov Report

Attention: Patch coverage is 24.32432%, with 28 lines in your changes missing coverage. Please review.

Project coverage is 26.67%. Comparing base (a978956) to head (6124dd5).
Report is 29 commits behind head on main.

Files Patch % Lines
tests/torchtune/utils/test_distributed.py 21.05% 15 Missing ⚠️
torchtune/utils/_distributed.py 27.77% 13 Missing ⚠️
Additional details and impacted files
@@             Coverage Diff             @@
##             main     #865       +/-   ##
===========================================
- Coverage   66.39%   26.67%   -39.72%     
===========================================
  Files         155      172       +17     
  Lines        6484     7182      +698     
===========================================
- Hits         4305     1916     -2389     
- Misses       2179     5266     +3087     

☔ View full report in Codecov by Sentry.

@rohan-varma merged commit fa1392b into main May 7, 2024
andrewor14 added a commit to andrewor14/torchtune that referenced this pull request May 14, 2024
@joecummings deleted the wrapping_imp branch May 14, 2024 19:51