
add drop_last to dataloader #1654


Merged
felipemello1 merged 2 commits into pytorch:main on Sep 24, 2024

Conversation

felipemello1
Contributor

Context

What is the purpose of this PR? Is it to

  • add a new feature
  • fix a bug
  • update tests and/or documentation
  • other (please add here)

flex_attention + compile break if the last batch has a different shape. This PR adds drop_last, which drops the last batch only if the number of samples is not evenly divisible by the batch size (i.e. the final batch would be smaller than the rest).
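To illustrate the flag itself (a minimal sketch of the stock torch.utils.data.DataLoader behavior, not torchtune code; the dataset and batch sizes are made up):

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Toy dataset: 10 samples with batch_size=4 leaves a final batch of 2,
# mirroring the [4, ...] -> [2, ...] shape change described below.
dataset = TensorDataset(torch.arange(10))

for drop_last in (False, True):
    loader = DataLoader(dataset, batch_size=4, drop_last=drop_last)
    print(drop_last, [batch[0].shape for batch in loader])

# drop_last=False -> [torch.Size([4]), torch.Size([4]), torch.Size([2])]
# drop_last=True  -> [torch.Size([4]), torch.Size([4])]
```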

Test plan

tune run --nnodes 1 --nproc_per_node 2 full_finetune_distributed --config qwen2/0.5B_full max_steps_per_epoch=5

Reproducing the bug

error log: https://www.internalfb.com/phabricator/paste/view/P1605485255

tune run --nproc_per_node 2 full_finetune_distributed --config llama3_1/8B_full enable_activation_checkpointing=False fsdp_cpu_offload=False compile=True dataset.packed=True dataset.split=train[:20%] tokenizer.max_seq_len=4096 metric_logger=torchtune.training.metric_logging.WandBLogger metric_logger.project=profiling metric_logger.tags=[my_experiment_name] log_every_n_steps=1 log_peak_memory_stats=True gradient_accumulation_steps=1 max_steps_per_epoch=40 epochs=1 batch_size=4 enable_activation_checkpointing=True

torch 2.6.0.dev20240922+cu121
torchao 0.6.0.dev20240922+cu121
torchtune 0.0.0 /data/users/felipemello/torchtune
torchvision 0.20.0.dev20240922+cu121

This happens in the very last batch. All other batches have shape:
tokens shape torch.Size([4, 4096])
labels shape torch.Size([4, 4096])
mask (4, 1, 4096, 4096)
input_pos torch.Size([4, 4096])

But the last batch has shape:
tokens shape torch.Size([2, 4096])
labels shape torch.Size([2, 4096])
mask (2, 1, 4096, 4096)
input_pos torch.Size([2, 4096])

Adding drop_last=True to the dataloader solves it.
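For context, a hedged distributed-style sketch (hypothetical, not the actual recipe code; the sampler arguments and tensor sizes are invented for illustration) of where the flag sits, so torch.compile only ever sees one static batch shape:

```python
# Hypothetical sketch: drop_last=True on the DataLoader discards the short final
# batch, so every batch reaching the compiled model has the same shape.
import torch
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.distributed import DistributedSampler

dataset = TensorDataset(torch.randn(18, 16))  # sample count and feature size are illustrative
sampler = DistributedSampler(dataset, num_replicas=2, rank=0, shuffle=True, seed=0)
loader = DataLoader(dataset, batch_size=4, sampler=sampler, drop_last=True)

for (tokens,) in loader:
    assert tokens.shape == (4, 16)  # no undersized final batch reaches the model
```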


pytorch-bot (bot) commented on Sep 23, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/1654

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit b0689a2 with merge base f51f894:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot added the CLA Signed label on Sep 23, 2024
@joecummings (Contributor) left a comment:

I'm coo with this

@felipemello1 merged commit ce58cb1 into pytorch:main on Sep 24, 2024
17 checks passed
@felipemello1 deleted the drop_last branch on September 24, 2024 00:43