LoRA FSDP2 recipe migration #1517
✅ No failures as of commit b0db0d6 with merge base 66590b4 (Dr. CI). See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/1517.
Codecov Report — coverage diff (full report in Codecov by Sentry):

@@            Coverage Diff            @@
##             main    #1517       +/-  ##
===========================================
+ Coverage   27.22%   71.13%   +43.91%
===========================================
  Files         286      285        -1
  Lines       13828    13743       -85
===========================================
+ Hits         3764     9776     +6012
+ Misses      10064     3967     -6097
Would you mind setting expandable_segments=True and rerunning one of the tests with a huge memory alloc increase? I wonder if the increase is just memory fragmentation.
Left some comments, mostly questions, looks great overall.
set_trainable_params, | ||
validate_state_dict_for_lora, | ||
validate_missing_and_unexpected_for_lora, |
Why did we rename this? To me, reading this without seeing the previous name, it isn't obvious that it's referring to the state dict.
We didn't rename this; they are two different functions. The old one was necessary with FSDP1 because there were extra FSDP prefixes floating around in the keys. As a result we couldn't just do string comparison on the missing and unexpected keys returned by load_state_dict with strict=False; we instead had to actually call model.state_dict(), which used a bunch of extra memory.
TLDR: this version is better and we should delete the old one once we're able to.
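For a concrete picture, here is a minimal sketch of the FSDP2-style check (illustrative only; the key-matching rule below is an assumption, and the real helper is validate_missing_and_unexpected_for_lora in torchtune.modules.peft):

```python
from typing import List


def validate_lora_load(missing: List[str], unexpected: List[str]) -> None:
    # Sketch: with FSDP2 there are no FSDP prefixes in the keys, so we can
    # validate directly on the lists returned by
    # model.load_state_dict(full_sd, strict=False), without calling
    # model.state_dict() again.
    # Assumption: LoRA/DoRA param names contain "lora" or "magnitude".
    non_lora_missing = [
        k for k in missing if "lora" not in k and "magnitude" not in k
    ]
    if non_lora_missing:
        raise RuntimeError(f"Missing non-LoRA keys in base load: {non_lora_missing}")
    if unexpected:
        raise RuntimeError(f"Unexpected keys in base load: {unexpected}")
```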
I still prefer the other name but I understand. Could it be something like validate_lora_missing_and_unexpected_keys?
Yeah we can do this separately when deprecating the other one
recipes/lora_finetune_distributed.py
Return True for layers.i and False for all other module names
Covers sharding for both AC-wrapped and non-AC-wrapped modules in one shot
"""
s_list = s.split(".")
This will miss the convolution layer in the vision transformer. Though I think besides that all of our transformers so far use "layers" by convention.
Good point. I mention this in the comment just above _is_layer_fqn: we can directly use the module type instead of the name, it just takes extra effort to do properly for AC. We can make that change once we're ready to enable a LoRA recipe using ViT.
recipes/lora_finetune_distributed.py
# Shard transformer decoder layers (or AC-wrapped versions)
# Alternatively we could condition on the module type (TransformerDecoder or CheckpointWrapper)
# But directly using the name is more concise
def _is_layer_fqn(s: str) -> bool:
nit: why not make this fsdp_shard_conditions and throw away the n variable instead of passing this into a lambda? (also not sure what fqn is)
Do you just mean something like this?
def _is_layer_fqn(s: str, m: nn.Module) -> bool:
    # same as before, m is unused

fsdp_shard_conditions = _is_layer_fqn
And fqn = fully-qualified name
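Filling in that snippet for illustration, a sketch of the predicate plus how it could be handed to the sharding utility (names and the exact consumer are assumptions, not necessarily the final recipe code):

```python
import torch.nn as nn


def _is_layer_fqn(s: str, m: nn.Module) -> bool:
    """Return True for FQNs of the form ``layers.<i>``; the module arg is unused.

    Works for both AC-wrapped and non-AC-wrapped decoder layers, since wrapping
    does not change the FQN.
    """
    s_list = s.split(".")
    return len(s_list) == 2 and s_list[0] == "layers" and s_list[1].isdigit()


# No lambda needed: the predicate is assumed to be consumed by the FSDP
# sharding utility as a (name, module) -> bool condition.
fsdp_shard_conditions = _is_layer_fqn
```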
else None
)
),
with training.set_default_dtype(self._dtype), self._device:
What is going on here? Can you add a comment before the context manager? This is just initializing the model? I feel we should standardize how we do this type of init code. For KVCache we're going very OO, I think we have functional utils for other things, and here we're doing it all directly in the recipe. I think I favor the second option, but I would prefer consistency.
Yeah this is just initializing various parameters. Functional util is fine but I may punt on it for now and just add a comment if you're good with that
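For reference, a rough sketch of what a functional util for this init block could look like (assuming torchtune's training.set_default_dtype and config.instantiate; simplified relative to the actual recipe):

```python
import torch
import torch.nn as nn
from omegaconf import DictConfig

from torchtune import config, training


def init_model(cfg_model: DictConfig, dtype: torch.dtype, device: torch.device) -> nn.Module:
    # Instantiate the model with `dtype` as the default parameter dtype, directly
    # on `device` (e.g. meta for distributed init), so a full fp32 copy is never
    # materialized on CPU first.
    with training.set_default_dtype(dtype), device:
        return config.instantiate(cfg_model)
```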
from torch.optim import Optimizer
from torch.utils.data import DataLoader, DistributedSampler
from torchtune import config, modules, training, utils
from torchtune.data import padded_collate_sft
from torchtune.datasets import ConcatDataset
from torchtune.modules.peft import (
    DoRALinear,
Promise that when we go from 2 -> 3 XLinear types we'll update the PEFT builder functions instead.
K I promise
self._gradient_accumulation_steps = cfg.gradient_accumulation_steps
self._fsdp_sharding_strategy = torch.distributed.fsdp.ShardingStrategy[
Does FSDP2 not have this concept anymore? Or is it just unpacked from the config somewhere?
See this comment
Thanks. This is annoying since now if you don't want sharding you have to use a completely different set of modules. My only remaining comment then is that people know that FSDP has 3 levels (also coming from DeepSpeed) and I don't think it's clear that Full shard and opt shard are controlled by reshard_after_forward. Could we make the config name different?
Want to call out that gradients and optimizer state are sharded regardless of reshard_after_forward=true/false; reshard_after_forward=False mainly avoids the all-gather in the backward pass.
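To illustrate the point, a sketch against FSDP2's fully_shard (the import path may differ across PyTorch versions, and the model.layers attribute is an assumption borrowed from torchtune's decoders):

```python
import torch.nn as nn
from torch.distributed._composable.fsdp import fully_shard


def shard_model(model: nn.Module, reshard_after_forward: bool = True) -> None:
    # Parameters, gradients, and optimizer state are sharded either way.
    # reshard_after_forward=True frees the unsharded params after forward and
    # re-all-gathers them in backward (FULL_SHARD-like); False keeps them
    # around, skipping the backward all-gather (SHARD_GRAD_OP-like).
    for layer in model.layers:  # assumes a `layers` ModuleList
        fully_shard(layer, reshard_after_forward=reshard_after_forward)
    fully_shard(model, reshard_after_forward=reshard_after_forward)
```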
self.adapter_params = get_adapter_params(model)
set_trainable_params(model, self.adapter_params)
if lora_weights_state_dict:
    lora_missing, lora_unexpected = training.load_from_full_model_state_dict(
I know this isn't new, but "load_from_full_model_state_dict" is a very ambiguous name. What does that imply it's doing over "load_state_dict"?
It's loading the full model weights on meta device and then sharding them into DTensors. But I agree, we can update the name to be a bit more explicit
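Conceptually, the helper does something like the following (a simplified sketch: the real implementation also handles NF4/quantized params and CPU offload, and the distribute_tensor import path varies by PyTorch version):

```python
from typing import Any, Dict

import torch.nn as nn
from torch.distributed._tensor import distribute_tensor


def load_full_sd_into_sharded_model(
    model: nn.Module, full_sd: Dict[str, Any], strict: bool = False
):
    # The model already holds sharded DTensor params (possibly still on meta);
    # chunk each full tensor according to that param's mesh/placements, then
    # load with assign=True so meta params are replaced rather than copied into.
    meta_sharded_sd = model.state_dict()
    sharded_sd = {}
    for name, full_tensor in full_sd.items():
        sharded_meta = meta_sharded_sd[name]
        sharded_sd[name] = distribute_tensor(
            full_tensor, sharded_meta.device_mesh, sharded_meta.placements
        )
    # Returns (missing, unexpected), which the recipe then validates for LoRA.
    return model.load_state_dict(sharded_sd, strict=strict, assign=True)
```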
    device=self._device,
)

if intermediate_checkpoint:
Did you port over the ability to save adapter weights only, without first consolidating the full model checkpoint from the state dict?
Do we do that today? IIUC it seems like we don't? (ref)
Sorry, I forgot that's only in the 405B PR
Overall it looks great! I think that there is an opportunity to make full and lora distributed recipes a bit more similar to each other, but the recipe itself is good.
I recommend using this: https://text-compare.com/
Put the code from lora and full side by side and look for where they differ. We should probably do the same for single-device vs distributed too.
torch.distributed.fsdp.ShardingStrategy
(https://pytorch.org/docs/stable/fsdp.html#torch.distributed.fsdp.ShardingStrategy).
For example, in your config, simply pass ``fsdp_sharding=NO_SHARD`` for DDP.
- FSDP. Supported using PyTorch's FSDP APIs. DDP is currently not supported. Training on CPU is not
Should we add cpu_offloading? Or is that not mature enough yet?
Are there any FSDP strategies that the user can control? e.g. hybrid, full_sharded?
This is how it's stated in full_distributed:
- FSDP. Supported using PyTorch's FSDP APIs. CPU offload of parameters, gradients, and optimizer states
is supported via the ``fsdp_cpu_offload``. Resharding of parameters after the forward pass is
done by default (corresponding to FULL_SHARD sharding strategy), but can be disabled by setting the config
``fsdp_reshard_after_forward`` to False (this corresponds to SHARD_GRAD_OP sharding strategy).
DDP is currently not supported. Training on CPU is not supported.
> Should we add cpu_offloading? Or is that not mature enough yet?
My goal here is to hit parity with the existing LoRA recipe, not to add any new features. We can consider CPU offload as a follow-up if we like, but we haven't really received any requests for it yet (and ofc it's not as valuable for LoRA as it is for full finetune).
> Are there any FSDP strategies that the user can control? e.g. hybrid, full_sharded?
The sharding_strategy from FSDP1 is no longer supported in FSDP2. You can see this comment explaining things.
Isn't the feature already present? We do this in _setup_model:
fsdp_cpu_offload=cfg.get("fsdp_cpu_offload", False),
reshard_after_forward=cfg.get("fsdp_reshard_after_forward", True),
self._model = self._setup_model( | ||
cfg_model=cfg.model, | ||
enable_activation_checkpointing=cfg.enable_activation_checkpointing, | ||
fsdp_cpu_offload=cfg.get("fsdp_cpu_offload", False), | ||
reshard_after_forward=cfg.get("fsdp_reshard_after_forward", True), |
Just a note: I think we should try to keep the recipes as aligned as possible, to decrease the cognitive load of dealing with multiple recipes. In this case, I think the distributed one should remove ac_mode + ac_option.
This is full distributed:
def _setup_model(
    self,
    cfg_model: DictConfig,
    enable_activation_checkpointing: bool,
    custom_sharded_layers: Optional[List[str]],
    fsdp_cpu_offload: bool,
    reshard_after_forward: bool,
    model_state_dict: Dict[str, Any],
    ac_mode: Optional[str] = None,
    ac_option: Optional[int] = None,
) -> nn.Module:
This is lora distributed:
def _setup_model(
    self,
    cfg_model: DictConfig,
    enable_activation_checkpointing: bool,
    fsdp_cpu_offload: bool,
    reshard_after_forward: bool,
    base_model_state_dict: Dict[str, Any],
    lora_weights_state_dict: Optional[Dict[str, Any]] = None,
    cfg_fsdp: Optional[Union[DictConfig, None]] = None,
) -> nn.Module:
I agree. We added selective activation checkpointing a while back but never fully integrated it. It does provide increased flexibility over our current activation checkpointing. But imo we need to either migrate to it or remove it.
@@ -139,12 +134,10 @@ def __init__(self, cfg: DictConfig) -> None:
        self.total_epochs = cfg.epochs
        self.max_steps_per_epoch = cfg.max_steps_per_epoch
        self.global_step = 0
        self._resume_from_checkpoint = cfg.resume_from_checkpoint
Not related to this line: in the init, we are missing this check that exists in full distributed:
if (
    cfg.get("fsdp_cpu_offload", False)
    and cfg.optimizer.get("fused", False)
    and not utils.torch_version_ge("2.4.0")
):
    raise RuntimeError(
        "Using fused optimizer on CPU is only supported in PyTorch nightly."
    )
We should remove that check from the full distributed recipe then, since everything is >= 2.4 now (will not tackle that in this PR, that can be done separately)
log.info("Compiling loss with torch.compile...") | ||
self._loss_fn = torch.compile(self._loss_fn, backend=backend) | ||
log.info("Loss is initialized.") | ||
if self._is_rank_zero: |
full distributed doesn't have _is_rank_zero. Maybe we shouldn't update it in this PR, but we may forget it later.
You mean this line specifically, right? I think we should add it because right now there are duplicate logs here. But yeah, I don't wanna do it in this PR.
self.adapter_params = get_adapter_params(model)
set_trainable_params(model, self.adapter_params)
if lora_weights_state_dict:
    lora_missing, lora_unexpected = training.load_from_full_model_state_dict(
Offtopic: we need to update how we check for modules to finetune with LoRA. It just checks things like `if 'v_proj' in string`, which won't raise any error if the user defines "v_projection", for example, because "v_proj" is a substring of "v_projection". I found this out because I had written "'v_proj'" (with the quotes), so "v_proj" was in the string, but there is no module named "'v_proj'" with the quotes.
I'm not sure I follow this point; do you have a specific example in mind? We have the LORA_ATTN_MODULES Literal and could modify that to add some validation.
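For example, exact-membership validation over that Literal could look roughly like this (a sketch, not existing torchtune code; assumes LORA_ATTN_MODULES is importable from torchtune.modules.peft):

```python
from typing import List, get_args

from torchtune.modules.peft import LORA_ATTN_MODULES  # Literal of valid module names


def validate_lora_attn_modules(lora_attn_modules: List[str]) -> None:
    # Exact membership check instead of substring matching: "v_projection" or
    # "'v_proj'" (with stray quotes) both get rejected here, whereas a check
    # like `"v_proj" in s` would silently pass.
    valid = set(get_args(LORA_ATTN_MODULES))
    invalid = [m for m in lora_attn_modules if m not in valid]
    if invalid:
        raise ValueError(
            f"Invalid LoRA attention modules {invalid}; expected a subset of {sorted(valid)}"
        )
```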
@@ -577,27 +578,31 @@ def save_checkpoint(
        - Merged weights with key MODEL_KEY
        - Adapter weights with key ADAPTER_KEY
        - Relevant recipe state if training is not complete
        - If the `self._save_adapter_weights_only` option is True, the checkpointer will save only the adapter weights
That sounds like a nice feature to have. Did we get feedback to remove it?
Sorry, idk why this got removed. Saving adapter weights only should still be supported.
@@ -722,6 +726,7 @@ def train(self) -> None:

    # Compute loss
    loss = self._loss_fn(logits, labels)
Not related to this part of the code: we are missing this, from full_distributed:
if cfg.get("fsdp_cpu_offload", False):
    # Utilize all available CPU cores for intra-op parallelism. This provides ~2x
    # speed up when benchmarking fused AdamW on CPU
    training.set_torch_num_threads()
And we have this hardcoded, which doesn't sound like something we'd want:
os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1"
So set_torch_num_threads was some heuristic added way back when; you can see the convo here. Unless @janeyx99 has a better idea, I am happy to just copy this over from the full finetune recipe. But given the small number of trainable params for LoRA, it may not have much impact anyway.
Re TORCH_NCCL_AVOID_RECORD_STREAMS, I'll defer to @weifengpy who will know best.
To follow up on the second point, I chatted offline with @weifengpy and we should be OK to remove the TORCH_NCCL_AVOID_RECORD_STREAMS setting.
no looker
Migrating our main LoRA recipe onto FSDP2 and deleting the dev version.
Note: many recipes are showing a slight increase in peak memory (especially reserved memory) as a result of this move, but by removing the separate wrapping of LoRA(/DoRA) weights we get pretty big toks/sec improvements. This will also make our lives easier wrt compile (see #1445).
Test plan
Aside from green CI, run the following recipes:
Note that for Gemma and Qwen2 there is no baseline with compile, so we compare FSDP1 vs FSDP2 (no compile) vs FSDP2 with compile.
Separately, ran
and
to confirm that CPU offload and saving adapter weights only both work.
Llama2 13B LoRA
On main (baseline)
On this PR
Results
+40% on peak reserved memory, +8ish% toks/sec
Note: this is the one model that seems to show a net regression, but given the major perf improvements for other models I think this is secondary. Also, active and allocated memory are at parity, so it's possible there is actually a relatively quick fix for this one.
Llama2 7B QLoRA (with compile)
On main (baseline)
On this PR
Results
~3% increase in peak reserved memory, ~400% increase in toks/sec
Llama3 8B DoRA
On main (baseline)
On this PR
Results
No change to peak reserved memory, ~125% increase in toks/sec
Llama3.1 70B LoRA
On main (baseline)
On this PR
Results
~10% increase in memory, ~30-50% increase in toks/sec
Mistral 7B LoRA
On main (baseline)
On this PR
Results
~10% increase in peak reserved memory, ~150% increase in toks/sec
Gemma 2B LoRA (+compile)
On main (baseline)
On this PR
For the compiled version, just add compile=True.
Results
~50% increase in peak reserved memory, ~30% increase in toks/sec
Note: similar to Llama2 13B, the delta in peak allocated and active memory is much smaller. For this model, the issue may simply be that the output projection contains a much larger percentage of model weights (relative to transformer layers) than all the other models in our library. One solution is to increase the number of chunks in our output + CE chunking; e.g. by increasing the number of chunks from 8 to 128 we can achieve parity with FSDP1. However, I consider such a change out of scope for this PR.
Compile also doesn't seem to help much here, we can look into that as a follow-up.
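For reference, the chunking idea mentioned above in code (a simplified illustration of chunked output-projection + cross-entropy, not torchtune's actual implementation):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def chunked_output_ce(
    hidden: torch.Tensor,    # [batch, seq, dim] final hidden states
    output_proj: nn.Module,  # the (large) vocab projection
    labels: torch.Tensor,    # [batch, seq], with -100 for ignored positions
    num_chunks: int = 8,     # bumping this (e.g. to 128) trades compute for memory
) -> torch.Tensor:
    # Project and compute CE one sequence-chunk at a time so the full
    # [batch, seq, vocab] logits tensor is never materialized at once.
    total_loss = hidden.new_zeros(())
    total_tokens = torch.zeros((), dtype=torch.long, device=hidden.device)
    for h_chunk, l_chunk in zip(
        hidden.chunk(num_chunks, dim=1), labels.chunk(num_chunks, dim=1)
    ):
        logits = output_proj(h_chunk)
        total_loss = total_loss + F.cross_entropy(
            logits.reshape(-1, logits.size(-1)),
            l_chunk.reshape(-1),
            reduction="sum",
            ignore_index=-100,
        )
        total_tokens = total_tokens + (l_chunk != -100).sum()
    return total_loss / total_tokens.clamp(min=1)
```

Raising num_chunks shrinks the peak size of each logits chunk at the cost of more matmul launches, which is the 8 -> 128 tradeoff mentioned above.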
Phi3 Mini LoRA
On main (baseline)
On this PR
Results
~5% increase in peak reserved memory, ~90% increase in toks/sec
Qwen2 0.5B LoRA
On main (baseline)
On this PR
Results
Without compile: ~15% increase in peak reserved memory, 100%+ increase in toks/sec
With compile: ~70% reduction in peak reserved memory, ~250% increase in toks/sec (over baseline)