add drop_last to dataloader #1654
Merged
Context
What is the purpose of this PR?
flex_attention + torch.compile break if the last batch has a different shape from the others. This PR adds drop_last=True to the dataloader, which drops the final batch only when the dataset size is not evenly divisible by the batch size.
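To illustrate the semantics (a minimal sketch, not code from this PR; the sample/batch counts are made-up numbers), this mirrors how a map-style DataLoader counts batches with and without drop_last:

```python
import math

def num_batches(num_samples: int, batch_size: int, drop_last: bool) -> int:
    # Mirrors DataLoader.__len__ for a map-style dataset:
    # drop_last discards the trailing partial batch, if any.
    if drop_last:
        return num_samples // batch_size
    return math.ceil(num_samples / batch_size)

# With 10 samples and batch_size=4, the trailing partial batch (2 samples)
# is discarded when drop_last=True, so every remaining batch is full-sized.
print(num_batches(10, 4, drop_last=False))  # 3
print(num_batches(10, 4, drop_last=True))   # 2
```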
Test plan
Reproducing the bug:
error log: https://www.internalfb.com/phabricator/paste/view/P1605485255
tune run --nproc_per_node 2 full_finetune_distributed --config llama3_1/8B_full enable_activation_checkpointing=False fsdp_cpu_offload=False compile=True dataset.packed=True dataset.split=train[:20%] tokenizer.max_seq_len=4096 metric_logger=torchtune.training.metric_logging.WandBLogger metric_logger.project=profiling metric_logger.tags=[my_experiment_name] log_every_n_steps=1 log_peak_memory_stats=True gradient_accumulation_steps=1 max_steps_per_epoch=40 epochs=1 batch_size=4 enable_activation_checkpointing=True
torch 2.6.0.dev20240922+cu121
torchao 0.6.0.dev20240922+cu121
torchtune 0.0.0 /data/users/felipemello/torchtune
torchvision 0.20.0.dev20240922+cu121
This happens in the very last batch. All other batches have shapes:
tokens shape torch.Size([4, 4096])
labels shape torch.Size([4, 4096])
mask (4, 1, 4096, 4096)
input_pos torch.Size([4, 4096])
But the last batch has shapes:
tokens shape torch.Size([2, 4096])
labels shape torch.Size([2, 4096])
mask (2, 1, 4096, 4096)
input_pos torch.Size([2, 4096])
Adding drop_last=True to the dataloader solves it.
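A minimal sketch of the effect of the fix (the dataset here is a dummy TensorDataset, and 10 samples / batch_size=4 are illustrative numbers, not the recipe's config):

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# 10 samples with batch_size=4 normally yield batches of sizes 4, 4, 2.
dataset = TensorDataset(torch.arange(10))

# Without drop_last, the final batch has a different leading dimension,
# which changes the tensor shapes seen by the compiled model.
loader = DataLoader(dataset, batch_size=4, drop_last=False)
print([b[0].shape[0] for b in loader])  # [4, 4, 2]

# With drop_last=True, the ragged final batch is discarded, so every
# batch the compiled model sees has an identical shape.
loader = DataLoader(dataset, batch_size=4, drop_last=True)
print([b[0].shape[0] for b in loader])  # [4, 4]
```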