
LoRA FSDP2 recipe migration #1517


Merged: 12 commits merged into pytorch:main on Sep 10, 2024

Conversation

@ebsmothers (Contributor) commented Sep 6, 2024

Migrating our main LoRA recipe onto FSDP2 and deleting the dev version.

Note: many recipes are showing a slight increase in peak memory (especially reserved memory) as a result of this move, but by removing the separate wrapping of LoRA(/DoRA) weights we get pretty big toks/sec improvements. This will also make our lives easier wrt compile (see #1445).
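For intuition, the main structural change is that FSDP2's fully_shard is applied only at the decoder-layer (and root) level, rather than additionally wrapping each LoRA/DoRA adapter as its own unit the way the FSDP1 recipe did. A minimal sketch of that wrapping scheme, assuming the PyTorch 2.4-era fully_shard API (the actual recipe code differs in the details, so treat this as illustrative only):

from torch import nn
from torch.distributed._composable.fsdp import fully_shard

def shard_lora_model(model: nn.Module, reshard_after_forward: bool = True) -> None:
    # Each transformer decoder layer becomes its own FSDP2 unit; LoRA/DoRA
    # adapter weights are no longer wrapped separately as they were under FSDP1.
    for layer in model.layers:
        fully_shard(layer, reshard_after_forward=reshard_after_forward)
    # Shard the root module last so the remaining params (token embeddings,
    # output projection, final norm) are covered too.
    fully_shard(model, reshard_after_forward=reshard_after_forward)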

Test plan

Aside from green CI, run the following recipes:

  • Llama2 13B LoRA on 4 devices
  • Llama2 7B QLoRA on 2 devices with compile
  • Llama3 8B DoRA on 2 devices
  • Llama3.1 70B LoRA on 8 devices
  • Mistral 7B LoRA on 2 devices
  • Gemma 2B LoRA on 4 devices with compile
  • Phi3 Mini LoRA on 2 devices
  • Qwen2 0.5B LoRA on 2 devices with compile

Note that for Gemma and Qwen2 there is no baseline with compile, so we compare FSDP1 vs FSDP2 (no compile) vs FSDP2 with compile.

Separately, ran

tune run --nnodes 1 --nproc_per_node 2 lora_finetune_distributed --config phi3/mini_lora \
max_steps_per_epoch=10 gradient_accumulation_steps=2 epochs=1  fsdp_cpu_offload=True

and

tune run --nnodes 1 --nproc_per_node 2 lora_finetune_distributed --config phi3/mini_lora \
max_steps_per_epoch=10 gradient_accumulation_steps=2 epochs=1  save_adapter_weights_only=True

to confirm that CPU offload and saving adapter weights only both work.

Llama2 13B LoRA

On main (baseline)

tune run --nnodes 1 --nproc_per_node 4 lora_finetune_distributed --config llama2/13B_lora metric_logger=torchtune.training.metric_logging.WandBLogger metric_logger.project=testing-1517 metric_logger.name=llama2_13b_lora_baseline max_steps_per_epoch=50 gradient_accumulation_steps=2 log_peak_memory_stats=True

On this PR

tune run --nnodes 1 --nproc_per_node 4 lora_finetune_distributed --config llama2/13B_lora metric_logger=torchtune.training.metric_logging.WandBLogger metric_logger.project=testing-1517 metric_logger.name=llama2_13b_lora_new max_steps_per_epoch=50 gradient_accumulation_steps=2 log_peak_memory_stats=True

Results

+40% on peak reserved memory, +8ish% toks/sec

Note: this is the one model that seems to show a net regression, but given the major perf improvements for other models I think this is secondary. Also, active and allocated memory are at parity, so it's possible there is actually a relatively quick fix for this one.

Screenshot 2024-09-09 at 9 16 58 PM

Llama2 7B QLoRA (with compile)

On main (baseline)

tune run --nnodes 1 --nproc_per_node 2 lora_finetune_fsdp2 --config llama2/7B_qlora metric_logger=torchtune.training.metric_logging.WandBLogger metric_logger.project=testing-1517 metric_logger.name=llama2_7b_qlora_baseline max_steps_per_epoch=100 log_peak_memory_stats=True compile=True

On this PR

tune run --nnodes 1 --nproc_per_node 2 lora_finetune_distributed --config llama2/7B_qlora metric_logger=torchtune.training.metric_logging.WandBLogger metric_logger.project=testing-1517 metric_logger.name=llama2_7b_qlora_new max_steps_per_epoch=100 log_peak_memory_stats=True compile=True

Results

~3% increase in peak reserved memory, ~400% increase in toks/sec

Screenshot 2024-09-09 at 8 38 30 PM

Llama3 8B DoRA

On main (baseline)

tune run --nnodes 1 --nproc_per_node 2 lora_finetune_fsdp2 --config llama3/8B_dora metric_logger=torchtune.training.metric_logging.WandBLogger metric_logger.project=testing-1517 metric_logger.name=llama3_8b_dora_baseline max_steps_per_epoch=100 log_peak_memory_stats=True

On this PR

tune run --nnodes 1 --nproc_per_node 2 lora_finetune_distributed --config llama3/8B_dora metric_logger=torchtune.training.metric_logging.WandBLogger metric_logger.project=testing-1517 metric_logger.name=llama3_8b_dora_new max_steps_per_epoch=100 log_peak_memory_stats=True

Results

No change to peak reserved memory, ~125% increase in toks/sec

Screenshot 2024-09-09 at 8 30 19 PM

Llama3.1 70B LoRA

On main (baseline)

tune run --nproc_per_node 8 lora_finetune_distributed --config llama3_1/70B_lora metric_logger=torchtune.training.metric_logging.WandBLogger metric_logger.project=testing-1517 metric_logger.name=llama31_70b_lora_baseline max_steps_per_epoch=100 log_peak_memory_stats=True

On this PR

tune run --nproc_per_node 8 lora_finetune_distributed --config llama3_1/70B_lora metric_logger=torchtune.training.metric_logging.WandBLogger metric_logger.project=testing-1517 metric_logger.name=llama31_70b_lora_new max_steps_per_epoch=100 log_peak_memory_stats=True

Results

~10% increase in memory, ~30-50% increase in toks/sec

Screenshot 2024-09-09 at 8 57 18 PM

Mistral 7B LoRA

On main (baseline)

tune run --nnodes 1 --nproc_per_node 2 lora_finetune_distributed --config mistral/7B_lora metric_logger=torchtune.training.metric_logging.WandBLogger metric_logger.project=testing-1517 metric_logger.name=mistral_7b_lora_baseline max_steps_per_epoch=100 epochs=1 log_peak_memory_stats=True

On this PR

tune run --nnodes 1 --nproc_per_node 2 lora_finetune_distributed --config mistral/7B_lora metric_logger=torchtune.training.metric_logging.WandBLogger metric_logger.project=testing-1517 metric_logger.name=mistral_7b_lora max_steps_per_epoch=100 epochs=1 log_peak_memory_stats=True

Results

~10% increase in peak reserved memory, ~150% increase in toks/sec

Screenshot 2024-09-09 at 8 33 15 PM

Gemma 2B LoRA (+compile)

On main (baseline)

tune run --nnodes 1 --nproc_per_node 4 lora_finetune_distributed --config gemma/2B_lora metric_logger=torchtune.training.metric_logging.WandBLogger metric_logger.project=testing-1517 metric_logger.name=gemma_2b_lora_baseline max_steps_per_epoch=100 epochs=1 log_peak_memory_stats=True

On this PR

tune run --nnodes 1 --nproc_per_node 4 lora_finetune_distributed --config gemma/2B_lora metric_logger=torchtune.training.metric_logging.WandBLogger metric_logger.project=testing-1517 metric_logger.name=gemma_2b_lora_new max_steps_per_epoch=100 epochs=1 log_peak_memory_stats=True

For the compiled version, just add compile=True
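For example, the compiled run on this PR is the same command with the flag appended (the metric_logger.name for the compiled variant is illustrative):

tune run --nnodes 1 --nproc_per_node 4 lora_finetune_distributed --config gemma/2B_lora metric_logger=torchtune.training.metric_logging.WandBLogger metric_logger.project=testing-1517 metric_logger.name=gemma_2b_lora_new_compile max_steps_per_epoch=100 epochs=1 log_peak_memory_stats=True compile=True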

Results

~50% increase in peak reserved memory, ~30% increase in toks/sec

Note: similar to Llama2 13B, the delta in peak allocated and active memory is much smaller. For this model, the issue may simply be that the output projection contains a much larger percentage of model weights (relative to transformer layers) than all the other models in our library. One solution is to increase the number of chunks in our output + CE chunking; e.g. by increasing the number of chunks from 8 to 128 we can achieve parity with FSDP1. However, I consider such a change out of scope for this PR.
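For reference, the chunk count in question is the num_output_chunks of our chunked output + cross-entropy loss; bumping it looks roughly like this (sketch only; class and argument names are from memory and may differ by version):

from torchtune.modules.loss import CEWithChunkedOutputLoss

# Default in our configs is 8 chunks; a larger value (e.g. 128) shrinks the peak
# from the upcast logits at the cost of a little throughput.
loss_fn = CEWithChunkedOutputLoss(num_output_chunks=128)
# The recipe also tells the model to chunk its output projection to match:
# model.set_num_output_chunks(loss_fn.num_output_chunks)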

Compile also doesn't seem to help much here; we can look into that as a follow-up.

Screenshot 2024-09-10 at 7 39 45 AM

Phi3 Mini LoRA

On main (baseline)

tune run --nnodes 1 --nproc_per_node 2 lora_finetune_distributed --config phi3/mini_lora metric_logger=torchtune.training.metric_logging.WandBLogger metric_logger.project=testing-1517 metric_logger.name=phi3_mini_lora_baseline max_steps_per_epoch=100 gradient_accumulation_steps=2 epochs=1 log_peak_memory_stats=True

On this PR

tune run --nnodes 1 --nproc_per_node 2 lora_finetune_distributed --config phi3/mini_lora metric_logger=torchtune.training.metric_logging.WandBLogger metric_logger.project=testing-1517 metric_logger.name=phi3_mini_lora_new max_steps_per_epoch=100 gradient_accumulation_steps=2 epochs=1 log_peak_memory_stats=True

Results

~5% increase in peak reserved memory, ~90% increase in toks/sec

Screenshot 2024-09-09 at 8 42 48 PM

Qwen2 0.5B LoRA

On main (baseline)

tune run --nnodes 1 --nproc_per_node 2 lora_finetune_distributed --config qwen2/0.5B_lora metric_logger=torchtune.training.metric_logging.WandBLogger metric_logger.project=testing-1517 metric_logger.name=qwen2_0.5b_lora_baseline max_steps_per_epoch=100 log_peak_memory_stats=True

On this PR

tune run --nnodes 1 --nproc_per_node 2 lora_finetune_distributed --config qwen2/0.5B_lora metric_logger=torchtune.training.metric_logging.WandBLogger metric_logger.project=testing-1517 metric_logger.name=qwen2_0.5b_lora_new max_steps_per_epoch=100 log_peak_memory_stats=True

Results

Without compile: ~15% increase in peak reserved memory, 100%+ increase in toks/sec
With compile: ~70% reduction in peak reserved memory, ~250% increase in toks/sec (over baseline)

Screenshot 2024-09-10 at 7 42 31 AM

pytorch-bot commented Sep 6, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/1517

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit b0db0d6 with merge base 66590b4:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot added the CLA Signed label Sep 6, 2024
@codecov-commenter commented Sep 6, 2024

Codecov Report

Attention: Patch coverage is 1.96078% with 50 lines in your changes missing coverage. Please review.

Project coverage is 71.13%. Comparing base (66590b4) to head (b0db0d6).
Report is 95 commits behind head on main.

Files with missing lines | Patch % | Lines
recipes/lora_finetune_distributed.py | 0.00% | 50 Missing ⚠️
Additional details and impacted files
@@             Coverage Diff             @@
##             main    #1517       +/-   ##
===========================================
+ Coverage   27.22%   71.13%   +43.91%     
===========================================
  Files         286      285        -1     
  Lines       13828    13743       -85     
===========================================
+ Hits         3764     9776     +6012     
+ Misses      10064     3967     -6097     

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

@ebsmothers marked this pull request as ready for review September 10, 2024 04:39
@ebsmothers changed the title from "[wip] LoRA FSDP2 recipe migration" to "LoRA FSDP2 recipe migration" Sep 10, 2024
@felipemello1 (Contributor) commented Sep 10, 2024

Would you mind setting expandable_segments=True and rerunning one of the tests with a huge memory alloc increase? I wonder if the increase is just memory fragmentation.

@ebsmothers (Contributor, Author) commented Sep 10, 2024

> Would you mind setting expandable_segments=True and rerunning one of the tests with a huge memory alloc increase? I wonder if the increase is just memory fragmentation.

Actually I tried it with Gemma 2B, it didn’t seem to make a difference (I can also try with Llama 13B tomorrow though)

Update: Same story with Llama2 13B, adding expandable segments doesn't decrease the peak reserved memory. However, when artificially constraining the memory on the FSDP2 recipe, we can see that the reserved memory tracks much more closely with the FSDP1 recipe.
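For anyone reproducing this: expandable segments are enabled via the CUDA allocator config env var, and one way to artificially cap memory is the per-process memory fraction (the exact constraining method used above isn't spelled out here, so this is just one option):

import os
import torch

# Must be set before the first CUDA allocation (often exported in the shell instead).
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

# One way to artificially constrain memory: cap this process to a fraction of
# the device's total memory so the caching allocator can't over-reserve.
torch.cuda.set_per_process_memory_fraction(0.8, device=0)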

Screenshot 2024-09-10 at 7 46 05 AM

@pbontrager (Contributor) left a comment:
Left some comments, mostly questions, looks great overall.

set_trainable_params,
validate_state_dict_for_lora,
validate_missing_and_unexpected_for_lora,
Contributor:
Why did we rename this? To me, reading this without seeing the previous name I don't realize it's referring to the state dict.

ebsmothers (Author):
We didn't rename this; they are two different functions. The old one was previously necessary on FSDP1 because there were extra FSDP prefixes floating around in the keys. As a result we couldn't just do string comparison on missing and unexpected as returned by load_state_dict with strict=False. We instead had to actually call model.state_dict(), which used a bunch of extra memory.

TLDR: this version is better and we should delete the old one once we're able to.
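For anyone curious what the new check amounts to: with FSDP2 the keys returned by load_state_dict(strict=False) have no wrapper prefixes, so plain string checks are enough. A rough sketch (not the exact torchtune helper; the adapter-key substrings are assumptions):

# missing/unexpected come directly from loading the base weights non-strictly;
# no extra model.state_dict() call is needed.
missing, unexpected = model.load_state_dict(base_model_state_dict, strict=False)
if unexpected:
    raise RuntimeError(f"Unexpected keys when loading base model: {unexpected}")
# Every missing key should belong to an adapter (LoRA A/B or DoRA magnitude).
non_adapter = [k for k in missing if not any(s in k for s in ("lora_a", "lora_b", "magnitude"))]
if non_adapter:
    raise RuntimeError(f"Missing non-adapter keys when loading base model: {non_adapter}")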

Contributor:
I still prefer the other name but I understand. Could it be something like validate_lora_missing_and_unexpected_keys?

ebsmothers (Author):
Yeah we can do this separately when deprecating the other one

Return True for layers.i and False for all other module names
Covers sharding for both AC-wrapped and non-AC-wrapped modules in one shot
"""
s_list = s.split(".")
Contributor:
This will miss the convolution layer in the vision transformer. Though I think besides that all of our transformers so far use "layers" by convention.

ebsmothers (Author):
Good point. I mention this in the comment just above _is_layer_fqn: we can directly use the module type instead of the name, it just takes extra effort to do properly for AC. We can make that change once we're ready to enable a LoRA recipe using a ViT.

# Shard transformer decoder layers (or AC-wrapped versions)
# Alternatively we could condition on the module type (TransformerDecoder or CheckpointWrapper)
# But directly using the name is more concise
def _is_layer_fqn(s: str) -> bool:
Contributor:
nit: why not make this fsdp_shard_conditions and throw away the n variable instead of passing this into a lambda? (also not sure what fqn is)

ebsmothers (Author):
Do you just mean something like this?

def _is_layer_fqn(s: str, m: nn.Module) -> bool:
  # same as before, m is unused

fsdp_shard_conditions = _is_layer_fqn

And fqn = fully-qualified name
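For readers following along, the name-based condition being discussed boils down to something like this (a sketch; the exact body in the PR may differ slightly):

def _is_layer_fqn(s: str) -> bool:
    """Return True only for fully-qualified names of the form 'layers.<i>',
    which covers both AC-wrapped and non-AC-wrapped decoder layers."""
    s_list = s.split(".")
    return len(s_list) == 2 and s_list[0] == "layers" and s_list[1].isdigit()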

else None
)
),
with training.set_default_dtype(self._dtype), self._device:
Contributor:
What is going on here, can you add a comment before the context? This is just initializing the model? I feel we should standardize how we do this type of init code. For KVCache we're going very OO, I think we have functional utils for other things, and here we're directly doing it all in the recipe. I think I favor the second option, but I would prefer consistency.

ebsmothers (Author):
Yeah this is just initializing various parameters. Functional util is fine but I may punt on it for now and just add a comment if you're good with that

from torch.optim import Optimizer
from torch.utils.data import DataLoader, DistributedSampler
from torchtune import config, modules, training, utils
from torchtune.data import padded_collate_sft
from torchtune.datasets import ConcatDataset
from torchtune.modules.peft import (
DoRALinear,
Contributor:
Promise that when we go from 2 -> 3 XLinear types we'll update the PEFT builder functions instead.

ebsmothers (Author):
K I promise

self._gradient_accumulation_steps = cfg.gradient_accumulation_steps
self._fsdp_sharding_strategy = torch.distributed.fsdp.ShardingStrategy[
Contributor:
Does FSDP2 not have this concept anymore? Or is it just unpacked from the config somewhere?

ebsmothers (Author):
Contributor:
Thanks. This is annoying since now if you don't want sharding you have to use a completely different set of modules. My only remaining comment then is that people know that FSDP has 3 levels (also coming from DeepSpeed) and I don't think it's clear that Full shard and opt shard are controlled by reshard_after_forward. Could we make the config name different?

Contributor:
Want to call out that gradients and optimizer state are sharded regardless of reshard_after_forward=True/False; reshard_after_forward mainly controls whether we avoid the all-gather in the backward pass.

self.adapter_params = get_adapter_params(model)
set_trainable_params(model, self.adapter_params)
if lora_weights_state_dict:
lora_missing, lora_unexpected = training.load_from_full_model_state_dict(
Contributor:
I know this isn't new, but "load_from_full_model_state_dict" is a very ambiguous name. What does that imply it's doing over "load_state_dict"?

ebsmothers (Author):
It's loading the full model weights on meta device and then sharding them into DTensors. But I agree, we can update the name to be a bit more explicit
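Roughly, what that utility does, as a heavily simplified sketch (model, full_state_dict, and device assumed defined; the real torchtune utility also handles NF4 tensors, CPU offload, and memory more carefully):

from torch.distributed._tensor import distribute_tensor

# The model has already been sharded with fully_shard on meta device, so each
# param in its state dict is a DTensor that knows its mesh and placements.
meta_sharded_sd = model.state_dict()
sharded_sd = {}
for name, full_tensor in full_state_dict.items():
    meta_dtensor = meta_sharded_sd[name]
    # Chunk the full (unsharded) tensor into a DTensor with the same layout.
    sharded_sd[name] = distribute_tensor(
        full_tensor.to(device), meta_dtensor.device_mesh, meta_dtensor.placements
    )
# assign=True swaps the real shards in for the meta tensors.
missing, unexpected = model.load_state_dict(sharded_sd, strict=False, assign=True)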

device=self._device,
)

if intermediate_checkpoint:
Contributor:
Did you port over the ability to save adapter weights only without consolidating the full model checkpoint first from the state dict?

ebsmothers (Author):
Do we do that today? IIUC it seems like we don't (ref).

Contributor:
Sorry, I forgot that's only in the 405B PR

@felipemello1 (Contributor) left a comment:
Overall it looks great! I think that there is an opportunity to make full and lora distributed recipes a bit more similar to each other, but the recipe itself is good.

I recommend using this: https://text-compare.com/

Put the code from lora and full side by side and see where they look different. We should probably do it for single vs distributed too.

torch.distributed.fsdp.ShardingStrategy
(https://pytorch.org/docs/stable/fsdp.html#torch.distributed.fsdp.ShardingStrategy).
For example, in your config, simply pass ``fsdp_sharding=NO_SHARD`` for DDP.
- FSDP. Supported using PyTorch's FSDP APIs. DDP is currently not supported. Traning on CPU is not
Contributor:
Should we add cpu_offloading? Or is that not mature enough yet?

Are there any FSDP strategies that the user can control? e.g. hybrid, full_sharded?

This is how it's stated in full_distributed:

 - FSDP. Supported using PyTorch's FSDP APIs. CPU offload of parameters, gradients, and optimizer states
            is supported via the ``fsdp_cpu_offload``. Resharding of parameters after the forward pass is
            done by default (corresponding to FULL_SHARD sharding strategy), but can be disabled by setting the config
            ``fsdp_reshard_after_forward`` to False (this corresponds to SHARD_GRAD_OP sharding strategy).
            DDP is currently not supported. Training on CPU is not supported.

ebsmothers (Author):
> Should we add cpu_offloading? Or is that not mature enough yet?

My goal here is to hit parity with the existing LoRA recipe, not to add any new features. We can consider CPU offload as a follow-up if we like, but we haven't really received any requests for it yet (and ofc it's not as valuable for LoRA as it is for full finetune).

> Are there any FSDP strategies that the user can control? e.g. hybrid, full_sharded?

The sharding_strategy from FSDP1 is no longer supported in FSDP2. You can see this comment explaining things

@felipemello1 (Contributor), Sep 10, 2024:
Isn't the feature already present?

We do this in _setup_model:
fsdp_cpu_offload=cfg.get("fsdp_cpu_offload", False),
reshard_after_forward=cfg.get("fsdp_reshard_after_forward", True),

self._model = self._setup_model(
cfg_model=cfg.model,
enable_activation_checkpointing=cfg.enable_activation_checkpointing,
fsdp_cpu_offload=cfg.get("fsdp_cpu_offload", False),
reshard_after_forward=cfg.get("fsdp_reshard_after_forward", True),
Contributor:
Just a note: I think we should try to keep the recipes as aligned as possible, to decrease the cognitive load of having to deal with multiple recipes. In this case, I think the distributed one should remove ac_mode + ac_option.

This is full distributed:

def _setup_model(
        self,
        cfg_model: DictConfig,
        enable_activation_checkpointing: bool,
        custom_sharded_layers: Optional[List[str]],
        fsdp_cpu_offload: bool,
        reshard_after_forward: bool,
        model_state_dict: Dict[str, Any],
        ac_mode: Optional[str] = None,
        ac_option: Optional[int] = None,
    ) -> nn.Module:

This is lora distributed:

def _setup_model(
        self,
        cfg_model: DictConfig,
        enable_activation_checkpointing: bool,
        fsdp_cpu_offload: bool,
        reshard_after_forward: bool,
        base_model_state_dict: Dict[str, Any],
        lora_weights_state_dict: Optional[Dict[str, Any]] = None,
        cfg_fsdp: Optional[Union[DictConfig, None]] = None,
    ) -> nn.Module:

ebsmothers (Author):
I agree. We added selective activation checkpointing a while back but never fully integrated it. It does provide increased flexibility over our current activation checkpointing. But imo we need to either migrate to it or remove it.

@@ -139,12 +134,10 @@ def __init__(self, cfg: DictConfig) -> None:
self.total_epochs = cfg.epochs
self.max_steps_per_epoch = cfg.max_steps_per_epoch
self.global_step = 0
self._resume_from_checkpoint = cfg.resume_from_checkpoint

@felipemello1 (Contributor), Sep 10, 2024:
not related to this line: in the init, we are missing this check that exists in full distributed

if (
            cfg.get("fsdp_cpu_offload", False)
            and cfg.optimizer.get("fused", False)
            and not utils.torch_version_ge("2.4.0")
        ):
            raise RuntimeError(
                "Using fused optimizer on CPU is only supported in PyTorch nightly."
            )

ebsmothers (Author):
We should remove that check from the full distributed recipe then, since everything is >= 2.4 now (I won't tackle that in this PR; that can be done separately).

log.info("Compiling loss with torch.compile...")
self._loss_fn = torch.compile(self._loss_fn, backend=backend)
log.info("Loss is initialized.")
if self._is_rank_zero:
Contributor:
full distributed doesn't have _is_rank_zero. Maybe we shouldn't update it in this PR, but we may forget it later.

ebsmothers (Author):
You mean this line specifically, right? I think we should add it because right now there are duplicate logs here. But yeah, I don't want to do it in this PR.

self.adapter_params = get_adapter_params(model)
set_trainable_params(model, self.adapter_params)
if lora_weights_state_dict:
lora_missing, lora_unexpected = training.load_from_full_model_state_dict(
@felipemello1 (Contributor), Sep 10, 2024:
Off topic: we need to update how we check for modules to finetune with LoRA. It just checks things like "if 'v_proj' in string", which won't raise any error if the user defines "v_projection", for example, because v_proj is a substring of v_projection. I found this out because I was passing "'v_proj'" (with the extra quotes), so v_proj was in the string, but there is no module 'v_proj' with the quotes.

ebsmothers (Author):
I'm not sure I follow this point; do you have a specific example in mind? We have the LORA_ATTN_MODULES Literal and could modify that to add some validation.
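As an illustration of the kind of validation being discussed, an exact-match check against the allowed names could look like this (a sketch, not existing torchtune code; validate_lora_attn_modules is a hypothetical helper):

from typing import Iterable

ALLOWED_LORA_ATTN_MODULES = {"q_proj", "k_proj", "v_proj", "output_proj"}

def validate_lora_attn_modules(modules: Iterable[str]) -> None:
    # Exact membership instead of substring matching, so values like
    # "'v_proj'" (stray quotes) or "v_projection" fail loudly.
    bad = [m for m in modules if m not in ALLOWED_LORA_ATTN_MODULES]
    if bad:
        raise ValueError(
            f"Invalid LoRA attention modules {bad}; expected a subset of "
            f"{sorted(ALLOWED_LORA_ATTN_MODULES)}"
        )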

@@ -577,27 +578,31 @@ def save_checkpoint(
- Merged weights with key MODEL_KEY
- Adapter weights with key ADAPTER_KEY
- Relevant recipe state if training is not complete
- If the `self._save_adapter_weights_only` option is True, the checkpointer will save only the adapter weights
Contributor:
That sounds like a nice feature to have. Did we get feedback to remove it?

@ebsmothers (Author), Sep 10, 2024:
Sorry, I don't know why this got removed. Saving adapter weights only should still be supported.

@@ -722,6 +726,7 @@ def train(self) -> None:

# Compute loss
loss = self._loss_fn(logits, labels)

Contributor:
not related to this part of the code:

we are missing this, from full_distributed

if cfg.get("fsdp_cpu_offload", False):
        # Utilize all available CPU cores for intra-op parallelism. This provides ~2x
        # speed up when benchmarking fused AdamW on CPU
        training.set_torch_num_threads()

and we have this hardcoded, which doesn't sound like something we would want:

os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1"

@ebsmothers (Author), Sep 10, 2024:
So set_torch_num_threads was some heuristic added way back when; you can see the convo here. Unless @janeyx99 has a better idea, I am happy to just copy this over from the full finetune recipe. But given the small number of trainable params for LoRA it may not have much impact anyway.

Re TORCH_NCCL_AVOID_RECORD_STREAMS, I'll defer to @weifengpy who will know best.

ebsmothers (Author):
To follow up on the second point, I chatted offline with @weifengpy and we should be OK to remove the TORCH_NCCL_AVOID_RECORD_STREAMS setting

@joecummings (Contributor) left a comment:
no looker

@ebsmothers merged commit 6deeda9 into pytorch:main Sep 10, 2024
17 checks passed
@ebsmothers deleted the migrate-lora-fsdp2 branch September 10, 2024 19:48