Skip to content

Add per-layer compile support to recipes #1419

Merged: 11 commits, Aug 28, 2024

Conversation

@yf225 (Contributor) commented Aug 27, 2024

This PR enables per-layer compile for our single-device LoRA, single-device full finetune, and FSDP2 LoRA recipes. FSDP2 full finetune will be handled in a follow-up.
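"Per-layer compile" here means calling torch.compile on each transformer layer in place rather than wrapping the whole model in a single torch.compile call, which keeps compile time and compile-time memory down. A rough sketch of the idea (the helper name and the layer-class argument are illustrative, not the exact recipe code):

import os
from typing import Type

import torch.nn as nn


def compile_per_layer(model: nn.Module, layer_cls: Type[nn.Module]) -> None:
    # Compile each transformer block in place instead of wrapping the whole
    # model in one torch.compile call; nn.Module.compile() swaps the module's
    # forward for a compiled version, so state_dict keys are unchanged.
    backend = os.environ.get("TORCH_COMPILE_BACKEND", "inductor")
    for m in model.modules():
        if isinstance(m, layer_cls):
            m.compile(backend=backend)

The recipes still drive this through the existing compile=True flag (see the repro commands below).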

Results

All recipes were run with three different configurations: (1) per-layer compile (this PR), (2) full-model compile (i.e. compile=True on main), (3) no compile.

QLoRA single-device

WandB results


Repro

TORCHINDUCTOR_FORCE_DISABLE_CACHES=1 CUDA_VISIBLE_DEVICES=0 tune run lora_finetune_single_device \
--config llama3/8B_qlora_single_device model.lora_rank=16 optimizer=bitsandbytes.optim.AdamW8bit \
 gradient_accumulation_steps=4 tokenizer.max_seq_len=2048 max_steps_per_epoch=100 \
 metric_logger=torchtune.utils.metric_logging.WandBLogger metric_logger.project=pr-1419 compile=True \
 log_peak_memory_stats=True metric_logger.name=qlora-per-layer-compile

LoRA

WandB results


Repro

TORCHINDUCTOR_FORCE_DISABLE_CACHES=1 CUDA_VISIBLE_DEVICES=1 tune run \
lora_finetune_single_device --config llama3/8B_lora_single_device model.lora_rank=16 \
 optimizer=bitsandbytes.optim.AdamW8bit gradient_accumulation_steps=4 tokenizer.max_seq_len=2048 \
 max_steps_per_epoch=100 model.lora_attn_modules=['q_proj','k_proj','v_proj','output_proj'] \
 model.apply_lora_to_mlp=True metric_logger=torchtune.utils.metric_logging.WandBLogger \
metric_logger.project=pr-1419 compile=True log_peak_memory_stats=True metric_logger.name=lora-per-layer-compile

FFT

WandB results


Repro

TORCHINDUCTOR_FORCE_DISABLE_CACHES=1 CUDA_VISIBLE_DEVICES=2 \
tune run full_finetune_single_device --config llama3/8B_full_single_device \
dataset=torchtune.datasets.alpaca_cleaned_dataset optimizer=bitsandbytes.optim.AdamW8bit \
gradient_accumulation_steps=4 tokenizer.max_seq_len=2048 max_steps_per_epoch=100 epochs=1 \
optimizer_in_bwd=False metric_logger=torchtune.utils.metric_logging.WandBLogger metric_logger.project=pr-1419 \
metric_logger.name=fft-per-layer-compile compile=True log_peak_memory_stats=True

LoRA FSDP2

Repro only, to verify nothing is broken; perf work will follow in a separate PR.

Repro

First change dev/llama2/70B_qlora_fsdp2.yaml recipe to this, then run

tune run --nproc_per_node 2 lora_finetune_fsdp2 --config llama2/70B_qlora max_steps_per_epoch=10
1|10|Loss: 1.7545080184936523: 100%|█████████████| 10/10 [06:03<00:00, 36.29s/it]

We also consolidated E2E time over 100 steps and compile time for QLoRA, LoRA, and full finetunes of Llama3 8B on a single device. Note that compile on main currently OOMs due to what looks like a memory leak, so we don't have E2E times for models compiled on main.

[screenshot: consolidated E2E time and compile time table]

pytorch-bot bot commented Aug 27, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/1419

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 4678ce9 with merge base 77fbb4f:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot added the CLA Signed label on Aug 27, 2024
@felipemello1 (Contributor) left a comment


Thanks, Will!

A couple of thoughts: I personally don't like that we have two flags for compiling, and would prefer to avoid cases like:

compile: False
per_layer_compile: True

IMO, we should either:

  1. If compile is true, enable per_layer_compile by default.
  2. Make compile a nested config, like:

     compile:
       per_layer: True
       loss: True

I prefer 1, as I don't think the user needs this level of control through configs (quick sketch below). Any thoughts?
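For concreteness, a minimal sketch of option 1, with a single compile flag driving both the per-layer model compile and the loss compile (function and argument names are illustrative, not the recipes' actual API):

import os
from typing import Type

import torch
import torch.nn as nn
from omegaconf import DictConfig


def setup_compile(cfg: DictConfig, model: nn.Module, loss_fn: nn.Module,
                  layer_cls: Type[nn.Module]) -> nn.Module:
    # Option 1: a single boolean `compile` flag implies per-layer model compile
    # plus loss compile; no separate `per_layer_compile` key in the config.
    if cfg.get("compile", False):
        backend = os.environ.get("TORCH_COMPILE_BACKEND", "inductor")
        for m in model.modules():
            if isinstance(m, layer_cls):
                m.compile(backend=backend)
        loss_fn = torch.compile(loss_fn, backend=backend)
    return loss_fn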

Second:
Do you think you could test it for our distributed recipe that already uses FSDP2? Or should it be another PR?

@SalmanMohammadi (Collaborator) commented Aug 27, 2024

Would this also enable optimizer in bwd 🤝 compile?

if self._model_compile:
    log.info("Compiling loss with torch.compile...")
    backend = os.environ.get("TORCH_COMPILE_BACKEND", "inductor")
    self._loss_fn = torch.compile(self._loss_fn, backend=backend)
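For reference, "optimizer in bwd" is the optimizer-step-in-backward setup; a minimal sketch of that pattern, independent of this PR (names are illustrative):

import torch
import torch.nn as nn


def setup_optimizer_in_backward(model: nn.Module, lr: float = 2e-5) -> None:
    # One optimizer per parameter, stepped from a post-accumulate-grad hook so
    # each gradient can be applied and freed during the backward pass.
    optim_dict = {
        p: torch.optim.AdamW([p], lr=lr)
        for p in model.parameters()
        if p.requires_grad
    }

    def optim_step(param: torch.Tensor) -> None:
        optim_dict[param].step()
        optim_dict[param].zero_grad()

    for p in optim_dict:
        p.register_post_accumulate_grad_hook(optim_step)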
@felipemello1 (Contributor) commented Aug 27, 2024


Chunked CE can only have compile applied to the CE + upcasting part. If the chunking is compiled with it, it loses the benefit :/

I think we can leave the loss compile outside of this PR if chunked CE will be the default.

Edit: scratch that. I'll add something like this to the Chunked CE PR:

loss_fn = config.instantiate(cfg.loss)
if isinstance(loss_fn, ChunkedCrossEntropy):
    # compile only the inner CE + upcast; leave the chunking in eager
    loss_fn._cross_entropy.compile()
else:
    loss_fn.compile()

@gau-nernst (Contributor) commented:

Came across this PR. @yf225 am I seeing a memory leak for "qlora-compile-main"? If you are using a torch nightly after 20240824, it might be the same as what I'm seeing in pytorch/pytorch#134642.

@ebsmothers (Contributor) commented:

> Came across this PR. @yf225 am I seeing a memory leak for "qlora-compile-main"? If you are using a torch nightly after 20240824, it might be the same as what I'm seeing in pytorch/pytorch#134642.

@gau-nernst yeah same here. I'm not sure what's going on because I've been compiling models with no problems for a while. But as of a couple days ago I also started seeing what looks like a memory leak. Attaching a memory viz I ran on a full finetune; seems like a bunch of memory is not getting freed after each step.

[memory snapshot visualization]

@@ -230,6 +231,10 @@ def setup(self, cfg: DictConfig) -> None:
)

self._loss_fn = config.instantiate(cfg.loss)
if self._model_compile:
A Contributor commented:

nit: should we rename "model_compile" to just "compile", since this flag will control loss compile and flex attention compile as well?

A Contributor replied:

Not a priority; I'm fine to keep it as is for now, and maybe add a TODO to rename.

A Contributor replied:

Yeah, I think our naming here is not great. Let's do it in a follow-up; rather than add a TODO in the code, I may just create an issue.

@SalmanMohammadi (Collaborator) commented Aug 28, 2024


@ebsmothers merged commit 9629a36 into pytorch:main on Aug 28, 2024. 20 checks passed.
@yf225 (Contributor, Author) commented Aug 28, 2024

@gau-nernst @ebsmothers For the full-model compile memory leak issue, this should now be resolved by reverting pytorch/pytorch#134272.

In this PR (#1419) we've also switched to per-layer compile, which doesn't hit the memory leak issue.

Labels: CLA Signed

6 participants