chunked cross entropy #1390
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/1390
Note: Links to docs will display an error until the docs builds have been completed. ✅ No failures as of commit 4b06b99 with merge base 9629a36. This comment was automatically generated by Dr. CI and updates every 15 minutes.
Just curious, do you have data for "compile step loss + chunk CE"?
How much longer was it for you? I didn't notice any significant difference on my machine when I added the compile step loss.
Yea I have always thought the logits slicing is a bit iffy (HF is doing this too 😆). Technically due to this, the number of tokens seen is
exciting!
torchtune/modules/transformer.py
Outdated
output = self.output(h).float()
# shape: [b, seq_len//num_chunks, out_dim] - out_dim is usually the vocab size
if self.num_output_chunks > 0:
    # chunk for ChunkedCrossEntropyLoss, saving memory
My hypothesis was that by chunking before the linear, the output chunks are not views into the same underlying tensor, allowing earlier ones to be freed and their memory reused for later ones.
I wonder if we can check the memory snapshot to confirm that, if we haven't already, as it might be nice to understand the constraint on the modeling side (namely, having to chunk the input to the final output projection rather than chunking the output itself).
This has been bugging me for a while and I still don't understand this. Where would they be freed though? We pass the entire list of chunks to the loss? Or did I misunderstand your point?
It's not that the chunks are freed. If we split AFTER, then we have the chunks + the whole output. It's duplicated. When we split BEFORE, we have only the chunks. No duplication.
The memory savings come from the fact that, when we have chunks, the extra memory from upcasting is at most the size of one chunk, instead of doubling the whole output.
Ideally, if we split AFTER, PyTorch should learn to delete the whole output that we no longer need, avoiding the duplication. But this doesn't happen.
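A minimal way to sanity-check this with peak-memory stats (a standalone sketch with made-up toy sizes, assuming a CUDA device; not torchtune code):

```python
import torch

torch.manual_seed(0)
proj = torch.nn.Linear(1024, 32_000, bias=False, dtype=torch.bfloat16, device="cuda")
h = torch.randn(2, 2048, 1024, dtype=torch.bfloat16, device="cuda")

def peak_gb(fn):
    torch.cuda.reset_peak_memory_stats()
    fn()
    return torch.cuda.max_memory_allocated() / 1e9

# split AFTER the projection: the chunks are views into the full logits tensor,
# so the whole output stays alive while the fp32 copies are created
after = peak_gb(lambda: [c.float() for c in proj(h).chunk(8, dim=1)])
# split BEFORE the projection: each bf16 chunk can be dropped once it is upcast
before = peak_gb(lambda: [proj(c).float() for c in h.chunk(8, dim=1)])
print(f"chunk after linear:  {after:.2f} GB peak")
print(f"chunk before linear: {before:.2f} GB peak")
```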
I will get it for you. If we can chunk after the linear, that would be beautiful
One update: compiling loss + model consumes more memory, but it is faster if activation checkpointing is enabled. Previously I claimed that the speed was the same. I need to test it more. Working with @weifengpy to see if we can make chunked loss + compile work together.
I think there are two important things to understand:
I tried to reproduce 1. in a standalone script with Linear (LM head) + CE but couldn't reproduce the speedup. Perhaps something else is happening here 👀 Another question: is there a need to upcast logits after the LM head to FP32? From what I understand, CE will upcast logits to FP32 internally anyway, both in eager and compile mode.
I don't know :/ . Note: compiling the upcast with the CE has a small difference in the loss.
Yeah, I would love to have the best of all worlds: speedup + fast compile + lowest memory
Loss is different if I remove .float()
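A standalone way to see that difference (toy shapes, not the recipe code):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
device = "cuda" if torch.cuda.is_available() else "cpu"
logits = torch.randn(2, 1024, 32_000, dtype=torch.bfloat16, device=device)
labels = torch.randint(0, 32_000, (2, 1024), device=device)

# cross entropy fed bf16 logits directly vs. logits explicitly upcast to fp32
loss_bf16 = F.cross_entropy(logits.reshape(-1, 32_000), labels.reshape(-1))
loss_fp32 = F.cross_entropy(logits.float().reshape(-1, 32_000), labels.reshape(-1))
print(loss_bf16.item(), loss_fp32.item())  # typically differ in the later decimal places
```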
# equivalent to doing labels[..., 1:] and logits[..., :-1, :]
# But this way we don't need to slice the logits. We just add an ignore index to labels.
labels = torch.hstack(
    (labels[..., 1:], self.ignore_labels_cache[: labels.shape[0]])
Sorry if I'm being dense but isn't labels.shape[0] always going to be batch_size which is also what the first dim of the cache is? If so, why do we need to slice?
I believe we need to for the last batch of the dataloader, which may be of a different size.
E.g.: you have 10 samples, but bsz=3. Then, unless we have drop_last=True in the dataloader, the last batch will have size=1.
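A tiny illustration of that last-batch case (made-up shapes, mirroring the snippet above):

```python
import torch

ignore_index, bsz, seq_len = -100, 3, 8
# cache is allocated once for the configured batch size
ignore_labels_cache = torch.full((bsz, 1), ignore_index)

# last batch of 10 samples with bsz=3 only has 1 sample
labels = torch.randint(0, 100, (1, seq_len))
shifted = torch.hstack((labels[..., 1:], ignore_labels_cache[: labels.shape[0]]))
print(shifted.shape)  # torch.Size([1, 8]); without the slice, hstack would fail on dim 0
```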
labels = torch.hstack(
    (labels[..., 1:], self.ignore_labels_cache[: labels.shape[0]])
)
if not isinstance(logits, list):
Ah :(
Yeah no way around this one unfortunately. But it's worth it for the memory wins 😃
""" | ||
|
||
def __init__(self, num_output_chunks: int = 16, ignore_index: int = -100): | ||
super(ChunkedCrossEntropyLoss, self).__init__() |
Why not super().__init__()? I think that's how we do this across the code base?
self.num_output_chunks = num_output_chunks
self.ignore_index = ignore_index
self.cross_entropy_loss = torch.nn.CrossEntropyLoss(
    reduction="sum", ignore_index=self.ignore_index
Why the change in reduction? I think previously we were using the default?
The default is mean. I could take the mean per chunk, sum the means, and divide by the number of chunks. It should be equivalent and probably a bit faster.
I think your current way is more correct. When there are different amounts of padding across samples in a batch (i.e. the number of ignore_index tokens differs), taking the mean of (mean of each chunk) != (total loss of non-padded tokens / number of non-padded tokens).
Oh, right! I would have to do a few more operations for it to be numerically the same. That's why I computed the sum instead of the mean.
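A small numeric sketch of that point (made-up values, two chunks, uneven padding):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
vocab, ignore_index = 10, -100
logits = torch.randn(1, 8, vocab)
# chunk 1 has 3 real targets, chunk 2 has only 1: padding is uneven
labels = torch.tensor([[1, 2, 3, ignore_index, 4, ignore_index, ignore_index, ignore_index]])
logit_chunks = [c.reshape(-1, vocab) for c in logits.chunk(2, dim=1)]
label_chunks = [t.reshape(-1) for t in labels.chunk(2, dim=1)]

# mean of per-chunk means
mean_of_chunk_means = torch.stack([
    F.cross_entropy(c, t, ignore_index=ignore_index)
    for c, t in zip(logit_chunks, label_chunks)
]).mean()

# sum of per-chunk losses, normalized once by the number of non-padded tokens
total = sum(
    F.cross_entropy(c, t, reduction="sum", ignore_index=ignore_index)
    for c, t in zip(logit_chunks, label_chunks)
)
true_mean = total / (labels != ignore_index).sum()
print(mean_of_chunk_means.item(), true_mean.item())  # not equal
```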
torchtune/modules/transformer.py
Outdated
    self.output(chunk) for chunk in h.chunk(self.num_output_chunks, dim=1)
]
else:
    output = self.output(h)
Sorry if we already discussed this, but why move the float call from here to the recipe? Easy for recipe writers to trip up on, and it doesn't seem conceptually different?
That's true! I can put the float back here. It doesn't need to be in the recipe.
Some interesting Triton kernels from Liger-Kernel; might be interesting to test them too: https://github.com/linkedin/Liger-Kernel
Codecov Report
Attention: Patch coverage is
Additional details and impacted files
@@ Coverage Diff @@
##             main    #1390      +/-   ##
==========================================
- Coverage   70.14%   69.59%   -0.55%
==========================================
  Files         272      270       -2
  Lines       12919    13011      +92
==========================================
- Hits         9062     9055       -7
- Misses       3857     3956      +99

☔ View full report in Codecov by Sentry.
A handful of small comments (for whatever reason I left them all in the QAT recipe, but ofc applicable to all recipes). But no major concerns from me here. Really excited to get this in and realize these memory savings!
@@ -533,6 +547,12 @@ def __init__(
    self.num_heads = num_heads
    self.head_dim = head_dim
    self.causal_mask = None
    self.num_output_chunks = 0

def set_num_output_chunks(self, num_output_chunks: int) -> None:
Could consider just using @property on def num_output_chunks(self), but no strong preference here.
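For reference, a sketch of what the property-based version could look like (simplified, not the actual class):

```python
from torch import nn

class TransformerDecoder(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self._num_output_chunks = 0  # 0 means "don't chunk", as in the PR

    @property
    def num_output_chunks(self) -> int:
        return self._num_output_chunks

    @num_output_chunks.setter
    def num_output_chunks(self, value: int) -> None:
        self._num_output_chunks = value

decoder = TransformerDecoder()
decoder.num_output_chunks = 8  # instead of decoder.set_num_output_chunks(8)
```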
self._loss_fn = config.instantiate(cfg.loss)
backend = os.environ.get("TORCH_COMPILE_BACKEND", "inductor")
if self._loss_fn.__class__.__name__ == "CEWithChunkedOutputLoss":
Why not just use isinstance here?
Either I would have to import the class just for this check, or I could use modules.loss.CEWithChunkedOutputLoss, since modules is already imported. However, for some reason I was getting the error "loss is not in torchtune.modules". So this seemed to be the easiest way to avoid an extra import.
I don't mind changing it if you prefer.
Hmm weird. It's definitely OK to import from modules in the recipe and imo a little bit cleaner (even if it does require the extra import). We already import from modules in our LoRA recipes for various PEFT stuff, though we don't in the full finetune recipes. Anyways, no strong preference here, definitely not a blocker.
recipes/qat_distributed.py
Outdated
self._loss_fn._compute_cross_entropy = torch.compile(
    self._loss_fn._compute_cross_entropy, backend=backend
)
Could be good to add a quick comment here around why we compile this method specifically (I know you already explain in detail in the class, but this is more visible)
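Something along these lines, perhaps (suggested wording only):

```python
# Compile only the per-chunk upcast + cross-entropy computation. The fp32
# upcast has to live inside the compiled region for the chunked loss to keep
# its memory savings (see the CEWithChunkedOutputLoss docstring for details).
self._loss_fn._compute_cross_entropy = torch.compile(
    self._loss_fn._compute_cross_entropy, backend=backend
)
```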
@@ -202,6 +203,7 @@ def setup(self, cfg: DictConfig) -> None:

    checkpoint_dict = self.load_checkpoint(cfg_checkpointer=cfg.checkpointer)

    self._model_compile = cfg.get("compile", False)
We should definitely land distributed FFT model compile as a fast follow here; until then the behavior will be confusing to folks
@SLR722, this changed all default configs. So before it lands in genie, you would have to update the recipes.
Context
What is the purpose of this PR?
Mechanism: chunked cross entropy reduces memory by processing and upcasting one chunk at a time. This allows the memory to be released after each chunk's computation. If we don't chunk, then compiling model+loss together or independently makes no difference, as long as we compile the upcasting with the loss, e.g.:
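Something along these lines (a sketch of "upcast inside the compiled loss", not the exact recipe code):

```python
import torch
import torch.nn.functional as F

def upcast_and_ce(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    # the .float() upcast lives inside the compiled region, so compile can
    # avoid holding a full separate fp32 copy of the logits alive
    return F.cross_entropy(
        logits.float().reshape(-1, logits.size(-1)),
        labels.reshape(-1),
        reduction="sum",
        ignore_index=-100,
    )

compiled_ce = torch.compile(upcast_and_ce)

# per-chunk usage: sum the per-chunk losses, then normalize once, e.g.
#   total = sum(compiled_ce(lc, tc) for lc, tc in zip(logit_chunks, label_chunks))
#   loss = total / (labels != -100).sum()
```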
It is important to note that the chunking must happen BEFORE the linear output layer, or the memory is not released:

Chunking after the linear output layer (bad) vs. chunking before the linear output layer (this PR), sketched below:
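Roughly, in standalone form (toy sizes; the real code uses self.output inside TransformerDecoder.forward):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
num_chunks, ignore_index = 8, -100
output_proj = torch.nn.Linear(1024, 32_000, bias=False, dtype=torch.bfloat16)
h = torch.randn(2, 2048, 1024, dtype=torch.bfloat16)  # hidden states
labels = torch.randint(0, 32_000, (2, 2048))

# chunking AFTER the projection (bad): the full [b, s, vocab] logits are
# materialized and upcast in one shot, roughly doubling that allocation
logits = output_proj(h).float()
loss_full = F.cross_entropy(logits.reshape(-1, logits.size(-1)), labels.reshape(-1))

# chunking BEFORE the projection (this PR): only bf16 per-chunk logits are kept,
# and the fp32 upcast touches one chunk at a time inside the loss
logit_chunks = [output_proj(c) for c in h.chunk(num_chunks, dim=1)]
label_chunks = labels.chunk(num_chunks, dim=1)
total = sum(
    F.cross_entropy(lc.float().reshape(-1, lc.size(-1)), tc.reshape(-1),
                    reduction="sum", ignore_index=ignore_index)
    for lc, tc in zip(logit_chunks, label_chunks)
)
loss_chunked = total / (labels != ignore_index).sum()
print(loss_full.item(), loss_chunked.item())  # (near-)identical loss, very different peak memory
```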

When compiling the model and loss independently, if the upcast doesn't happen inside the compiled cross entropy, chunking is less effective.

The other impact of this PR is removing logits manipulation:
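A sketch of that change with toy shapes (the real code keeps the ignore_index column in a pre-allocated ignore_labels_cache, as in the diff above):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
ignore_index, bsz, seq_len, vocab = -100, 2, 6, 10
logits = torch.randn(bsz, seq_len, vocab)
labels = torch.randint(0, vocab, (bsz, seq_len))
ignore_labels_cache = torch.full((bsz, 1), ignore_index)

# old approach: slice both logits and labels to align next-token targets
loss_sliced = F.cross_entropy(
    logits[..., :-1, :].reshape(-1, vocab), labels[..., 1:].reshape(-1)
)

# this PR: leave the logits alone; shift the labels and pad with ignore_index
shifted = torch.hstack((labels[..., 1:], ignore_labels_cache[: labels.shape[0]]))
loss_shifted = F.cross_entropy(
    logits.reshape(-1, vocab), shifted.reshape(-1), ignore_index=ignore_index
)
print(torch.allclose(loss_sliced, loss_shifted))  # True: the extra position is ignored
```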
torch.compile is much better in 2.5 vs 2.4

At least for this config, there are no huge gains in memory for num_chunks > 4. However, we don't see any performance degradation using num_chunks=16, so this was the chosen value.

Edit: when running smaller models, there is a performance degradation (~10% TPS). I changed the number to num_chunks=8; there is almost no TPS reduction, and the memory savings are still very high.
IMPACT
For Llama 8B, seq_len=8192, if activation checkpointing is FALSE, then:
toks/sec:
compiling model+loss together == chunked + compiling model + loss separately
However, if checkpointing is TRUE, then:
compiling model+loss together >> chunked + compiling model + loss separately
In every scenario, chunked saves about 2.5GB of memory per batch with max_seq_len=8192.
For bsz=2, this means 5GB saved.
Therefore, it seems to be an issue with compile + AC. If we solve that, we can compile model + loss separately, use the chunked loss, and have the best of all worlds: low compile time, low memory, high toks/sec.

Edit: we decided to compile the model per layer. This doesn't match the speed of compile(model + loss), but gets us closer to it, while having a much faster compile time (~40s vs 10min for Evan).
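A sketch of the per-layer compile approach (assuming the decoder exposes its layers as an nn.ModuleList named layers and PyTorch >= 2.2 for Module.compile; not necessarily the exact torchtune helper):

```python
import torch

def compile_model_per_layer(model: torch.nn.Module, backend: str = "inductor") -> None:
    # compile each transformer layer in place rather than the whole model:
    # much shorter compile time, while keeping most of the runtime benefit
    for layer in model.layers:
        layer.compile(backend=backend)
```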
Overall:
Things to be aware of
Compiling the upcasting + CE changes the loss by a small amount. According to the compile team, this is expected. We would have to be aware of it in future parity checks. This is independent of chunked cross entropy.
Edit: when training for 200 steps, there is no difference in loss up to the 2nd decimal point.
Design
We have two main options: setter vs wrapper
The setter set_num_output_chunks is called in the recipe (sketched below):
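```python
# recipe setup (sketch; names follow the diffs above): ask the model to return
# chunked outputs matching the number of chunks the loss expects. If this is
# never called, the model keeps returning a single logits tensor.
if self._loss_fn.__class__.__name__ == "CEWithChunkedOutputLoss":
    model.set_num_output_chunks(self._loss_fn.num_output_chunks)
```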
If no one intentionally calls set_num_output_chunks, the behavior of the transformer is unchanged, and it will work as expected during inference.
Pros:
Cons:
Call it in the recipe:
Pros:
Cons:
Results:

Apparently compile has some issues with the wrapper. It also causes higher reserved memory (probably due to a graph break?). Maybe these can be solved?
"
Other alternatives that were ruled out:
Changelog
TODO:
[ ] add unit tests after reviewers are ok with the changes.
[ ] consider removing the step loss, since we won't compile it with the model anymore.
[ ] run 100+ steps for every recipe
[ ] add srt
[ ] improve docstring
Test plan
llama lora single

llama full single

qwen distributed lora

gemma distributed full

QWEN QAT
UX
If your function changed a public API, please add a dummy example of what the user experience will look like when calling it.
Example of docstring:
torchtune/torchtune/modules/vision_transformer.py, line 285 in 6a7951f
Example in our docs: https://pytorch.org/torchtune/main/tutorials/qat_finetune.html#applying-qat-to-llama3-models