
Llama3-70b: Full Finetune w/CPU offload + fused optimizer #993


Merged: 12 commits merged into main on Jun 1, 2024

Conversation

@rohan-varma (Member) commented May 17, 2024:

Context

What is the purpose of this PR? Is it to

  • add a new feature
  • fix a bug
  • update tests and/or documentation
  • other (please add here)

Please link to any issues this PR addresses.

Changelog

  • This PR enables a 70B full finetune for the Llama3 workload on an 8x A100 setup with CPU offload. The offered configuration only works with PyTorch nightly, which has recent changes that support fused=True for AdamW on CPU to utilize vectorized CPU instructions (see the sketch after this list).
  • Similar to Llama3-70B LoRA multi GPU #802, we use the HF checkpointer and load the model in HF format; the meta format is currently unsupported.
  • Throughput looks as follows: 1|95|Loss: 0.7184870839118958: 3%|▎ | 95/3251 [22:02<11:53:47, 13.57s/it] - while slow (~13.6 s/iteration), this actually seems competitive with 70B training workloads seen in other libraries.
  • Appropriate warnings and error checking are added.
  • This PR also calls torch.set_num_threads after setup to ensure that an appropriate number of CPU threads is used for intra-op parallelism. On my box, this results in a ~2x speedup compared to not setting it. I got the idea to do this after talking with @msaroufim @janeyx99.
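
For reference, a minimal sketch of the combination this config relies on (illustrative only, not the recipe code: a tiny nn.Linear stands in for the wrapped Llama3 model, and fused=True on CPU parameters assumes a PyTorch nightly / >= 2.4):

import torch
import torch.distributed as dist
from torch import nn
from torch.distributed.fsdp import CPUOffload, FullyShardedDataParallel as FSDP

# Assumes launch via torchrun so RANK/WORLD_SIZE etc. are already set.
dist.init_process_group(backend="nccl")
torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())

# Stand-in module; the recipe wraps the Llama3-70B transformer layer-by-layer instead.
model = FSDP(
    nn.Linear(4096, 4096, device="cuda"),
    cpu_offload=CPUOffload(offload_params=True),  # params and optimizer states live on CPU
)

# Because the optimizer states sit on CPU, the optimizer step runs on CPU;
# fused=True selects the vectorized CPU kernels (PyTorch nightly / >= 2.4 only).
optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5, fused=True, foreach=False)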

Test plan

tune run --nproc_per_node 8 full_finetune_distributed --config recipes/configs/llama3/70B_full.yaml metric_logger=torchtune.utils.metric_logging.WandBLogger log_peak_memory_stats=True

Memory usage sits at ~35 GB active per GPU, ~42 GB reserved:

[screenshot: per-GPU memory usage]

Loss looks comparable to the 70B model trained with LoRA (see the loss curves in #802):
[screenshot: loss curves]

pytorch-bot bot commented May 17, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/993

✅ No Failures

As of commit cb4e311 with merge base 3b1bcf9:
💚 Looks good so far! There are no failures yet. 💚

@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label May 17, 2024

checkpointer:
_component_: torchtune.utils.FullModelHFCheckpointer
checkpoint_dir: /tmp/Meta-Llama-3-70b
rohan-varma (Member Author):

Need to convert the paths to instruct.

@rohan-varma rohan-varma requested a review from ebsmothers May 17, 2024 07:28
if cfg.get("fsdp_cpu_offload", False):
# Utilize all available CPU cores for intra-op parallelism. This provides ~2x
# speed up when benchmarking fused AdamW on CPU
threads = os.cpu_count() // torch.distributed.get_world_size()
janeyx99 (Contributor):

How did you calculate this? I opened the draftiest PR pytorch/pytorch#126199 to account for other environments where CPU affinity is set and cgroups could be used, and plan on talking with Intel to get this into better shape. But why os.cpu_count() // torch.distributed.get_world_size()?

rohan-varma (Member Author):

@janeyx99 Unfortunately I don't have the best answer to this; @msaroufim suggested following this over chat and referred to AnswerAI doing something similar (following his suggestions) - AnswerDotAI/fsdp_qlora@1a9fddf

Intuitively, this makes sense since it's essentially assigning each process an equal number of vCPUs out of the vCPUs available. But as you mentioned, this doesn't take into account physical CPUs being < vCPUs, cgroups, and the like. It's mostly meant to be a stopgap for now, and I'm sure perf will be very different on different HW setups. In the medium term, we might want to use what you end up landing in torch core, or even expose this to the user so they can configure it.
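
For illustration, a hedged sketch of the stopgap heuristic next to an affinity-aware variant (os.sched_getaffinity is Linux-only, hence the fallback; none of this is the final torchtune utility):

import os
import torch
import torch.distributed as dist

def default_num_threads() -> int:
    # Prefer the CPUs this process is actually allowed to run on (respects
    # taskset/cgroup affinity masks); fall back to the raw count elsewhere.
    try:
        usable = len(os.sched_getaffinity(0))
    except AttributeError:
        usable = os.cpu_count() or 1
    world_size = dist.get_world_size() if dist.is_initialized() else 1
    # Split the usable cores evenly across the training processes.
    return max(1, usable // world_size)

torch.set_num_threads(default_num_threads())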

Member:

Yeah, I suspect comprehensive benchmarking might take a while given how diverse CPUs are. So this is a fine stopgap.

@janeyx99 (Contributor) left a comment:

Woo! Nice to integrate fused CPU adam into torchtune :)

@kartikayk (Contributor) left a comment:

Thanks for adding this change! I think there are some open recipe UX questions here which I'd like for us to think through before we land this.

Comment on lines 108 to 122
if cfg.get("fsdp_cpu_offload", False):
if not cfg.optimizer.get("fused", False):
log.warning(
"""
It is highly recommended to use fused optimizer implementation when CPU
offloading, otherwise optimizer will use non vectorized kernels and cause massive slowdown.
"""
)
else:
# We are fusing + CPU offloading - check to make sure we're in a nightly.
if not utils.torch_version_ge("2.4.0"):
raise RuntimeError(
"Using fused optimizer on CPU is only supported in PyTorch nightly."
)

Contributor:

I'm not a fan of this if-else block being in the main recipe at all. This is starting to get back into the territory of the wild-wild-west recipe scripts that we spoke about a while back. Let's please brainstorm a bit about how to handle this and come up with the right design for this.

@ebsmothers (Contributor) commented May 17, 2024:

I gave a suggestion in _recipe_registry.py, at least related to the versioning side of things (but it's a bigger change than what we're looking at here). Separately, we are kinda drawing a line we haven't (I don't think?) drawn before: we are putting an explicit nightly check in one of our core 4 recipes.

As for the actual validation of fsdp_cpu_offload + optimizer.fused, I guess it's a bit tricky to put in a single util since it's a more cross-cutting config pair. The only alternative I can think of offhand is to move into a config validation, which I think we also had legitimate concerns about doing previously. (I guess we can't validate directly on the optimizer either since it gets instantiated after the FSDP wrapping)

rohan-varma (Member Author):

Versioning recipes makes sense, though this portion involves a bit more than versioning, since I also want to warn if a fused optimizer is not used with CPU offload.

> we are putting an explicit nightly check in one of our core 4 recipes.

We are - but only under a very specific case - CPU offload with FSDP distributed training.

Can definitely add a helper function (see the sketch below), though let me know if we should consider something else / an alternative design. @ebsmothers @kartikayk
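
One possible shape for such a helper (a sketch only; the function name and logger are illustrative, while cfg.get("fsdp_cpu_offload"), cfg.optimizer.get("fused") and utils.torch_version_ge are the checks already present in the diff above):

import logging

from omegaconf import DictConfig
from torchtune import utils

log = logging.getLogger(__name__)

def validate_fsdp_cpu_offload_cfg(cfg: DictConfig) -> None:
    # Hypothetical utility: keeps the fused-optimizer / CPU-offload checks out of the recipe body.
    if not cfg.get("fsdp_cpu_offload", False):
        return
    if not cfg.optimizer.get("fused", False):
        log.warning(
            "CPU offloading without a fused optimizer falls back to non-vectorized "
            "CPU kernels and causes a massive slowdown; set optimizer.fused=True."
        )
    elif not utils.torch_version_ge("2.4.0"):
        raise RuntimeError(
            "Using a fused optimizer on CPU is only supported in PyTorch nightly (>= 2.4.0)."
        )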

@@ -563,6 +584,12 @@ def recipe_main(cfg: DictConfig) -> None:
)

init_process_group(backend="gloo" if cfg.device == "cpu" else "nccl")
if cfg.get("fsdp_cpu_offload", False):
Contributor:

Again, not a fan of this being in the recipe code; this should likely be in a utility or something the user probably doesn't need to reason about.

cc: @ebsmothers to help brainstorm this a bit

Contributor:

Yeah +1 can we put this in a utility?

rohan-varma (Member Author):

Yeah, definitely!

Contributor:

Bumping this comment. Also, I don't fully understand the code comment just below: do we only need to call set_num_threads in the case of a fused optimizer? Looking at the answer.ai code, it seems like the answer is no, but based on the code comment it's not clear to me.

rohan-varma (Member Author):

We see the biggest gain from setting this when the fused optimizer is used, since it controls the threads used for intra-op parallelism and the CPU optimizer step is the heaviest CPU op. We may see very slight speedups when not using the fused optimizer, but those can be investigated separately from this PR.
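
To illustrate the point, a rough micro-benchmark sketch (not the PR's benchmark; numbers vary widely by CPU and by the torch.set_num_threads value, and fused=True on CPU tensors assumes PyTorch nightly / >= 2.4):

import time
import torch

# A handful of large CPU parameters with dummy gradients.
params = [torch.randn(4096, 4096, requires_grad=True) for _ in range(8)]
for p in params:
    p.grad = torch.randn_like(p)

for fused in (False, True):
    opt = torch.optim.AdamW(params, lr=2e-5, fused=fused)
    opt.step()  # warm-up
    start = time.perf_counter()
    for _ in range(5):
        opt.step()
    print(f"fused={fused}: {(time.perf_counter() - start) / 5:.3f} s per optimizer step")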

checkpointer:
_component_: torchtune.utils.FullModelHFCheckpointer
checkpoint_dir: /tmp/Meta-Llama-3-70B-Instruct
checkpoint_files: [
Contributor:

😢

_component_: torch.optim.AdamW
lr: 2e-5
foreach: False
fused: True
Contributor:

Just to confirm: this is a necessary change to get 70B full finetune runnable on 8x A100? Or is it more for speedup of CPU offload?

rohan-varma (Member Author):

This is a speedup-only change (will clarify the amount of speedup we get once I re-run some benchmarks).

@@ -57,6 +57,7 @@ class Recipe:
Config(name="llama2/7B_full", file_path="llama2/7B_full.yaml"),
Config(name="llama2/13B_full", file_path="llama2/13B_full.yaml"),
Config(name="llama3/8B_full", file_path="llama3/8B_full.yaml"),
Config(name="llama3/70B_full", file_path="llama3/70B_full.yaml"),
Contributor:

Obv not necessary for this particular PR, but just a (maybe controversial) thought: we could consider adding additional metadata to config and/or recipe dataclasses to capture dependencies (like specific torch/ao versions or certain optional deps). A bit tricky if users want to copy and modify, but this way we would validate in here and not have to do checks in recipes (or wherever else). Thoughts? @kartikayk @joecummings @pbontrager @RdoubleA
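
A sketch of what that metadata might look like (the requires field is hypothetical; name and file_path mirror the existing dataclass in the diff above):

from dataclasses import dataclass, field
from typing import List

@dataclass
class Config:
    name: str
    file_path: str
    # Hypothetical field: declared dependencies the registry could validate up front,
    # e.g. a nightly-torch requirement for configs that rely on fused CPU AdamW.
    requires: List[str] = field(default_factory=list)

configs = [
    Config(name="llama3/8B_full", file_path="llama3/8B_full.yaml"),
    Config(
        name="llama3/70B_full",
        file_path="llama3/70B_full.yaml",
        requires=["torch>=2.4.0.dev"],
    ),
]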


@ebsmothers (Contributor):

Do we know the story behind the large delta between peak active and reserved? (Basically I am wondering if this is similar to other cases where we saw that we could get peak reserved down just by constraining the available memory a bit. And if so, could we actually run this on 4x A100? That'd be nice)

@janeyx99 (Contributor):

@ebsmothers why do we care about reserved memory? The CUDACachingAllocator will keep reserving and reusing memory as long as it's available, so the reserved memory will usually be higher. Peak active memory is more crucial to care about.

@ebsmothers (Contributor):

> @ebsmothers why do we care about reserved memory? The CUDACachingAllocator will keep reserving and reusing memory as long as it's available, so the reserved memory will usually be higher. Peak active memory is more crucial to care about.

@janeyx99 yeah I only care insofar as it impacts how many devices we can run on. If we can actually run on 4x A100 (which it seems like we should be able to based on peak active) that'd be great.
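
For reference, a small sketch of how the two numbers are read out; reserved can usually be squeezed by capping what the caching allocator may grab (e.g. torch.cuda.set_per_process_memory_fraction), typically at some throughput cost:

import torch

torch.cuda.reset_peak_memory_stats()
# ... run a few training steps here ...
peak_active = torch.cuda.max_memory_allocated() / 1024**3
peak_reserved = torch.cuda.max_memory_reserved() / 1024**3
print(f"peak active: {peak_active:.1f} GiB | peak reserved: {peak_reserved:.1f} GiB")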

@rohan-varma rohan-varma requested a review from ebsmothers May 30, 2024 19:01
torch.distributed.get_world_size() if torch.distributed.is_initialized() else 1
)
torch.set_num_threads(num_threads)
log.info(f"Set intra op parallelism no. of threads to {num_threads}")
Contributor:

For some reason the logger in this file is _log and not log, so this errors out

rohan-varma (Member Author):

Ouch, great catch

@ebsmothers (Contributor) left a comment:

Thanks for working through all the comments here! This is awesome functionality to have, and I like the way the UX turned out. One more comment to address before merging, but otherwise looks great!

@ebsmothers ebsmothers dismissed kartikayk’s stale review June 1, 2024 20:43

Changes have been addressed

@ebsmothers ebsmothers merged commit eac2dc5 into main Jun 1, 2024
29 checks passed
weifengpy pushed a commit to weifengpy/torchtune that referenced this pull request Jun 4, 2024
@joecummings joecummings deleted the 70b_full branch June 4, 2024 19:21
@andyl98 (Contributor) commented Jun 6, 2024:

Hi @rohan-varma, I was trying some fine-tuning runs with an 8x A100-80GB setup and 1TB memory. The training loop works fine, but when entering the checkpoint-saving stage, memory utilization continues to go up until it exceeds my node's limit. Do you know if this is a bug or intended behavior? Thanks a lot for helping out!
[screenshot: memory utilization climbing during checkpoint save]

@andyl98 (Contributor) commented Jun 6, 2024:

Update: I think I can successfully save the checkpoint after training if I don't specify offload_to_cpu via FullOptimStateDictConfig and don't store the optimizer states.

# To prevent GPU memory from spiking during checkpoint save,
# we consolidate the full model and optim state dicts on CPU for rank 0
with FSDP.state_dict_type(
    self._model,
    StateDictType.FULL_STATE_DICT,
    FullStateDictConfig(offload_to_cpu=True, rank0_only=True),
    # FullOptimStateDictConfig(offload_to_cpu=True, rank0_only=True),
):
    cpu_state_dict = self._model.state_dict()
    # opt_state_dict = FSDP.optim_state_dict(self._model, self._optimizer)

# Now that we have the model and opt state dict, create the actual checkpoint dict
# to be sent to the checkpointer and ultimately written to file
if self._is_rank_zero:
    checkpoint_dict.update({utils.MODEL_KEY: cpu_state_dict})

    # if training is in-progress, checkpoint the optimizer state as well
    if epoch + 1 < self.total_epochs:
        checkpoint_dict.update(
            {
                # utils.OPT_KEY: opt_state_dict,
                utils.SEED_KEY: self.seed,
                utils.EPOCHS_KEY: self.epochs_run,
                utils.TOTAL_EPOCHS_KEY: self.total_epochs,
                utils.MAX_STEPS_KEY: self.max_steps_per_epoch,
            }
        )

    self._checkpointer.save_checkpoint(
        checkpoint_dict,
        epoch=epoch,
        intermediate_checkpoint=(epoch + 1 < self.total_epochs),
    )
[screenshot: memory utilization with the workaround applied]

@musabgultekin (Contributor):

@andyl98 Extra note: I was able to train&save with 4xA100 80GB with 1TB DRAM. It also worked with 8xA100 with 2TB DRAM. Haven't done 8xA100 with 1TB though like in your case, it might be that 1TB DRAM is too small for 8x.

@andyl98 (Contributor) commented Jun 7, 2024:

> @andyl98 Extra note: I was able to train&save with 4xA100 80GB with 1TB DRAM. It also worked with 8xA100 with 2TB DRAM. Haven't done 8xA100 with 1TB though like in your case, it might be that 1TB DRAM is too small for 8x.

Right, it seems like the full (unsharded) model checkpointing stage inevitably consumes quite a lot of host memory. Appreciate the note!
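
One mitigation worth noting (a sketch only, not what this recipe does today): gathering a sharded state dict instead of a rank-0 full one, so no single host has to materialize all 70B parameters in DRAM. The resulting checkpoint would still need to be consolidated before the HF checkpointer could consume it, so this is a tradeoff rather than a drop-in fix.

import torch.distributed.checkpoint as dcp
from torch.distributed.fsdp import (
    FullyShardedDataParallel as FSDP,
    ShardedStateDictConfig,
    StateDictType,
)

def save_sharded_checkpoint(model: FSDP, path: str) -> None:
    # Each rank writes only its own shards; peak CPU memory stays roughly
    # per-shard sized rather than full-model sized.
    with FSDP.state_dict_type(
        model,
        StateDictType.SHARDED_STATE_DICT,
        ShardedStateDictConfig(offload_to_cpu=True),
    ):
        dcp.save({"model": model.state_dict()}, checkpoint_id=path)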

@ebsmothers (Contributor):

@andyl98 @musabgultekin sorry for the delay here, I agree it's suspicious that we'd require 1TB to save the 70B model and optimizer states. I created #1092 so we can track it there

maximegmd pushed a commit to maximegmd/torchtune that referenced this pull request Jul 13, 2024
yinfan98 pushed a commit to yinfan98/sgl-tune-eagle that referenced this pull request May 26, 2025