Llama3-70b: Full Finetune w/CPU offload + fused optimizer #993
Conversation
🔗 Helpful links: 🧪 see artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/993. Note: links to docs will display an error until the docs builds have been completed.
✅ No failures as of commit cb4e311 with merge base 3b1bcf9. (This comment was automatically generated by Dr. CI and updates every 15 minutes.)
recipes/configs/llama3/70B_full.yaml (Outdated)

```yaml
checkpointer:
  _component_: torchtune.utils.FullModelHFCheckpointer
  checkpoint_dir: /tmp/Meta-Llama-3-70b
```
Need to convert the paths to the Instruct model.
recipes/full_finetune_distributed.py (Outdated)

```python
if cfg.get("fsdp_cpu_offload", False):
    # Utilize all available CPU cores for intra-op parallelism. This provides ~2x
    # speed up when benchmarking fused AdamW on CPU
    threads = os.cpu_count() // torch.distributed.get_world_size()
```
How did you calculate this? I opened the draftiest PR pytorch/pytorch#126199 to account for other environments where CPU affinity is set and cgroups could be used, and plan on talking with Intel to get this into better shape. But why `os.cpu_count() // torch.distributed.get_world_size()`?
@janeyx99 Unfortunately I don't have the best answer to this; @msaroufim suggested following this over chat and referred to AnswerAI doing something similar (following his suggestions): AnswerDotAI/fsdp_qlora@1a9fddf
Intuitively, this makes sense since it's essentially assigning each process an equal number of vCPUs out of the vCPUs available. But as you mentioned, this doesn't take into account physical CPUs being < vCPUs, cgroups, and the like. It's mostly meant to be a stopgap for now, and I'm sure perf will be very different on different HW setups. In the medium term, we might want to use what you end up landing in torch core, or even expose this to the user so they can configure it.
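For concreteness, a minimal sketch of that stopgap plus the "expose it to the user" idea; `intra_op_threads` is a hypothetical knob for illustration, not something this PR adds:

```python
import os
from typing import Optional

import torch
import torch.distributed as dist


def resolve_intra_op_threads(intra_op_threads: Optional[int] = None) -> int:
    # Honor an explicit user override, if one were exposed in the config.
    if intra_op_threads is not None:
        return intra_op_threads
    world_size = dist.get_world_size() if dist.is_initialized() else 1
    # Stopgap heuristic: split logical CPUs (vCPUs) evenly across ranks.
    # Ignores physical-core counts, CPU affinity masks, and cgroup limits.
    return max(1, (os.cpu_count() or 1) // world_size)


torch.set_num_threads(resolve_intra_op_threads())
```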
Yeah, I suspect comprehensive benchmarking might take a while given how diverse CPUs are. So this is a fine stopgap.
Woo! Nice to integrate fused CPU adam into torchtune :)
Thanks for adding this change! I think there are some open recipe UX questions here which I'd like for us to think through before we land this.
recipes/full_finetune_distributed.py (Outdated)

```python
if cfg.get("fsdp_cpu_offload", False):
    if not cfg.optimizer.get("fused", False):
        log.warning(
            """
            It is highly recommended to use fused optimizer implementation when CPU
            offloading, otherwise optimizer will use non vectorized kernels and cause massive slowdown.
            """
        )
    else:
        # We are fusing + CPU offloading - check to make sure we're in a nightly.
        if not utils.torch_version_ge("2.4.0"):
            raise RuntimeError(
                "Using fused optimizer on CPU is only supported in PyTorch nightly."
            )
```
I'm not a fan of this if-else block being in the main recipe at all. This is starting to get back into the territory of the wild-wild-west recipe scripts that we spoke about a while back. Let's please brainstorm a bit about how to handle this and come up with the right design for this.
I gave a suggestion in _recipe_registry.py, at least related to the versioning side of things (but it's a bigger change than what we're looking at here). Separately, we are kinda drawing a line we haven't (I don't think?) drawn before: we are putting an explicit nightly check in one of our core 4 recipes.
As for the actual validation of fsdp_cpu_offload + optimizer.fused, I guess it's a bit tricky to put in a single util since it's a more cross-cutting config pair. The only alternative I can think of offhand is to move it into config validation, which I think we also had legitimate concerns about previously. (I guess we can't validate directly on the optimizer either, since it gets instantiated after the FSDP wrapping.)
Versioning recipes makes sense, though this portion is about a bit more than versioning, since I also want to warn if the fused optimizer is not used with CPU offload.
> we are putting an explicit nightly check in one of our core 4 recipes.
We are, but only under a very specific case: CPU offload with FSDP distributed training.
Can definitely add a helper function, though let me know if we should consider something else / an alternative design. @ebsmothers @kartikayk
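Something along these lines could work as the helper, as a rough sketch (name and placement are not final; it assumes the `utils.torch_version_ge` check and a logger like the ones already used in this PR):

```python
from torchtune import utils

log = utils.get_logger("DEBUG")


def validate_cpu_offload_optimizer(fsdp_cpu_offload: bool, fused_optimizer: bool) -> None:
    """Warn or fail fast on the fsdp_cpu_offload + optimizer.fused combination."""
    if not fsdp_cpu_offload:
        return
    if not fused_optimizer:
        log.warning(
            "CPU offloading without a fused optimizer falls back to non-vectorized "
            "CPU kernels and can cause a massive slowdown."
        )
    elif not utils.torch_version_ge("2.4.0"):
        raise RuntimeError(
            "Using a fused optimizer with CPU offload requires a PyTorch nightly (>= 2.4.0)."
        )
```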
```
@@ -563,6 +584,12 @@ def recipe_main(cfg: DictConfig) -> None:
    )

    init_process_group(backend="gloo" if cfg.device == "cpu" else "nccl")
    if cfg.get("fsdp_cpu_offload", False):
```
Again not a fan of this being in the recipe code; this should likely be in a utility or something that the user probably doesn't need to reason about.
cc: @ebsmothers to help brainstorm this a bit
Yeah +1 can we put this in a utility?
Yeah, definitely!
Bumping this comment. Also, I don't fully understand the code comment just below: do we only need to `set_num_threads` in the case of the fused optimizer? Looking at the answer.ai code it seems like the answer is no, but based on the code comment it's not clear to me.
We see the biggest gain from setting this when the fused optimizer is used, since it controls the threads used for intra-op parallelism and the CPU optimizer step is the heaviest CPU op. We may see very slight speed-ups when not using the fused optimizer, but those can be investigated separately from this PR.
```yaml
checkpointer:
  _component_: torchtune.utils.FullModelHFCheckpointer
  checkpoint_dir: /tmp/Meta-Llama-3-70B-Instruct
  checkpoint_files: [
```
😢
```yaml
_component_: torch.optim.AdamW
lr: 2e-5
foreach: False
fused: True
```
Just to confirm: this is a necessary change to get 70B full finetune runnable on 8x A100? Or is it more for speedup of CPU offload?
This is a speedup-only change (I'll clarify the amount of speedup we get once I re-run some benchmarks).
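For anyone curious about the magnitude, a rough micro-benchmark sketch of fused vs. non-fused AdamW on CPU tensors (numbers depend heavily on the CPU; fused CPU AdamW needs PyTorch >= 2.4, i.e. a nightly at the time of this PR):

```python
import time

import torch

# A few large CPU parameters with gradients, loosely mimicking optimizer
# state kept on CPU under FSDP offload.
params = [torch.randn(4096, 4096, requires_grad=True) for _ in range(4)]
for p in params:
    p.grad = torch.randn_like(p)

for fused in (False, True):
    opt = torch.optim.AdamW(params, lr=2e-5, foreach=False, fused=fused)
    opt.step()  # warm-up; also materializes optimizer state
    start = time.perf_counter()
    for _ in range(5):
        opt.step()
    print(f"fused={fused}: {(time.perf_counter() - start) / 5:.4f}s per step")
```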
```
@@ -57,6 +57,7 @@ class Recipe:
    Config(name="llama2/7B_full", file_path="llama2/7B_full.yaml"),
    Config(name="llama2/13B_full", file_path="llama2/13B_full.yaml"),
    Config(name="llama3/8B_full", file_path="llama3/8B_full.yaml"),
    Config(name="llama3/70B_full", file_path="llama3/70B_full.yaml"),
```
Obv not necessary for this particular PR, but just a (maybe controversial) thought: we could consider adding additional metadata to config and/or recipe dataclasses to capture dependencies (like specific torch/ao versions or certain optional deps). A bit tricky if users want to copy and modify, but this way we would validate in here and not have to do checks in recipes (or wherever else). Thoughts? @kartikayk @joecummings @pbontrager @RdoubleA
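To make the idea concrete, a rough sketch with hypothetical fields (not the actual torchtune dataclasses):

```python
from dataclasses import dataclass, field
from typing import List, Optional


@dataclass
class Config:
    name: str
    file_path: str
    # Hypothetical metadata: let the registry validate these up front instead
    # of sprinkling version/dependency checks inside individual recipes.
    min_torch_version: Optional[str] = None
    optional_deps: List[str] = field(default_factory=list)


Config(
    name="llama3/70B_full",
    file_path="llama3/70B_full.yaml",
    min_torch_version="2.4.0",  # e.g. fused CPU AdamW with offload needs nightly
)
```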
Do we know the story behind the large delta between peak active and reserved? (Basically I am wondering if this is similar to other cases where we saw that we could get peak reserved down just by constraining the available memory a bit. And if so, could we actually run this on 4x A100? That'd be nice.)
@ebsmothers why do we care about reserved memory? The CUDACachingAllocator will keep reserving and reusing memory as long as it's available, so the reserved memory will usually be higher. Peak active memory is more crucial to care about.
@janeyx99 yeah I only care insofar as it impacts how many devices we can run on. If we can actually run on 4x A100 (which it seems like we should be able to based on peak active) that'd be great.
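For reference, a quick sketch of how the two numbers discussed here can be read off a device (peak active is what the workload actually touched; reserved includes blocks the caching allocator keeps around for reuse):

```python
import torch

if torch.cuda.is_available():
    peak_active_gib = torch.cuda.max_memory_allocated() / 1024**3
    peak_reserved_gib = torch.cuda.max_memory_reserved() / 1024**3
    # Reserved >= active: the CUDACachingAllocator holds freed blocks for
    # reuse, so reserved usually overstates what the run strictly needs.
    print(f"peak active: {peak_active_gib:.1f} GiB | peak reserved: {peak_reserved_gib:.1f} GiB")
```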
torchtune/utils/_distributed.py (Outdated)

```python
    torch.distributed.get_world_size() if torch.distributed.is_initialized() else 1
)
torch.set_num_threads(num_threads)
log.info(f"Set intra op parallelism no. of threads to {num_threads}")
```
For some reason the logger in this file is `_log` and not `log`, so this errors out.
Ouch, great catch
Thanks for working through all the comments here! This is awesome functionality to have and I like the way the UX turned out here. One more comment to address before merging but otherwise looks great!
Hi @rohan-varma, I was trying some fine-tuning runs with an 8x A100-80GB setup and 1TB of memory. The training loop works fine, but when entering the checkpoint saving stage, the memory utilization continues to go up until it exceeds my node's limit. Do you know if this is a bug or intended behavior? Thanks a lot for helping out!
@andyl98 Extra note: I was able to train & save with 4x A100 80GB and 1TB DRAM. It also worked with 8x A100 and 2TB DRAM. I haven't done 8x A100 with 1TB like in your case, though; it might be that 1TB DRAM is too small for 8x.
Right, it seems like the full-shard model checkpointing stage inevitably consumes quite a bit of memory. Appreciate the note!
@andyl98 @musabgultekin sorry for the delay here, I agree it's suspicious that we'd require 1TB to save the 70B model and optimizer states. I created #1092 so we can track it there.
Context
What is the purpose of this PR? Please link to any issues this PR addresses.
Changelog
- Set `fused=True` for AdamW to utilize vectorized CPU instructions. Training throughput looks like `1|95|Loss: 0.7184870839118958: 3%|▎ | 95/3251 [22:02<11:53:47, 13.57s/it]`; while slow, this actually seems competitive with 70B training workloads seen in other libraries.
- Call `torch.set_num_threads` after setup to ensure that the appropriate # of CPU cores are used for intra-op parallelism. On my box, this results in a ~2x speed up compared to not setting it. I got the idea to do this after talking with @msaroufim @janeyx99. (A rough timing sketch follows this list.)
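A rough sketch of how the ~2x claim can be sanity-checked on a given box (illustrative only; results will vary with core count and PyTorch build, and fused CPU AdamW needs a recent PyTorch):

```python
import time

import torch

p = torch.randn(8192, 8192, requires_grad=True)
p.grad = torch.randn_like(p)
# Fused CPU AdamW, as enabled in this PR's config.
opt = torch.optim.AdamW([p], lr=2e-5, foreach=False, fused=True)
opt.step()  # warm-up / state init

default_threads = torch.get_num_threads()
for n in (1, default_threads):
    torch.set_num_threads(n)
    start = time.perf_counter()
    for _ in range(10):
        opt.step()
    print(f"{n} intra-op threads: {(time.perf_counter() - start) / 10:.4f}s per step")
```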
Test plan
Memory usage sits at ~35 GB active per GPU, 42 GB reserved.
Loss looks comparable to the 70B model trained with LoRA (see loss curves in #802).
