
Normalize CE loss by total number of (non-padding) tokens #1875


Merged: 4 commits into pytorch:main from grad-accum, Oct 25, 2024

Conversation

@ebsmothers (Contributor) commented Oct 20, 2024

In honor of the day the ML community first discovered the fact that (x1 / n1) + (x2 / n2) != (x1 + x2) / (n1 + n2)
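For concreteness, here is a toy example (numbers are made up, not from the PR) of why summing or averaging per-batch mean losses is not the same as taking the mean over all tokens when batches contain different numbers of non-padding tokens:

```python
# Toy numbers only: x_i = summed CE loss of batch i, n_i = its non-padding token count.
x1, n1 = 6.0, 2
x2, n2 = 4.0, 8

sum_of_batch_means = x1 / n1 + x2 / n2        # 3.5
mean_over_all_tokens = (x1 + x2) / (n1 + n2)  # 1.0

# Dividing the first quantity by the number of accumulation steps (the old
# behavior) gives 1.75, still not equal to the true per-token mean of 1.0.
```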

This PR changes how we calculate the loss when gradient accumulation is enabled. This way we'll get an exact match in loss curves with and without gradient accumulation.

The approach

Keep a running tally of the number of unmasked tokens in the recipe. We don't actually change our loss implementations (they are still normalized by the number of non-padding tokens in a batch); instead, when we get a batch's loss in the recipe, we multiply it by the number of non-padding tokens in that batch to recover the unnormalized loss.
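A minimal sketch of that bookkeeping for a single batch, using made-up tensors and F.cross_entropy as a stand-in for the actual loss module (the ignore index of -100 is an assumption here, not quoted from the recipe):

```python
import torch
import torch.nn.functional as F

IGNORE_INDEX = -100  # assumed label value for padding / masked positions

# Dummy batch: 2 sequences of length 4, vocab size 10; some labels are padding.
logits = torch.randn(2, 4, 10)
labels = torch.tensor([[1, 2, 3, IGNORE_INDEX],
                       [4, 5, IGNORE_INDEX, IGNORE_INDEX]])

# The loss module still returns a mean over non-padding tokens...
per_token_mean = F.cross_entropy(
    logits.view(-1, 10), labels.view(-1), ignore_index=IGNORE_INDEX
)

# ...so the recipe multiplies by this batch's non-padding token count
# to recover the summed (unnormalized) loss before accumulating it.
current_num_tokens = (labels != IGNORE_INDEX).sum()
unnormalized_loss = per_token_mean * current_num_tokens
```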

Previously we called .backward() after every batch (after dividing the loss by the number of gradient accumulation steps). Now we can't do that, because we need to accumulate the losses to normalize them properly. So instead we wait until it's time to step, divide our accumulated loss by the total number of non-padding tokens seen across all batches in the step, and then call loss.backward().
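Putting it together, the recipe's inner loop roughly becomes the following. This is a self-contained sketch with a toy model; `gradient_accumulation_steps`, the batch construction, and the ignore index are all illustrative, not the real recipe code:

```python
import torch
import torch.nn.functional as F

IGNORE_INDEX = -100
torch.manual_seed(0)

# Toy stand-in for the model: hidden size 8, vocab size 10.
model = torch.nn.Linear(8, 10)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
gradient_accumulation_steps = 2

# Four fake batches of shape [2, 5, 8], with the last position padded out.
batches = []
for _ in range(4):
    labels = torch.randint(0, 10, (2, 5))
    labels[:, -1] = IGNORE_INDEX
    batches.append((torch.randn(2, 5, 8), labels))

running_loss, num_tokens = 0.0, 0
for idx, (inputs, labels) in enumerate(batches):
    logits = model(inputs)
    current_num_tokens = (labels != IGNORE_INDEX).sum()

    # Accumulate the *unnormalized* loss and the non-padding token count.
    batch_loss = F.cross_entropy(
        logits.view(-1, 10), labels.view(-1), ignore_index=IGNORE_INDEX
    ) * current_num_tokens
    running_loss = running_loss + batch_loss
    num_tokens += current_num_tokens

    if (idx + 1) % gradient_accumulation_steps == 0:
        # Normalize once by the total non-padding tokens seen across the whole
        # accumulation window, then backprop and step.
        (running_loss / num_tokens).backward()
        optimizer.step()
        optimizer.zero_grad()
        running_loss, num_tokens = 0.0, 0
```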

Note: as a side effect, our tokens/sec metric now counts only non-padding tokens. So yes, the tokens/sec we see in our logs will decrease, but it will also be more representative of meaningful throughput (and you won't have to listen to me complaining about misleading tokens/sec anymore).
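In other words, the logged token count is now the unmasked-token count rather than batch_size * seq_len (toy numbers, purely illustrative):

```python
import torch

IGNORE_INDEX = -100  # assumed padding label value
labels = torch.tensor([[1, 2, 3, IGNORE_INDEX],
                       [4, IGNORE_INDEX, IGNORE_INDEX, IGNORE_INDEX]])

total_positions = labels.numel()                     # 8
non_padding = (labels != IGNORE_INDEX).sum().item()  # 4

# tokens/sec is now computed from `non_padding`, so the reported number
# drops but reflects only tokens that contribute to the loss.
```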

Test plan

Updated a bunch of recipe tests to explicitly test with gradient accumulation enabled. Note that these tests would previously fail because the loss values did not match (see e.g. this comment in a test we added specifically to check gradient accumulation parity when all samples have the same sequence length), but now we get the same loss values regardless of whether gradient accumulation is enabled.

E2E tests

For all E2E tests, we compare the following four cases:

  1. batch size N, gradient accumulation disabled on main
  2. batch size 1, N gradient accumulation steps on main
  3. batch size N, gradient accumulation disabled on this PR
  4. batch size 1, N gradient accumulation steps on this PR

We also change the logging of num_tokens_per_second on main to match what's in this PR for a fair comparison.

Llama 3 8B full finetune, single device

TLDR: we get the same loss curves for cases (1), (3), and (4). The updated gradient accumulation logic also increases tokens/second and reduces peak allocated memory. Full wandb workspace here

Loss curves: [screenshot]

Peak allocated memory: [screenshot]

Tokens/sec: [screenshot]

Qwen 2 1.5B with LoRA on two devices

TLDR: same loss curves for (1), (3) and (4). There is a slight increase in peak allocated memory, but also a pretty big jump in tokens/sec. Full wandb workspace

Loss curves: [screenshot]

Peak allocated memory: [screenshot]

Tokens/sec: [screenshot]


pytorch-bot bot commented Oct 20, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/1875

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit d5ff9ec with merge base e030626:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot added the CLA Signed label Oct 20, 2024
@@ -80,4 +80,4 @@ def forward(self, logits: List[torch.Tensor], labels: torch.Tensor) -> torch.Tensor:
         for logits_chunk, labels_chunk in zip(logits, labels):
             total_loss += self.compute_cross_entropy(logits_chunk, labels_chunk)
 
-        return total_loss / total_elements
+        return total_loss
Contributor

Isn't this unnormalized?

ebsmothers (Contributor, Author)

Yes, we need to decide where to divide by the number of tokens. This version of the PR does it all in the recipe. Even if we continue normalizing it here, we would then need to do something like `running_loss += self._loss_step(batch) * current_num_tokens` in the recipe, which is also a bit awkward.

@joecummings linked an issue Oct 21, 2024 that may be closed by this pull request
@ebsmothers changed the title from "[WIP] Explicitly normalize CE loss by # of tokens" to "Normalize CE loss by total number of tokens" Oct 25, 2024
@ebsmothers changed the title from "Normalize CE loss by total number of tokens" to "Normalize CE loss by total number of (non-padding) tokens" Oct 25, 2024
@felipemello1 (Contributor) left a comment

lgtm! thanks for fixing that

@ebsmothers merged commit 23c8829 into pytorch:main Oct 25, 2024
17 checks passed
@ebsmothers deleted the grad-accum branch October 25, 2024 21:41
Labels: CLA Signed
Linked issue that may be closed by this pull request: Grad acc "fix"
Participants: 4