Normalize CE loss by total number of (non-padding) tokens #1875
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/1875
Note: Links to docs will display an error until the docs builds have been completed.
✅ No failures as of commit d5ff9ec with merge base e030626. (This comment was automatically generated by Dr. CI and updates every 15 minutes.)
@@ -80,4 +80,4 @@ def forward(self, logits: List[torch.Tensor], labels: torch.Tensor) -> torch.Tensor:
         for logits_chunk, labels_chunk in zip(logits, labels):
             total_loss += self.compute_cross_entropy(logits_chunk, labels_chunk)

-        return total_loss / total_elements
+        return total_loss
Isn't this unnormalized?
Yes, we need to decide where to divide by the number of tokens. This version of the PR does it all in the recipe. Even if we continue normalizing it here, we would then need to do something like running_loss += self._loss_step(batch) * current_num_tokens in the recipe, which is also a bit awkward.
lgtm! thanks for fixing that
In honor of the day the ML community first discovered the fact that (x1 / n1) + (x2 / n2) != (x1 + x2) / (n1 + n2)
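(A toy illustration with made-up numbers: if batch 1 has a summed loss of 2 over 1 non-padding token and batch 2 has a summed loss of 3 over 3 non-padding tokens, averaging the per-batch means gives (2/1 + 3/3) / 2 = 1.5, while the true per-token mean is (2 + 3) / (1 + 3) = 1.25.)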
This PR changes how we calculate the loss when gradient accumulation is enabled. This way we'll get an exact match in loss curves with and without gradient accumulation.
The approach
Keep a running tally of the number of unmasked tokens in the recipe. Don't actually change our loss implementations (so they are still normalized by number of non-padding tokens in a batch), but when we get a batch's loss in the recipe we now just multiply by the number of non-padding tokens in that batch to get the unnormalized loss.
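A minimal sketch of that recipe-side step, with hypothetical names (loss_step, ignore_idx, the batch keys) rather than the exact torchtune recipe attributes:

```python
# Sketch only: loss_step, ignore_idx, and the batch keys are illustrative names,
# not the actual torchtune recipe attributes.
import torch
import torch.nn as nn

def loss_step(model: nn.Module, batch: dict, loss_fn: nn.Module, ignore_idx: int = -100):
    labels = batch["labels"]         # [batch_size, seq_len]
    logits = model(batch["tokens"])  # [batch_size, seq_len, vocab_size]
    # The loss implementation is unchanged: still a mean over this batch's non-padding tokens.
    loss = loss_fn(logits.reshape(-1, logits.size(-1)), labels.reshape(-1))
    # Multiply back by the batch's non-padding token count to get the unnormalized loss,
    # so losses from batches with different token counts can be summed correctly.
    current_num_tokens = (labels != ignore_idx).sum()
    return loss * current_num_tokens, current_num_tokens
```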
Previously we called .backward() after every batch (after dividing the loss by the number of gradient accumulation steps). Now we can't do that, because we need to accumulate all the losses to do proper normalization. So we instead wait until it's time to step, divide our accumulated loss by the total number of tokens seen across all batches in the step, then call loss.backward().
Note: as a side effect, our tokens/sec now logs only non-padding tokens. So yes, the tokens/sec we see in our logs will decrease, but it will also now be more representative of meaningful throughput (and you won't have to listen to me complaining about misleading tokens/sec anymore).
Test plan
Updated a bunch of recipe tests to explicitly test with gradient accumulation enabled. Note that previously these tests would fail as the loss values would not match (see e.g. this comment in a test that we added specifically to check parity of gradient accumulation when all samples have the same sequence length), but now we get the same loss values regardless of whether or not gradient accumulation is enabled.
E2E tests
For all E2E tests, we compare the following four cases:
We also change the logging of num_tokens_per_second on main to match what's in this PR for a fair comparison.
Llama 3 8B full finetune, single device
TLDR: we get the same loss curves for cases (1), (3), and (4). The updated gradient accumulation logic also increases tokens/second and reduces peak allocated memory. Full wandb workspace here
Loss curves:
Peak allocated memory:
Tokens/sec:
Qwen 2 1.5B with LoRA on two devices
TLDR: same loss curves for (1), (3) and (4). There is a slight increase in peak allocated memory, but also a pretty big jump in tokens/sec. Full wandb workspace
Loss curves:
Peak allocated memory:
Tokens/sec: