chunked cross entropy #1390


Merged: 31 commits merged into pytorch:main on Aug 29, 2024

Conversation

@felipemello1 (Contributor) commented on Aug 22, 2024:

Context

What is the purpose of this PR? Is it to

  • add a new feature
  • fix a bug
  • update tests and/or documentation
  • other (please add here)

Mechanism: Chunked cross entropy reduces memory by processing and upcasting one chunk at a time, which allows each chunk's memory to be released after its computation. If we don't chunk, then compiling model+loss together or independently makes no difference, as long as we compile the upcasting together with the loss, e.g.:

@torch.compile()
def _cross_entropy_loss(logits, labels):
    return cross_entropy_loss(logits.float(), labels)
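
For reference, here is a minimal sketch of the per-chunk computation (illustrative only; function names and signatures are assumptions, not torchtune's exact CEWithChunkedOutputLoss implementation):

import torch
import torch.nn.functional as F

@torch.compile()
def _chunk_cross_entropy(logits_chunk, labels_chunk, ignore_index=-100):
    # the upcast happens inside the compiled region, together with the CE
    return F.cross_entropy(
        logits_chunk.reshape(-1, logits_chunk.size(-1)).float(),
        labels_chunk.reshape(-1),
        reduction="sum",
        ignore_index=ignore_index,
    )

def chunked_cross_entropy(logit_chunks, labels, num_chunks=8, ignore_index=-100):
    # logit_chunks: list of [b, s/num_chunks, vocab] tensors produced by the model
    label_chunks = labels.chunk(num_chunks, dim=1)
    total = sum(
        _chunk_cross_entropy(lc, tc, ignore_index)
        for lc, tc in zip(logit_chunks, label_chunks)
    )
    # normalize once by the total number of non-ignored tokens
    return total / (labels != ignore_index).sum()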

It is important to notice that chunking must happen BEFORE the final linear output projection, or the memory is not released:

Chunking after the linear output layer (bad):
[screenshot]

Chunking before the linear output layer (this PR):
[screenshot]

When compiling model and loss independently, if the upcast doesn't happen inside the compiled cross entropy, chunking is less effective:
[screenshot]

The other impact of this PR is removing the logits manipulation:

logits = logits[..., :-1, :].contiguous()
logits = logits.transpose(1, 2)

[screenshot]
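
The off-by-one shift is instead applied to the labels by appending a column of ignore_index, so the logits tensor is never sliced or copied. Roughly (a sketch; the actual code uses a preallocated ignore-labels cache rather than torch.full):

import torch

labels = torch.randint(0, 32000, (2, 8))  # dummy [bsz, seq_len] labels
ignore_col = torch.full((labels.shape[0], 1), -100, dtype=labels.dtype)
labels = torch.hstack((labels[..., 1:], ignore_col))  # same shape, logits untouched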

torch.compile is much better in PyTorch 2.5 vs 2.4:
[screenshot]

At least for this config, we see no huge gains in memory for num_chunks > 4. However, we don't see any performance degradation using num_chunks=16, so this was the chosen value.
[screenshot]

Edit: when running smaller models, there is a performance degradation (~10% TPS) with num_chunks=16. I changed it to num_chunks=8, which shows almost no TPS reduction while the memory savings are still very high.

IMPACT

For Llama 8B with seq_len=8192, if activation checkpointing is FALSE, then:

toks/sec:
compiling model+loss together == chunked + compiling model and loss separately

However, if activation checkpointing is TRUE, then:
compiling model+loss together >> chunked + compiling model and loss separately

In every scenario, chunking saves about 2.5GB of memory per batch with max_seq_len=8192.
For bsz=2, this means 5GB saved.

Therefore, it seems to be an issue with compile + AC. If we solve that, we can compile model and loss separately, use the chunked loss, and have the best of all worlds: low compile time, low memory, high toks/sec.
[screenshot]

Edit: we decided to compile the model per layer. This doesn't match the speed of compile(model + loss), but it gets us closer while having a much faster compile time (~40s vs ~10min for Evan).
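
For reference, per-layer compile follows roughly the pattern below (a sketch; the exact helper and attribute names used in torchtune may differ):

import torch

def compile_model_per_layer(model: torch.nn.Module, backend: str = "inductor") -> None:
    # compile each transformer block in place instead of wrapping the whole model,
    # which keeps compile time short (assumes the decoder exposes its blocks as model.layers)
    for layer in model.layers:
        layer.compile(backend=backend)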

Overall:

  • We reduced peak memory, and we should see higher impact as bsz and seq length grow;
  • We removed an extra copy of logits, reducing memory further;
  • We greatly improved compile time by only having to compile the model, instead of the loss_step (on my machine: 200s -> 20s);

Things to be aware of

Compiling the upcasting + CE changes the loss by a small amount. According to the compile team, this is expected. We will have to be aware of it in future parity checks. This is independent of chunked cross entropy.

Edit: when training for 200 steps, there is no difference in loss up to the 2nd decimal point.

Design

We have two main options: setter vs wrapper

  1. setter (this PR): Directly changes the TransformerDecoder, and adds to it:
def set_num_output_chunks(self, num_output_chunks: int) -> None:
    """Use to save memory in combination with ChunkedCrossEntropy (TODO: add link to ChunkedCrossEntropy)"""
    self.num_output_chunks = num_output_chunks

def forward(...):

    if self.num_output_chunks > 0:
        # chunk for ChunkedCrossEntropyLoss, saving memory
        output = [
            self.output(chunk) for chunk in h.chunk(self.num_output_chunks, dim=1)
        ]
    else:
        output = self.output(h)

The set_num_output_chunks is called in the recipe:

self.num_output_chunks = getattr(self._loss_fn, "num_output_chunks", 0)
model.set_num_output_chunks(self.num_output_chunks)

If set_num_output_chunks is never called, the behavior of the transformer is unchanged, and it will work as expected during inference.

Pros:

  • The contract of the transformer is clear: the output may be a list. Nothing is hidden from the user.
  • It is simple: no need to worry about checkpointing/FSDP wrapping/other functions.

Cons:

  • The TransformerDecoder is a bit more polluted.

  2. Wrapper: wrap the output layer with the chunk logic:
class ChunkedOutputWrapper(nn.Module):
    def __init__(self, output_layer, num_output_chunks: int = 16):
        super(ChunkedOutputWrapper, self).__init__()
        self.output_layer = output_layer
        self.num_output_chunks = num_output_chunks

    def forward(self, x):
        return [self.output_layer(chunk) for chunk in x.chunk(self.num_output_chunks, dim=-2)]

    def state_dict(self, destination=None, prefix='', keep_vars=False):
        return self.output_layer.state_dict(destination=destination, prefix=prefix + 'output_layer.', keep_vars=keep_vars)

    def load_state_dict(self, state_dict, strict=True):
        # Adjust keys in state_dict
        adjusted_state_dict = {key[len('output_layer.'):]: value for key, value in state_dict.items() if key.startswith('output_layer.')}
        return self.output_layer.load_state_dict(adjusted_state_dict, strict=strict)

Call it in the recipe:

self.num_output_chunks = getattr(self._loss_fn, "num_output_chunks", 0)
if hasattr(model, 'output') and self.num_output_chunks > 0:
    model.output = ChunkedOutputWrapper(model.output, num_output_chunks=self.num_output_chunks)

Pros:

  • No extra lines in the TransformerDecoder

Cons:

  • Having to worry about any other logic that checks if the layer is a LoRA layer
  • Getting checkpointing right
  • Users can't see what is happening in the transformer. They would have to realize that a wrapper was applied after instantiation.

Results:
Apparently compile has some issues with the wrapper. It also causes higher reserved memory (probably due to a graph break?). Maybe these can be solved?
[screenshot]

Other alternatives that were ruled out:

  • Adding it to the builders: this would require changing many files and adding extra logic not only in the builders, but also in the LoRA modules and configs.

Changelog

  • Added Chunked Cross Entropy class
  • Added an if/else to TransformerDecoder to chunk if self.num_output_chunks > 0, else keep the normal behavior
  • Changed compile to compile only the model, and not the loss_step
  • Changed the label/logit off-by-one operation to be more efficient
  • Moved the .float() upcast from the transformer into the chunked cross entropy loss

TODO:
[ ] add unit tests after reviewers are ok with the changes.
[ ] consider removing the loss_step, since we won't compile it with the model anymore.
[ ] run 100+ steps for every recipe
[ ] add srt
[ ] improve docstring

Test plan

llama lora single
[screenshot]

tune run lora_finetune_single_device --config llama3_1/8B_qlora_single_device optimizer_in_bwd=False enable_activation_checkpointing=True loss=torchtune.modules.loss.CEWithChunkedOutputLoss  optimizer._component_=torch.optim.AdamW compile=True dataset.packed=False dataset.train_on_input=True tokenizer.max_seq_len=2048 metric_logger=torchtune.utils.metric_logging.WandBLogger metric_logger.project=chunked_CE log_every_n_steps=1 log_peak_memory_stats=True gradient_accumulation_steps=1 epochs=1 batch_size=16 max_steps_per_epoch=20 metric_logger.name=8b__vanilaCE__lora__AC checkpointer.output_dir=/tmp/Meta-Llama-3.1-8B-Instruct/8b__vanilaCE__lora__AC

llama full single
[screenshot]

tune run full_finetune_single_device --config llama3_1/8B_full_single_device optimizer_in_bwd=False enable_activation_checkpointing=True loss=torchtune.modules.loss.CEWithChunkedOutputLoss optimizer._component_=bitsandbytes.optim.PagedAdamW8bit compile=True dataset.packed=False dataset.train_on_input=True tokenizer.max_seq_len=2048 metric_logger=torchtune.utils.metric_logging.WandBLogger metric_logger.project=chunked_CE log_every_n_steps=1 log_peak_memory_stats=True gradient_accumulation_steps=1 epochs=1 batch_size=16 max_steps_per_epoch=20 metric_logger.name=8b__chunkedCE__full__AC checkpointer.output_dir=/tmp/Meta-Llama-3.1-8B-Instruct/8b__chunkedCE__full__AC

qwen distributed full
[screenshot]

tune run --nnodes 1 --nproc_per_node 8 full_finetune_distributed --config qwen2/0.5B_full optimizer_in_bwd=False enable_activation_checkpointing=True loss=torchtune.modules.loss.CEWithChunkedOutputLoss optimizer._component_=torch.optim.AdamW compile=True dataset.packed=False dataset.train_on_input=True tokenizer.max_seq_len=2048 metric_logger=torchtune.utils.metric_logging.WandBLogger metric_logger.project=chunked_CE log_every_n_steps=1 log_peak_memory_stats=True gradient_accumulation_steps=1 epochs=1 batch_size=32 max_steps_per_epoch=20 metric_logger.name=qwen__chunkedCE__nrpoc8_full__AC checkpointer.output_dir=/tmp/Meta-Llama-3.1-8B-Instruct/qwen__chunkedCE__nrpoc8_full__AC

gemma distributed lora
[screenshot]

tune run --nnodes 1 --nproc_per_node 8 lora_finetune_distributed --config gemma/2B_lora optimizer_in_bwd=False enable_activation_checkpointing=True loss=torchtune.modules.loss.CEWithChunkedOutputLoss optimizer._component_=torch.optim.AdamW compile=True dataset.packed=False dataset.train_on_input=True tokenizer.max_seq_len=2048 metric_logger=torchtune.utils.metric_logging.WandBLogger metric_logger.project=chunked_CE log_every_n_steps=1 log_peak_memory_stats=True gradient_accumulation_steps=1 epochs=1 batch_size=16 max_steps_per_epoch=20 metric_logger.name=gemma2b__chunkedCE__nrpoc8_lora__AC checkpointer.output_dir=/tmp/Meta-Llama-3.1-8B-Instruct/gemma2b__chunkedCE__nrpoc8_lora__AC

QWEN QAT

tune run --nnodes 1 --nproc_per_node 8 qat_distributed --config qwen2/0.5B_full optimizer_in_bwd=False enable_activation_checkpointing=True loss=torchtune.modules.loss.CEWithChunkedOutputLoss optimizer._component_=torch.optim.AdamW compile=True dataset.packed=False dataset.train_on_input=True tokenizer.max_seq_len=2048 metric_logger=torchtune.utils.metric_logging.WandBLogger metric_logger.project=chunked_CE log_every_n_steps=1 log_peak_memory_stats=True gradient_accumulation_steps=1 epochs=1 batch_size=32 max_steps_per_epoch=20 metric_logger.name=qwen__chunkedCE__nrpoc8_full__AC checkpointer.output_dir=/tmp/Meta-Llama-3.1-8B-Instruct/qwen__chunkedCE__nrpoc8_full__AC quantizer._component_=torchtune.utils.quantization.Int8DynActInt4WeightQATQuantizer quantizer.groupsize=256
[screenshot]

UX

If your function changed a public API, please add a dummy example of what the user experience will look like when calling it.
Example of docstring:


Example in our docs: https://pytorch.org/torchtune/main/tutorials/qat_finetune.html#applying-qat-to-llama3-models

  • I did not change any public API;
  • I have added an example to docs or docstrings;

pytorch-bot (bot) commented on Aug 22, 2024:

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/1390

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 4b06b99 with merge base 9629a36:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot added the "CLA Signed" label on Aug 22, 2024 (this label is managed by the Facebook bot; authors need to sign the CLA before a PR can be reviewed).
@gau-nernst (Contributor) commented:

Just curious, do you have data for "compile step loss + chunk CE"?

> We greatly improved compile time by only having to compile the model, instead of the loss_step

How much longer was it for you? I didn't notice any significant difference on my machine when I added the compiled loss step.

> We removed an extra copy of logits, reducing memory further

Yeah, I have always thought the logits slicing is a bit iffy (HF does this too 😆). Technically, because of this, the number of tokens seen is seq_len-1, since we remove one token from the loss calculation. I still think that right-shifting the labels should be done in the dataset instead, but that would be a separate discussion.

@awgu left a comment:

exciting!

output = self.output(h).float()
# shape: [b, seq_len//num_chunks, out_dim] - out_dim is usually the vocab size
if self.num_output_chunks > 0:
    # chunk for ChunkedCrossEntropyLoss, saving memory

My hypothesis was that by chunking before the linear, the output chunks are not views into the same underlying tensor, allowing earlier ones to be freed and their memory reused for later ones.

I wonder if we can check the memory snapshot to confirm that (if not done already), as it might be nice to understand the constraint on the modeling side (namely, having to chunk the input to the final output projection rather than chunking the output itself).
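
For reference, PyTorch's memory snapshot tooling can capture this; a minimal sketch (assumes CUDA and a recent PyTorch build):

import torch

# start recording allocator events
torch.cuda.memory._record_memory_history(max_entries=100000)

# ... run a few training steps with and without chunking ...

# dump the snapshot and inspect it at https://pytorch.org/memory_viz
torch.cuda.memory._dump_snapshot("chunked_ce_memory.pickle")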

Contributor:
This has been bugging me for a while and I still don't understand this. Where would they be freed though? We pass the entire list of chunks to the loss? Or did I misunderstand your point?

@felipemello1 (Author) replied on Aug 23, 2024:

It's not that the chunks are freed. If we split AFTER, then we have the chunks + the whole output: it's duplicated. When we split BEFORE, we have only the chunks. No duplication.

The memory savings come from the fact that, when we have chunks, the extra memory from upcasting is at most the size of one chunk, instead of doubling the whole output.

Ideally, if we split AFTER, PyTorch should learn to delete the whole output that we no longer need, avoiding the duplication. But this doesn't happen.

[screenshots]

@felipemello1 (Author) commented on Aug 22, 2024:

@gau-nernst

> Just curious, do you have data for "compile step loss + chunk CE"?

It's a mess (green). It consumes more memory and takes 10+ minutes to do 5 steps, because it keeps recompiling on every step.
Red is the compiled loss step without chunking.
Grey is this PR.
[screenshot]

> How much longer was it for you?

180s
[screenshot]

> Still think that right shift the labels should be done in the dataset instead

I see your point, and it makes sense. But I guess having it exposed in the recipe gives more flexibility to the user.

@felipemello1 (Author):

@awgu

> I wonder if we can check the memory snapshot to confirm that, as it might be nice to understand the constraint on the modeling side (namely, having to chunk the input to the final output projection rather than chunking the output itself).

I will get it for you. If we can chunk after the linear, that would be beautiful.

@felipemello1 (Author) commented on Aug 23, 2024:

One update: compiling loss + model together consumes more memory, but it is faster if activation checkpointing is true. Previously I claimed that the speed was the same; I need to test it more. I am working with @weifengpy to see if we can make compile work with chunking + compiling model and loss together.

[screenshot]

@gau-nernst (Contributor):

I think two important things to understand are:

  1. How compiling model+loss together gives such a significant speedup.
  2. Where the memory savings come from. I believe what Andrew said is probably right. Also try to see if the memory savings are still there if chunking is done after.

I tried to reproduce 1. in a standalone script with Linear (LM head) + CE but couldn't reproduce the speedup. Perhaps something else is happening here 👀

Another question: is there a need to upcast the logits to FP32 after the LM head? From what I understand, CE will upcast the logits to FP32 internally anyway, both in eager and compile mode.

@felipemello1 (Author):

> Is there a need to upcast logits after the LM head to FP32? From what I understand, CE will upcast logits to FP32 anyway internally, both in eager and compile mode.

I don't know :/ Note: compiling the upcast together with the CE produces a small difference in the loss.

> I try to reproduce 1. in a standalone script with Linear (LM head) + CE but couldn't reproduce the speedup. Perhaps something else is happening here 👀

Yeah, I would love to have the best of all worlds: speedup + fast compile + lowest memory.

@felipemello1 (Author) commented on Aug 23, 2024:

OK, I confirmed it: if activation checkpointing is FALSE, then:

toks/sec:
compiling model+loss together == chunked + compiling model and loss separately

However, if activation checkpointing is TRUE, then:
compiling model+loss together >> chunked + compiling model and loss separately

In every scenario, chunking saves about 2.5GB of memory per batch with max_seq_len=8192.
For bsz=2, this means 5GB saved.

Therefore, it seems to be an issue with compile + AC. If we solve that, we can compile model and loss separately, use the chunked loss, and have the best of all worlds: low compile time, low memory, high toks/sec.
[screenshot]

@felipemello1 (Author):

@gau-nernst

> From what I understand, CE will upcast logits to FP32 anyway internally, both in eager and compile mode.

The loss is different if I remove .float().
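
For intuition, the difference comes from computing the (log-)softmax in bf16 rather than fp32; a toy check, not the torchtune code path:

import torch

x = torch.randn(8, 32000)  # fake logits
lp_fp32 = torch.log_softmax(x, dim=-1)
lp_bf16 = torch.log_softmax(x.bfloat16(), dim=-1).float()
print((lp_fp32 - lp_bf16).abs().max())  # small but nonzero difference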

# equivalent to doing labels[..., 1:] and logits[..., :-1, :]
# But this way we dont need to slice the logits. We just add an ignore index to labels.
labels = torch.hstack(
    (labels[..., 1:], self.ignore_labels_cache[: labels.shape[0]])
)
Contributor:

Sorry if I'm being dense, but isn't labels.shape[0] always going to be batch_size, which is also what the first dim of the cache is? If so, why do we need to slice?

Contributor Author:

I believe we need it for the last batch of the dataloader, which may have a different size.

E.g., you have 10 samples, but bsz=3. Then, unless we have drop_last=True in the dataloader, the last batch will have size 1.

labels = torch.hstack(
    (labels[..., 1:], self.ignore_labels_cache[: labels.shape[0]])
)
if not isinstance(logits, list):
Contributor:
Ah :(

Contributor:
Yeah no way around this one unfortunately. But it's worth it for the memory wins 😃

"""

def __init__(self, num_output_chunks: int = 16, ignore_index: int = -100):
super(ChunkedCrossEntropyLoss, self).__init__()
Contributor:
Why not super().__init__()? I think that's how we do this across the codebase.

        self.num_output_chunks = num_output_chunks
        self.ignore_index = ignore_index
        self.cross_entropy_loss = torch.nn.CrossEntropyLoss(
            reduction="sum", ignore_index=self.ignore_index
        )
Contributor:
Why the change in reduction? I think previously we were using the default?

Contributor Author:
The default is mean. I could compute the mean per chunk, sum the means, and divide by the number of chunks. It should be equivalent and probably a bit faster.

Contributor:
I think your current way is more correct. When there is different padding across samples in a batch (i.e., the number of ignore_index tokens differs), the mean of the per-chunk means != (total loss over non-padded tokens / number of non-padded tokens).

felipemello1 (Author), Aug 23, 2024:

Oh, right! I would have to do a few more operations for it to be numerically the same. That's why I computed the sum instead of the mean.
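
A tiny worked example of the padding point above (made-up numbers):

  • chunk 1: 3 non-padded tokens, summed loss 6.0 -> chunk mean 2.0
  • chunk 2: 1 non-padded token, summed loss 4.0 -> chunk mean 4.0
  • mean of chunk means: (2.0 + 4.0) / 2 = 3.0
  • correct global mean: (6.0 + 4.0) / (3 + 1) = 2.5

Summing the per-chunk losses and dividing once by the total number of non-ignored tokens gives the correct 2.5.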

                self.output(chunk) for chunk in h.chunk(self.num_output_chunks, dim=1)
            ]
        else:
            output = self.output(h)
Contributor:
Sorry if we already discussed this, but why move the float call from here to the recipe? It's easy for recipe writers to trip up on, and it doesn't seem conceptually different.

Contributor Author:
That's true! I can put the float back here. It doesn't need to be in the recipe.


@gau-nernst (Contributor):

Some interesting Triton kernels from Liger-Kernel; it might be interesting to test them too: https://github.com/linkedin/Liger-Kernel

  • Cross entropy: fused forward and backward. Drop-in replacement by using liger_kernel.transformers.cross_entropy.LigerCrossEntropyLoss
  • Fused linear cross entropy: also implements the pre-LM-head chunking logic. Integration will be harder.

@codecov-commenter:

Codecov Report

Attention: Patch coverage is 5.04202% with 113 lines in your changes missing coverage. Please review.

Project coverage is 69.59%. Comparing base (7e084d9) to head (131f058).
Report is 5 commits behind head on main.

Files with missing lines Patch % Lines
recipes/full_finetune_distributed.py 0.00% 18 Missing ⚠️
torchtune/modules/loss/ce_chunked_output_loss.py 0.00% 18 Missing ⚠️
recipes/lora_finetune_single_device.py 0.00% 17 Missing ⚠️
recipes/qat_distributed.py 0.00% 17 Missing ⚠️
recipes/lora_finetune_distributed.py 0.00% 16 Missing ⚠️
recipes/full_finetune_single_device.py 0.00% 15 Missing ⚠️
torchtune/modules/transformer.py 45.45% 6 Missing ⚠️
torchtune/models/gemma/transformer.py 20.00% 4 Missing ⚠️
torchtune/modules/loss/__init__.py 0.00% 2 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main    #1390      +/-   ##
==========================================
- Coverage   70.14%   69.59%   -0.55%     
==========================================
  Files         272      270       -2     
  Lines       12919    13011      +92     
==========================================
- Hits         9062     9055       -7     
- Misses       3857     3956      +99     

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

@ebsmothers (Contributor) left a comment:

A handful of small comments (for whatever reason I left them all in the QAT recipe, but of course they are applicable to all recipes). But no major concerns from me here. Really excited to get this in and realize these memory savings!

@@ -533,6 +547,12 @@ def __init__(
        self.num_heads = num_heads
        self.head_dim = head_dim
        self.causal_mask = None
        self.num_output_chunks = 0

    def set_num_output_chunks(self, num_output_chunks: int) -> None:
Contributor:
Could consider just using @property on num_output_chunks, but no strong preference here.

self._loss_fn = config.instantiate(cfg.loss)
backend = os.environ.get("TORCH_COMPILE_BACKEND", "inductor")
if self._loss_fn.__class__.__name__ == "CEWithChunkedOutputLoss":
Contributor:
Why not just use isinstance here?

felipemello1 (Author), Aug 28, 2024:
Either I would have to import the class just for this check, or I could use modules.loss.CEWithChunkedOutputLoss, since modules is already imported. However, for some reason I was getting the error "loss is not in torchtune.modules", so this seemed to be the easiest way to avoid an extra import.

I don't mind changing it if you prefer.

Contributor:
Hmm, weird. It's definitely OK to import from modules in the recipe, and imo it's a little bit cleaner (even if it does require the extra import). We already import from modules in our LoRA recipes for various PEFT stuff, though we don't in the full finetune recipes. Anyway, no strong preference here, definitely not a blocker.

Comment on lines 248 to 250
self._loss_fn._compute_cross_entropy = torch.compile(
    self._loss_fn._compute_cross_entropy, backend=backend
)
Contributor:
It could be good to add a quick comment here about why we compile this method specifically (I know you already explain it in detail in the class, but this is more visible).


@@ -202,6 +203,7 @@ def setup(self, cfg: DictConfig) -> None:

        checkpoint_dict = self.load_checkpoint(cfg_checkpointer=cfg.checkpointer)

        self._model_compile = cfg.get("compile", False)
Contributor:
We should definitely land distributed FFT model compile as a fast follow here; until then the behavior will be confusing to folks

@felipemello1 merged commit 4fba6cd into pytorch:main on Aug 29, 2024
20 checks passed
@felipemello1 (Author):

@SLR722, this changed all default configs. So before it lands in genie, you would have to update the recipes.
